Skip to content

Commit 5851671

Browse files
committed
[MLIR] Constant fold multiplies in deriveStaticUpperBound.
Summary: This operation occurs during collapseParallelLoops, so we constant fold them also to allow more situations of determining a loop invariant upper bound when lowering to the GPU dialect from the Loop dialect. Differential Revision: https://reviews.llvm.org/D77723
1 parent 9657385 commit 5851671

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -506,16 +506,36 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
506506
/// `upperBound`.
507507
static Value deriveStaticUpperBound(Value upperBound,
508508
PatternRewriter &rewriter) {
509-
if (AffineMinOp minOp =
510-
dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
509+
if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
510+
return op;
511+
}
512+
513+
if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
511514
for (const AffineExpr &result : minOp.map().getResults()) {
512-
if (AffineConstantExpr constExpr =
513-
result.dyn_cast<AffineConstantExpr>()) {
515+
if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
514516
return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
515517
constExpr.getValue());
516518
}
517519
}
518520
}
521+
522+
if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
523+
if (auto lhs = dyn_cast_or_null<ConstantIndexOp>(
524+
deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
525+
.getDefiningOp()))
526+
if (auto rhs = dyn_cast_or_null<ConstantIndexOp>(
527+
deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter)
528+
.getDefiningOp())) {
529+
// Assumptions about the upper bound of minimum computations no longer
530+
// work if multiplied by a negative value, so abort in this case.
531+
if (lhs.getValue() < 0 || rhs.getValue() < 0)
532+
return {};
533+
534+
return rewriter.create<ConstantIndexOp>(
535+
multiplyOp.getLoc(), lhs.getValue() * rhs.getValue());
536+
}
537+
}
538+
519539
return {};
520540
}
521541

mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,10 @@ module {
213213
loop.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c2, %c3) {
214214
%2 = dim %arg0, 0 : memref<?x?xf32, #map0>
215215
%3 = affine.min #map1(%arg3)[%2]
216+
%squared_min = muli %3, %3 : index
216217
%4 = dim %arg0, 1 : memref<?x?xf32, #map0>
217218
%5 = affine.min #map2(%arg4)[%4]
218-
%6 = std.subview %arg0[%arg3, %arg4][%3, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
219+
%6 = std.subview %arg0[%arg3, %arg4][%squared_min, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
219220
%7 = dim %arg1, 0 : memref<?x?xf32, #map0>
220221
%8 = affine.min #map1(%arg3)[%7]
221222
%9 = dim %arg1, 1 : memref<?x?xf32, #map0>
@@ -226,7 +227,7 @@ module {
226227
%14 = dim %arg2, 1 : memref<?x?xf32, #map0>
227228
%15 = affine.min #map2(%arg4)[%14]
228229
%16 = std.subview %arg2[%arg3, %arg4][%13, %15][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
229-
loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%3, %5) step (%c1, %c1) {
230+
loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%squared_min, %5) step (%c1, %c1) {
230231
%17 = load %6[%arg5, %arg6] : memref<?x?xf32, #map3>
231232
%18 = load %11[%arg5, %arg6] : memref<?x?xf32, #map3>
232233
%19 = load %16[%arg5, %arg6] : memref<?x?xf32, #map3>
@@ -259,7 +260,7 @@ module {
259260
// CHECK: [[VAL_9:%.*]] = constant 1 : index
260261
// CHECK: [[VAL_10:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_7]], [[VAL_4]], [[VAL_6]]]
261262
// CHECK: [[VAL_11:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_8]], [[VAL_4]], [[VAL_5]]]
262-
// CHECK: [[VAL_12:%.*]] = constant 2 : index
263+
// CHECK: [[VAL_12:%.*]] = constant 4 : index
263264
// CHECK: [[VAL_13:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_12]], [[VAL_4]], [[VAL_3]]]
264265
// CHECK: [[VAL_14:%.*]] = constant 3 : index
265266
// CHECK: [[VAL_15:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_14]], [[VAL_4]], [[VAL_3]]]
@@ -268,9 +269,10 @@ module {
268269
// CHECK: [[VAL_29:%.*]] = affine.apply #[[MAP2]]([[VAL_17]]){{\[}}[[VAL_5]], [[VAL_4]]]
269270
// CHECK: [[VAL_30:%.*]] = dim [[VAL_0]], 0 : memref<?x?xf32, #[[MAP0]]>
270271
// CHECK: [[VAL_31:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]]
272+
// CHECK: [[VAL_31_SQUARED:%.*]] = muli [[VAL_31]], [[VAL_31]] : index
271273
// CHECK: [[VAL_32:%.*]] = dim [[VAL_0]], 1 : memref<?x?xf32, #[[MAP0]]>
272274
// CHECK: [[VAL_33:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]]
273-
// CHECK: [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
275+
// CHECK: [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31_SQUARED]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
274276
// CHECK: [[VAL_35:%.*]] = dim [[VAL_1]], 0 : memref<?x?xf32, #[[MAP0]]>
275277
// CHECK: [[VAL_36:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]]
276278
// CHECK: [[VAL_37:%.*]] = dim [[VAL_1]], 1 : memref<?x?xf32, #[[MAP0]]>
@@ -282,7 +284,7 @@ module {
282284
// CHECK: [[VAL_43:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_42]]]
283285
// CHECK: [[VAL_44:%.*]] = subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_41]], [[VAL_43]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
284286
// CHECK: [[VAL_45:%.*]] = affine.apply #[[MAP2]]([[VAL_22]]){{\[}}[[VAL_3]], [[VAL_4]]]
285-
// CHECK: [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31]] : index
287+
// CHECK: [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31_SQUARED]] : index
286288
// CHECK: loop.if [[VAL_46]] {
287289
// CHECK: [[VAL_47:%.*]] = affine.apply #[[MAP2]]([[VAL_23]]){{\[}}[[VAL_3]], [[VAL_4]]]
288290
// CHECK: [[VAL_48:%.*]] = cmpi "slt", [[VAL_47]], [[VAL_33]] : index

0 commit comments

Comments
 (0)