Skip to content

Commit 1fad57e

Browse files
make lookup unambiguous
1 parent e44a14e commit 1fad57e

File tree

5 files changed

+222
-83
lines changed

5 files changed

+222
-83
lines changed

mlir/docs/DialectConversion.md

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,33 @@ struct MyConversionPattern : public ConversionPattern {
202202
203203
#### Type Safety
204204
205-
The types of the remapped operands provided to a conversion pattern must be of a
206-
type expected by the pattern. The expected types of a pattern are determined by
207-
a provided [TypeConverter](#type-converter). If no type converter is provided,
208-
the types of the remapped operands are expected to match the types of the
209-
original operands. If a type converter is provided, the types of the remapped
210-
operands are expected to be legal as determined by the converter. If the
211-
remapped operand types are not of an expected type, and a materialization to the
212-
expected type could not be performed, the pattern fails application before the
213-
`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
214-
explicitly ensure type safety, or sanitize the types of the incoming remapped
215-
operands. More information on type conversion is detailed in the
205+
The types of the remapped operands provided to a conversion pattern (through
206+
the adaptor or `ArrayRef` of operands) depend on type conversion rules.
207+
208+
If the pattern was initialized with a [type converter](#type-converter), the
209+
conversion driver passes values whose types match the legalized types of the
210+
operands of the matched operation as per the type converter. To that end, the
211+
conversion driver may insert target materializations to convert the most
212+
recently mapped values to the expected legalized types. The driver tries to
213+
reuse existing materializations on a best-effort basis, but this is not
214+
guaranteed by the infrastructure. If the operand types of the matched op could
215+
not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
216+
is invoked.
217+
218+
If the pattern was initialized without a type converter, the conversion driver
219+
passes the most recently mapped values to the pattern, excluding any
220+
materializations. Materializations are intentionally excluded because their
221+
presence may depend on other patterns. Passing materializationsm would make the
222+
conversion infrastructure fragile and unpredictable. Moreover, there could be
223+
multiple materializations to different types. (This can be the case when
224+
multiple patterns are running with different type converters.) In such a case,
225+
it would be unclear which materialization to pass. If a value with the same
226+
type as the original operand is desired, users can directly take the respective
227+
operand from the matched operation.
228+
229+
The above rules ensure that patterns do not have to explicitly ensure type
230+
safety, or sanitize the types of the incoming remapped operands. More
231+
information on type conversion is detailed in the
216232
[dedicated section](#type-conversion) below.
217233
218234
## Type Conversion

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 127 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,17 @@ struct ConversionValueMapping {
121121
/// false positives.
122122
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
123123

124-
/// Lookup the most recently mapped values with the desired types in the
125-
/// mapping.
124+
/// Lookup a value in the mapping. If `skipPureTypeConversions` is "true",
125+
/// pure type conversions are not considered. Return an empty vector if no
126+
/// mapping was found.
126127
///
127-
/// Special cases:
128-
/// - If the desired type range is empty, simply return the most recently
129-
/// mapped values.
130-
/// - If there is no mapping to the desired types, also return the most
131-
/// recently mapped values.
132-
/// - If there is no mapping for the given values at all, return the given
133-
/// value.
134-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
128+
/// Note: This mapping data structure supports N:M mappings. This function
129+
/// first tries to look up mappings for each input value individually (and
130+
/// then composes the results). If such a lookup is unsuccessful, the entire
131+
/// vector is looked up together. If the lookup is still unsuccessful, an
132+
/// empty vector is returned.
133+
ValueVector lookup(const ValueVector &from,
134+
bool skipPureTypeConversions = false) const;
135135

136136
template <typename T>
137137
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +185,55 @@ struct ConversionValueMapping {
185185
};
186186
} // namespace
187187

188-
ValueVector
189-
ConversionValueMapping::lookupOrDefault(Value from,
190-
TypeRange desiredTypes) const {
191-
// Try to find the deepest values that have the desired types. If there is no
192-
// such mapping, simply return the deepest values.
193-
ValueVector desiredValue;
194-
ValueVector current{from};
195-
do {
196-
// Store the current value if the types match.
197-
if (TypeRange(ValueRange(current)) == desiredTypes)
198-
desiredValue = current;
199-
200-
// If possible, Replace each value with (one or multiple) mapped values.
201-
ValueVector next;
202-
for (Value v : current) {
203-
auto it = mapping.find({v});
204-
if (it != mapping.end()) {
205-
llvm::append_range(next, it->second);
206-
} else {
207-
next.push_back(v);
208-
}
209-
}
210-
if (next != current) {
211-
// If at least one value was replaced, continue the lookup from there.
212-
current = std::move(next);
213-
continue;
214-
}
188+
/// Marker attribute for pure type conversions. I.e., mappings whose only
189+
/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
190+
/// the replacement values of a "replaceOp" call, etc., are not pure type
191+
/// conversions.)
192+
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
193+
194+
/// A vector of values is a pure type conversion if all values are defined by
195+
/// the same operation and the operation has the `kPureTypeConversionMarker`
196+
/// attribute.
197+
static bool isPureTypeConversion(const ValueVector &values) {
198+
assert(!values.empty() && "expected non-empty value vector");
199+
Operation *op = values.front().getDefiningOp();
200+
for (Value v : llvm::drop_begin(values))
201+
if (v.getDefiningOp() != op)
202+
return false;
203+
return op && op->hasAttr(kPureTypeConversionMarker);
204+
}
215205

216-
// Otherwise: Check if there is a mapping for the entire vector. Such
217-
// mappings are materializations. (N:M mapping are not supported for value
218-
// replacements.)
219-
//
220-
// Note: From a correctness point of view, materializations do not have to
221-
// be stored (and looked up) in the mapping. But for performance reasons,
222-
// we choose to reuse existing IR (when possible) instead of creating it
223-
// multiple times.
224-
auto it = mapping.find(current);
225-
if (it == mapping.end()) {
226-
// No mapping found: The lookup stops here.
227-
break;
206+
ValueVector ConversionValueMapping::lookup(const ValueVector &from,
207+
bool skipPureTypeConversions) const {
208+
// If possible, replace each value with (one or multiple) mapped values.
209+
ValueVector next;
210+
for (Value v : from) {
211+
auto it = mapping.find({v});
212+
if (it != mapping.end()) {
213+
llvm::append_range(next, it->second);
214+
} else {
215+
next.push_back(v);
228216
}
229-
current = it->second;
230-
} while (true);
231-
232-
// If the desired values were found use them, otherwise default to the leaf
233-
// values.
234-
// Note: If `desiredTypes` is empty, this function always returns `current`.
235-
return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
217+
}
218+
if (next != from) {
219+
// At least one value was replaced.
220+
return next;
221+
}
222+
223+
// Otherwise: Check if there is a mapping for the entire vector. Such
224+
// mappings are materializations. (N:M mapping are not supported for value
225+
// replacements.)
226+
//
227+
// Note: From a correctness point of view, materializations do not have to
228+
// be stored (and looked up) in the mapping. But for performance reasons,
229+
// we choose to reuse existing IR (when possible) instead of creating it
230+
// multiple times.
231+
auto it = mapping.find(from);
232+
if (it == mapping.end()) {
233+
// No mapping found: The lookup stops here.
234+
return {};
235+
}
236+
return it->second;
236237
}
237238

238239
//===----------------------------------------------------------------------===//
@@ -930,7 +931,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
930931
/// recently mapped values.
931932
/// - If there is no mapping for the given values at all, return the given
932933
/// value.
933-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
934+
///
935+
/// If `skipPureTypeConversions` is "true", materializations that are pure
936+
/// type conversions are not considered.
937+
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
938+
bool skipPureTypeConversions = false) const;
934939

935940
/// Lookup the given value within the map, or return an empty vector if the
936941
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +998,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
993998
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
994999
/// the results of the unresolved materialization in the conversion value
9951000
/// mapping.
1001+
///
1002+
/// If `isPureTypeConversion` is "true", the materialization is created only
1003+
/// to resolve a type mismatch. That means it is not a regular value
1004+
/// replacement issued by the user. (Replacement values that are created
1005+
/// "out of thin air" appear like unresolved materializations because they are
1006+
/// unrealized_conversion_cast ops. However, they must be treated like
1007+
/// regular value replacements.)
9961008
ValueRange buildUnresolvedMaterialization(
9971009
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
9981010
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
9991011
Type originalType, const TypeConverter *converter,
1000-
UnrealizedConversionCastOp *castOp = nullptr);
1012+
UnrealizedConversionCastOp *castOp = nullptr,
1013+
bool isPureTypeConversion = true);
10011014

10021015
/// Find a replacement value for the given SSA value in the conversion value
10031016
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1277,42 @@ void ConversionPatternRewriterImpl::applyRewrites() {
12641277
// State Management
12651278
//===----------------------------------------------------------------------===//
12661279

1267-
ValueVector
1268-
ConversionPatternRewriterImpl::lookupOrDefault(Value from,
1269-
TypeRange desiredTypes) const {
1270-
return mapping.lookupOrDefault(from, desiredTypes);
1280+
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
1281+
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1282+
// Try to find the deepest values that have the desired types. If there is no
1283+
// such mapping, simply return the deepest values.
1284+
ValueVector desiredValue;
1285+
ValueVector current{from};
1286+
ValueVector lastNonMaterialization{from};
1287+
do {
1288+
// Store the current value if the types match.
1289+
bool match = TypeRange(ValueRange(current)) == desiredTypes;
1290+
if (skipPureTypeConversions) {
1291+
// Skip pure type conversions, if requested.
1292+
bool pureConversion = isPureTypeConversion(current);
1293+
match &= !pureConversion;
1294+
// Keep track of the last mapped value that was not a pure type
1295+
// conversion.
1296+
if (!pureConversion)
1297+
lastNonMaterialization = current;
1298+
}
1299+
if (match)
1300+
desiredValue = current;
1301+
1302+
// Lookup next value in the mapping.
1303+
ValueVector next = mapping.lookup(current, skipPureTypeConversions);
1304+
if (next.empty())
1305+
break;
1306+
current = std::move(next);
1307+
} while (true);
1308+
1309+
// If the desired values were found use them, otherwise default to the leaf
1310+
// values. (Skip pure type conversions, if requested.)
1311+
if (!desiredTypes.empty())
1312+
return desiredValue;
1313+
if (skipPureTypeConversions)
1314+
return lastNonMaterialization;
1315+
return current;
12711316
}
12721317

12731318
ValueVector
@@ -1324,10 +1369,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13241369
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
13251370

13261371
if (!currentTypeConverter) {
1327-
// The current pattern does not have a type converter. I.e., it does not
1328-
// distinguish between legal and illegal types. For each operand, simply
1329-
// pass through the most recently mapped values.
1330-
remapped.push_back(lookupOrDefault(operand));
1372+
// The current pattern does not have a type converter. Pass the most
1373+
// recently mapped values, excluding materializations. Materializations
1374+
// are intentionally excluded because their presence may depend on other
1375+
// patterns. Including materializations would make the lookup fragile
1376+
// and unpredictable.
1377+
remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1378+
/*skipPureTypeConversions=*/true));
13311379
continue;
13321380
}
13331381

@@ -1356,7 +1404,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13561404
}
13571405

13581406
// Create a materialization for the most recently mapped values.
1359-
repl = lookupOrDefault(operand);
1407+
repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1408+
/*skipPureTypeConversions=*/true);
13601409
ValueRange castValues = buildUnresolvedMaterialization(
13611410
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
13621411
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1531,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14821531
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
14831532
origArg.getLoc(),
14841533
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1485-
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1534+
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1535+
/*castOp=*/nullptr, /*isPureTypeConversion=*/false)
14861536
.front();
14871537
replaceUsesOfBlockArgument(origArg, mat, converter);
14881538
continue;
@@ -1523,7 +1573,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15231573
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
15241574
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
15251575
Type originalType, const TypeConverter *converter,
1526-
UnrealizedConversionCastOp *castOp) {
1576+
UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
15271577
assert((!originalType || kind == MaterializationKind::Target) &&
15281578
"original type is valid only for target materializations");
15291579
assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1585,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15351585
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
15361586
auto convertOp =
15371587
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1588+
if (isPureTypeConversion)
1589+
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
15381590
if (!valuesToMap.empty())
15391591
mapping.map(valuesToMap, convertOp.getResults());
15401592
if (castOp)
@@ -1650,7 +1702,8 @@ void ConversionPatternRewriterImpl::replaceOp(
16501702
MaterializationKind::Source, computeInsertPoint(result),
16511703
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
16521704
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
1653-
currentTypeConverter);
1705+
currentTypeConverter, /*castOp=*/nullptr,
1706+
/*isPureTypeConversion=*/false);
16541707
continue;
16551708
}
16561709

@@ -2902,6 +2955,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
29022955
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
29032956
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
29042957

2958+
// Drop markers.
2959+
for (UnrealizedConversionCastOp castOp : remainingCastOps)
2960+
castOp->removeAttr(kPureTypeConversionMarker);
2961+
29052962
// Try to legalize all unresolved materializations.
29062963
if (config.buildMaterializations) {
29072964
IRRewriter rewriter(rewriterImpl.context, config.listener);

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
415415
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
416416
"test.invalid"(%0) : (f16) -> ()
417417
}
418+
419+
// -----
420+
421+
// CHECK-LABEL: func @test_lookup_without_converter
422+
// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
423+
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
424+
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
425+
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
426+
func.func @test_lookup_without_converter() {
427+
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
428+
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
429+
// Make sure that the second "replace_with_valid_consumer" lowering does not
430+
// lookup the materialization that was created for the above op.
431+
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
432+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
433+
return
434+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
21042104
Arguments<(ins Variadic<AnyType>)>;
21052105
def TestTypeProducerOp : TEST_Op<"type_producer">,
21062106
Results<(outs AnyType)>;
2107+
def TestValidProducerOp : TEST_Op<"valid_producer">,
2108+
Results<(outs AnyType)>;
2109+
def TestValidConsumerOp : TEST_Op<"valid_consumer">,
2110+
Arguments<(ins AnyType)>;
21072111
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
21082112
Results<(outs AnyType)>;
21092113
def TestTypeConsumerOp : TEST_Op<"type_consumer">,

0 commit comments

Comments
 (0)