Skip to content

Commit 3d4f1fe

Browse files
[mlir][spirv] Fix UpdateVCEPass to deduce the correct set of capabilities (#151108)
When deducing capabilities implied capabilities are not considered, which causes generation of incorrect SPIR-V modules. This commit fixes that by pulling in the capability set for all the implied ones. --------- Signed-off-by: Davide Grohmann <[email protected]>
1 parent bae8f13 commit 3d4f1fe

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
9595
return success();
9696
}
9797

98+
static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
99+
for (spirv::Capability cap : caps) {
100+
ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
101+
caps.insert_range(impliedCaps);
102+
}
103+
}
104+
98105
void UpdateVCEPass::runOnOperation() {
99106
spirv::ModuleOp module = getOperation();
100107

@@ -168,6 +175,8 @@ void UpdateVCEPass::runOnOperation() {
168175
return WalkResult::interrupt();
169176
}
170177

178+
addAllImpliedCapabilities(deducedCapabilities);
179+
171180
return WalkResult::advance();
172181
});
173182

mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// Test deducing minimal version.
88
// spirv.IAdd is available from v1.0.
99

10-
// CHECK: requires #spirv.vce<v1.0, [Shader], []>
10+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
1111
spirv.module Logical GLSL450 attributes {
1212
spirv.target_env = #spirv.target_env<
1313
#spirv.vce<v1.5, [Shader], []>, #spirv.resource_limits<>>
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
2121
// Test deducing minimal version.
2222
// spirv.GroupNonUniformBallot is available since v1.3.
2323

24-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
24+
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
2525
spirv.module Logical GLSL450 attributes {
2626
spirv.target_env = #spirv.target_env<
2727
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes {
3232
}
3333
}
3434

35-
// CHECK: requires #spirv.vce<v1.4, [Shader], []>
35+
// CHECK: requires #spirv.vce<v1.4, [Shader, Matrix], []>
3636
spirv.module Logical GLSL450 attributes {
3737
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, #spirv.resource_limits<>>
3838
} {
@@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes {
4848

4949
// Test minimal capabilities.
5050

51-
// CHECK: requires #spirv.vce<v1.0, [Shader], []>
51+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
5252
spirv.module Logical GLSL450 attributes {
5353
spirv.target_env = #spirv.target_env<
5454
#spirv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>, #spirv.resource_limits<>>
@@ -61,10 +61,10 @@ spirv.module Logical GLSL450 attributes {
6161

6262
// Test Physical Storage Buffers are deduced correctly.
6363

64-
// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader], [SPV_EXT_physical_storage_buffer]>
64+
// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
6565
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
6666
spirv.target_env = #spirv.target_env<
67-
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
67+
#spirv.vce<v1.0, [PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
6868
} {
6969
spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
7070
spirv.Return
@@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
7474
// Test deducing implied capability.
7575
// AtomicStorage implies Shader.
7676

77-
// CHECK: requires #spirv.vce<v1.0, [Shader], []>
77+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
7878
spirv.module Logical GLSL450 attributes {
7979
spirv.target_env = #spirv.target_env<
8080
#spirv.vce<v1.0, [AtomicStorage], []>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
9595
// * GroupNonUniformArithmetic
9696
// * GroupNonUniformBallot
9797

98-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
98+
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
9999
spirv.module Logical GLSL450 attributes {
100100
spirv.target_env = #spirv.target_env<
101101
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
106106
}
107107
}
108108

109-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
109+
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
110110
spirv.module Logical GLSL450 attributes {
111111
spirv.target_env = #spirv.target_env<
112112
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
120120
// Test type required capabilities
121121

122122
// Using 8-bit integers in non-interface storage class requires Int8.
123-
// CHECK: requires #spirv.vce<v1.0, [Int8, Shader], []>
123+
// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
124124
spirv.module Logical GLSL450 attributes {
125125
spirv.target_env = #spirv.target_env<
126126
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
132132
}
133133

134134
// Using 16-bit floats in non-interface storage class requires Float16.
135-
// CHECK: requires #spirv.vce<v1.0, [Float16, Shader], []>
135+
// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
136136
spirv.module Logical GLSL450 attributes {
137137
spirv.target_env = #spirv.target_env<
138138
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
144144
}
145145

146146
// Using 16-element vectors requires Vector16.
147-
// CHECK: requires #spirv.vce<v1.0, [Vector16, Shader], []>
147+
// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
148148
spirv.module Logical GLSL450 attributes {
149149
spirv.target_env = #spirv.target_env<
150150
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
162162
// Test deducing minimal extensions.
163163
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
164164

165-
// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
165+
// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
166166
spirv.module Logical GLSL450 attributes {
167167
spirv.target_env = #spirv.target_env<
168168
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
193193

194194
// Using 8-bit integers in interface storage class requires additional
195195
// extensions and capabilities.
196-
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
196+
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
197197
spirv.module Logical GLSL450 attributes {
198198
spirv.target_env = #spirv.target_env<
199199
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
208208
// Complicated nested types
209209
// * Buffer requires ImageBuffer or SampledBuffer.
210210
// * Rg32f requires StorageImageExtendedFormats.
211-
// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
211+
// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
212212
spirv.module Logical GLSL450 attributes {
213213
spirv.target_env = #spirv.target_env<
214214
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
219219
}
220220

221221
// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
222-
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
222+
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
223223
spirv.module Logical GLSL450 attributes {
224224
spirv.target_env = #spirv.target_env<
225225
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

0 commit comments

Comments
 (0)