Skip to content

Commit 103461f

Browse files
oojahoookuhar
andauthored
[mlir][spirv] Fix lookup logic spirv.target_env for gpu.module (#147262)
The `gpu.module` operation can contain `spirv.target_env` attributes within an array attribute named `"targets"`. So it accounts for that case by iterating over the `"targets"` attribute, if present, and looking up `spirv.target_env`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent ceb2b9c commit 103461f

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
385385
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
386386
spirv::getTargetEnvAttrName()))
387387
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
388+
if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
389+
for (Attribute targetAttr : targets)
390+
if (auto spirvTargetEnvAttr =
391+
dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
392+
spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
393+
break;
394+
}
395+
}
388396

389397
rewriter.eraseOp(moduleOp);
390398
return success();

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,45 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
4848
void runOnOperation() override;
4949

5050
private:
51+
/// Queries the target environment from 'targets' attribute of the given
52+
/// `moduleOp`.
53+
spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp);
54+
55+
/// Queries the target environment from 'targets' attribute of the given
56+
/// `moduleOp` or returns target environment as returned by
57+
/// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'.
58+
spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp);
5159
bool mapMemorySpace;
5260
};
5361

62+
spirv::TargetEnvAttr
63+
GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
64+
if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
65+
for (Attribute targetAttr : targets)
66+
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
67+
return spirvTargetEnvAttr;
68+
}
69+
70+
return {};
71+
}
72+
73+
spirv::TargetEnvAttr
74+
GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) {
75+
if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp))
76+
return targetEnvAttr;
77+
78+
return spirv::lookupTargetEnvOrDefault(moduleOp);
79+
}
80+
5481
void GPUToSPIRVPass::runOnOperation() {
5582
MLIRContext *context = &getContext();
5683
ModuleOp module = getOperation();
5784

5885
SmallVector<Operation *, 1> gpuModules;
5986
OpBuilder builder(context);
6087

61-
auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
62-
Operation *gpuModule = moduleOp.getOperation();
63-
auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
88+
auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) {
89+
auto targetAttr = lookupTargetEnvOrDefault(moduleOp);
6490
spirv::TargetEnv targetEnv(targetAttr);
6591
return targetEnv.allows(spirv::Capability::Kernel);
6692
};
@@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() {
86112
// TargetEnv attributes.
87113
for (Operation *gpuModule : gpuModules) {
88114
spirv::TargetEnvAttr targetAttr =
89-
spirv::lookupTargetEnvOrDefault(gpuModule);
115+
lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule));
90116

91117
// Map MemRef memory space to SPIR-V storage class first if requested.
92118
if (mapMemorySpace) {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s
2+
3+
module attributes {gpu.container_module} {
4+
// CHECK-LABEL: spirv.module @{{.*}} GLSL450
5+
gpu.module @kernels [#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>] {
6+
// CHECK: spirv.func @load_kernel
7+
// CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
8+
gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
9+
%c0 = arith.constant 0 : index
10+
// CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
11+
// CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
12+
%0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
13+
// CHECK: spirv.Return
14+
gpu.return
15+
}
16+
}
17+
}
18+
19+
// -----
20+
// Checks that the `-convert-gpu-to-spirv` pass selects the first
21+
// `spirv.target_env` from the `targets` array attribute attached to `gpu.module`.
22+
module attributes {gpu.container_module} {
23+
// CHECK-LABEL: spirv.module @{{.*}} GLSL450
24+
// CHECK-SAME: #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>
25+
gpu.module @kernels [
26+
#spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>,
27+
#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>,
28+
#spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>] {
29+
// CHECK: spirv.func @load_kernel
30+
// CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
31+
gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
32+
%c0 = arith.constant 0 : index
33+
// CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
34+
// CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
35+
%0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
36+
// CHECK: spirv.Return
37+
gpu.return
38+
}
39+
}
40+
}

0 commit comments

Comments
 (0)