Skip to content

Commit 3b7dd9f

Browse files
:WIP
update erase immediately update 2 fix fix some tests
1 parent 1fad57e commit 3b7dd9f

19 files changed

+570
-123
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
5252
"Test conversion patterns of only the specified dialects">,
5353
Option<"useDynamic", "dynamic", "bool", "false",
5454
"Use op conversion attributes to configure the conversion">,
55+
Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "false",
56+
"Experimental performance flag to disallow pattern rollback">
5557
];
5658
}
5759

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
728728
public:
729729
~ConversionPatternRewriter() override;
730730

731+
/// Return the configuration of the current dialect conversion.
732+
const ConversionConfig &getConfig() const;
733+
731734
/// Apply a signature conversion to given block. This replaces the block with
732735
/// a new block containing the updated signature. The operations of the given
733736
/// block are inlined into the newly-created block, which is returned.
@@ -1228,18 +1231,18 @@ struct ConversionConfig {
12281231
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
12291232
/// and cannot be legalized by subsequent foldings / pattern applications.
12301233
///
1231-
/// If set to "false", the conversion driver will produce an LLVM fatal error
1232-
/// instead of rolling back IR modifications. Moreover, in case of a failed
1233-
/// conversion, the original IR is not restored. The resulting IR may be a
1234-
/// mix of original and rewritten IR. (Same as a failed greedy pattern
1235-
/// rewrite.)
1234+
/// Experimental: If set to "false", the conversion driver will produce an
1235+
/// LLVM fatal error instead of rolling back IR modifications. Moreover, in
1236+
/// case of a failed conversion, the original IR is not restored. The
1237+
/// resulting IR may be a mix of original and rewritten IR. (Same as a failed
1238+
/// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1239+
/// with ASAN to detect invalid pattern API usage.
12361240
///
1237-
/// Note: This flag was added in preparation of the One-Shot Dialect
1238-
/// Conversion refactoring, which will remove the ability to roll back IR
1239-
/// modifications from the conversion driver. Use this flag to ensure that
1240-
/// your patterns do not trigger any IR rollbacks. For details, see
1241+
/// When pattern rollback is disabled, the conversion driver has to maintain
1242+
/// less internal state. This is more efficient, but not supported by all
1243+
/// lowering patterns. For details, see
12411244
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1242-
bool allowPatternRollback = true;
1245+
bool allowPatternRollback = false;
12431246
};
12441247

12451248
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace {
3131
class ConvertToLLVMPassInterface {
3232
public:
3333
ConvertToLLVMPassInterface(MLIRContext *context,
34-
ArrayRef<std::string> filterDialects);
34+
ArrayRef<std::string> filterDialects,
35+
bool allowPatternRollback = true);
3536
virtual ~ConvertToLLVMPassInterface() = default;
3637

3738
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
6061
MLIRContext *context;
6162
/// List of dialects names to use as filters.
6263
ArrayRef<std::string> filterDialects;
64+
/// An experimental flag to disallow pattern rollback. This is more efficient
65+
/// but not supported by all lowering patterns.
66+
bool allowPatternRollback;
6367
};
6468

6569
/// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
128132

129133
/// Apply the conversion driver.
130134
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
131-
if (failed(applyPartialConversion(op, *target, *patterns)))
135+
ConversionConfig config;
136+
config.allowPatternRollback = allowPatternRollback;
137+
if (failed(applyPartialConversion(op, *target, *patterns, config)))
132138
return failure();
133139
return success();
134140
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
179185
patterns);
180186

181187
// Apply the conversion.
182-
if (failed(applyPartialConversion(op, target, std::move(patterns))))
188+
ConversionConfig config;
189+
config.allowPatternRollback = allowPatternRollback;
190+
if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
183191
return failure();
184192
return success();
185193
}
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
206214
std::shared_ptr<ConvertToLLVMPassInterface> impl;
207215
// Choose the pass implementation.
208216
if (useDynamic)
209-
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
217+
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
218+
allowPatternRollback);
210219
else
211-
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
220+
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
221+
allowPatternRollback);
212222
if (failed(impl->initialize()))
213223
return failure();
214224
this->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
228238
//===----------------------------------------------------------------------===//
229239

230240
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
231-
MLIRContext *context, ArrayRef<std::string> filterDialects)
232-
: context(context), filterDialects(filterDialects) {}
241+
MLIRContext *context, ArrayRef<std::string> filterDialects,
242+
bool allowPatternRollback)
243+
: context(context), filterDialects(filterDialects),
244+
allowPatternRollback(allowPatternRollback) {}
233245

234246
void ConvertToLLVMPassInterface::getDependentDialects(
235247
DialectRegistry &registry) {

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,22 @@ struct LinalgDetensorize
458458
}
459459
};
460460

461+
/// A listener that forwards notifyBlockErased and notifyOperationErased to
462+
/// the given callbacks.
463+
struct CallbackListener : public RewriterBase::Listener {
464+
CallbackListener(std::function<void(Operation *op)> onOperationErased,
465+
std::function<void(Block *block)> onBlockErased)
466+
: onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
467+
468+
void notifyBlockErased(Block *block) override { onBlockErased(block); }
469+
void notifyOperationErased(Operation *op) override {
470+
onOperationErased(op);
471+
}
472+
473+
std::function<void(Operation *op)> onOperationErased;
474+
std::function<void(Block *block)> onBlockErased;
475+
};
476+
461477
void runOnOperation() override {
462478
MLIRContext *context = &getContext();
463479
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,23 @@ struct LinalgDetensorize
551567
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
552568
shouldConvertBranchOperand);
553569

554-
if (failed(
555-
applyFullConversion(getOperation(), target, std::move(patterns))))
570+
ConversionConfig config;
571+
CallbackListener listener(/*onOperationErased=*/
572+
[&](Operation *op) {
573+
opsToDetensor.erase(op);
574+
detensorableBranchOps.erase(op);
575+
},
576+
/*onBlockErased=*/
577+
[&](Block *block) {
578+
for (BlockArgument arg :
579+
block->getArguments()) {
580+
blockArgsToDetensor.erase(arg);
581+
}
582+
});
583+
584+
config.listener = &listener;
585+
if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
586+
config)))
556587
signalPassFailure();
557588

558589
RewritePatternSet canonPatterns(context);

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
730730
{tensor, lvlCoords, values, filled, added, count},
731731
EmitCInterface::On);
732732
Operation *parent = getTop(op);
733+
rewriter.setInsertionPointAfter(parent);
733734
rewriter.replaceOp(op, adaptor.getTensor());
734735
// Deallocate the buffers on exit of the loop nest.
735-
rewriter.setInsertionPointAfter(parent);
736736
memref::DeallocOp::create(rewriter, loc, values);
737737
memref::DeallocOp::create(rewriter, loc, filled);
738738
memref::DeallocOp::create(rewriter, loc, added);

0 commit comments

Comments
 (0)