Skip to content

Commit 0d21522

Browse files
authored
[mlir][gpu] Make offset and width in gpu.rotate as attributes (#150901)
`offset` and `width` must be constants and there are constraints on their values. Update the operation definition to use attributes instead of operands.
1 parent d598590 commit 0d21522

File tree

6 files changed

+35
-123
lines changed

6 files changed

+35
-123
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,12 +1368,14 @@ def GPU_ShuffleOp : GPU_Op<
13681368

13691369
def GPU_RotateOp : GPU_Op<
13701370
"rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
1371-
Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
1371+
Arguments<(ins AnyIntegerOrFloatOr1DVector:$value,
1372+
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$offset,
1373+
ConfinedAttr<I32Attr, [IntPowerOf2]>:$width)>,
13721374
Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
13731375
let summary = "Rotate values within a subgroup.";
13741376
let description = [{
13751377
The "rotate" op moves values across lanes in a subgroup (a.k.a., local
1376-
invocations) within the same subgroup. The `width` argument specifies the
1378+
invocations) within the same subgroup. The `width` attribute specifies the
13771379
number of lanes that participate in the rotation, and must be uniform across
13781380
all participating lanes. Further, the first `width` lanes of the subgroup
13791381
must be active.
@@ -1394,9 +1396,7 @@ def GPU_RotateOp : GPU_Op<
13941396
example:
13951397

13961398
```mlir
1397-
%offset = arith.constant 1 : i32
1398-
%width = arith.constant 16 : i32
1399-
%1, %2 = gpu.rotate %0, %offset, %width : f32
1399+
%1, %2 = gpu.rotate %0, 1, 16 : f32
14001400
```
14011401

14021402
For lane `k`, returns the value from lane `(k + cst1) % width`.
@@ -1406,11 +1406,6 @@ def GPU_RotateOp : GPU_Op<
14061406
$value `,` $offset `,` $width attr-dict `:` type($value)
14071407
}];
14081408

1409-
let builders = [
1410-
// Helper function that creates a rotate with constant offset/width.
1411-
OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
1412-
];
1413-
14141409
let hasVerifier = 1;
14151410
}
14161411

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
507507
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
508508
unsigned subgroupSize =
509509
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
510-
IntegerAttr widthAttr;
511-
if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
512-
widthAttr.getValue().getZExtValue() > subgroupSize)
510+
unsigned width = rotateOp.getWidth();
511+
if (width > subgroupSize)
513512
return rewriter.notifyMatchFailure(
514-
rotateOp,
515-
"rotate width is not a constant or larger than target subgroup size");
513+
rotateOp, "rotate width is larger than target subgroup size");
516514

517515
Location loc = rotateOp.getLoc();
518516
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
517+
Value offsetVal =
518+
arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
519+
Value widthVal =
520+
arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
519521
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
520-
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
521-
adaptor.getWidth());
522+
rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
522523
Value validVal;
523-
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
524+
if (width == subgroupSize) {
524525
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
525526
} else {
527+
IntegerAttr widthAttr = adaptor.getWidthAttr();
526528
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
527529
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
528-
laneId, adaptor.getWidth());
530+
laneId, widthVal);
529531
}
530532

531533
rewriter.replaceOp(rotateOp, {rotateResult, validVal});

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
13951395
// RotateOp
13961396
//===----------------------------------------------------------------------===//
13971397

1398-
void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
1399-
int32_t offset, int32_t width) {
1400-
build(builder, result, value,
1401-
arith::ConstantOp::create(builder, result.___location,
1402-
builder.getI32IntegerAttr(offset)),
1403-
arith::ConstantOp::create(builder, result.___location,
1404-
builder.getI32IntegerAttr(width)));
1405-
}
1406-
14071398
LogicalResult RotateOp::verify() {
1408-
auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
1409-
if (!offsetConstOp)
1410-
return emitOpError() << "offset is not a constant value";
1411-
1412-
auto offsetIntAttr =
1413-
llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
1414-
1415-
auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
1416-
if (!widthConstOp)
1417-
return emitOpError() << "width is not a constant value";
1418-
1419-
auto widthIntAttr =
1420-
llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
1421-
1422-
llvm::APInt offsetValue = offsetIntAttr.getValue();
1423-
llvm::APInt widthValue = widthIntAttr.getValue();
1424-
1425-
if (!widthValue.isPowerOf2())
1426-
return emitOpError() << "width must be a power of two";
1399+
uint32_t offset = getOffset();
1400+
uint32_t width = getWidth();
14271401

1428-
if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
1429-
int64_t widthValueInt = widthValue.getSExtValue();
1430-
return emitOpError() << "offset must be in the range [0, " << widthValueInt
1431-
<< ")";
1402+
if (offset >= width) {
1403+
return emitOpError() << "offset must be in the range [0, " << width << ")";
14321404
}
14331405

14341406
return success();

mlir/test/Conversion/GPUToSPIRV/rotate.mlir

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@ gpu.module @kernels {
1010
// CHECK-LABEL: spirv.func @rotate()
1111
gpu.func @rotate() kernel
1212
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
13-
%offset = arith.constant 4 : i32
14-
%width = arith.constant 16 : i32
1513
%val = arith.constant 42.0 : f32
1614

15+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
1716
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
1817
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
19-
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
2018
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
2119
// CHECK: %{{.+}} = spirv.Constant true
22-
%result, %valid = gpu.rotate %val, %offset, %width : f32
20+
%result, %valid = gpu.rotate %val, 4, 16 : f32
2321
gpu.return
2422
}
2523
}
@@ -38,18 +36,16 @@ gpu.module @kernels {
3836
// CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
3937
gpu.func @rotate_width_less_than_subgroup_size() kernel
4038
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
41-
%offset = arith.constant 4 : i32
42-
%width = arith.constant 8 : i32
4339
%val = arith.constant 42.0 : f32
4440

41+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
4542
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
4643
// CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
47-
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
4844
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
4945
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
5046
// CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
5147
// CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
52-
%result, %valid = gpu.rotate %val, %offset, %width : f32
48+
%result, %valid = gpu.rotate %val, 4, 8 : f32
5349
gpu.return
5450
}
5551
}
@@ -67,34 +63,10 @@ module attributes {
6763
gpu.module @kernels {
6864
gpu.func @rotate_with_bigger_than_subgroup_size() kernel
6965
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
70-
%offset = arith.constant 4 : i32
71-
%width = arith.constant 32 : i32
7266
%val = arith.constant 42.0 : f32
7367

7468
// expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
75-
%result, %valid = gpu.rotate %val, %offset, %width : f32
76-
gpu.return
77-
}
78-
}
79-
80-
}
81-
82-
// -----
83-
84-
module attributes {
85-
gpu.container_module,
86-
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
87-
#spirv.resource_limits<subgroup_size = 16>>
88-
} {
89-
90-
gpu.module @kernels {
91-
gpu.func @rotate_non_const_width(%width: i32) kernel
92-
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
93-
%offset = arith.constant 4 : i32
94-
%val = arith.constant 42.0 : f32
95-
96-
// expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
97-
%result, %valid = gpu.rotate %val, %offset, %width : f32
69+
%result, %valid = gpu.rotate %val, 4, 32 : f32
9870
gpu.return
9971
}
10072
}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
479479
// -----
480480

481481
func.func @rotate_mismatching_type(%arg0 : f32) {
482-
%offset = arith.constant 4 : i32
483-
%width = arith.constant 16 : i32
484482
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
485-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
483+
%rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1)
486484
return
487485
}
488486

489487
// -----
490488

491489
func.func @rotate_unsupported_type(%arg0 : index) {
492-
%offset = arith.constant 4 : i32
493-
%width = arith.constant 16 : i32
494490
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
495-
%rotate, %valid = gpu.rotate %arg0, %offset, %width : index
491+
%rotate, %valid = gpu.rotate %arg0, 4, 16 : index
496492
return
497493
}
498494

@@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
502498
%offset = arith.constant 4 : i32
503499
%width = arith.constant 16 : i32
504500
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
505-
%rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
501+
%rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32>
506502
return
507503
}
508504

509505
// -----
510506

511507
func.func @rotate_unsupported_width(%arg0 : f32) {
512-
%offset = arith.constant 4 : i32
513-
%width = arith.constant 15 : i32
514-
// expected-error@+1 {{op width must be a power of two}}
515-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
508+
// expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}}
509+
%rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1)
516510
return
517511
}
518512

519513
// -----
520514

521515
func.func @rotate_unsupported_offset(%arg0 : f32) {
522-
%offset = arith.constant 16 : i32
523-
%width = arith.constant 16 : i32
524516
// expected-error@+1 {{op offset must be in the range [0, 16)}}
525-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
517+
%rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1)
526518
return
527519
}
528520

529521
// -----
530522

531523
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
532-
%offset = arith.constant -1 : i32
533-
%width = arith.constant 16 : i32
534-
// expected-error@+1 {{op offset must be in the range [0, 16)}}
535-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
536-
return
537-
}
538-
539-
// -----
540-
541-
func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
542-
%width = arith.constant 16 : i32
543-
// expected-error@+1 {{op offset is not a constant value}}
544-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
545-
return
546-
}
547-
548-
// -----
549-
550-
func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
551-
%offset = arith.constant 0 : i32
552-
// expected-error@+1 {{op width is not a constant value}}
553-
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
524+
// expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}}
525+
%rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1)
554526
return
555527
}
556528

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,8 @@ module attributes {gpu.container_module} {
140140
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
141141
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
142142

143-
// CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
144-
%rotate_width = arith.constant 16 : i32
145-
%rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
143+
// CHECK: gpu.rotate %{{.*}}, 3, 16 : f32
144+
%rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32
146145

147146
"gpu.barrier"() : () -> ()
148147

0 commit comments

Comments
 (0)