@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
312
312
313
313
} // namespace
314
314
315
+ static void propagateLoopAttrs (Operation *scfOp, Operation *brOp) {
316
+ // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
317
+ // llvm.loop_annotation attribute.
318
+ // LLVM requires the loop metadata to be attached on the "latch" block. Which
319
+ // is the back-edge to the header block (conditionBlock)
320
+ SmallVector<NamedAttribute> llvmAttrs;
321
+ llvm::copy_if (scfOp->getAttrs (), std::back_inserter (llvmAttrs),
322
+ [](auto attr) {
323
+ return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
324
+ });
325
+ brOp->setDiscardableAttrs (llvmAttrs);
326
+ }
327
+
315
328
LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
316
329
PatternRewriter &rewriter) const {
317
330
Location loc = forOp.getLoc ();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
350
363
auto branchOp =
351
364
cf::BranchOp::create (rewriter, loc, conditionBlock, loopCarried);
352
365
353
- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
354
- // llvm.loop_annotation attribute.
355
- // LLVM requires the loop metadata to be attached on the "latch" block. Which
356
- // is the back-edge to the header block (conditionBlock)
357
- SmallVector<NamedAttribute> llvmAttrs;
358
- llvm::copy_if (forOp->getAttrs (), std::back_inserter (llvmAttrs),
359
- [](auto attr) {
360
- return isa<LLVM::LLVMDialect>(attr.getValue ().getDialect ());
361
- });
362
- branchOp->setDiscardableAttrs (llvmAttrs);
363
-
366
+ propagateLoopAttrs (forOp, branchOp);
364
367
rewriter.eraseOp (terminator);
365
368
366
369
// Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
589
592
590
593
rewriter.setInsertionPointToEnd (after);
591
594
auto yieldOp = cast<scf::YieldOp>(after->getTerminator ());
592
- rewriter.replaceOpWithNewOp <cf::BranchOp>(yieldOp, before,
593
- yieldOp.getResults ());
595
+ auto latch = rewriter.replaceOpWithNewOp <cf::BranchOp>(yieldOp, before,
596
+ yieldOp.getResults ());
594
597
598
+ propagateLoopAttrs (whileOp, latch);
595
599
// Replace the op with values "yielded" from the "before" region, which are
596
600
// visible by dominance.
597
601
rewriter.replaceOp (whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
631
635
// Loop around the "before" region based on condition.
632
636
rewriter.setInsertionPointToEnd (before);
633
637
auto condOp = cast<ConditionOp>(before->getTerminator ());
634
- cf::CondBranchOp::create (rewriter, condOp. getLoc (), condOp. getCondition (),
635
- before , condOp.getArgs (), continuation ,
636
- ValueRange ());
638
+ auto latch = cf::CondBranchOp::create (
639
+ rewriter, condOp. getLoc () , condOp.getCondition (), before ,
640
+ condOp. getArgs (), continuation, ValueRange ());
637
641
642
+ propagateLoopAttrs (whileOp, latch);
638
643
// Replace the op with values "yielded" from the "before" region, which are
639
644
// visible by dominance.
640
645
rewriter.replaceOp (whileOp, condOp.getArgs ());
0 commit comments