Commit 1e504be

[MLIR] Specify new padOp's output type in DropPadUnitDims (#150706)
Previously, when dropping a unit dim from a pad with mixed dynamic/static input/output shapes, the new pad's result shape would take on the type of the input, producing invalid IR. Also did some minor cleanup of the formatting of the `drop_unit_dim_corresponding_to_dynamic_dim` test to make it match the rest of the file.

---------

Signed-off-by: dan <[email protected]>
1 parent cdf75df commit 1e504be
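
For illustration only (not part of the commit): a minimal sketch of the failure mode, based on the IR from the new `@drop_unit_dim_mixed_static_dynamic` test added below. The pad's source dim is dynamic while its declared result dim is static, so inferring the collapsed pad's type from the source would yield tensor<?xf32> instead of the tensor<16xf32> that the expansion back to the original type requires. Value names follow the test; this is a sketch, not transform output taken verbatim.

  // Original pad: dynamic source dim, static result type.
  %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<1x?xf32> to tensor<1x16xf32>

  // After dropping the leading unit dim, the rewritten pad must carry the
  // static result shape explicitly; left to type inference it would become
  // tensor<?xf32> and no longer match the expansion back to tensor<1x16xf32>.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
  %pad = tensor.pad %collapsed low[1] high[0] {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<?xf32> to tensor<16xf32>
  %expanded = tensor.expand_shape %pad [[0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>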

File tree

2 files changed: +48 -22 lines changed

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 10 additions & 5 deletions

@@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
     }
 
     ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
     int64_t padRank = sourceShape.size();
 
     auto isStaticZero = [](OpFoldResult f) {
@@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
                                                  allowedUnitDims.end());
     llvm::SmallDenseSet<unsigned> unitDims;
     SmallVector<int64_t> newShape;
+    SmallVector<int64_t> newResultShape;
     SmallVector<OpFoldResult> newLowPad;
     SmallVector<OpFoldResult> newHighPad;
-    for (const auto [dim, size, low, high] :
-         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
-                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+    for (const auto [dim, size, outSize, low, high] : zip_equal(
+             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
       if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
+        newResultShape.push_back(outSize);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
@@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
         collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                       reassociationMap, options.rankReductionStrategy);
 
-    auto newPadOp = tensor::PadOp::create(
-        rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+    auto newResultType = RankedTensorType::get(
+        newResultShape, padOp.getResultType().getElementType());
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
         newHighPad, paddingVal, padOp.getNofold());
 
     Value dest = padOp.getResult();

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 38 additions & 17 deletions

@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?xf32> to tensor<1x16xf32>
+  return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?xf32> to tensor<16xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %2 = arith.mulf %in, %in_0 : f32
+      %3 = arith.addf %out, %2 : f32
+      linalg.yield %3 : f32
+    } -> tensor<?x1x61x1xf32>
+    return %1 : tensor<?x1x61x1xf32>
+  }
+}
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
 
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32>
 // CHECK: }
 
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
-  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
-    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
-    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
-    ^bb0(%in: f32, %in_0: f32, %out: f32):
-      %2 = arith.mulf %in, %in_0 : f32
-      %3 = arith.addf %out, %2 : f32
-      linalg.yield %3 : f32
-    } -> tensor<?x1x61x1xf32>
-    return %1 : tensor<?x1x61x1xf32>
-  }
-}
-
 // -----
 
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
