diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904375e2b..30df3b739e5ca 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,10 +1990,30 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, - Arguments<(ins LLVM_PointerShared:$ptr, - Variadic<I32>:$sources, - MMALayoutAttr:$layout)> { +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for ldmatrix and stmatrix"; + let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; + +def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elt_type"> { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, + Arguments<(ins LLVM_PointerShared: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout, + LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2001,21 +2021,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) }]; - - let assemblyFormat = 
"$ptr `,` $sources attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - int d = getSources().size(); - std::string ptx = "stmatrix.sync.aligned"; - ptx += ".x" + std::to_string(d); - if (getLayout() == NVVM::MMALayout::col) - ptx += ".trans"; - if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; - if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; - if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; - return ptx; - } + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType); + createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; + let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b129e8835..7e46057db1b65 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -819,15 +819,26 @@ LogicalResult NVVM::LdMatrixOp::verify() { } LogicalResult NVVM::StMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - int numMatrix = getSources().size(); if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) return emitOpError("expected num attribute to be 1, 2 or 4"); + int m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (getEltType() != NVVM::LdStMatrixEltType::B16) { + return emitOpError("expected element type to be B16 for 8x8 matrix"); + } + } else if (m == 16 && n == 8) { + if (getEltType() != NVVM::LdStMatrixEltType::B8) { + return emitOpError("expected element type to be B8 for 16x8 matrix"); + } + if (getLayout() != NVVM::MMALayout::col) { + return emitOpError("expected layout to be col 
for 16x8 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8 or 16x8"); + } + return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index b3577c6702389..90462d16c874e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } } +/// Return the intrinsic ID associated with stmatrix for the given parameters. +static llvm::Intrinsic::ID +getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { + switch (num) { + case 1: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + } + } else if (shape.getM() == 16 && shape.getN() == 8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + } + } + llvm_unreachable("unknown stmatrix kind"); +} + /// Return the intrinsic ID associated with st.bulk for the given address type. 
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce62a91b..580b09d70c480 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @stmatrix( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, -// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) -llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) 
-> () - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32 - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l" diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8c4f0aafd36a7..85478cc160064 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> llvm.return } + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: 
i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..5c2cfa4683104 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) 
%{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, 
i32, i32, i32 + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {nvvm.kernel} {