I'm trying to lower a simple matmul to SVE, as shown in the test here. I'm using the same transform code, plus some more passes, to minimize the flags passed to mlir-opt. This is the code:
```mlir
module {
  func.func @bare_matmul(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
    %cst = arith.constant 0.000000e+00 : f32
    %c128_i32 = arith.constant 128 : i32
    %0 = arith.muli %arg9, %c128_i32 : i32
    %1 = arith.index_cast %0 : i32 to index
    %2 = arith.muli %arg10, %c128_i32 : i32
    %3 = arith.index_cast %2 : i32 to index
    %4 = arith.index_cast %arg5 : i32 to index
    %5 = arith.muli %1, %4 : index
    %6 = arith.addi %5, %3 : index
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%6], sizes: [128, 128], strides: [%4, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
    %alloc = memref.alloc() : memref<128x128xf32>
    memref.copy %reinterpret_cast, %alloc : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
    %7 = bufferization.to_tensor %alloc restrict writable : memref<128x128xf32> to tensor<128x128xf32>
    %8 = arith.index_cast %arg4 : i32 to index
    %9 = arith.muli %1, %8 : index
    %10 = arith.addi %9, %3 : index
    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%10], sizes: [128, 128], strides: [%8, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
    %alloc_1 = memref.alloc() : memref<128x128xf32>
    memref.copy %reinterpret_cast_0, %alloc_1 : memref<128x128xf32, strided<[?, 1], offset: ?>> to memref<128x128xf32>
    %11 = bufferization.to_tensor %alloc_1 restrict writable : memref<128x128xf32> to tensor<128x128xf32>
    %12 = tensor.empty() : tensor<128x128xf32>
    %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<128x128xf32>) -> tensor<128x128xf32>
    %14 = linalg.matmul ins(%7, %11 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%13 : tensor<128x128xf32>) -> tensor<128x128xf32>
    %reinterpret_cast_2 = memref.reinterpret_cast %arg2 to offset: [%10], sizes: [128, 128], strides: [%8, 1] : memref<*xf32> to memref<128x128xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %14 in writable %reinterpret_cast_2 : (tensor<128x128xf32>, memref<128x128xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__tile_and_vectorize(%arg0: !transform.op<"func.func"> {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.op<"func.func">) -> !transform.any_op
    %tiled_linalg_op, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.structured.vectorize %tiled_linalg_op vector_sizes [2, [4], 1] : !transform.any_op
    transform.apply_patterns to %arg0 {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.transfer_permutation_patterns
      transform.apply_patterns.vector.lower_masked_transfers
      transform.apply_patterns.vector.sink_ops
    } : !transform.op<"func.func">
    transform.apply_patterns to %arg0 {
      transform.apply_patterns.vector.lower_contraction
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.op<"func.func">
    transform.yield
  }

  transform.named_sequence @opt(%arg0: !transform.op<"func.func"> {transform.consumed}) {
    transform.apply_cse to %arg0 : !transform.op<"func.func">
    %0 = transform.apply_registered_pass "canonicalize" to %arg0 : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %1 = transform.apply_registered_pass "convert-vector-to-scf" to %0 : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %2 = transform.apply_registered_pass "convert-linalg-to-loops" to %1 : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %3 = transform.apply_registered_pass "arm-sve-legalize-vector-storage" to %2 : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %4 = transform.apply_registered_pass "convert-vector-to-llvm" with options = {"enable-arm-sve" = true} to %3 : (!transform.op<"func.func">) -> !transform.op<"func.func">
    transform.yield
  }

  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    %0 = transform.bufferization.one_shot_bufferize %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.func"]} in %0 : (!transform.any_op) -> !transform.op<"func.func">
    transform.foreach %1 : !transform.op<"func.func"> {
    ^bb0(%arg1: !transform.op<"func.func">):
      transform.include @__tile_and_vectorize failures(propagate) (%arg1) : (!transform.op<"func.func">) -> ()
      transform.include @opt failures(propagate) (%arg1) : (!transform.op<"func.func">) -> ()
    }
    transform.yield
  }
}
```
However, when running mlir-opt like this:

```sh
mlir-opt file.mlir --transform-interpreter --test-transform-dialect-erase-schedule --test-lower-to-llvm
```

I get an error regarding the control-flow dialect. It comes from the convert-cf-to-llvm pass, which the -test-lower-to-llvm pipeline runs, and the error is:
```
/tmp/dumps/file.mlir:25:11: error: failed to legalize operation 'cf.br' that was explicitly marked illegal
    %14 = linalg.matmul ins(%7, %11 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%13 : tensor<128x128xf32>) -> tensor<128x128xf32>
          ^
/tmp/dumps/file.mlir:25:11: note: see current operation: "cf.br"(%10, %31)[^bb1] : (index, tensor<128x128xf32>) -> ()
```
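
If I'm reading the note right, the problem is the tensor-typed value that the branch carries: convert-cf-to-llvm has no type conversion for tensor types, so any cf.br with a tensor operand cannot be legalized. Here is a standalone sketch of that shape (my own hypothetical reduction, not extracted from the file above):

```mlir
// Hypothetical reduction: a loop that still carries a tensor iter_arg when
// scf is lowered. --convert-scf-to-cf rewrites the loop into blocks, so the
// tensor becomes a cf.br block argument, and --convert-cf-to-llvm then fails
// to legalize the branch because tensor has no LLVM type equivalent.
func.func @tensor_loop(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index
  %r = scf.for %i = %c0 to %c128 step %c1 iter_args(%acc = %t) -> (tensor<128x128xf32>) {
    scf.yield %acc : tensor<128x128xf32>
  }
  return %r : tensor<128x128xf32>
}
```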
I'm using the latest LLVM version, exactly commit 5ae83b0ccd28.
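
In case it helps with triage, this is roughly how I would expand -test-lower-to-llvm into individual upstream passes to isolate the failing one. This is only my approximation of that test pipeline, so the pass set and order may not match it exactly:

```sh
mlir-opt file.mlir \
  --transform-interpreter \
  --test-transform-dialect-erase-schedule \
  --convert-vector-to-scf \
  --convert-scf-to-cf \
  --expand-strided-metadata \
  --convert-vector-to-llvm \
  --finalize-memref-to-llvm \
  --convert-func-to-llvm \
  --convert-arith-to-llvm \
  --convert-cf-to-llvm \
  --reconcile-unrealized-casts
```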