From 5aed821f5dab515a59905342ec09b83bc6df336d Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Sat, 12 Jul 2025 23:17:41 +0800 Subject: [PATCH 01/12] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 +++++++++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 +++- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 ++++++- llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 ++ llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma.py | 125 ++++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 39 +++--- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 43 ++++++ .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 24 ---- mlir/test/Target/LLVMIR/nvvmir.mlir | 23 ++++ 12 files changed, 365 insertions(+), 54 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 0375f29ad8906..aad21fd4cba1c 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,6 +331,11 @@ class WMMA_REGS { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), + // stmatrix b8 -> s32 @ m16n8 + !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), + !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), + !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), + ); } @@ -403,6 +408,17 @@ class LDMATRIX_NAME { !subst("llvm.", "int_", intr)); } +class STMATRIX_NAME { + string intr = "llvm.nvvm.stmatrix.sync.aligned" + # "." # Frag.geom + # "." # Frag.frag + # !if(Trans, ".trans", "") + # "." # Frag.ptx_elt_type + ; + string record = !subst(".", "_", + !subst("llvm.", "int_", intr)); +} + // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. 
// TypeN: PTX type of the corresponding fragment's element. @@ -443,6 +459,16 @@ class LDMATRIX_OPS Geom, list Frags, list Types> { list ops = !foreach(x, ret, x.gft); } +class STMATRIX_OPS Geom, list Frags, list Types> { + list ret = + !foldl([], Geom, t1, geom, !listconcat(t1, + !foldl([], Frags, t2, frag, !listconcat(t2, + !foldl([], Types, t3, type, !listconcat(t3, + [WMMA_REGS])))))); + // Debugging aid for readable representation of the list above. + list ops = !foreach(x, ret, x.gft); +} + // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -537,9 +563,18 @@ class NVVM_MMA_OPS { list ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; + list stmatrix_b16_ops = STMATRIX_OPS< + ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; + + list stmatrix_b8_ops = STMATRIX_OPS< + ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; + list all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); + + list all_stmatrix_ops = !listconcat(stmatrix_b16_ops, + stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED { ); } +// Returns true if the fragment is supported for stmatrix ops; +// false otherwise. 
+class NVVM_STMATRIX_SUPPORTED { + string g = frag.geom; + string t = frag.ptx_elt_type; + + bit ret = !cond( + !and(!eq(g, "m8n8"), !eq(t, "b16")): true, + !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, + true: false + ); +} + class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in { } } +// STMATRIX +class NVVM_STMATRIX + : Intrinsic<[], + !listconcat([llvm_anyptr_ty], Frag.regs), + [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, + WriteOnly>, NoCapture>], + STMATRIX_NAME.intr>; + +foreach transposed = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { + if NVVM_STMATRIX_SUPPORTED.ret then { + def STMATRIX_NAME.record + : NVVM_STMATRIX; + } + } +} + // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture>] in { def int_nvvm_mapa diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 3d010e04824c5..d94be492b0c02 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { + 
Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(4); + return true; + } + + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 93827be5c2811..1e24bf8ab99e1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4597,7 +4597,14 @@ class WMMA_REGINFO !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], + + !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), + !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], + + !and(!eq(op, "stmatrix"), + !eq(ptx_elt_type, "b8"), + !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4878,6 +4885,40 @@ defset list LDMATRIXs = { } // transposed } // defset +// +// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class STMATRIX + : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + Requires { + // Build PatFrag that only matches particular address space. + dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); + PatFrag IntrFrag = PatFrag; + // Build AS-constrained pattern. 
+ let IntrinsicPattern = BuildPatternPF.ret; + let OutOperandList = (outs); + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "stmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." # Frag.ptx_elt_type + # " [$dst], " # Frag.regstring # ";"; +} + +// Create all stmatrix variants +defset list STMATRIXs = { + foreach transposed = [false, true] in {foreach space = [".shared", ""] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in + if NVVM_STMATRIX_SUPPORTED.ret then + def : STMATRIX, transposed, space>; + } // space + } // transposed +} // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4888,7 +4929,7 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py new file mode 100644 index 0000000000000..8f502065345c1 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py @@ -0,0 +1,14 @@ +# Check all variants of instructions supported by PTX78 on SM90 +# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll +# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \ +# RUN: --check-prefixes=PTX78STMATRIX-DAG +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | FileCheck %t-ptx78-sm_90.ll +# RUN: %if ptxas-12.7 %{ \ +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | %ptxas-verify -arch=sm_90 \ +# RUN: %} + +import wmma + +wmma.main() diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py index 
6ad0a2a5865c4..5c14a54601ed9 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM100a # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_100a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py index 7d9953484da7d..a77f9adddff9c 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM101a # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_101a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py index 7bddf0b6fbb78..8126e64d6cc85 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM120a # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: 
--check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_120a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 2ee489670e9e4..3888e9b6b1b8d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,6 +10,7 @@ from itertools import product from string import Template + class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type): "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, + # stmatrix + "m8n8:x1:b16": 1, + "m8n8:x2:b16": 2, + "m8n8:x4:b16": 4, + "m16n8:x1:b8": 1, + "m16n8:x2:b8": 2, + "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types): ] +def make_stmatrix_ops(geoms, frags, types): + return [ + MMAFrag(geom, frag, ptx_type) + for (geom, frag, ptx_type) in product(geoms, frags, types) + ] + + def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -315,6 +330,12 @@ def get_ldmatrix_ops(): ) +def get_stmatrix_ops(): + return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( + ["m16n8"], ["x1", "x2", "x4"], ["b8"] + ) + + def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. +def is_stmatrix_geom_supported(geom): + if geom in ["m8n8"]: + return ptx_version >= 78 and gpu_arch >= 90 + elif geom in ["m16n8"]: + return ptx_version >= 86 and gpu_arch >= 100 and aa + assert False # Unexpected geometry. 
+ + def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. + +def is_stmatrix_trans_supported(geom, trans): + if geom in ["m8n8"]: + return True + elif geom in ["m16n8"]: + return trans == ".trans" + assert False # Unexpected geometry. + + def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] +def is_stmatrix_variant_supported(frag, trans): + if not ( + is_type_supported(frag.mma_type.ptx_type) + and is_stmatrix_geom_supported(frag.geom) + and is_stmatrix_trans_supported(frag.geom, trans) + ): + return False + return frag.frag in ["x1", "x2", "x4"] + + def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -716,6 +764,61 @@ def gen_ldmatrix_tests(): return generated_items +def gen_stmatrix_tests(): + stmatrix_template = """ +declare void @${intrinsic}(i8 ${as}* %dst, ${args}); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define void @test_${function}(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] +; CHECK: {${check_args}} + call void @${intrinsic}(i8${as}* %dst, ${args}); + ret void +} + +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define void @test_${function}_o(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], +; CHECK: {${check_args}} + %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; + call void @${intrinsic}(i8 ${as}* %dst1, ${args}); + ret void +} +""" + intrinsic_template = ( + "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" + ) + instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + ) + generated_items = [] + + for frag, space, trans in product(get_stmatrix_ops(), + ["", ".shared"], + ["", 
".trans"], + ): + if not is_stmatrix_variant_supported(frag, trans): + continue + + params = { + "frag": frag.frag, + "space": space,"trans": trans, + "itype": frag.mma_type.ptx_type, + "pspace": get_pspace(space), + "as": "addrspace(%d)" % get_aspace(space), + "geom": frag.geom, + } + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["args"] = make_wmma_slice_args(frag) + test_params["check_args"] = check_pattern(frag) + + print(Template(stmatrix_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], test_params["instruction"])) + + return generated_items def mma_signature(op): if op.a.mma_type.ptx_type == "f16": @@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned +; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 +; 
PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 + +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 + ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1039,6 +1163,7 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() + items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904375e2b..8de5932aaf2e3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, - Arguments<(ins LLVM_PointerShared:$ptr, - Variadic:$sources, - MMALayoutAttr:$layout)> { +def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; +def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; +def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; +def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; + +def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", + [LdStMatrixShapeM8N8, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixShapeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_StMatrixOp: 
NVVM_Op<"stmatrix">, + Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2001,21 +2013,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) }]; - - let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - int d = getSources().size(); - std::string ptx = "stmatrix.sync.aligned"; - ptx += ".x" + std::to_string(d); - if (getLayout() == NVVM::MMALayout::col) - ptx += ".trans"; - if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; - if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; - if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; - return ptx; - } + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape); + createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; + let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; let hasVerifier = 1; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index eecca64c4bf81..d03242f402ec5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -163,6 +163,49 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } } +/// Return the intrinsic ID associated with stmatrix for the given parameters. 
+static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout,
+                                                  int32_t num,
+                                                  NVVM::LdStMatrixShape shape) {
+  if (shape == NVVM::LdStMatrixShape::M8N8) {
+    if (layout == NVVM::MMALayout::row) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16;
+      default:
+        llvm_unreachable("unsupported number of matrix");
+      }
+    } else {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
+      default:
+        llvm_unreachable("unsupported number of matrix");
+      }
+    }
+  } else if (shape == NVVM::LdStMatrixShape::M16N8) { // .trans is mandatory
+    switch (num) {
+    case 1:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
+    case 2:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
+    case 4:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
+    default:
+      llvm_unreachable("unsupported number of matrix");
+    }
+  }
+  llvm_unreachable("unsupported stmatrix shape"); // e.g. m16n16
+}
+
 /// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce62a91b..580b09d70c480 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @stmatrix( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, -// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) -llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) 
-> () - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l" diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..3be35faf091e2 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, 
%r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. 
llvm.func @kernel_func() attributes {nvvm.kernel} { From 653ae854e5d88f62b7e2e2353f8bb385251294eb Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 14 Jul 2025 10:40:47 +0800 Subject: [PATCH 02/12] Remove changes on NVPTX --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 +---- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 +------ llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 --- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma.py | 125 ------------------- 7 files changed, 12 insertions(+), 213 deletions(-) delete mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index d94be492b0c02..3d010e04824c5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3952,10 +3952,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3978,30 +3975,6 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { - Info.opc = ISD::INTRINSIC_VOID; - Info.memVT = MVT::i32; - Info.ptrVal = I.getArgOperand(0); - Info.offset = 0; - Info.flags = 
MachineMemOperand::MOStore; - Info.align = Align(4); - return true; - } - - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { - Info.opc = ISD::INTRINSIC_VOID; - Info.memVT = MVT::v4i32; - Info.ptrVal = I.getArgOperand(0); - Info.offset = 0; - Info.flags = MachineMemOperand::MOStore; - Info.align = Align(16); - return true; - } - case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 1e24bf8ab99e1..93827be5c2811 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4597,14 +4597,7 @@ class WMMA_REGINFO !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], - - !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), - !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], - - !and(!eq(op, "stmatrix"), - !eq(ptx_elt_type, "b8"), - !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4885,40 +4878,6 @@ defset list LDMATRIXs = { } // transposed } // defset -// -// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 -// -class STMATRIX - : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, - Requires { - // Build PatFrag that only matches particular address space. - dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); - PatFrag IntrFrag = PatFrag; - // Build AS-constrained pattern. 
- let IntrinsicPattern = BuildPatternPF.ret; - let OutOperandList = (outs); - let InOperandList = !con(Args, (ins MmaCode:$ptx)); - let AsmString = "stmatrix.sync.aligned." - # Frag.geom - # "." # Frag.frag - # !if(Transposed, ".trans", "") - # Space - # "." # Frag.ptx_elt_type - # " [$dst], " # Frag.regstring # ";"; -} - -// Create all stmatrix variants -defset list STMATRIXs = { - foreach transposed = [false, true] in {foreach space = [".shared", ""] in { - foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in - if NVVM_STMATRIX_SUPPORTED.ret then - def : STMATRIX, transposed, space>; - } // space - } // transposed -} // defset - // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4929,7 +4888,7 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py deleted file mode 100644 index 8f502065345c1..0000000000000 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py +++ /dev/null @@ -1,14 +0,0 @@ -# Check all variants of instructions supported by PTX78 on SM90 -# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll -# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \ -# RUN: --check-prefixes=PTX78STMATRIX-DAG -# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ -# RUN: | FileCheck %t-ptx78-sm_90.ll -# RUN: %if ptxas-12.7 %{ \ -# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ -# RUN: | %ptxas-verify -arch=sm_90 \ -# RUN: %} - -import wmma - -wmma.main() diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py index 
5c14a54601ed9..6ad0a2a5865c4 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on SM100a # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_100a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py index a77f9adddff9c..7d9953484da7d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on SM101a # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_101a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py index 8126e64d6cc85..7bddf0b6fbb78 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on SM120a # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: 
--check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_120a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 3888e9b6b1b8d..2ee489670e9e4 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,7 +10,6 @@ from itertools import product from string import Template - class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -177,13 +176,6 @@ def __init__(self, geom, frag, ptx_elt_type): "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, - # stmatrix - "m8n8:x1:b16": 1, - "m8n8:x2:b16": 2, - "m8n8:x4:b16": 4, - "m16n8:x1:b8": 1, - "m16n8:x2:b8": 2, - "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -249,13 +241,6 @@ def make_ldmatrix_ops(geoms, frags, types): ] -def make_stmatrix_ops(geoms, frags, types): - return [ - MMAFrag(geom, frag, ptx_type) - for (geom, frag, ptx_type) in product(geoms, frags, types) - ] - - def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -330,12 +315,6 @@ def get_ldmatrix_ops(): ) -def get_stmatrix_ops(): - return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( - ["m16n8"], ["x1", "x2", "x4"], ["b8"] - ) - - def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -381,14 +360,6 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. -def is_stmatrix_geom_supported(geom): - if geom in ["m8n8"]: - return ptx_version >= 78 and gpu_arch >= 90 - elif geom in ["m16n8"]: - return ptx_version >= 86 and gpu_arch >= 100 and aa - assert False # Unexpected geometry. 
- - def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -398,15 +369,6 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. - -def is_stmatrix_trans_supported(geom, trans): - if geom in ["m8n8"]: - return True - elif geom in ["m16n8"]: - return trans == ".trans" - assert False # Unexpected geometry. - - def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -501,16 +463,6 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] -def is_stmatrix_variant_supported(frag, trans): - if not ( - is_type_supported(frag.mma_type.ptx_type) - and is_stmatrix_geom_supported(frag.geom) - and is_stmatrix_trans_supported(frag.geom, trans) - ): - return False - return frag.frag in ["x1", "x2", "x4"] - - def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -764,61 +716,6 @@ def gen_ldmatrix_tests(): return generated_items -def gen_stmatrix_tests(): - stmatrix_template = """ -declare void @${intrinsic}(i8 ${as}* %dst, ${args}); - -; CHECK-LABEL: .func {{.*}}test_${function}( -define void @test_${function}(i8 ${as}* %dst, ${args}) { -; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] -; CHECK: {${check_args}} - call void @${intrinsic}(i8${as}* %dst, ${args}); - ret void -} - -; CHECK-LABEL: .func{{.*}}test_${function}_o( -define void @test_${function}_o(i8 ${as}* %dst, ${args}) { -; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], -; CHECK: {${check_args}} - %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; - call void @${intrinsic}(i8 ${as}* %dst1, ${args}); - ret void -} -""" - intrinsic_template = ( - "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" - ) - instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" - ) - generated_items = [] - - for frag, space, trans in product(get_stmatrix_ops(), - ["", ".shared"], - ["", 
".trans"], - ): - if not is_stmatrix_variant_supported(frag, trans): - continue - - params = { - "frag": frag.frag, - "space": space,"trans": trans, - "itype": frag.mma_type.ptx_type, - "pspace": get_pspace(space), - "as": "addrspace(%d)" % get_aspace(space), - "geom": frag.geom, - } - - test_params = params - test_params["intrinsic"] = Template(intrinsic_template).substitute(params) - test_params["function"] = test_params["intrinsic"].replace(".", "_") - test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["args"] = make_wmma_slice_args(frag) - test_params["check_args"] = check_pattern(frag) - - print(Template(stmatrix_template).substitute(test_params)) - generated_items.append((test_params["intrinsic"], test_params["instruction"])) - - return generated_items def mma_signature(op): if op.a.mma_type.ptx_type == "f16": @@ -996,7 +893,6 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned -; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -1098,26 +994,6 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 -; 
PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 - -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 - ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1163,7 +1039,6 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() - items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) From 56db71e89825f6d727f14e7bdced49019fb63380 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 10:20:09 +0800 Subject: [PATCH 03/12] Move the changes of IntrinsicsNVVM.td to another PR --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 -------------------------- 1 file changed, 65 deletions(-) diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index aad21fd4cba1c..0375f29ad8906 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,11 +331,6 @@ class WMMA_REGS { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), - // stmatrix b8 -> s32 @ m16n8 - !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), - !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), - !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), - ); } @@ -408,17 +403,6 @@ class LDMATRIX_NAME { !subst("llvm.", "int_", intr)); } -class STMATRIX_NAME { - string intr = "llvm.nvvm.stmatrix.sync.aligned" - # "." # Frag.geom - # "." # Frag.frag - # !if(Trans, ".trans", "") - # "." 
# Frag.ptx_elt_type - ; - string record = !subst(".", "_", - !subst("llvm.", "int_", intr)); -} - // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. @@ -459,16 +443,6 @@ class LDMATRIX_OPS Geom, list Frags, list Types> { list ops = !foreach(x, ret, x.gft); } -class STMATRIX_OPS Geom, list Frags, list Types> { - list ret = - !foldl([], Geom, t1, geom, !listconcat(t1, - !foldl([], Frags, t2, frag, !listconcat(t2, - !foldl([], Types, t3, type, !listconcat(t3, - [WMMA_REGS])))))); - // Debugging aid for readable representation of the list above. - list ops = !foreach(x, ret, x.gft); -} - // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -563,18 +537,9 @@ class NVVM_MMA_OPS { list ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; - list stmatrix_b16_ops = STMATRIX_OPS< - ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; - - list stmatrix_b8_ops = STMATRIX_OPS< - ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; - list all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); - - list all_stmatrix_ops = !listconcat(stmatrix_b16_ops, - stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -715,19 +680,6 @@ class NVVM_LDMATRIX_SUPPORTED { ); } -// Returns true if the fragment is valid for stmatrix ops is supported; -// false otherwise. 
-class NVVM_STMATRIX_SUPPORTED { - string g = frag.geom; - string t = frag.ptx_elt_type; - - bit ret = !cond( - !and(!eq(g, "m8n8"), !eq(t, "b16")): true, - !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, - true: false - ); -} - class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -2017,23 +1969,6 @@ foreach transposed = [0, 1] in { } } -// STMATRIX -class NVVM_STMATRIX - : Intrinsic<[], - !listconcat([llvm_anyptr_ty], Frag.regs), - [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, - WriteOnly>, NoCapture>], - STMATRIX_NAME.intr>; - -foreach transposed = [0, 1] in { - foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { - if NVVM_STMATRIX_SUPPORTED.ret then { - def STMATRIX_NAME.record - : NVVM_STMATRIX; - } - } -} - // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture>] in { def int_nvvm_mapa From e5a277a3ad4735cc67ecd82ccf4597dbddde355f Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 14:22:56 +0800 Subject: [PATCH 04/12] Modify the arguments of the stmatrix op --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 26 ++++--- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 73 ++++++++++--------- mlir/test/Target/LLVMIR/nvvmir.mlir | 18 ++--- 3 files changed, 63 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8de5932aaf2e3..af9c29274dc1d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,22 +1990,30 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; -def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; -def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; -def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for ldmatrix and stmatrix"; 
+ let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; -def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", - [LdStMatrixShapeM8N8, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { +def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", + [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, + LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } -def LdStMatrixShapeAttr : EnumAttr { +def LdStMatrixEltTypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, - Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> { + Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, + LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$elttype)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2015,7 +2023,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, }]; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $elttype); createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index d03242f402ec5..3491f658529c2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -164,46 +164,47 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } /// Return the intrinsic ID associated with stmatrix for the given paramters. -static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num, - NVVM::LdStMatrixShape shape) { - if (shape == NVVM::LdStMatrixShape::M8N8) { - if (layout == NVVM::MMALayout::row) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); - } - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); +static llvm::Intrinsic::ID +getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { + if (eltType == NVVM::LdStMatrixEltType::B16) { + if (layout == NVVM::MMALayout::row) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; + } + } else { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return 
llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + } } } - } else { - // for 16x8 matrices, .trans is mandatory - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; - default: - llvm_unreachable("unsupported number of matrix"); + } else if (shape.getM() == 16 && shape.getN() == 8) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + if (layout == NVVM::MMALayout::col) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + } + } } } + llvm_unreachable("unknown stmatrix kind"); } /// Return the intrinsic ID associated with st.bulk for the given address type. 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 3be35faf091e2..ad3e67b039d8f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -576,23 +576,23 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK-LABEL: @st_matrix llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, 
i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 llvm.return } From 9e2a6d2dcd280f3cd57c65ac90df239e585fc983 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 21 Jul 2025 15:20:06 +0800 Subject: [PATCH 05/12] Add verifier checks --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 16 ++++++++++ mlir/test/Dialect/LLVMIR/invalid.mlir | 36 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b129e8835..e46c06bb9bd9c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -828,6 +828,22 @@ LogicalResult NVVM::StMatrixOp::verify() { if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) return emitOpError("expected num attribute to be 1, 2 or 4"); + int m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (getElttype() != NVVM::LdStMatrixEltType::B16) { + return emitOpError("expected element type to be B16 for 8x8 matrix"); + } + } else if (m == 16 && n == 8) { + if (getElttype() != NVVM::LdStMatrixEltType::B8) { + return emitOpError("expected element type to be B8 for 16x8 matrix"); + } + if (getLayout() != NVVM::MMALayout::col) { + return emitOpError("expected layout to be col for 16x8 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8 or 16x8"); + } + return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 7f2c8c72e5cf9..f4643b87cf107 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1144,8 +1144,44 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, 
shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + // ----- +llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- llvm.func @caller() { // expected-error @below {{expected function call to produce a value}} llvm.call @callee() : () -> () From 31f80368c4d8f0df8815b142cf6527bdc80fbb67 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 21 Jul 2025 15:37:49 +0800 Subject: [PATCH 06/12] Fix the typo --- mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir 
b/mlir/test/Dialect/LLVMIR/invalid.mlir index f4643b87cf107..c73c3ac115642 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1144,7 +1144,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.return } -llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32 llvm.return @@ -1152,7 +1152,7 @@ llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: // ----- -llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return @@ -1160,14 +1160,14 @@ llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: // ----- -llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return } // ----- -llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be 
B8 for 16x8 matrix}} nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return @@ -1175,7 +1175,7 @@ llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: // ----- -llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return From 73aefcf0923fdf1c26b182ac1b96f52bfd2ba549 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 21 Jul 2025 16:12:32 +0800 Subject: [PATCH 07/12] Follow the convention of eltType in the naming --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 ++-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 ++-- mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++++----- mlir/test/Target/LLVMIR/nvvmir.mlir | 18 +++++++++--------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index af9c29274dc1d..cda96049996f1 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2013,7 +2013,7 @@ def LdStMatrixEltTypeAttr : EnumAttr, Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, - LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$elttype)> { + LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2023,7 +2023,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, }]; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - 
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $elttype); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType); createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e46c06bb9bd9c..b7be91bf7979a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -830,11 +830,11 @@ LogicalResult NVVM::StMatrixOp::verify() { int m = getShape().getM(), n = getShape().getN(); if (m == 8 && n == 8) { - if (getElttype() != NVVM::LdStMatrixEltType::B16) { + if (getEltType() != NVVM::LdStMatrixEltType::B16) { return emitOpError("expected element type to be B16 for 8x8 matrix"); } } else if (m == 16 && n == 8) { - if (getElttype() != NVVM::LdStMatrixEltType::B8) { + if (getEltType() != NVVM::LdStMatrixEltType::B8) { return emitOpError("expected element type to be B8 for 16x8 matrix"); } if (getLayout() != NVVM::MMALayout::col) { diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index c73c3ac115642..7355f27b2ace1 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1146,7 +1146,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} - nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32 llvm.return } @@ -1154,7 +1154,7 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: 
i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return } @@ -1162,14 +1162,14 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return } // ----- llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return } @@ -1177,7 +1177,7 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, 
i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index ad3e67b039d8f..8946321f64828 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -576,23 +576,23 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK-LABEL: @st_matrix llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = 
#nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 llvm.return } From f964ba8a82e57683091dd347e4c4e90c33f65823 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 10:16:59 +0800 Subject: [PATCH 08/12] Change "ld_st_matrix_elttype" to "ld_st_matrix_elt_type" --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++++----- mlir/test/Target/LLVMIR/nvvmir.mlir | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index cda96049996f1..78d90bbac6124 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2007,7 +2007,7 @@ def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmat let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } -def LdStMatrixEltTypeAttr : EnumAttr { +def LdStMatrixEltTypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 7355f27b2ace1..a391e16e49b46 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1146,7 +1146,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op 
expected num attribute to be 1, 2 or 4}} - nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32 llvm.return } @@ -1154,7 +1154,7 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 llvm.return } @@ -1162,14 +1162,14 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 llvm.return } // ----- llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = 
#nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 llvm.return } @@ -1177,7 +1177,7 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 8946321f64828..5c2cfa4683104 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -576,23 +576,23 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK-LABEL: @st_matrix llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = 
#nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32, i32 llvm.return } From 461c7d9900dd54f9cca046c17eb392a3e25bb6b5 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 10:23:24 +0800 Subject: [PATCH 09/12] Simplify the structure of `getStMatrixIntrinsicId` --- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 54 ++++++++----------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 214378883c473..90462d16c874e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -170,39 +170,31 @@ getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType) { if (shape.getM() == 8 && shape.getN() == 8) { - if (eltType == 
NVVM::LdStMatrixEltType::B16) { - if (layout == NVVM::MMALayout::row) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; - } - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; - } - } + switch (num) { + case 1: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; } } else if (shape.getM() == 16 && shape.getN() == 8) { - if (eltType == NVVM::LdStMatrixEltType::B8) { - if (layout == NVVM::MMALayout::col) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; - } - } + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; } } llvm_unreachable("unknown stmatrix kind"); From 6a69cba933069d155d9cf93f5f539856ebd6e511 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 10:28:17 +0800 Subject: [PATCH 10/12] Move the negative test to nvvmir-invalid.mlir --- mlir/test/Dialect/LLVMIR/invalid.mlir | 37 ------------------- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 39 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index a391e16e49b46..6c43b1a7611c9 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1144,43 +1144,6 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.return } -llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { - // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} - nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32 - llvm.return -} - -// ----- - -llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { - 
// expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 - llvm.return -} - -// ----- - -llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { - // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 - llvm.return -} -// ----- - -llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { - // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 - llvm.return -} - -// ----- - -llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { - // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 - llvm.return -} - // ----- llvm.func @caller() { // expected-error @below {{expected function call to produce a value}} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8c4f0aafd36a7..85478cc160064 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> llvm.return } + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute 
to be 1, 2 or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : !llvm.ptr<3>, i32 + llvm.return +} From 39f881e374e40a27868eec521f95981a657725bb Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 10:35:27 +0800 Subject: [PATCH 11/12] Keep the pointer as shared instead of any --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git 
a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 78d90bbac6124..30df3b739e5ca 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2012,7 +2012,7 @@ def LdStMatrixEltTypeAttr : EnumAttr, - Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, + Arguments<(ins LLVM_PointerShared: $ptr, Variadic:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> { let summary = "cooperative matrix store"; let description = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index b7be91bf7979a..7e46057db1b65 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -819,11 +819,6 @@ LogicalResult NVVM::LdMatrixOp::verify() { } LogicalResult NVVM::StMatrixOp::verify() { - unsigned addressSpace = - llvm::cast(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - int numMatrix = getSources().size(); if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) return emitOpError("expected num attribute to be 1, 2 or 4"); From a54cedc2f505a62b91a73fbbf43c6a32a185508a Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Thu, 31 Jul 2025 14:11:20 +0800 Subject: [PATCH 12/12] Undo unintentional changes --- mlir/test/Dialect/LLVMIR/invalid.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 6c43b1a7611c9..7f2c8c72e5cf9 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1145,6 +1145,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { } // ----- + llvm.func @caller() { // expected-error @below {{expected function call to produce a value}} llvm.call @callee() : () -> ()