The DecomposePackUnpack pattern can produce invalid IR when it creates a rank-reducing tensor.extract_slice op.
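Background on what "rank-reducing" means here: a tensor.extract_slice takes one offset, size and stride per source dimension, and its result type may drop some of the size-1 dimensions of the slice. The Python sketch below is a simplified model of that result-type rule (static shapes only; `is_valid_rank_reduction` is a hypothetical helper, not MLIR code), tried on shapes taken from the example input below.

```python
def is_valid_rank_reduction(sizes, result_shape):
    """Return True if result_shape can be obtained from the static slice
    sizes by dropping only dimensions of size 1 (simplified model)."""
    r = 0  # index of the next result dimension to match
    for s in sizes:
        if r < len(result_shape) and result_shape[r] == s:
            r += 1  # this slice dimension is kept in the result
        elif s != 1:
            return False  # a non-unit dimension cannot be dropped
    return r == len(result_shape)

# Dropping the unit dims of a 1x126x1x1x8 slice gives 126x8.
print(is_valid_rank_reduction([1, 126, 1, 1, 8], [126, 8]))  # True
# A 126x8 slice can never be rank-reduced to a single 1001-element dim.
print(is_valid_rank_reduction([126, 8], [1001]))  # False
```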
Example input IR to reproduce this issue:
module {
  func.func @example(%arg0: tensor<1008xf32>) -> tensor<1x1001xf32> {
    %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [1, 126, 1, 1, 8] : tensor<1008xf32> into tensor<1x126x1x1x8xf32>
    %0 = tensor.empty() : tensor<1x1x1x1001xf32>
    %unpack = linalg.unpack %expanded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [3] inner_tiles = [8] into %0 : tensor<1x126x1x1x8xf32> -> tensor<1x1x1x1001xf32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1, 2], [3]] : tensor<1x1x1x1001xf32> into tensor<1x1001xf32>
    return %collapsed : tensor<1x1001xf32>
  }
}
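The shape arithmetic of this unpack shows which packed dimension holds the tiles of dest dim 3. The sketch below models the static linalg.pack/linalg.unpack shape rule in Python (the helpers `packed_shape` and `tiled_outer_dims` are hypothetical): each tiled dest dim becomes ceil(size / tile) outer tiles, and outer_dims_perm then reorders the outer dims of the packed source.

```python
from math import ceil

def packed_shape(dest_shape, inner_dims_pos, inner_tiles, outer_dims_perm):
    """Shape of the packed tensor (the unpack source) for an unpacked dest shape."""
    outer = list(dest_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = ceil(outer[pos] / tile)  # number of tiles, including padding
    permuted = [outer[p] for p in outer_dims_perm]  # apply outer_dims_perm
    return permuted + list(inner_tiles)  # inner tile dims go last

def tiled_outer_dims(src_shape, inner_dims_pos, outer_dims_perm):
    """Sizes of the packed outer dims that carry the tiles (undoes outer_dims_perm)."""
    return [src_shape[outer_dims_perm.index(pos)] for pos in inner_dims_pos]

perm = [0, 3, 2, 1]
src = packed_shape([1, 1, 1, 1001], [3], [8], perm)
print(src)                               # [1, 126, 1, 1, 8]
print(tiled_outer_dims(src, [3], perm))  # [126]
print([src[pos] for pos in [3]])         # [1] (indexing without undoing the perm)
```

So the tiled outer dimension has 126 tiles (126 × 8 = 1008, i.e. 7 padding elements beyond 1001) and is not a unit dimension, while indexing the packed shape with inner_dims_pos directly, without undoing outer_dims_perm, sees a 1. Whether that is why the pattern fired here is a guess, not something confirmed by this report.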
Running this pattern on the example input produces the following verifier error and IR dump:
error.mlir:5:15: error: expected mixed offsets rank to match mixed sizes rank (2 vs 1) so the rank of the result type is well-formed.
%unpack = linalg.unpack %expanded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [3] inner_tiles = [8] into %0 : tensor<1x126x1x1x8xf32> -> tensor<1x1x1x1001xf32>
^
error.mlir:5:15: note: see current operation: %3 = "tensor.extract_slice"(%2) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 1001>, static_strides = array<i64: 1, 1>}> : (tensor<126x8xf32>) -> tensor<1001xf32>
// -----// IR Dump After LowerPackUnpack Failed (lower-pack-unpack) //----- //
"func.func"() <{function_type = (tensor<1008xf32>) -> tensor<1x1001xf32>, sym_name = "example"}> ({
^bb0(%arg0: tensor<1008xf32>):
%0 = "tensor.expand_shape"(%arg0) <{reassociation = [[0, 1, 2, 3, 4]], static_output_shape = array<i64: 1, 126, 1, 1, 8>}> : (tensor<1008xf32>) -> tensor<1x126x1x1x8xf32>
%1 = "tensor.empty"() : () -> tensor<1x1x1x1001xf32>
%2 = "tensor.extract_slice"(%0) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 126, 1, 1, 8>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<1x126x1x1x8xf32>) -> tensor<126x8xf32>
%3 = "tensor.extract_slice"(%2) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 1001>, static_strides = array<i64: 1, 1>}> : (tensor<126x8xf32>) -> tensor<1001xf32>
%4 = "tensor.insert_slice"(%3, %1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1001>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1001xf32>, tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32>
%5 = "tensor.collapse_shape"(%4) <{reassociation = [[0, 1, 2], [3]]}> : (tensor<1x1x1x1001xf32>) -> tensor<1x1001xf32>
"func.return"(%5) : (tensor<1x1001xf32>) -> ()
}) : () -> ()
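The verifier error comes from a simple consistency rule: a tensor.extract_slice needs exactly one offset, one size and one stride per source dimension. The Python sketch below is a simplified model of that rule (static values only; `slice_rank_errors` is a hypothetical helper, not the actual MLIR verifier), applied to the attributes of %2 and %3 from the dump.

```python
def slice_rank_errors(source_shape, offsets, sizes, strides):
    """List rank mismatches for a static extract_slice (simplified model)."""
    rank = len(source_shape)
    errors = []
    for name, vals in (("offsets", offsets), ("sizes", sizes), ("strides", strides)):
        if len(vals) != rank:
            errors.append(f"{name} rank {len(vals)} != source rank {rank}")
    return errors

# %2: tensor<1x126x1x1x8xf32> -> tensor<126x8xf32>
print(slice_rank_errors([1, 126, 1, 1, 8], [0] * 5, [1, 126, 1, 1, 8], [1] * 5))
# []
# %3: tensor<126x8xf32> -> tensor<1001xf32>
print(slice_rank_errors([126, 8], [0, 0], [1001], [1, 1]))
# ['sizes rank 1 != source rank 2']
```

Fixing only the rank of the sizes would not be enough: 1001 exceeds both 126 and 8, so no slice of tensor<126x8xf32> can be rank-reduced to tensor<1001xf32>. A well-formed lowering would presumably have to reshape the 126x8 data (for example to 1008 elements) before slicing off 1001 elements.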
The LowerPackUnpack pass in the failure above does nothing else: it adds this pattern to a RewritePatternSet via populateDecomposePackUnpackPatterns and applies it to the FuncOp body with applyPatternsGreedily, using the default configuration.