Skip to content

Commit 8ae4dee

Browse files
[NVPTX] Lower stmatrix intrinsics to PTX (#148561)
Lower stmatrix intrinsics defined in #148377 to PTX. See [PTX Doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix). --------- Co-authored-by: peterbell10 <[email protected]>
1 parent 1b8defd commit 8ae4dee

File tree

8 files changed

+284
-12
lines changed

8 files changed

+284
-12
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
331331
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
332332
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
333333

334+
// stmatrix b8 -> s32 @ m16n8
335+
!eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
336+
!eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
337+
!eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
338+
334339
);
335340
}
336341

@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
403408
!subst("llvm.", "int_", intr));
404409
}
405410

411+
class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
412+
string intr = "llvm.nvvm.stmatrix.sync.aligned"
413+
# "." # Frag.geom
414+
# "." # Frag.frag
415+
# !if(Trans, ".trans", "")
416+
# "." # Frag.ptx_elt_type
417+
;
418+
string record = !subst(".", "_",
419+
!subst("llvm.", "int_", intr));
420+
}
421+
406422
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
407423
// Geom: list of supported geometries.
408424
// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
443459
list<string> ops = !foreach(x, ret, x.gft);
444460
}
445461

462+
class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
463+
list<WMMA_REGS> ret =
464+
!foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
465+
!foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
466+
!foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
467+
[WMMA_REGS<geom, frag, type>]))))));
468+
// Debugging aid for readable representation of the list above.
469+
list<string> ops = !foreach(x, ret, x.gft);
470+
}
471+
446472
// Creates list of valid combinations of fragments. This is the main list that
447473
// drives generation of corresponding intrinsics and instructions.
448474
class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
537563
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
538564
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
539565

566+
list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
567+
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
568+
569+
list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
570+
["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
571+
540572
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
541573
ldmatrix_geom_m16n16_ops,
542574
ldmatrix_geom_m8n16_ops);
575+
576+
list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
577+
stmatrix_b8_ops);
543578
}
544579

545580
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
680715
);
681716
}
682717

718+
// Returns true if the fragment is valid for stmatrix ops is supported;
719+
// false otherwise.
720+
class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
721+
string g = frag.geom;
722+
string t = frag.ptx_elt_type;
723+
724+
bit ret = !cond(
725+
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
726+
!and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
727+
true: false
728+
);
729+
}
730+
683731
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
684732
string Suffix = !if(sync, "sync_", "")
685733
# mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
19692017
}
19702018
}
19712019

2020+
// STMATRIX
2021+
class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
2022+
: Intrinsic<[],
2023+
!listconcat([llvm_anyptr_ty], Frag.regs),
2024+
[IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
2025+
WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2026+
STMATRIX_NAME<Frag, Transposed>.intr>;
2027+
2028+
foreach transposed = [0, 1] in {
2029+
foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
2030+
if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
2031+
def STMATRIX_NAME<frag, transposed>.record
2032+
: NVVM_STMATRIX<frag, transposed>;
2033+
}
2034+
}
2035+
}
2036+
19722037
// MAPA
19732038
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
19742039
def int_nvvm_mapa

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4006,7 +4006,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
40064006
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
40074007
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
40084008
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4009-
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
4009+
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4010+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4011+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4012+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
40104013
Info.opc = ISD::INTRINSIC_VOID;
40114014
Info.memVT = MVT::v2i32;
40124015
Info.ptrVal = I.getArgOperand(0);
@@ -4029,6 +4032,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
40294032
return true;
40304033
}
40314034

4035+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4036+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4037+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4038+
Info.opc = ISD::INTRINSIC_VOID;
4039+
Info.memVT = MVT::i32;
4040+
Info.ptrVal = I.getArgOperand(0);
4041+
Info.offset = 0;
4042+
Info.flags = MachineMemOperand::MOStore;
4043+
Info.align = Align(4);
4044+
return true;
4045+
}
4046+
4047+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4048+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4049+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4050+
Info.opc = ISD::INTRINSIC_VOID;
4051+
Info.memVT = MVT::v4i32;
4052+
Info.ptrVal = I.getArgOperand(0);
4053+
Info.offset = 0;
4054+
Info.flags = MachineMemOperand::MOStore;
4055+
Info.align = Align(16);
4056+
return true;
4057+
}
4058+
40324059
case Intrinsic::nvvm_atomic_add_gen_f_cta:
40334060
case Intrinsic::nvvm_atomic_add_gen_f_sys:
40344061
case Intrinsic::nvvm_atomic_add_gen_i_cta:

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4758,7 +4758,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
47584758

47594759
!and(!eq(op, "ldmatrix"),
47604760
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
4761-
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
4761+
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
4762+
4763+
!and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
4764+
!eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
4765+
4766+
!and(!eq(op, "stmatrix"),
4767+
!eq(ptx_elt_type, "b8"),
4768+
!eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
47624769

47634770
// template DAGs for instruction inputs/output.
47644771
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -5039,6 +5046,42 @@ defset list<WMMA_INSTR> LDMATRIXs = {
50395046
} // transposed
50405047
} // defset
50415048

5049+
//
5050+
// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
5051+
//
5052+
class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
5053+
: WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
5054+
Requires<Frag.Predicates> {
5055+
// Build PatFrag that only matches particular address space.
5056+
dag PFOperands = !con((ops node:$dst),
5057+
!dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
5058+
PatFrag IntrFrag = PatFrag<PFOperands,
5059+
!foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
5060+
!cond(!eq(Space, ".shared"): AS_match.shared,
5061+
true: AS_match.generic)>;
5062+
// Build AS-constrained pattern.
5063+
let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
5064+
let OutOperandList = (outs);
5065+
let InOperandList = !con(Args, (ins MmaCode:$ptx));
5066+
let AsmString = "stmatrix.sync.aligned."
5067+
# Frag.geom
5068+
# "." # Frag.frag
5069+
# !if(Transposed, ".trans", "")
5070+
# Space
5071+
# "." # Frag.ptx_elt_type
5072+
# " [$dst], " # Frag.regstring # ";";
5073+
}
5074+
5075+
// Create all stmatrix variants
5076+
defset list<WMMA_INSTR> STMATRIXs = {
5077+
foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
5078+
foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
5079+
if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
5080+
def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
5081+
} // space
5082+
} // transposed
5083+
} // defset
5084+
50425085
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
50435086
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
50445087
// the instruction record.
@@ -5049,7 +5092,7 @@ class MMA_PAT<WMMA_INSTR wi>
50495092
Requires<wi.Predicates>;
50505093

50515094
// Build intrinsic->instruction patterns for all MMA instructions.
5052-
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
5095+
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
50535096
def : MMA_PAT<mma>;
50545097

50555098
multiclass MAPA<string suffix, Intrinsic Intr> {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Check all variants of instructions supported by PTX78 on SM90
2+
# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
3+
# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
4+
# RUN: --check-prefixes=PTX78STMATRIX-DAG
5+
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
6+
# RUN: | FileCheck %t-ptx78-sm_90.ll
7+
# RUN: %if ptxas-12.7 %{ \
8+
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
9+
# RUN: | %ptxas-verify -arch=sm_90 \
10+
# RUN: %}
11+
12+
import wmma
13+
14+
wmma.main()

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM100a
22
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
33
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_100a.ll
97
# RUN: %if ptxas-12.7 %{ \

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM101a
22
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
33
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_101a.ll
97
# RUN: %if ptxas-12.7 %{ \

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM120a
22
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
33
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_120a.ll
97
# RUN: %if ptxas-12.7 %{ \

0 commit comments

Comments
 (0)