Skip to content

[AMDGPU] gfx1250 v_wmma_ld_scale instructions #152010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

rampitec
Copy link
Collaborator

@rampitec rampitec commented Aug 4, 2025

No description provided.

@rampitec rampitec requested a review from changpeng August 4, 2025 18:05
Copy link
Collaborator Author

rampitec commented Aug 4, 2025

This stack of pull requests is managed by Graphite. Learn more about stacking.

@rampitec rampitec marked this pull request as ready for review August 4, 2025 18:05
@llvmbot llvmbot added backend:AMDGPU mc Machine (object) code labels Aug 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mc

Author: Stanislav Mekhanoshin (rampitec)

Changes

Patch is 51.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152010.diff

10 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+84)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+69)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+13)
  • (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+11)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+8)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+78-42)
  • (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+21-7)
  • (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+170)
  • (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+10)
  • (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+90)
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a83caa0db8a69..d33765db9cc7d 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyBitOp3,
     ImmTyMatrixAFMT,
     ImmTyMatrixBFMT,
+    ImmTyMatrixAScale,
+    ImmTyMatrixBScale,
+    ImmTyMatrixAScaleFmt,
+    ImmTyMatrixBScaleFmt,
     ImmTyMatrixAReuse,
     ImmTyMatrixBReuse,
     ImmTyScaleSel,
@@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
   bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
   bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+  bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+  bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+  bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+  bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
   bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
   bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyBitOp3: OS << "BitOp3"; break;
     case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
     case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+    case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+    case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+    case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+    case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
     case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
     case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
     case ImmTyScaleSel: OS << "ScaleSel" ; break;
@@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                                 AMDGPUOperand::ImmTy Type);
   ParseStatus parseMatrixAFMT(OperandVector &Operands);
   ParseStatus parseMatrixBFMT(OperandVector &Operands);
+  ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+                                  AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScale(OperandVector &Operands);
+  ParseStatus parseMatrixBScale(OperandVector &Operands);
+  ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+                                     AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+  ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
                            AMDGPUOperand::ImmTyMatrixBFMT);
 }
 
+ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
+                                                 StringRef Name,
+                                                 AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(
+      Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
+  return tryParseMatrixScale(Operands, "matrix_a_scale",
+                             AMDGPUOperand::ImmTyMatrixAScale);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
+  return tryParseMatrixScale(Operands, "matrix_b_scale",
+                             AMDGPUOperand::ImmTyMatrixBScale);
+}
+
+ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
+                                                    StringRef Name,
+                                                    AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(
+      Operands, Name,
+      {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
+      Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
+  return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt",
+                                AMDGPUOperand::ImmTyMatrixAScaleFmt);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
+  return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt",
+                                AMDGPUOperand::ImmTyMatrixBScaleFmt);
+}
+
 // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
 // values to live in a joint format operand in the MCInst encoding.
 ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
                           AMDGPUOperand::ImmTyMatrixBFMT, 0);
   }
 
+  int MatrixAScaleIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
+  if (MatrixAScaleIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAScale, 0);
+  }
+
+  int MatrixBScaleIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
+  if (MatrixBScaleIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBScale, 0);
+  }
+
+  int MatrixAScaleFmtIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
+  if (MatrixAScaleFmtIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
+  }
+
+  int MatrixBScaleFmtIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
+  if (MatrixBScaleFmtIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
+  }
+
   if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
     addOptionalImmOperand(Inst, Operands, OptIdx,
                           AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index 42c4d8b8a9717..ee8683a549a80 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
   printMatrixFMT(MI, OpNo, STI, O, 'b');
 }
 
+void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
+                                         const MCSubtargetInfo &STI,
+                                         raw_ostream &O, char AorB) {
+  auto Imm = MI->getOperand(OpNo).getImm() & 1;
+  if (Imm == 0)
+    return;
+
+  O << " matrix_" << AorB << "_scale:";
+  switch (Imm) {
+  default:
+    O << Imm;
+    break;
+  case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
+    O << "MATRIX_SCALE_ROW0";
+    break;
+  case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
+    O << "MATRIX_SCALE_ROW1";
+    break;
+  }
+}
+
+void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
+  printMatrixScale(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
+  printMatrixScale(MI, OpNo, STI, O, 'b');
+}
+
+void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+                                            const MCSubtargetInfo &STI,
+                                            raw_ostream &O, char AorB) {
+  auto Imm = MI->getOperand(OpNo).getImm() & 3;
+  if (Imm == 0)
+    return;
+
+  O << " matrix_" << AorB << "_scale_fmt:";
+  switch (Imm) {
+  default:
+    O << Imm;
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
+    O << "MATRIX_SCALE_FMT_E8";
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
+    O << "MATRIX_SCALE_FMT_E5M3";
+    break;
+  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
+    O << "MATRIX_SCALE_FMT_E4M3";
+    break;
+  }
+}
+
+void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+                                             const MCSubtargetInfo &STI,
+                                             raw_ostream &O) {
+  printMatrixScaleFmt(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+                                             const MCSubtargetInfo &STI,
+                                             raw_ostream &O) {
+  printMatrixScaleFmt(MI, OpNo, STI, O, 'b');
+}
+
 void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
                                         const MCSubtargetInfo &STI,
                                         raw_ostream &O) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
index f6739b14926e1..be32061c64537 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
                        const MCSubtargetInfo &STI, raw_ostream &O);
   void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
                        const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixScale(const MCInst *MI, unsigned OpNo,
+                        const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
+  void printMatrixAScale(const MCInst *MI, unsigned OpNo,
+                         const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixBScale(const MCInst *MI, unsigned OpNo,
+                         const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+                           const MCSubtargetInfo &STI, raw_ostream &O,
+                           char AorB);
+  void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+                            const MCSubtargetInfo &STI, raw_ostream &O);
+  void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+                            const MCSubtargetInfo &STI, raw_ostream &O);
   void printInterpSlot(const MCInst *MI, unsigned OpNo,
                        const MCSubtargetInfo &STI, raw_ostream &O);
   void printInterpAttr(const MCInst *MI, unsigned OpNo,
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index c56414519a6fe..deadb7aed0f69 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
   MATRIX_FMT_BF6 = 3,
   MATRIX_FMT_FP4 = 4
 };
+
+enum MatrixScale : unsigned {
+  MATRIX_SCALE_ROW0 = 0,
+  MATRIX_SCALE_ROW1 = 1,
+};
+
+enum MatrixScaleFmt : unsigned {
+  MATRIX_SCALE_FMT_E8 = 0,
+  MATRIX_SCALE_FMT_E5M3 = 1,
+  MATRIX_SCALE_FMT_E4M3 = 2
+};
 } // namespace WMMA
 
 namespace VOP3PEncoding {
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 4698a5805ee0c..50914a5ef231f 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
 def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
 def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;
 
+def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
+def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;
+
+def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
+def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;
+
 def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
 def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
 
@@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
   field bit HasNeg = HasModifiers;
   field bit HasMatrixReuse = 0;
   field bit HasMatrixFMT = 0;
+  field bit HasMatrixScale = 0;
+  field bit HasMatrixReuse = 0;
 
   field bit HasSrc0Mods = HasModifiers;
   field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 95fcd4ac1c101..457c0eed4f047 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1407,9 +1407,9 @@ let WaveSizePredicate = isWave64 in {
 }
 
 class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
-                             bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
-                             bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0,
-                             bit _IsF4 = 0>
+                        bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
+                        bit _HasMatrixFMT = 0, bit _HasMatrixScale = 0,
+                        bit _Scale16 = 0, bit _HasMatrixReuse = 0, bit _IsF4 = 0>
     : VOP3P_Profile<VOPProfile<ArgTy>> {
   bit IsIU = _IsIU;
   bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B
@@ -1417,6 +1417,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
 
   int IndexType = _IndexType;
   let HasMatrixFMT = _HasMatrixFMT;
+  let HasMatrixScale = _HasMatrixScale;
+  bit Scale16 = _Scale16;
   let HasMatrixReuse = _HasMatrixReuse;
 
   bit HasIModOp = _Has_ImodOp;
@@ -1455,6 +1457,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                                                             IsC_F16:  "_f16",
                                                             IsC_BF16: "_bf16",
                                                             1: "_b32")));
+  ValueType ScaleTy = !if(Scale16, i64, i32);
 
   // For f16 and bf16 matrices A and B, each element can be modified by
   // fneg(neg_lo,neg_hi = 1). For f32 and f64, neg_lo[0:1] is allowed, but
@@ -1516,6 +1519,13 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                        !eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit));
   dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt),
                                    (ins));
+  dag MatrixScaleSrc = !if(HasMatrixScale,
+                           !if(Scale16, (ins VCSrc_b64:$scale_src0, VCSrc_b64:$scale_src1),
+                                        (ins VCSrc_b32:$scale_src0, VCSrc_b32:$scale_src1)),
+                           (ins));
+  dag MatrixScale = !if(HasMatrixScale, (ins MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale,
+                                             MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt),
+                                        (ins));
   dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins));
   dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins));
   dag Neg = !cond(!and(NegLoAny, NegHiAny)             : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi),
@@ -1529,7 +1539,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                                                  (ins VRegSrc_64:$src2),
                                                  (ins VRegSrc_32:$src2)),
                                             IndexKey)),
-                      MatrixFMT, MatrixReuse, Clamp, Neg);
+                      MatrixScaleSrc, MatrixFMT, MatrixScale, MatrixReuse, Clamp, Neg);
 
   // asm
 
@@ -1538,13 +1548,15 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
                              !eq(IndexType, 16) : "$index_key_16bit",
                              !eq(IndexType, 32) : "$index_key_32bit");
   string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", "");
+  string MatrixScaleSrcAsm = !if(HasMatrixScale, ", $scale_src0, $scale_src1", "");
+  string MatrixScaleAsm = !if(HasMatrixScale, "$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt", "");
   string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", "");
   string ClampAsm = !if(HasClamp, "$clamp", "");
   string NegAsm = !cond(!and(NegLoAny, NegHiAny)             : "$neg_lo$neg_hi",
                         !and(NegLoAny, !not(NegHiAny))       : "$neg_lo",
                         !and(!not(NegLoAny), !not(NegHiAny)) : "");
 
-  let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm;
+  let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixScaleSrcAsm#MatrxFMTAsm#MatrixScaleAsm#MatrixReuseAsm#NegAsm#ClampAsm;
 
   // isel patterns
   bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp));
@@ -1606,20 +1618,27 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
   dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins));
   dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2))));
   dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1,  (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2));
+  dag MatrixScaleInPat = !if(HasMatrixScale, (ins timm:$matrix_a_scale, timm:$matrix_a_scale_fmt, ScaleTy:$scale_src0,
+                                                  timm:$matrix_b_scale, timm:$matrix_b_scale_fmt, ScaleTy:$scale_src1),
+                                             (ins));
 
   dag MatrixReuseInPat = !if(HasMatrixReuse, (ins timm:$matrix_a_reuse, timm:$matrix_b_reuse), (ins));
+  dag MatrixScaleOutSrcPat = !if(HasMatrixScale, (ins ScaleTy:$scale_src0, ScaleTy:$scale_src1), (ins));
+  dag MatrixScaleOutModPat = !if(HasMatrixScale, (ins i32:$matrix_a_scale, i32:$matrix_b_scale, i32:$matrix_a_scale_fmt, i32:$matrix_b_scale_fmt), (ins));
   dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins));
 
-  dag WmmaInPat  = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat);
-  dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+  dag WmmaInPat  = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+  dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixScaleOutSrcPat, MatrixFMTOutPat,
+                        MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
 
   dag SwmmacInPat  = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat);
   dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat);
 
   // wmma pattern where src2 is inline imm uses _threeaddr pseudo,
   // can't use _twoaddr since it would violate src2 tied to vdst constraint.
-  dag WmmaInlineInPat  = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat);
-  dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+  dag WmmaInlineInPat  = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+  dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixScaleOutSrcPat,
+                              MatrixFMTOutPat, MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
 }
 
 def WMMAInstInfoTable : GenericTable {
@@ -1728,39 +1747,51 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32,   i32, v2i32, v4f32], 1,
 // *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored
 //                       for matrix A, index is i16; Matrix B uses all lanes
 
-def F32_F32_WMMA_w32             : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_BF16X32_WMMA_w32         : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_F16X32_WMMA_w32          : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F16_F16X32_WMMA_w32        ...
[truncated]

@rampitec rampitec merged commit dd0737b into main Aug 4, 2025
14 checks passed
@rampitec rampitec deleted the users/rampitec/08-04-_amdgpu_gfx1250_v_wmma_ld_scale_instructions branch August 4, 2025 18:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU mc Machine (object) code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants