diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951ca61d6..9b1c84ee66156 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements( return success(); } +static void addAllImpliedCapabilities(SetVector &caps) { + for (spirv::Capability cap : caps) { + ArrayRef impliedCaps = getDirectImpliedCapabilities(cap); + caps.insert_range(impliedCaps); + } +} + void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); @@ -168,6 +175,8 @@ void UpdateVCEPass::runOnOperation() { return WalkResult::interrupt(); } + addAllImpliedCapabilities(deducedCapabilities); + return WalkResult::advance(); }); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b237665ffc4a..d657633665876 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -7,7 +7,7 @@ // Test deducing minimal version. // spirv.IAdd is available from v1.0. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal version. // spirv.GroupNonUniformBallot is available since v1.3. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { @@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes { // Test minimal capabilities. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -61,10 +61,10 @@ spirv.module Logical GLSL450 attributes { // Test Physical Storage Buffers are deduced correctly. -// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce +// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, #spirv.resource_limits<>> + #spirv.vce, #spirv.resource_limits<>> } { spirv.func @physical_ptr(%val : !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { spirv.Return @@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes { // Test deducing implied capability. // AtomicStorage implies Shader. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes { // * GroupNonUniformArithmetic // * GroupNonUniformBallot -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes { // Test type required capabilities // Using 8-bit integers in non-interface storage class requires Int8. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-bit floats in non-interface storage class requires Float16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-element vectors requires Vector16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal extensions. // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes { // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes { } // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce,