[MLIR][NVVM] Support stmatrix intrinsics #148377

Merged · 14 commits · Aug 1, 2025

47 changes: 29 additions & 18 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,32 +1990,43 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}

def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared:$ptr,
Variadic<I32>:$sources,
MMALayoutAttr:$layout)> {
def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
let summary = "Matrix shape for ldmatrix and stmatrix";
let parameters = (ins "int":$m, "int":$n);
let assemblyFormat = "`<` struct(params) `>`";
}

def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;

def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
[LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elt_type"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout,
LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> {
let summary = "cooperative matrix store";
let description = [{
Collectively store one or more matrices across all threads in a warp to the
location indicated by the address operand $ptr in shared memory.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
}];

let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int d = getSources().size();
std::string ptx = "stmatrix.sync.aligned";
ptx += ".x" + std::to_string(d);
if (getLayout() == NVVM::MMALayout::col)
ptx += ".trans";
if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
return ptx;
}
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType);
createIntrinsicCall(builder, intId, operands, operands[0]->getType());
}];
let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
let hasVerifier = 1;
}
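
For reference, a minimal usage sketch of the updated op, mirroring the tests added later in this patch (the SSA value names are placeholders): the shape and element type are now explicit attributes, and the op lowers to the matching stmatrix intrinsic instead of inline PTX.

```mlir
// One 8x8 b16 fragment, row layout.
nvvm.stmatrix %ptr, %r1 {layout = #nvvm.mma_layout<row>,
                         shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                         eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32

// Four 16x8 b8 fragments; this shape requires the col layout.
nvvm.stmatrix %ptr, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>,
                                        shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>,
                                        eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
```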

21 changes: 16 additions & 5 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -819,15 +819,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
}

LogicalResult NVVM::StMatrixOp::verify() {
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
if (addressSpace != NVVM::kSharedMemorySpace)
return emitOpError("expected source pointer in memory space 3");

int numMatrix = getSources().size();
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
return emitOpError("expected num attribute to be 1, 2 or 4");

int m = getShape().getM(), n = getShape().getN();
if (m == 8 && n == 8) {
if (getEltType() != NVVM::LdStMatrixEltType::B16) {
return emitOpError("expected element type to be B16 for 8x8 matrix");
}
} else if (m == 16 && n == 8) {
if (getEltType() != NVVM::LdStMatrixEltType::B8) {
return emitOpError("expected element type to be B8 for 16x8 matrix");
}
if (getLayout() != NVVM::MMALayout::col) {
return emitOpError("expected layout to be col for 16x8 matrix");
}
} else {
return emitOpError("expected shape to be 8x8 or 16x8");
}

return success();
}
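
As an illustration of what the new checks reject, here is a sketch mirroring the negative tests added in nvvmir-invalid.mlir below (value names are placeholders):

```mlir
// Rejected by the verifier: the 8x8 shape only supports the b16 element type.
nvvm.stmatrix %ptr, %r1 {layout = #nvvm.mma_layout<row>,
                         shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
                         eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
```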

36 changes: 36 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
}
}

/// Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
NVVM::LdStMatrixShapeAttr shape,
NVVM::LdStMatrixEltType eltType) {
if (shape.getM() == 8 && shape.getN() == 8) {
switch (num) {
case 1:
return (layout == NVVM::MMALayout::row)
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
: llvm::Intrinsic::
nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
return (layout == NVVM::MMALayout::row)
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
: llvm::Intrinsic::
nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
return (layout == NVVM::MMALayout::row)
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
: llvm::Intrinsic::
nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
}
} else if (shape.getM() == 16 && shape.getN() == 8) {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
case 2:
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
case 4:
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
}
}
llvm_unreachable("unknown stmatrix kind");
}

/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
24 changes: 0 additions & 24 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {

// -----

// CHECK-LABEL: @stmatrix(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
llvm.return
}

// -----

// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
39 changes: 39 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
llvm.return
}

// -----

llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32
llvm.return
}

// -----

llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
llvm.return
}

// -----

llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
llvm.return
}
// -----

llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
llvm.return
}

// -----

llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
llvm.return
}
23 changes: 23 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
llvm.return
}

// CHECK-LABEL: @st_matrix
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
llvm.return
}

// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {