Skip to content

Commit 04b5274

Browse files
committed
[MLIR] Introduce applyOpPatternsAndFold for op local rewrites
Introduce mlir::applyOpPatternsAndFold which applies patterns as well as any folding only on a specified op (in contrast to applyPatternsAndFoldGreedily which applies patterns only on the regions of an op isolated from above). The caller is made aware of the op being folded away or erased. Depends on D77485. Differential Revision: https://reviews.llvm.org/D77487
1 parent bd47c47 commit 04b5274

File tree

8 files changed

+258
-67
lines changed

8 files changed

+258
-67
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,15 @@ bool applyPatternsAndFoldGreedily(Operation *op,
455455
/// Rewrite the given regions, which must be isolated from above.
456456
bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
457457
const OwningRewritePatternList &patterns);
458+
459+
/// Applies the specified patterns on `op` alone while also trying to fold it,
460+
/// by selecting the highest benefits patterns in a greedy manner. Returns true
461+
/// if no more patterns can be matched. `erased` is set to true if `op` was
462+
/// folded away or erased as a result of becoming dead. Note: This does not
463+
/// apply any patterns recursively to the regions of `op`.
464+
bool applyOpPatternsAndFold(Operation *op,
465+
const OwningRewritePatternList &patterns,
466+
bool *erased = nullptr);
458467
} // end namespace mlir
459468

460469
#endif // MLIR_PATTERN_MATCH_H

mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -211,20 +211,25 @@ void AffineDataCopyGeneration::runOnFunction() {
211211
for (auto &block : f)
212212
runOnBlock(&block, copyNests);
213213

214-
// Promote any single iteration loops in the copy nests.
214+
// Promote any single iteration loops in the copy nests and collect
215+
// load/stores to simplify.
216+
SmallVector<Operation *, 4> copyOps;
215217
for (auto nest : copyNests)
216-
nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
218+
// With a post order walk, the erasure of loops does not affect
219+
// continuation of the walk or the collection of load/store ops.
220+
nest->walk([&](Operation *op) {
221+
if (auto forOp = dyn_cast<AffineForOp>(op))
222+
promoteIfSingleIteration(forOp);
223+
else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
224+
copyOps.push_back(op);
225+
});
217226

218227
// Promoting single iteration loops could lead to simplification of
219-
// load's/store's. We will run canonicalization patterns on load/stores.
220-
// TODO: this whole function load/store canonicalization should be replaced by
221-
// canonicalization that is limited to only the load/store ops
222-
// introduced/touched by this pass (those inside 'copyNests'). This would be
223-
// possible once the necessary support is available in the pattern rewriter.
224-
if (!copyNests.empty()) {
225-
OwningRewritePatternList patterns;
226-
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
227-
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
228-
applyPatternsAndFoldGreedily(f, std::move(patterns));
229-
}
228+
// contained load's/store's, and the latter could anyway also be
229+
// canonicalized.
230+
OwningRewritePatternList patterns;
231+
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
232+
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
233+
for (auto op : copyOps)
234+
applyOpPatternsAndFold(op, std::move(patterns));
230235
}

mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements a pass to simplify affine structures.
9+
// This file implements a pass to simplify affine structures in operations.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "PassDetail.h"
1414
#include "mlir/Analysis/Utils.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Affine/Passes.h"
1617
#include "mlir/IR/IntegerSet.h"
18+
#include "mlir/IR/PatternMatch.h"
1719
#include "mlir/Transforms/Utils.h"
1820

1921
#define DEBUG_TYPE "simplify-affine-structure"
@@ -77,13 +79,22 @@ mlir::createSimplifyAffineStructuresPass() {
7779
void SimplifyAffineStructures::runOnFunction() {
7880
auto func = getFunction();
7981
simplifiedAttributes.clear();
80-
func.walk([&](Operation *opInst) {
81-
for (auto attr : opInst->getAttrs()) {
82+
OwningRewritePatternList patterns;
83+
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
84+
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
85+
AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
86+
func.walk([&](Operation *op) {
87+
for (auto attr : op->getAttrs()) {
8288
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
83-
simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
89+
simplifyAndUpdateAttribute(op, attr.first, mapAttr);
8490
else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
85-
simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
91+
simplifyAndUpdateAttribute(op, attr.first, setAttr);
8692
}
93+
94+
// The simplification of the attribute will likely simplify the op. Try to
95+
// fold / apply canonicalization patterns when we have affine dialect ops.
96+
if (isa<AffineForOp>(op) || isa<AffineIfOp>(op) || isa<AffineApplyOp>(op))
97+
applyOpPatternsAndFold(op, patterns);
8798
});
8899

89100
// Turn memrefs' non-identity layouts maps into ones with identity. Collect

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ using namespace mlir;
2626
/// The max number of iterations scanning for pattern match.
2727
static unsigned maxPatternMatchIterations = 10;
2828

29+
//===----------------------------------------------------------------------===//
30+
// GreedyPatternRewriteDriver
31+
//===----------------------------------------------------------------------===//
32+
2933
namespace {
3034
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3135
/// applies the locally optimal patterns in a roughly "bottom up" way.
@@ -37,8 +41,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
3741
worklist.reserve(64);
3842
}
3943

40-
/// Perform the rewrites while folding and erasing any dead ops. Return true
41-
/// if the rewrite converges in `maxIterations`.
4244
bool simplify(MutableArrayRef<Region> regions, int maxIterations);
4345

4446
void addToWorklist(Operation *op) {
@@ -248,3 +250,112 @@ bool mlir::applyPatternsAndFoldGreedily(
248250
});
249251
return converged;
250252
}
253+
254+
//===----------------------------------------------------------------------===//
255+
// OpPatternRewriteDriver
256+
//===----------------------------------------------------------------------===//
257+
258+
namespace {
259+
/// This is a simple driver for the PatternMatcher to apply patterns and perform
260+
/// folding on a single op. It repeatedly applies locally optimal patterns.
261+
class OpPatternRewriteDriver : public PatternRewriter {
262+
public:
263+
explicit OpPatternRewriteDriver(MLIRContext *ctx,
264+
const OwningRewritePatternList &patterns)
265+
: PatternRewriter(ctx), matcher(patterns), folder(ctx) {}
266+
267+
bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
268+
269+
/// No additional action needed other than inserting the op.
270+
Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
271+
272+
// These are hooks implemented for PatternRewriter.
273+
protected:
274+
/// If an operation is about to be removed, mark it so that we can let clients
275+
/// know.
276+
void notifyOperationRemoved(Operation *op) override {
277+
opErasedViaPatternRewrites = true;
278+
}
279+
280+
// When a root is going to be replaced, its removal will be notified as well.
281+
// So there is nothing to do here.
282+
void notifyRootReplaced(Operation *op) override {}
283+
284+
private:
285+
/// The low-level pattern matcher.
286+
RewritePatternMatcher matcher;
287+
288+
/// Non-pattern based folder for operations.
289+
OperationFolder folder;
290+
291+
/// Set to true if the operation has been erased via pattern rewrites.
292+
bool opErasedViaPatternRewrites = false;
293+
};
294+
295+
} // anonymous namespace
296+
297+
/// Performs the rewrites and folding only on `op`. The simplification converges
298+
/// if the op is erased as a result of being folded, replaced, or dead, or no
299+
/// more changes happen in an iteration. Returns true if the rewrite converges
300+
/// in `maxIterations`. `erased` is set to true if `op` gets erased.
301+
bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
302+
bool &erased) {
303+
bool changed = false;
304+
erased = false;
305+
opErasedViaPatternRewrites = false;
306+
int i = 0;
307+
// Iterate until convergence or until maxIterations. Deletion of the op as
308+
// a result of being dead or folded is convergence.
309+
do {
310+
// If the operation is trivially dead - remove it.
311+
if (isOpTriviallyDead(op)) {
312+
op->erase();
313+
erased = true;
314+
return true;
315+
}
316+
317+
// Try to fold this op.
318+
bool inPlaceUpdate;
319+
if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
320+
/*preReplaceAction=*/nullptr,
321+
&inPlaceUpdate))) {
322+
changed = true;
323+
if (!inPlaceUpdate) {
324+
erased = true;
325+
return true;
326+
}
327+
}
328+
329+
// Make sure that any new operations are inserted at this point.
330+
setInsertionPoint(op);
331+
332+
// Try to match one of the patterns. The rewriter is automatically
333+
// notified of any necessary changes, so there is nothing else to do here.
334+
changed |= matcher.matchAndRewrite(op, *this);
335+
if ((erased = opErasedViaPatternRewrites))
336+
return true;
337+
} while (changed && ++i < maxIterations);
338+
339+
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
340+
return !changed;
341+
}
342+
343+
/// Rewrites only `op` using the supplied canonicalization patterns and
344+
/// folding. `erased` is set to true if the op is erased as a result of being
345+
/// folded, replaced, or dead.
346+
bool mlir::applyOpPatternsAndFold(Operation *op,
347+
const OwningRewritePatternList &patterns,
348+
bool *erased) {
349+
// Start the pattern driver.
350+
OpPatternRewriteDriver driver(op->getContext(), patterns);
351+
bool opErased;
352+
bool converged =
353+
driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
354+
if (erased)
355+
*erased = opErased;
356+
LLVM_DEBUG(if (!converged) {
357+
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
358+
<< maxPatternMatchIterations << " times";
359+
});
360+
return converged;
361+
}

mlir/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/BlockAndValueMapping.h"
2424
#include "mlir/IR/Function.h"
2525
#include "mlir/IR/IntegerSet.h"
26+
#include "mlir/IR/PatternMatch.h"
2627
#include "mlir/Transforms/RegionUtils.h"
2728
#include "mlir/Transforms/Utils.h"
2829
#include "llvm/ADT/DenseMap.h"
@@ -312,9 +313,19 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
312313
opGroupQueue, /*offset=*/0, forOp, b);
313314
lbShift = d * step;
314315
}
315-
if (!prologue && res)
316-
prologue = res;
317-
epilogue = res;
316+
317+
if (res) {
318+
// Simplify/canonicalize the affine.for.
319+
OwningRewritePatternList patterns;
320+
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
321+
bool erased;
322+
applyOpPatternsAndFold(res, std::move(patterns), &erased);
323+
324+
if (!erased && !prologue)
325+
prologue = res;
326+
if (!erased)
327+
epilogue = res;
328+
}
318329
} else {
319330
// Start of first interval.
320331
lbShift = d * step;
@@ -694,7 +705,8 @@ bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
694705
}
695706

696707
/// Return true if `loops` is a perfect nest.
697-
static bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops) {
708+
static bool LLVM_ATTRIBUTE_UNUSED
709+
isPerfectlyNested(ArrayRef<AffineForOp> loops) {
698710
auto outerLoop = loops.front();
699711
for (auto loop : loops.drop_front()) {
700712
auto parentForOp = dyn_cast<AffineForOp>(loop.getParentOp());

mlir/test/Dialect/Affine/affine-data-copy.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
216216
return %A : memref<4096xf32>
217217
}
218218
// CHECK: affine.for %[[IV1:.*]] = 0 to 4096 step 100
219-
// CHECK-NEXT: %[[BUF:.*]] = alloc() : memref<100xf32>
219+
// CHECK: %[[BUF:.*]] = alloc() : memref<100xf32>
220220
// CHECK-NEXT: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
221221
// CHECK-NEXT: affine.load %{{.*}}[%[[IV2]]] : memref<4096xf32>
222222
// CHECK-NEXT: affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
@@ -226,7 +226,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
226226
// CHECK-NEXT: mulf
227227
// CHECK-NEXT: affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
228228
// CHECK-NEXT: }
229-
// CHECK-NEXT: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
229+
// CHECK: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
230230
// CHECK-NEXT: affine.load %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
231231
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[IV2]]] : memref<4096xf32>
232232
// CHECK-NEXT: }
@@ -239,8 +239,8 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
239239
// with multi-level tiling when the tile sizes used don't divide loop trip
240240
// counts.
241241

242-
#lb = affine_map<(d0, d1) -> (d0 * 512, d1 * 6)>
243-
#ub = affine_map<(d0, d1) -> (d0 * 512 + 512, d1 * 6 + 6)>
242+
#lb = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
243+
#ub = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
244244

245245
// CHECK-DAG: #[[LB:.*]] = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
246246
// CHECK-DAG: #[[UB:.*]] = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
@@ -250,7 +250,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
250250
// CHECK-SAME: [[j:arg[0-9]+]]
251251
func @max_lower_bound(%M: memref<2048x516xf64>, %i : index, %j : index) {
252252
affine.for %ii = 0 to 2048 {
253-
affine.for %jj = max #lb(%i, %j) to min #ub(%i, %j) {
253+
affine.for %jj = max #lb()[%i, %j] to min #ub()[%i, %j] {
254254
affine.load %M[%ii, %jj] : memref<2048x516xf64>
255255
}
256256
}

0 commit comments

Comments
 (0)