Skip to content

Commit 0b37de2

Browse files
authored
[MLIR][SCF] Propagate loop annotation during while op lowering (#151746)
This is expanding on #102562 This allows also propagating attributes for scf.while lowering
1 parent c188e1d commit 0b37de2

File tree

2 files changed

+63
-17
lines changed

2 files changed

+63
-17
lines changed

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
312312

313313
} // namespace
314314

315+
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
316+
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
317+
// llvm.loop_annotation attribute.
318+
// LLVM requires the loop metadata to be attached on the "latch" block. Which
319+
// is the back-edge to the header block (conditionBlock)
320+
SmallVector<NamedAttribute> llvmAttrs;
321+
llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
322+
[](auto attr) {
323+
return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
324+
});
325+
brOp->setDiscardableAttrs(llvmAttrs);
326+
}
327+
315328
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
316329
PatternRewriter &rewriter) const {
317330
Location loc = forOp.getLoc();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
350363
auto branchOp =
351364
cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
352365

353-
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
354-
// llvm.loop_annotation attribute.
355-
// LLVM requires the loop metadata to be attached on the "latch" block. Which
356-
// is the back-edge to the header block (conditionBlock)
357-
SmallVector<NamedAttribute> llvmAttrs;
358-
llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
359-
[](auto attr) {
360-
return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
361-
});
362-
branchOp->setDiscardableAttrs(llvmAttrs);
363-
366+
propagateLoopAttrs(forOp, branchOp);
364367
rewriter.eraseOp(terminator);
365368

366369
// Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
589592

590593
rewriter.setInsertionPointToEnd(after);
591594
auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
592-
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
593-
yieldOp.getResults());
595+
auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
596+
yieldOp.getResults());
594597

598+
propagateLoopAttrs(whileOp, latch);
595599
// Replace the op with values "yielded" from the "before" region, which are
596600
// visible by dominance.
597601
rewriter.replaceOp(whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
631635
// Loop around the "before" region based on condition.
632636
rewriter.setInsertionPointToEnd(before);
633637
auto condOp = cast<ConditionOp>(before->getTerminator());
634-
cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
635-
before, condOp.getArgs(), continuation,
636-
ValueRange());
638+
auto latch = cf::CondBranchOp::create(
639+
rewriter, condOp.getLoc(), condOp.getCondition(), before,
640+
condOp.getArgs(), continuation, ValueRange());
637641

642+
propagateLoopAttrs(whileOp, latch);
638643
// Replace the op with values "yielded" from the "before" region, which are
639644
// visible by dominance.
640645
rewriter.replaceOp(whileOp, condOp.getArgs());

mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,4 +708,45 @@ func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 :
708708
} {llvm.loop_annotation = #full_unroll}
709709
} {llvm.loop_annotation = #no_unroll}
710710
return
711-
}
711+
}
712+
713+
// -----
714+
715+
// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true>
716+
// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]>
717+
// CHECK: func @simple_while_loops_annotation
718+
// CHECK: cf.br
719+
// CHECK: cf.cond_br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]}
720+
// CHECK: return
721+
#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
722+
func.func @simple_while_loops_annotation(%arg0 : i1) {
723+
scf.while : () -> () {
724+
scf.condition(%arg0)
725+
} do {
726+
scf.yield
727+
} attributes {llvm.loop_annotation = #no_unroll}
728+
return
729+
}
730+
731+
// -----
732+
733+
// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true>
734+
// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]>
735+
// CHECK: func @do_while_loops_annotation
736+
// CHECK: cf.br
737+
// CHECK: cf.cond_br
738+
// CHECK: cf.br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]}
739+
// CHECK: return
740+
#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
741+
func.func @do_while_loops_annotation() {
742+
%c0_i32 = arith.constant 0 : i32
743+
scf.while (%arg2 = %c0_i32) : (i32) -> (i32) {
744+
%0 = "test.make_condition"() : () -> i1
745+
scf.condition(%0) %c0_i32 : i32
746+
} do {
747+
^bb0(%arg2: i32):
748+
scf.yield %c0_i32: i32
749+
} attributes {llvm.loop_annotation = #no_unroll}
750+
return
751+
}
752+

0 commit comments

Comments
 (0)