From b5156d6735280ddb074558582c1fd11b6004a3c8 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Tue, 29 Jul 2025 10:23:34 +0200 Subject: [PATCH 1/4] [mlir][spirv] Fix UpdateVCEPass to deduce the correct set of capabilities When deducing capabilities implied capabilities are not considered, which causes generation of incorrect SPIR-V modules. This commit fixes that by pulling in the capability set all the implied ones. Signed-off-by: Davide Grohmann Change-Id: Ia30149fb35bbf0071010cb7bc92b86d2e5b6a6af --- .../SPIRV/Transforms/UpdateVCEPass.cpp | 12 ++++++++ .../SPIRV/Transforms/vce-deduction.mlir | 30 +++++++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951ca61d6..da316b98c2b20 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -95,6 +95,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements( return success(); } +static SetVector +withImpliedCapabilities(SetVector &caps) { + SetVector allCaps(caps.begin(), caps.end()); + for (auto cap : caps) { + ArrayRef directCaps = getDirectImpliedCapabilities(cap); + allCaps.insert(directCaps.begin(), directCaps.end()); + } + return allCaps; +} + void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); @@ -168,6 +178,8 @@ void UpdateVCEPass::runOnOperation() { return WalkResult::interrupt(); } + deducedCapabilities = withImpliedCapabilities(deducedCapabilities); + return WalkResult::advance(); }); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b237665ffc4a..b536b8e4003f9 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -7,7 +7,7 @@ // Test deducing minimal version. // spirv.IAdd is available from v1.0. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal version. // spirv.GroupNonUniformBallot is available since v1.3. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { @@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes { // Test minimal capabilities. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes { // Test Physical Storage Buffers are deduced correctly. -// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce +// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes { // Test deducing implied capability. // AtomicStorage implies Shader. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes { // * GroupNonUniformArithmetic // * GroupNonUniformBallot -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes { // Test type required capabilities // Using 8-bit integers in non-interface storage class requires Int8. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-bit floats in non-interface storage class requires Float16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-element vectors requires Vector16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal extensions. // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes { // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes { } // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, From 02dac82f88b8ae590942d53cb850e163b45b50be Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 30 Jul 2025 10:19:00 +0200 Subject: [PATCH 2/4] Resolve code review comments Signed-off-by: Davide Grohmann Change-Id: Ib58ef4d1d24e395678c9527abdd7e96a9b1df9eb --- .../SPIRV/Transforms/UpdateVCEPass.cpp | 17 +++++++++----- .../SPIRV/Transforms/vce-deduction.mlir | 22 +++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index da316b98c2b20..ae79c39c29b46 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -96,11 +96,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements( } static SetVector -withImpliedCapabilities(SetVector &caps) { - SetVector allCaps(caps.begin(), caps.end()); - for (auto cap : caps) { - ArrayRef directCaps = getDirectImpliedCapabilities(cap); - allCaps.insert(directCaps.begin(), directCaps.end()); +addAllImpliedCapabilities(SetVector &caps) { + SetVector allCaps; + while (!caps.empty()) { + spirv::Capability cap = caps.pop_back_val(); + allCaps.insert(cap); + ArrayRef impliedCaps = getDirectImpliedCapabilities(cap); + for (spirv::Capability impliedCap : impliedCaps) { + if (!allCaps.contains(impliedCap)) + caps.insert(impliedCap); + } } return allCaps; } @@ -178,7 +183,7 @@ void UpdateVCEPass::runOnOperation() { return WalkResult::interrupt(); } - deducedCapabilities = withImpliedCapabilities(deducedCapabilities); + deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities); return WalkResult::advance(); }); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index b536b8e4003f9..9410435bbea99 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal version. // spirv.GroupNonUniformBallot is available since v1.3. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes { // Test Physical Storage Buffers are deduced correctly. -// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce +// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes { // * GroupNonUniformArithmetic // * GroupNonUniformBallot -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes { // Test type required capabilities // Using 8-bit integers in non-interface storage class requires Int8. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-bit floats in non-interface storage class requires Float16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-element vectors requires Vector16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal extensions. // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes { // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes { } // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, From f0e913952a9673146ff5ed9a442e530917e11c69 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 30 Jul 2025 12:53:17 +0200 Subject: [PATCH 3/4] More improvements from code review Signed-off-by: Davide Grohmann Change-Id: I34150644e4bcf559597b3c3b3dbb668e5c828faf --- .../SPIRV/Transforms/UpdateVCEPass.cpp | 16 ++++---------- .../SPIRV/Transforms/vce-deduction.mlir | 22 +++++++++---------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index ae79c39c29b46..9b1c84ee66156 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -95,19 +95,11 @@ static LogicalResult checkAndUpdateCapabilityRequirements( return success(); } -static SetVector -addAllImpliedCapabilities(SetVector &caps) { - SetVector allCaps; - while (!caps.empty()) { - spirv::Capability cap = caps.pop_back_val(); - allCaps.insert(cap); +static void addAllImpliedCapabilities(SetVector &caps) { + for (spirv::Capability cap : caps) { ArrayRef impliedCaps = getDirectImpliedCapabilities(cap); - for (spirv::Capability impliedCap : impliedCaps) { - if (!allCaps.contains(impliedCap)) - caps.insert(impliedCap); - } + caps.insert_range(impliedCaps); } - return allCaps; } void UpdateVCEPass::runOnOperation() { @@ -183,7 +175,7 @@ void UpdateVCEPass::runOnOperation() { return WalkResult::interrupt(); } - deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities); + addAllImpliedCapabilities(deducedCapabilities); return WalkResult::advance(); }); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 9410435bbea99..b536b8e4003f9 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal version. // spirv.GroupNonUniformBallot is available since v1.3. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes { // Test Physical Storage Buffers are deduced correctly. -// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce +// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes { // * GroupNonUniformArithmetic // * GroupNonUniformBallot -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes { // Test type required capabilities // Using 8-bit integers in non-interface storage class requires Int8. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-bit floats in non-interface storage class requires Float16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-element vectors requires Vector16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal extensions. // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes { // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes { } // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, From 54a95a4ed89a5d31144e29adb2389e2f5d4da390 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 30 Jul 2025 13:41:32 +0200 Subject: [PATCH 4/4] Tweak a test to make sure fix point computation works Signed-off-by: Davide Grohmann Change-Id: Ic8b4f2035d110a659368726be4e4504934541c5c --- mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index b536b8e4003f9..d657633665876 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -64,7 +64,7 @@ spirv.module Logical GLSL450 attributes { // CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, #spirv.resource_limits<>> + #spirv.vce, #spirv.resource_limits<>> } { spirv.func @physical_ptr(%val : !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { spirv.Return