
[mlir] [linalg] DecomposePackUnpack pattern can create invalid output IR when creating a rank-reducing extract_slice. #152037

@nprisament

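The DecomposePackUnpack pattern can produce invalid IR when it creates a rank-reducing tensor.extract_slice op: the generated op has offsets, sizes and strides lists of different ranks, so it fails verification.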

Example input IR that reproduces this issue:

module {
  func.func @example(%arg0: tensor<1008xf32>) -> tensor<1x1001xf32> {
    %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [1, 126, 1, 1, 8] : tensor<1008xf32> into tensor<1x126x1x1x8xf32>
    %0 = tensor.empty() : tensor<1x1x1x1001xf32>
    %unpack = linalg.unpack %expanded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [3] inner_tiles = [8] into %0 : tensor<1x126x1x1x8xf32> -> tensor<1x1x1x1001xf32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1, 2], [3]] : tensor<1x1x1x1001xf32> into tensor<1x1001xf32>
    return %collapsed : tensor<1x1001xf32>
  }
}
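Modeling the unpack's shape mapping in a few lines shows which unpacked dimension each packed outer dimension feeds. This is a minimal Python sketch that assumes the linalg.pack/linalg.unpack convention that packed outer dim `i` corresponds to unpacked (dest) dim `outer_dims_perm[i]`. The helper name `outer_dims_by_dest_dim` is hypothetical and not part of MLIR. It shows that the tiled dest dimension 3 actually has 126 tiles of 8 (1008 elements, truncated to 1001), so this unpack does not have all-unit tiled outer dims. One hypothesis, not confirmed in this report, is that the pattern's unit-dim check indexes the packed shape with `inner_dims_pos` without applying `outer_dims_perm`. The last lines show that such a lookup would read 1 instead of 126.

```python
from typing import List, Sequence


def outer_dims_by_dest_dim(packed_shape: Sequence[int],
                           outer_dims_perm: Sequence[int],
                           num_inner_tiles: int) -> List[int]:
    """Hypothetical helper: number of outer tiles for each unpacked (dest) dim.

    Assumes packed outer dim i maps to dest dim outer_dims_perm[i].
    """
    outer = list(packed_shape[: len(packed_shape) - num_inner_tiles])
    by_dest = [0] * len(outer)
    for packed_idx, dest_idx in enumerate(outer_dims_perm):
        by_dest[dest_idx] = outer[packed_idx]
    return by_dest


# Values taken from the linalg.unpack in the example input.
packed_shape = [1, 126, 1, 1, 8]
outer_dims_perm = [0, 3, 2, 1]
inner_dims_pos = [3]
inner_tiles = [8]
dest_shape = [1, 1, 1, 1001]

by_dest = outer_dims_by_dest_dim(packed_shape, outer_dims_perm, len(inner_tiles))
print("outer tiles per dest dim:", by_dest)  # [1, 1, 1, 126]

for pos, tile in zip(inner_dims_pos, inner_tiles):
    # Elements covered by the tiles before truncation to the dest size.
    covered = by_dest[pos] * tile
    print(f"dest dim {pos}: {by_dest[pos]} tiles x {tile} = {covered}, "
          f"dest size {dest_shape[pos]}")

# Reading the packed outer dims at inner_dims_pos WITHOUT applying the
# permutation picks packed dim 3, which is a unit dim.
naive = [packed_shape[pos] for pos in inner_dims_pos]
print("unpermuted lookup:", naive)  # [1]
```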

Running this pattern on the example input produces the following verifier error and IR dump:

error.mlir:5:15: error: expected mixed offsets rank to match mixed sizes rank (2 vs 1) so the rank of the result type is well-formed.
    %unpack = linalg.unpack %expanded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [3] inner_tiles = [8] into %0 : tensor<1x126x1x1x8xf32> -> tensor<1x1x1x1001xf32>
              ^
error.mlir:5:15: note: see current operation: %3 = "tensor.extract_slice"(%2) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 1001>, static_strides = array<i64: 1, 1>}> : (tensor<126x8xf32>) -> tensor<1001xf32>
// -----// IR Dump After LowerPackUnpack Failed (lower-pack-unpack) //----- //
"func.func"() <{function_type = (tensor<1008xf32>) -> tensor<1x1001xf32>, sym_name = "example"}> ({
^bb0(%arg0: tensor<1008xf32>):
  %0 = "tensor.expand_shape"(%arg0) <{reassociation = [[0, 1, 2, 3, 4]], static_output_shape = array<i64: 1, 126, 1, 1, 8>}> : (tensor<1008xf32>) -> tensor<1x126x1x1x8xf32>
  %1 = "tensor.empty"() : () -> tensor<1x1x1x1001xf32>
  %2 = "tensor.extract_slice"(%0) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 126, 1, 1, 8>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<1x126x1x1x8xf32>) -> tensor<126x8xf32>
  %3 = "tensor.extract_slice"(%2) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 1001>, static_strides = array<i64: 1, 1>}> : (tensor<126x8xf32>) -> tensor<1001xf32>
  %4 = "tensor.insert_slice"(%3, %1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1001>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1001xf32>, tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32>
  %5 = "tensor.collapse_shape"(%4) <{reassociation = [[0, 1, 2], [3]]}> : (tensor<1x1x1x1001xf32>) -> tensor<1x1001xf32>
  "func.return"(%5) : (tensor<1x1001xf32>) -> ()
}) : () -> ()
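The verifier error comes from a rank mismatch: `%3` has two offsets and two strides but only one size. The following simplified Python model of the `tensor.extract_slice` well-formedness rules makes this concrete. It is not MLIR's actual verifier code, and the function names are hypothetical. Offsets, sizes and strides must each have one entry per source dimension, the slice should stay within the source, and a rank-reducing result type may drop only size-1 dimensions. Applied to the two slices in the dump, `%2` passes and `%3` fails with the same "2 vs 1" mismatch.

```python
from typing import Optional, Sequence


def drops_only_unit_dims(sizes: Sequence[int], result_shape: Sequence[int]) -> bool:
    """True if result_shape equals sizes with some size-1 entries removed."""
    r = 0
    for s in sizes:
        if r < len(result_shape) and result_shape[r] == s:
            r += 1  # this size is kept in the result type
        elif s != 1:
            return False  # only unit dims may be dropped
    return r == len(result_shape)


def check_extract_slice(source_shape: Sequence[int], offsets: Sequence[int],
                        sizes: Sequence[int], strides: Sequence[int],
                        result_shape: Sequence[int]) -> Optional[str]:
    """Hypothetical, simplified static check; returns an error string or None."""
    if len(offsets) != len(sizes):
        return f"offsets rank {len(offsets)} != sizes rank {len(sizes)}"
    if len(sizes) != len(strides):
        return f"sizes rank {len(sizes)} != strides rank {len(strides)}"
    if len(sizes) != len(source_shape):
        return f"slice rank {len(sizes)} != source rank {len(source_shape)}"
    for dim, off, size, stride in zip(source_shape, offsets, sizes, strides):
        # The last accessed index must lie inside the source dimension.
        if size > 0 and off + (size - 1) * stride >= dim:
            return f"offset {off} + size {size} (stride {stride}) exceeds dim {dim}"
    if not drops_only_unit_dims(sizes, result_shape):
        return f"result {list(result_shape)} is not sizes {list(sizes)} minus unit dims"
    return None


# %2: rank-reducing slice of tensor<1x126x1x1x8xf32> to tensor<126x8xf32>.
print(check_extract_slice([1, 126, 1, 1, 8], [0] * 5, [1, 126, 1, 1, 8],
                          [1] * 5, [126, 8]))
# %3: the op rejected by the verifier.
print(check_extract_slice([126, 8], [0, 0], [1001], [1, 1], [1001]))
# Padding the sizes to rank 2 still does not fit a 126x8 source.
print(check_extract_slice([126, 8], [0, 0], [1, 1001], [1, 1], [1001]))
```

The third call suggests that no choice of sizes lets a single extract_slice of `tensor<126x8xf32>` yield `tensor<1001xf32>`. The 126x8 tile data in `%2` would first have to be collapsed, which indicates the pattern should not have matched this unpack in the first place.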

The LowerPackUnpack pass in the failure above does nothing else: it adds the DecomposePackUnpack patterns to a RewritePatternSet via populateDecomposePackUnpackPatterns and applies them to the FuncOp body with applyPatternsGreedily using the default configuration.
