diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index cf577eca5b9a6..9408019fe1870 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -202,17 +202,33 @@ struct MyConversionPattern : public ConversionPattern { #### Type Safety -The types of the remapped operands provided to a conversion pattern must be of a -type expected by the pattern. The expected types of a pattern are determined by -a provided [TypeConverter](#type-converter). If no type converter is provided, -the types of the remapped operands are expected to match the types of the -original operands. If a type converter is provided, the types of the remapped -operands are expected to be legal as determined by the converter. If the -remapped operand types are not of an expected type, and a materialization to the -expected type could not be performed, the pattern fails application before the -`matchAndRewrite` hook is invoked. This ensures that patterns do not have to -explicitly ensure type safety, or sanitize the types of the incoming remapped -operands. More information on type conversion is detailed in the +The types of the remapped operands provided to a conversion pattern (through +the adaptor or `ArrayRef` of operands) depend on type conversion rules. + +If the pattern was initialized with a [type converter](#type-converter), the +conversion driver passes values whose types match the legalized types of the +operands of the matched operation as per the type converter. To that end, the +conversion driver may insert target materializations to convert the most +recently mapped values to the expected legalized types. The driver tries to +reuse existing materializations on a best-effort basis, but this is not +guaranteed by the infrastructure. If the operand types of the matched op could +not be legalized, the pattern fails to apply before the `matchAndRewrite` hook +is invoked. + +If the pattern was initialized without a type converter, the conversion driver +passes the most recently mapped values to the pattern, excluding any +materializations. Materializations are intentionally excluded because their +presence may depend on other patterns. Passing materializationsm would make the +conversion infrastructure fragile and unpredictable. Moreover, there could be +multiple materializations to different types. (This can be the case when +multiple patterns are running with different type converters.) In such a case, +it would be unclear which materialization to pass. If a value with the same +type as the original operand is desired, users can directly take the respective +operand from the matched operation. + +The above rules ensure that patterns do not have to explicitly ensure type +safety, or sanitize the types of the incoming remapped operands. More +information on type conversion is detailed in the [dedicated section](#type-conversion) below. ## Type Conversion diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 6e1baaf23fcf7..4a9464ff265e0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { "Test conversion patterns of only the specified dialects">, Option<"useDynamic", "dynamic", "bool", "false", "Use op conversion attributes to configure the conversion">, + Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "false", + "Experimental performance flag to disallow pattern rollback"> ]; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f6437657c9a93..84b9035dc6358 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter { public: ~ConversionPatternRewriter() override; + /// Return the configuration of the current dialect conversion. + const ConversionConfig &getConfig() const; + /// Apply a signature conversion to given block. This replaces the block with /// a new block containing the updated signature. The operations of the given /// block are inlined into the newly-created block, which is returned. @@ -1228,18 +1231,18 @@ struct ConversionConfig { /// 2. Pattern produces IR (in-place modification or new IR) that is illegal /// and cannot be legalized by subsequent foldings / pattern applications. /// - /// If set to "false", the conversion driver will produce an LLVM fatal error - /// instead of rolling back IR modifications. Moreover, in case of a failed - /// conversion, the original IR is not restored. The resulting IR may be a - /// mix of original and rewritten IR. (Same as a failed greedy pattern - /// rewrite.) + /// Experimental: If set to "false", the conversion driver will produce an + /// LLVM fatal error instead of rolling back IR modifications. Moreover, in + /// case of a failed conversion, the original IR is not restored. The + /// resulting IR may be a mix of original and rewritten IR. (Same as a failed + /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + /// with ASAN to detect invalid pattern API usage. /// - /// Note: This flag was added in preparation of the One-Shot Dialect - /// Conversion refactoring, which will remove the ability to roll back IR - /// modifications from the conversion driver. Use this flag to ensure that - /// your patterns do not trigger any IR rollbacks. For details, see + /// When pattern rollback is disabled, the conversion driver has to maintain + /// less internal state. This is more efficient, but not supported by all + /// lowering patterns. For details, see /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083. - bool allowPatternRollback = true; + bool allowPatternRollback = false; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4a7fe40..cdb715064b0f7 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -31,7 +31,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef filterDialects); + ArrayRef filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface { MLIRContext *context; /// List of dialects names to use as filters. ArrayRef filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +214,11 @@ class ConvertToLLVMPass std::shared_ptr impl; // Choose the pass implementation. if (useDynamic) - impl = std::make_shared(context, filterDialects); + impl = std::make_shared(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared(context, filterDialects); + impl = std::make_shared(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +238,10 @@ class ConvertToLLVMPass //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry ®istry) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 830905495e759..d6f26fa200dbc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function onOperationErased, + std::function onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function onOperationErased; + std::function onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,23 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + CallbackListener listener(/*onOperationErased=*/ + [&](Operation *op) { + opsToDetensor.erase(op); + detensorableBranchOps.erase(op); + }, + /*onBlockErased=*/ + [&](Block *block) { + for (BlockArgument arg : + block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }); + + config.listener = &listener; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3a6c719..0e88d31dae8e8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern { {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f23c6197accd5..8008958720e23 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -121,17 +121,8 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// Lookup a value in the mapping. + ValueVector lookup(const ValueVector &from) const; template struct IsValueVector : std::is_same, ValueVector> {}; @@ -185,54 +176,40 @@ struct ConversionValueMapping { }; } // namespace -ValueVector -ConversionValueMapping::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - // Try to find the deepest values that have the desired types. If there is no - // such mapping, simply return the deepest values. - ValueVector desiredValue; - ValueVector current{from}; - do { - // Store the current value if the types match. - if (TypeRange(ValueRange(current)) == desiredTypes) - desiredValue = current; - - // If possible, Replace each value with (one or multiple) mapped values. - ValueVector next; - for (Value v : current) { - auto it = mapping.find({v}); - if (it != mapping.end()) { - llvm::append_range(next, it->second); - } else { - next.push_back(v); - } - } - if (next != current) { - // If at least one value was replaced, continue the lookup from there. - current = std::move(next); - continue; - } - - // Otherwise: Check if there is a mapping for the entire vector. Such - // mappings are materializations. (N:M mapping are not supported for value - // replacements.) - // - // Note: From a correctness point of view, materializations do not have to - // be stored (and looked up) in the mapping. But for performance reasons, - // we choose to reuse existing IR (when possible) instead of creating it - // multiple times. - auto it = mapping.find(current); - if (it == mapping.end()) { - // No mapping found: The lookup stops here. - break; - } - current = it->second; - } while (true); - - // If the desired values were found use them, otherwise default to the leaf - // values. - // Note: If `desiredTypes` is empty, this function always returns `current`. - return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); +/// Marker attribute for pure type conversions. I.e., mappings whose only +/// purpose is to resolve a type mismatch. (In contrast, mappings that point to +/// the replacement values of a "replaceOp" call, etc., are not pure type +/// conversions.) +static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; + +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. +static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} + +/// A vector of values is a pure type conversion if all values are defined by +/// the same operation and the operation has the `kPureTypeConversionMarker` +/// attribute. +static bool isPureTypeConversion(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = getCommonDefiningOp(values); + return op && op->hasAttr(kPureTypeConversionMarker); +} + +ValueVector ConversionValueMapping::lookup(const ValueVector &from) const { + auto it = mapping.find(from); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. + return {}; + } + return it->second; } //===----------------------------------------------------------------------===// @@ -873,7 +850,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), config(config) {} + : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -895,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique(*this, std::forward(args)...)); } @@ -921,16 +899,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// If `skipPureTypeConversions` is "true", materializations that are pure + /// type conversions are not considered. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}, + bool skipPureTypeConversions = false) const; /// Lookup the given value within the map, or return an empty vector if the /// value is not mapped. If it is mapped, this follows the same behavior @@ -993,11 +968,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. + /// + /// If `isPureTypeConversion` is "true", the materialization is created only + /// to resolve a type mismatch. That means it is not a regular value + /// replacement issued by the user. (Replacement values that are created + /// "out of thin air" appear like unresolved materializations because they are + /// unrealized_conversion_cast ops. However, they must be treated like + /// regular value replacements.) ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr); + UnrealizedConversionCastOp *castOp = nullptr, + bool isPureTypeConversion = true); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -1086,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). + /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1109,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap unresolvedMaterializations; @@ -1124,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// simialar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". + DenseSet erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1160,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1181,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1264,10 +1275,99 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// -ValueVector -ConversionPatternRewriterImpl::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - return mapping.lookupOrDefault(from, desiredTypes); +ValueVector ConversionPatternRewriterImpl::lookupOrDefault( + Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. All replacements have + // already been materialized in IR. + Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + + auto composedLookup = [&](const ValueVector &values) -> ValueVector { + // If possible, replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : values) { + ValueVector r = lookup({v}); + if (!r.empty()) { + llvm::append_range(next, r); + } else { + next.push_back(v); + } + } + if (next != values) { + // At least one value was replaced. + return next; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + ValueVector r = lookup(values); + if (r.empty()) { + // No mapping found: The lookup stops here. + return {}; + } + return r; + }; + + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; + ValueVector lastNonMaterialization{from}; + do { + // Store the current value if the types match. + bool match = TypeRange(ValueRange(current)) == desiredTypes; + if (skipPureTypeConversions) { + // Skip pure type conversions, if requested. + bool pureConversion = isPureTypeConversion(current); + match &= !pureConversion; + // Keep track of the last mapped value that was not a pure type + // conversion. + if (!pureConversion) + lastNonMaterialization = current; + } + if (match) + desiredValue = current; + + // Lookup next value in the mapping. + ValueVector next = composedLookup(current); + if (next.empty()) + break; + current = std::move(next); + } while (true); + + // If the desired values were found use them, otherwise default to the leaf + // values. (Skip pure type conversions, if requested.) + if (!desiredTypes.empty()) + return desiredValue; + if (skipPureTypeConversions) + return lastNonMaterialization; + return current; } ValueVector @@ -1300,15 +1400,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } @@ -1324,10 +1417,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); if (!currentTypeConverter) { - // The current pattern does not have a type converter. I.e., it does not - // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped values. - remapped.push_back(lookupOrDefault(operand)); + // The current pattern does not have a type converter. Pass the most + // recently mapped values, excluding materializations. Materializations + // are intentionally excluded because their presence may depend on other + // patterns. Including materializations would make the lookup fragile + // and unpredictable. + remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true)); continue; } @@ -1356,7 +1452,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = lookupOrDefault(operand); + repl = lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1368,12 +1465,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. - return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1457,7 +1554,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. bool fastPath = !config.listener; if (fastPath) { - appendRewrite(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1482,7 +1580,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), - /*outputTypes=*/origArgType, /*originalType=*/Type(), converter) + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, + /*castOp=*/nullptr, /*isPureTypeConversion=*/false) .front(); replaceUsesOfBlockArgument(origArg, mat, converter); continue; @@ -1504,7 +1603,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1523,7 +1623,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp) { + UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); assert(TypeRange(inputs) != outputTypes && @@ -1533,21 +1633,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); + if (isPureTypeConversion) + convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); + + // Register the materialization. if (castOp) *castOp = convertOp; unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1609,26 +1720,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite(op); + // If the op was detached, it is most likely a newly created op. Add it the + // set of newly created ops, so that it will be legalized. If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless. patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite(op, previous); + if (config.allowPatternRollback) + appendRewrite(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. +static SmallVector +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air". + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(to) == TypeRange(from.getType())) { + // The replacement value already has the correct type. Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. + notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1650,7 +1854,8 @@ void ConversionPatternRewriterImpl::replaceOp( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputTypes=*/result.getType(), /*originalType=*/Type(), - currentTypeConverter); + currentTypeConverter, /*castOp=*/nullptr, + /*isPureTypeConversion=*/false); continue; } @@ -1667,11 +1872,46 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector toConv = llvm::to_vector(to); + SmallVector repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + appendRewrite(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener. + notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite(block); @@ -1705,23 +1945,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another. - appendRewrite(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1754,6 +2008,10 @@ ConversionPatternRewriter::ConversionPatternRewriter( ConversionPatternRewriter::~ConversionPatternRewriter() = default; +const ConversionConfig &ConversionPatternRewriter::getConfig() const { + return impl->config; +} + void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { assert(op && newOp && "expected non-null op"); replaceOp(op, newOp->getResults()); @@ -1897,7 +2155,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // a bit more efficient, so we try to do that when possible. bool fastPath = !impl->config.listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1923,6 +2181,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed. + PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1932,20 +2195,27 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -2302,6 +2572,37 @@ OperationLegalizer::legalizeWithFold(Operation *op, return success(); } +/// Report a fatal error indicating that newly produced or modified IR could +/// not be legalized. +static void +reportNewIrLegalizationFatalError(const Pattern &pattern, + const SetVector &newOps, + const SetVector &modifiedOps, + const SetVector &insertedBlocks) { + StringRef detachedBlockStr = "(detached block)"; + std::string newOpNames = llvm::join( + llvm::map_range( + newOps, [](Operation *op) { return op->getName().getStringRef(); }), + ", "); + std::string modifiedOpNames = llvm::join( + llvm::map_range( + newOps, [](Operation *op) { return op->getName().getStringRef(); }), + ", "); + std::string insertedBlockNames = llvm::join( + llvm::map_range(insertedBlocks, + [&](Block *block) { + if (block->getParentOp()) + return block->getParentOp()->getName().getStringRef(); + return detachedBlockStr; + }), + ", "); + llvm::report_fatal_error( + "pattern '" + pattern.getDebugName() + + "' produced IR that could not be legalized. " + "new ops: {" + + newOpNames + "}, " + "modified ops: {" + modifiedOpNames + "}, " + + "inserted block into ops: {" + insertedBlockNames + "}"); +} + LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { @@ -2345,17 +2646,35 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if 0 + // Cheap pattern check that could have false positives. Can be enabled + // manually for debugging purposes. E.g., this check would report an API + // violation when an op is created and then erased in the same pattern. + if (!rewriterImpl.patternNewOps.empty() || + !rewriterImpl.patternModifiedOps.empty() || + !rewriterImpl.patternInsertedBlocks.empty()) { + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' rollback of IR modifications requested"); + } +#endif +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect more API violations and has no + // fewer false positives than the cheap check. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2379,6 +2698,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations. + for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); @@ -2389,8 +2718,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) - llvm::report_fatal_error("pattern '" + pattern.getDebugName() + - "' produced IR that could not be legalized"); + reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, + insertedBlocks); rewriterImpl.resetState(curState, pattern.getDebugName()); } if (config.listener) @@ -2469,6 +2798,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) @@ -2848,6 +3180,49 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, return failure(); } +static SmallVector +cseUnrealizedCasts(SmallVectorImpl &castOps) { + SmallVector result; + DominanceInfo domInfo; + DenseMap> hashedOps; + for (UnrealizedConversionCastOp castOp : castOps) { + unsigned hash = 0; + for (Type type : castOp.getResultTypes()) + hash ^= hash_value(type); + for (Value value : castOp.getInputs()) + hash ^= hash_value(value); + hashedOps[hash].push_back(castOp); + } + // TODO: This should run to a fixed point. + DenseSet erasedOps; + for (auto &it : hashedOps) { + SmallVector &ops = it.second; + if (ops.size() == 1) + continue; + UnrealizedConversionCastOp top = ops.front(); + for (UnrealizedConversionCastOp castOp : llvm::drop_begin(ops)) { + if (castOp.getInputs() != top.getInputs()) + continue; + if (castOp.getResultTypes() != top.getResultTypes()) + continue; + if (domInfo.dominates(castOp, top)) { + std::swap(top, castOp); + } + if (domInfo.properlyDominates(top, castOp)) { + castOp.replaceAllUsesWith(top); + castOp.erase(); + erasedOps.insert(castOp); + continue; + } + } + } + + for (UnrealizedConversionCastOp castOp : castOps) + if (!erasedOps.contains(castOp)) + result.push_back(castOp); + return result; +} + LogicalResult OperationConverter::convertOperations(ArrayRef ops) { assert(!ops.empty() && "expected at least one operation"); const ConversionTarget &target = opLegalizer.getTarget(); @@ -2901,6 +3276,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // patterns.) SmallVector remainingCastOps; reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + remainingCastOps = cseUnrealizedCasts(remainingCastOps); + + // Drop markers. + for (UnrealizedConversionCastOp castOp : remainingCastOps) + castOp->removeAttr(kPureTypeConversionMarker); // Try to legalize all unresolved materializations. if (config.buildMaterializations) { diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 8f74976c59773..25a338df8d790 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { // This buffer is properly aligned. There should be no error. // CHECK-NOT: ^ memref is not aligned to 8 diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 26c731c921356..4c6a48d577a6c 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref, %index: index) { %cst = arith.constant 1.0 : f32 memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref) -> f32 diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir index 8b6308e9c1939..1ac10306395ad 100644 --- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir @@ -1,11 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @cast_to_static_dim(%m: memref) -> memref<10xf32> { %0 = memref.cast %m : memref to memref<10xf32> return %0 : memref<10xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir index 95b9db2832cee..be9417baf93df 100644 --- a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + // Put memref.copy in a function, otherwise the memref.cast may fold. func.func @memcpy_helper(%src: memref, %dest: memref) { memref.copy %src, %dest : memref to memref diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir index 2e3f271743c93..ef4af62459738 100644 --- a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { %c4 = arith.constant 4 : index %alloca = memref.alloca() : memref<1xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir index b87e5bdf0970c..2e42648297875 100644 --- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir @@ -1,12 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ -// RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @load(%memref: memref<1xf32>, %index: index) { memref.load %memref[%index] : memref<1xf32> return diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index 12253fa3b5e83..dd000c6904bcb 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref, %index: index) { %cst = arith.constant 1.0 : f32 memref.store %cst, %memref[%index] : memref diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index ec7e4085f2fa5..9fbe5bc60321e 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -1,12 +1,22 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @subview(%memref: memref<1xf32>, %offset: index) { memref.subview %memref[%offset] [1] [1] : memref<1xf32> to diff --git a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir index e4aab32d4a390..f37a6d6383c48 100644 --- a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func private @cast_to_static_dim(%t: tensor) -> tensor<10xf32> { %0 = tensor.cast %t : tensor to tensor<10xf32> return %0 : tensor<10xf32> diff --git a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir index c6d8f698b9433..e9e5c040c6488 100644 --- a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir @@ -1,10 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -one-shot-bufferize \ -// RUN: -buffer-deallocation-pipeline \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ // RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s func.func @main() { diff --git a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir index 8e3cab7be704d..73fcec4d7abcd 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract(%tensor: tensor<1xf32>, %index: index) { tensor.extract %tensor[%index] : tensor<1xf32> return diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index 28f9be0fffe64..341a59e8b8102 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract_slice(%tensor: tensor<1xf32>, %offset: index) { tensor.extract_slice %tensor[%offset] [1] [1] : tensor<1xf32> to tensor<1xf32> return diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index e4406e60ffead..5630d1540e4d5 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() { %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) "test.invalid"(%0) : (f16) -> () } + +// ----- + +// CHECK-LABEL: func @test_lookup_without_converter +// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16 +// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64 +// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> () +// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> () +func.func @test_lookup_without_converter() { + %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) + "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () + // Make sure that the second "replace_with_valid_consumer" lowering does not + // lookup the materialization that was created for the above op. + "test.replace_with_valid_consumer"(%0) : (i64) -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2eaad552a7a3a..843bd30a51aff 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>, Arguments<(ins Variadic)>; def TestTypeProducerOp : TEST_Op<"type_producer">, Results<(outs AnyType)>; +def TestValidProducerOp : TEST_Op<"valid_producer">, + Results<(outs AnyType)>; +def TestValidConsumerOp : TEST_Op<"valid_consumer">, + Arguments<(ins AnyType)>; def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">, Results<(outs AnyType)>; def TestTypeConsumerOp : TEST_Op<"type_consumer">, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index eda618f5b09c6..52ca9c25ad2cd 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern { } }; +/// Pattern that replaces test.replace_with_valid_producer with +/// test.valid_producer and the specified type. +class TestReplaceWithValidProducer : public ConversionPattern { +public: + TestReplaceWithValidProducer(MLIRContext *ctx) + : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto attr = op->getAttrOfType("type"); + if (!attr) + return failure(); + rewriter.replaceOpWithNewOp(op, attr.getValue()); + return success(); + } +}; + +/// Pattern that replaces test.replace_with_valid_consumer with +/// test.valid_consumer. Can be used with and without a type converter. +class TestReplaceWithValidConsumer : public ConversionPattern { +public: + TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.replace_with_valid_consumer", 1, + ctx) {} + TestReplaceWithValidConsumer(MLIRContext *ctx) + : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // with_converter present: pattern must have been initialized with a type + // converter. + // with_converter absent: pattern must have been initialized without a type + // converter. + if (op->hasAttr("with_converter") != static_cast(getTypeConverter())) + return failure(); + rewriter.replaceOpWithNewOp(op, operands[0]); + return success(); + } +}; + /// This pattern matches a test.convert_block_args op. It either: /// a) Duplicates all block arguments, /// b) or: drops all block arguments and replaces each with 2x the first @@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter { TestTypeConverter() { addConversion(convertType); addSourceMaterialization(materializeCast); + addTargetMaterialization(materializeCast); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { @@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, TestUndoPropertiesModification, TestEraseOp, + TestReplaceWithValidProducer, TestReplaceWithValidConsumer, TestRepetitive1ToNConsumer>(&getContext()); patterns.add(&getContext(), converter); + TestBlockArgReplace, TestReplaceWithValidConsumer>( + &getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp(); target.addLegalOp(); + TerminatorOp, OneRegionOp, TestValidProducerOp, + TestValidConsumerOp>(); target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp(); @@ -1454,6 +1499,7 @@ struct TestLegalizePatternDriver if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; @@ -1475,6 +1521,7 @@ struct TestLegalizePatternDriver }); ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, @@ -1490,6 +1537,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, std::move(patterns), config))) @@ -1510,6 +1558,10 @@ struct TestLegalizePatternDriver clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), clEnumValN(ConversionMode::Partial, "partial", "Perform a partial conversion"))}; + + Option allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace