@@ -121,17 +121,17 @@ struct ConversionValueMapping {
121
121
// / false positives.
122
122
bool isMappedTo (Value value) const { return mappedTo.contains (value); }
123
123
124
- // / Lookup the most recently mapped values with the desired types in the
125
- // / mapping.
124
+ // / Lookup a value in the mapping. If `skipPureTypeConversions` is "true",
125
+ // / pure type conversions are not considered. Return an empty vector if no
126
+ // / mapping was found.
126
127
// /
127
- // / Special cases:
128
- // / - If the desired type range is empty, simply return the most recently
129
- // / mapped values.
130
- // / - If there is no mapping to the desired types, also return the most
131
- // / recently mapped values.
132
- // / - If there is no mapping for the given values at all, return the given
133
- // / value.
134
- ValueVector lookupOrDefault (Value from, TypeRange desiredTypes = {}) const ;
128
+ // / Note: This mapping data structure supports N:M mappings. This function
129
+ // / first tries to look up mappings for each input value individually (and
130
+ // / then composes the results). If such a lookup is unsuccessful, the entire
131
+ // / vector is looked up together. If the lookup is still unsuccessful, an
132
+ // / empty vector is returned.
133
+ ValueVector lookup (const ValueVector &from,
134
+ bool skipPureTypeConversions = false ) const ;
135
135
136
136
template <typename T>
137
137
struct IsValueVector : std::is_same<std::decay_t <T>, ValueVector> {};
@@ -185,54 +185,55 @@ struct ConversionValueMapping {
185
185
};
186
186
} // namespace
187
187
188
- ValueVector
189
- ConversionValueMapping::lookupOrDefault (Value from,
190
- TypeRange desiredTypes) const {
191
- // Try to find the deepest values that have the desired types. If there is no
192
- // such mapping, simply return the deepest values.
193
- ValueVector desiredValue;
194
- ValueVector current{from};
195
- do {
196
- // Store the current value if the types match.
197
- if (TypeRange (ValueRange (current)) == desiredTypes)
198
- desiredValue = current;
199
-
200
- // If possible, Replace each value with (one or multiple) mapped values.
201
- ValueVector next;
202
- for (Value v : current) {
203
- auto it = mapping.find ({v});
204
- if (it != mapping.end ()) {
205
- llvm::append_range (next, it->second );
206
- } else {
207
- next.push_back (v);
208
- }
209
- }
210
- if (next != current) {
211
- // If at least one value was replaced, continue the lookup from there.
212
- current = std::move (next);
213
- continue ;
214
- }
188
+ // / Marker attribute for pure type conversions. I.e., mappings whose only
189
+ // / purpose is to resolve a type mismatch. (In contrast, mappings that point to
190
+ // / the replacement values of a "replaceOp" call, etc., are not pure type
191
+ // / conversions.)
192
+ static const StringRef kPureTypeConversionMarker = " __pure_type_conversion__" ;
193
+
194
+ // / A vector of values is a pure type conversion if all values are defined by
195
+ // / the same operation and the operation has the `kPureTypeConversionMarker`
196
+ // / attribute.
197
+ static bool isPureTypeConversion (const ValueVector &values) {
198
+ assert (!values.empty () && " expected non-empty value vector" );
199
+ Operation *op = values.front ().getDefiningOp ();
200
+ for (Value v : llvm::drop_begin (values))
201
+ if (v.getDefiningOp () != op)
202
+ return false ;
203
+ return op && op->hasAttr (kPureTypeConversionMarker );
204
+ }
215
205
216
- // Otherwise: Check if there is a mapping for the entire vector. Such
217
- // mappings are materializations. (N:M mapping are not supported for value
218
- // replacements.)
219
- //
220
- // Note: From a correctness point of view, materializations do not have to
221
- // be stored (and looked up) in the mapping. But for performance reasons,
222
- // we choose to reuse existing IR (when possible) instead of creating it
223
- // multiple times.
224
- auto it = mapping.find (current);
225
- if (it == mapping.end ()) {
226
- // No mapping found: The lookup stops here.
227
- break ;
206
+ ValueVector ConversionValueMapping::lookup (const ValueVector &from,
207
+ bool skipPureTypeConversions) const {
208
+ // If possible, replace each value with (one or multiple) mapped values.
209
+ ValueVector next;
210
+ for (Value v : from) {
211
+ auto it = mapping.find ({v});
212
+ if (it != mapping.end ()) {
213
+ llvm::append_range (next, it->second );
214
+ } else {
215
+ next.push_back (v);
228
216
}
229
- current = it->second ;
230
- } while (true );
231
-
232
- // If the desired values were found use them, otherwise default to the leaf
233
- // values.
234
- // Note: If `desiredTypes` is empty, this function always returns `current`.
235
- return !desiredValue.empty () ? std::move (desiredValue) : std::move (current);
217
+ }
218
+ if (next != from) {
219
+ // At least one value was replaced.
220
+ return next;
221
+ }
222
+
223
+ // Otherwise: Check if there is a mapping for the entire vector. Such
224
+ // mappings are materializations. (N:M mapping are not supported for value
225
+ // replacements.)
226
+ //
227
+ // Note: From a correctness point of view, materializations do not have to
228
+ // be stored (and looked up) in the mapping. But for performance reasons,
229
+ // we choose to reuse existing IR (when possible) instead of creating it
230
+ // multiple times.
231
+ auto it = mapping.find (from);
232
+ if (it == mapping.end ()) {
233
+ // No mapping found: The lookup stops here.
234
+ return {};
235
+ }
236
+ return it->second ;
236
237
}
237
238
238
239
// ===----------------------------------------------------------------------===//
@@ -930,7 +931,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
930
931
// / recently mapped values.
931
932
// / - If there is no mapping for the given values at all, return the given
932
933
// / value.
933
- ValueVector lookupOrDefault (Value from, TypeRange desiredTypes = {}) const ;
934
+ // /
935
+ // / If `skipPureTypeConversions` is "true", materializations that are pure
936
+ // / type conversions are not considered.
937
+ ValueVector lookupOrDefault (Value from, TypeRange desiredTypes = {},
938
+ bool skipPureTypeConversions = false ) const ;
934
939
935
940
// / Lookup the given value within the map, or return an empty vector if the
936
941
// / value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +998,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
993
998
// / If `valuesToMap` is set to a non-null Value, then that value is mapped to
994
999
// / the results of the unresolved materialization in the conversion value
995
1000
// / mapping.
1001
+ // /
1002
+ // / If `isPureTypeConversion` is "true", the materialization is created only
1003
+ // / to resolve a type mismatch. That means it is not a regular value
1004
+ // / replacement issued by the user. (Replacement values that are created
1005
+ // / "out of thin air" appear like unresolved materializations because they are
1006
+ // / unrealized_conversion_cast ops. However, they must be treated like
1007
+ // / regular value replacements.)
996
1008
ValueRange buildUnresolvedMaterialization (
997
1009
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
998
1010
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
999
1011
Type originalType, const TypeConverter *converter,
1000
- UnrealizedConversionCastOp *castOp = nullptr );
1012
+ UnrealizedConversionCastOp *castOp = nullptr ,
1013
+ bool isPureTypeConversion = true );
1001
1014
1002
1015
// / Find a replacement value for the given SSA value in the conversion value
1003
1016
// / mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1277,42 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1264
1277
// State Management
1265
1278
// ===----------------------------------------------------------------------===//
1266
1279
1267
- ValueVector
1268
- ConversionPatternRewriterImpl::lookupOrDefault (Value from,
1269
- TypeRange desiredTypes) const {
1270
- return mapping.lookupOrDefault (from, desiredTypes);
1280
+ ValueVector ConversionPatternRewriterImpl::lookupOrDefault (
1281
+ Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1282
+ // Try to find the deepest values that have the desired types. If there is no
1283
+ // such mapping, simply return the deepest values.
1284
+ ValueVector desiredValue;
1285
+ ValueVector current{from};
1286
+ ValueVector lastNonMaterialization{from};
1287
+ do {
1288
+ // Store the current value if the types match.
1289
+ bool match = TypeRange (ValueRange (current)) == desiredTypes;
1290
+ if (skipPureTypeConversions) {
1291
+ // Skip pure type conversions, if requested.
1292
+ bool pureConversion = isPureTypeConversion (current);
1293
+ match &= !pureConversion;
1294
+ // Keep track of the last mapped value that was not a pure type
1295
+ // conversion.
1296
+ if (!pureConversion)
1297
+ lastNonMaterialization = current;
1298
+ }
1299
+ if (match)
1300
+ desiredValue = current;
1301
+
1302
+ // Lookup next value in the mapping.
1303
+ ValueVector next = mapping.lookup (current, skipPureTypeConversions);
1304
+ if (next.empty ())
1305
+ break ;
1306
+ current = std::move (next);
1307
+ } while (true );
1308
+
1309
+ // If the desired values were found use them, otherwise default to the leaf
1310
+ // values. (Skip pure type conversions, if requested.)
1311
+ if (!desiredTypes.empty ())
1312
+ return desiredValue;
1313
+ if (skipPureTypeConversions)
1314
+ return lastNonMaterialization;
1315
+ return current;
1271
1316
}
1272
1317
1273
1318
ValueVector
@@ -1324,10 +1369,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1324
1369
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc ();
1325
1370
1326
1371
if (!currentTypeConverter) {
1327
- // The current pattern does not have a type converter. I.e., it does not
1328
- // distinguish between legal and illegal types. For each operand, simply
1329
- // pass through the most recently mapped values.
1330
- remapped.push_back (lookupOrDefault (operand));
1372
+ // The current pattern does not have a type converter. Pass the most
1373
+ // recently mapped values, excluding materializations. Materializations
1374
+ // are intentionally excluded because their presence may depend on other
1375
+ // patterns. Including materializations would make the lookup fragile
1376
+ // and unpredictable.
1377
+ remapped.push_back (lookupOrDefault (operand, /* desiredTypes=*/ {},
1378
+ /* skipPureTypeConversions=*/ true ));
1331
1379
continue ;
1332
1380
}
1333
1381
@@ -1356,7 +1404,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1356
1404
}
1357
1405
1358
1406
// Create a materialization for the most recently mapped values.
1359
- repl = lookupOrDefault (operand);
1407
+ repl = lookupOrDefault (operand, /* desiredTypes=*/ {},
1408
+ /* skipPureTypeConversions=*/ true );
1360
1409
ValueRange castValues = buildUnresolvedMaterialization (
1361
1410
MaterializationKind::Target, computeInsertPoint (repl), operandLoc,
1362
1411
/* valuesToMap=*/ repl, /* inputs=*/ repl, /* outputTypes=*/ legalTypes,
@@ -1482,7 +1531,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1482
1531
OpBuilder::InsertPoint (newBlock, newBlock->begin ()),
1483
1532
origArg.getLoc (),
1484
1533
/* valuesToMap=*/ {}, /* inputs=*/ ValueRange (),
1485
- /* outputTypes=*/ origArgType, /* originalType=*/ Type (), converter)
1534
+ /* outputTypes=*/ origArgType, /* originalType=*/ Type (), converter,
1535
+ /* castOp=*/ nullptr , /* isPureTypeConversion=*/ false )
1486
1536
.front ();
1487
1537
replaceUsesOfBlockArgument (origArg, mat, converter);
1488
1538
continue ;
@@ -1523,7 +1573,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1523
1573
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1524
1574
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1525
1575
Type originalType, const TypeConverter *converter,
1526
- UnrealizedConversionCastOp *castOp) {
1576
+ UnrealizedConversionCastOp *castOp, bool isPureTypeConversion ) {
1527
1577
assert ((!originalType || kind == MaterializationKind::Target) &&
1528
1578
" original type is valid only for target materializations" );
1529
1579
assert (TypeRange (inputs) != outputTypes &&
@@ -1535,6 +1585,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1535
1585
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1536
1586
auto convertOp =
1537
1587
UnrealizedConversionCastOp::create (builder, loc, outputTypes, inputs);
1588
+ if (isPureTypeConversion)
1589
+ convertOp->setAttr (kPureTypeConversionMarker , builder.getUnitAttr ());
1538
1590
if (!valuesToMap.empty ())
1539
1591
mapping.map (valuesToMap, convertOp.getResults ());
1540
1592
if (castOp)
@@ -1650,7 +1702,8 @@ void ConversionPatternRewriterImpl::replaceOp(
1650
1702
MaterializationKind::Source, computeInsertPoint (result),
1651
1703
result.getLoc (), /* valuesToMap=*/ {result}, /* inputs=*/ ValueRange (),
1652
1704
/* outputTypes=*/ result.getType (), /* originalType=*/ Type (),
1653
- currentTypeConverter);
1705
+ currentTypeConverter, /* castOp=*/ nullptr ,
1706
+ /* isPureTypeConversion=*/ false );
1654
1707
continue ;
1655
1708
}
1656
1709
@@ -2902,6 +2955,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2902
2955
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2903
2956
reconcileUnrealizedCasts (allCastOps, &remainingCastOps);
2904
2957
2958
+ // Drop markers.
2959
+ for (UnrealizedConversionCastOp castOp : remainingCastOps)
2960
+ castOp->removeAttr (kPureTypeConversionMarker );
2961
+
2905
2962
// Try to legalize all unresolved materializations.
2906
2963
if (config.buildMaterializations ) {
2907
2964
IRRewriter rewriter (rewriterImpl.context , config.listener );
0 commit comments