Skip to content

Commit 0f35244

Browse files
authored
[mlir][vector] shape_cast(constant) -> constant fold for non-splats (#145539)
The folder `shape_cast(splat constant) -> splat constant` was first introduced [here](3648065#diff-484cea976e0c96459027c951733bf2d22d34c5a0c0de6f577069870ef4588983R2600) (Nov 2020). In that commit there is a comment to _Only handle splat for now_. Based on that I assume the intention was to, at a later time, support a general `shape_cast(constant) -> constant` folder. That is what this PR does One minor downside: It is possible with this folder end up with, instead of 1 large constant and 1 shape_cast, 2 large constants: ```mlir func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) { %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> # 'large' constant 1 %0 = vector.shape_cast %cst : vector<4xi32> to vector<2x2xi32> return %cst, %0 : vector<4xi32>, vector<2x2xi32> } ``` gets folded with this new folder to ```mlir func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) { %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> # 'large' constant 1 %cst_0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32> # 'large' constant 2 return %cst, %cst_0 : vector<4xi32>, vector<2x2xi32> } ``` Notes on the above case: 1) This only effects the textual IR, the actual values share the same context storage (I've verified this by checking pointer values in the `DenseIntOrFPElementsAttrStorage` [constructor](https://github.com/llvm/llvm-project/blob/da5c442550a3823fff05c14300c1664d0fbf68c8/mlir/lib/IR/AttributeDetail.h#L59)) so no compile-time memory overhead to this folding. At the LLVM IR level the constant is shared, too. 2) This only happens when the pre-folded constant cannot be dead code eliminated (i.e. when it has 2+ uses) which I don't think is common.
1 parent c7f3437 commit 0f35244

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
59165916
}
59175917

59185918
// shape_cast(constant) -> constant
5919-
if (auto splatAttr =
5920-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5921-
return splatAttr.reshape(getType());
5919+
if (auto denseAttr =
5920+
dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
5921+
return denseAttr.reshape(getType());
59225922

59235923
// shape_cast(poison) -> poison
5924-
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5924+
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
59255925
return ub::PoisonAttr::get(getContext());
5926-
}
59275926

59285927
return {};
59295928
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
13301330

13311331
// -----
13321332

1333-
// CHECK-LABEL: shape_cast_constant
1333+
// CHECK-LABEL: shape_cast_splat_constant
13341334
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
13351335
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
13361336
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
1337-
func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
1337+
func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
13381338
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
13391339
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
13401340
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
13441344

13451345
// -----
13461346

1347+
// Test of shape_cast's fold method:
1348+
// shape_cast(constant) -> constant.
1349+
//
1350+
// CHECK-LABEL: @shape_cast_dense_int_constant
1351+
// CHECK: %[[CST:.*]] = arith.constant
1352+
// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
1353+
// CHECK: return %[[CST]] : vector<2x3xi8>
1354+
func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
1355+
%cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
1356+
%0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
1357+
return %0 : vector<2x3xi8>
1358+
}
1359+
1360+
// -----
1361+
1362+
// Test of shape_cast fold's method:
1363+
// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
1364+
//
1365+
// CHECK-LABEL: @shape_cast_dense_float_constant
1366+
// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
1367+
// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
1368+
// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
1369+
func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
1370+
%cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
1371+
%0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
1372+
return %0, %cst : vector<2xf32>, vector<1x2xf32>
1373+
}
1374+
1375+
// -----
1376+
13471377
// CHECK-LABEL: shape_cast_poison
13481378
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
13491379
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>

0 commit comments

Comments
 (0)