Skip to content

Commit 2b27377

Browse files
authored
[MLIR][NVVM] Support stmatrix intrinsics (#148377)
Add support for the `@llvm.nvvm.stmatrix` intrinsic series. These correspond to PTX stmatrix operations, as documented in the [PTX ISA reference](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
1 parent 067f87c commit 2b27377

File tree

6 files changed

+143
-47
lines changed

6 files changed

+143
-47
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1990,32 +1990,43 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
19901990
let hasVerifier = 1;
19911991
}
19921992

1993-
def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
1994-
Arguments<(ins LLVM_PointerShared:$ptr,
1995-
Variadic<I32>:$sources,
1996-
MMALayoutAttr:$layout)> {
1993+
def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
1994+
let summary = "Matrix shape for ldmatrix and stmatrix";
1995+
let parameters = (ins "int":$m, "int":$n);
1996+
let assemblyFormat = "`<` struct(params) `>`";
1997+
}
1998+
1999+
def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
2000+
def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
2001+
def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
2002+
def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
2003+
2004+
def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
2005+
[LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
2006+
LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
2007+
let genSpecializedAttr = 0;
2008+
let cppNamespace = "::mlir::NVVM";
2009+
}
2010+
def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elt_type"> {
2011+
let assemblyFormat = "`<` $value `>`";
2012+
}
2013+
2014+
def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
2015+
Arguments<(ins LLVM_PointerShared: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout,
2016+
LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> {
19972017
let summary = "cooperative matrix store";
19982018
let description = [{
19992019
Collectively store one or more matrices across all threads in a warp to the
20002020
location indicated by the address operand $ptr in shared memory.
20012021

20022022
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
20032023
}];
2004-
2005-
let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
2006-
let extraClassDefinition = [{
2007-
std::string $cppClass::getPtx() {
2008-
int d = getSources().size();
2009-
std::string ptx = "stmatrix.sync.aligned";
2010-
ptx += ".x" + std::to_string(d);
2011-
if (getLayout() == NVVM::MMALayout::col)
2012-
ptx += ".trans";
2013-
if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
2014-
if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
2015-
if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
2016-
return ptx;
2017-
}
2024+
string llvmBuilder = [{
2025+
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
2026+
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType);
2027+
createIntrinsicCall(builder, intId, operands, operands[0]->getType());
20182028
}];
2029+
let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
20192030
let hasVerifier = 1;
20202031
}
20212032

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -813,15 +813,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
813813
}
814814

815815
LogicalResult NVVM::StMatrixOp::verify() {
816-
unsigned addressSpace =
817-
llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
818-
if (addressSpace != NVVM::kSharedMemorySpace)
819-
return emitOpError("expected source pointer in memory space 3");
820-
821816
int numMatrix = getSources().size();
822817
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
823818
return emitOpError("expected num attribute to be 1, 2 or 4");
824819

820+
int m = getShape().getM(), n = getShape().getN();
821+
if (m == 8 && n == 8) {
822+
if (getEltType() != NVVM::LdStMatrixEltType::B16) {
823+
return emitOpError("expected element type to be B16 for 8x8 matrix");
824+
}
825+
} else if (m == 16 && n == 8) {
826+
if (getEltType() != NVVM::LdStMatrixEltType::B8) {
827+
return emitOpError("expected element type to be B8 for 16x8 matrix");
828+
}
829+
if (getLayout() != NVVM::MMALayout::col) {
830+
return emitOpError("expected layout to be col for 16x8 matrix");
831+
}
832+
} else {
833+
return emitOpError("expected shape to be 8x8 or 16x8");
834+
}
835+
825836
return success();
826837
}
827838

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
164164
}
165165
}
166166

167+
/// Return the intrinsic ID associated with stmatrix for the given parameters.
168+
static llvm::Intrinsic::ID
169+
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
170+
NVVM::LdStMatrixShapeAttr shape,
171+
NVVM::LdStMatrixEltType eltType) {
172+
if (shape.getM() == 8 && shape.getN() == 8) {
173+
switch (num) {
174+
case 1:
175+
return (layout == NVVM::MMALayout::row)
176+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
177+
: llvm::Intrinsic::
178+
nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
179+
case 2:
180+
return (layout == NVVM::MMALayout::row)
181+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
182+
: llvm::Intrinsic::
183+
nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
184+
case 4:
185+
return (layout == NVVM::MMALayout::row)
186+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
187+
: llvm::Intrinsic::
188+
nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
189+
}
190+
} else if (shape.getM() == 16 && shape.getN() == 8) {
191+
switch (num) {
192+
case 1:
193+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
194+
case 2:
195+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
196+
case 4:
197+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
198+
}
199+
}
200+
llvm_unreachable("unknown stmatrix kind");
201+
}
202+
167203
/// Return the intrinsic ID associated with st.bulk for the given address type.
168204
static llvm::Intrinsic::ID
169205
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {
580580

581581
// -----
582582

583-
// CHECK-LABEL: @stmatrix(
584-
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
585-
// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
586-
// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
587-
// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
588-
// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
589-
llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
590-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
591-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
592-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
593-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
594-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
595-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
596-
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
597-
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
598-
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
599-
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
600-
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
601-
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
602-
llvm.return
603-
}
604-
605-
// -----
606-
607583
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
608584
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
609585
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
312312
nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
313313
llvm.return
314314
}
315+
316+
// -----
317+
318+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
319+
// expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
320+
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32
321+
llvm.return
322+
}
323+
324+
// -----
325+
326+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
327+
// expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
328+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
329+
llvm.return
330+
}
331+
332+
// -----
333+
334+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
335+
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
336+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
337+
llvm.return
338+
}
339+
// -----
340+
341+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
342+
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
343+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
344+
llvm.return
345+
}
346+
347+
// -----
348+
349+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
350+
// expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
351+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
352+
llvm.return
353+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
573573
llvm.return
574574
}
575575

576+
// CHECK-LABEL: @st_matrix
577+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
578+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
579+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
580+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
581+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
582+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
583+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
584+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
585+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
586+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
587+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
588+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
589+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32
590+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
591+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
592+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
593+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
594+
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
595+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
596+
llvm.return
597+
}
598+
576599
// This function has the "kernel" attribute attached and should appear in the
577600
// NVVM annotations after conversion.
578601
llvm.func @kernel_func() attributes {nvvm.kernel} {

0 commit comments

Comments
 (0)