[NVVM][NVPTX] Add support for tcgen05.mma #151949


Open
wants to merge 1 commit into
base: main

Conversation

schwarzschild-radius
Contributor

This commit adds support for tcgen05.mma instructions in NVPTX, with tests under CodeGen/NVPTX/tcgen05-mma*. The tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, MMA kind, collector usage, etc. The rationale for the design is documented in the NVPTXUsage.rst file. For more details, please refer to the PTX ISA.

@llvmbot
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-nvptx

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for tcgen05.mma instructions in NVPTX, with tests under CodeGen/NVPTX/tcgen05-mma*. The tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, MMA kind, collector usage, etc. The rationale for the design is documented in the NVPTXUsage.rst file. For more details, please refer to the PTX ISA.


Patch is 386.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151949.diff

13 Files Affected:

  • (modified) llvm/docs/NVPTXUsage.rst (+385-3)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+429-1)
  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+14)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+291)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+38-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+478-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+1-1)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale-ptx88.ll (+526)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale.ll (+291)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-disable-output-lane.ll (+677)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-scale-d.ll (+412)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-ws.ll (+569)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma.ll (+601)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index d28eb6860c33a..1b61df2cf5254 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1945,6 +1945,388 @@ The last argument `i1 %unpack` is a compile-time constant which when set, indica
 For more information, refer to the
 `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
 
+tcgen05.mma Intrinsics
+----------------------
+
+One of the key instruction families introduced in the Blackwell architecture is tcgen05.mma, which carries out matrix multiply-accumulate operations using the 5th generation Tensor Core unit. The `tcgen05.mma` instruction supports a broad range of capabilities, including sparsity, block scaling, and weight-stationary convolutions. Accurately modeling these through intrinsics is highly complex, and the following table outlines the large number of intrinsics required to fully support the tcgen05.mma instruction set.
+
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| variant                            | Configuration                                                                                     | Total Variants |
++====================================+===================================================================================================+================+
+| tcgen05.mma.shared                 | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.tensor.ashift          | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d                | 2 (space) x 2 (sp) x 2 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d.tensor.ashift  | 2 (sp) x 2 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 16             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane    | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane... | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4nvf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)                   | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf8f6f4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.ws                     | 2 (space) x 2 (sp) x 4 (kind) x 2 (zero_col_mask) x 4 (collector_usage_op) x 4 (collector_buffer) | 256            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| Total                              |                                                                                                   | 816            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+
+To reduce the number of possible intrinsic variations, we've modeled the tcgen05.mma instructions using flag operands. We've added range checks to these flags to prevent invalid values. We also expanded some flags back into intrinsic modifiers to avoid supporting invalid combinations of features.
+
+'``llvm.nvvm.tcgen05.mma.*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .sp variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+  ; sp.scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+
+`nvvm.tcgen05.mma` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = A * B + D` where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. The operation of the form `D = A*B` is issued when the input predicate argument `%enable_inp_d` is false. The optional immediate argument `%scale_d_imm` can be specified to scale the input matrix `D` as follows: `D = A * B + D * (2 ^ - %scale_d_imm)`. The valid range of values for argument `%scale_d_imm` is `[0, 15]`. The 32-bit register operand idesc is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__
+
+`nvvm.tcgen05.mma` has single thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. So, a single thread issuing the `nvvm.tcgen05.mma` will result in the initiation of the whole matrix multiply and accumulate operation
+
+When `.sp` is specified, the dimension of the A matrix is `M x (K/2)`, and an additional `%spmetadata` argument is required.
+
+`.ashift` shifts the rows of the A matrix down by one row, except for the last row in the Tensor Memory. `.ashift` is only allowed with M = 128 or M = 256.
+
+The `%collector_usage_a_op_flag` flag specifies the usage of the collector buffer for matrix `A`. It is illegal to specify `USE` or `FILL` for `%collector_usage_a_op_flag` together with `.ashift`.
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
+
+The following tables describe the possible values of the flag arguments:
+
+`%kind_flag` flag:
+
+============= ==========
+  `kind_flag`   value
+============= ==========
+     F16          0
+     TF32         1
+     F8F6F4       2
+     I8           3
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group`    value
+============= ==========
+     CG1          1
+     CG2          2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+============================= ==========
+ `collector_usage_a_op_flag`    value
+============================= ==========
+     DISCARD                      0
+     LASTUSE                      1
+     USE                          2
+     FILL                         3
+============================= ==========
+
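+As a minimal illustrative sketch (assuming the declarations above; the kernel name, descriptor values, and addresses below are placeholders), the flags are passed as constant operands. The second call uses the `f16.scale_d` variant with `%scale_d_imm` = 4, i.e. `D = A * B + D * (2 ^ -4)`:
+
+.. code-block:: llvm
+
+  define void @example_tcgen05_mma(ptr addrspace(6) %d, i64 %adesc, i64 %bdesc, i32 %idesc) {
+    ; kind = F16 (0), cta_group = CG1 (1), collector usage = DISCARD (0); %enable_inp_d = true => D = A * B + D
+    call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %adesc, i64 %bdesc, i32 %idesc, i1 true, i32 0, i32 1, i32 0)
+    ; f16.scale_d variant with %scale_d_imm = 4, cta_group = CG1 (1), collector usage = DISCARD (0)
+    call void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %adesc, i64 %bdesc, i32 %idesc, i1 true, i64 4, i32 1, i32 0)
+    ret void
+  }
+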
+'``llvm.nvvm.tcgen05.mma.block_scale*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  ; mxf8f6f4
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; mxf4
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; mxf4nvf4
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+`nvvm.tcgen05.mma.block_scale` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = (A * scale_a) * (B * scale_b) + D`, where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. The matrices `A` and `B` are scaled with the `%scale_a` and `%scale_b` matrices respectively before the matrix multiply and accumulate operation is performed. The operation of the form `D = A*B` is issued when the input predicate argument `%enable_inp_d` is false. The 32-bit register operand idesc is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__
+
+`nvvm.tcgen05.mma.block_scale` has single thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. So, a single thread issuing the `nvvm.tcgen05.mma.block_scale` will result in the initiation of the whole matrix multiply and accumulate operation
+
+When `.sp` is specified, the dimension of the A matrix is `M x (K/2)`, and an additional `%spmetadata` argument is required.
+
+The `%collector_usage_a_op_flag` flag specifies the usage of the collector buffer for matrix `A`.
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
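+
+A minimal illustrative sketch (assuming the `mxf8f6f4` shared declaration above; the kernel name and operand values are placeholders, and the collector-usage encoding is assumed to match the `llvm.nvvm.tcgen05.mma.*` table):
+
+.. code-block:: llvm
+
+  define void @example_tcgen05_mma_block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b) {
+    ; kind = MXF8F6F4 (0), cta_group = CG1 (1), collector usage = DISCARD (assumed encoding 0)
+    ; %enable_inp_d = true => D = (A * scale_a) * (B * scale_b) + D
+    call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 true, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 0, i32 1, i32 0)
+    ret void
+  }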
+
+The following tables describe the possible values of the flag arguments:
+
+`%kind_flag` flag:
+
+============= ==========
+ `kind_flag`    value
+============= ==========
+  MXF8F6F4        0
+  MXF4            1
+  MXF4NVF4        2
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group`    value
+============= ==========
+     CG1          1
+     CG2          2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+========================...
[truncated]


github-actions bot commented Aug 4, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@schwarzschild-radius force-pushed the tcgen05_mma_nvptx_support branch from bfe642c to 8dd88b7 on August 4, 2025 12:25
This commit adds support for tcgen05.mma instructions in NVPTX, with tests under CodeGen/NVPTX/tcgen05-mma*. The tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, MMA kind, collector usage, etc. The rationale for the design is documented in the NVPTXUsage.rst file.
@schwarzschild-radius force-pushed the tcgen05_mma_nvptx_support branch from 8dd88b7 to 8d4b4f5 on August 4, 2025 14:46