[AMDGPU] gfx1250 v_wmma_ld_scale instructions #152010

Merged: rampitec merged 1 commit into main from users/rampitec/08-04-_amdgpu_gfx1250_v_wmma_ld_scale_instructions on Aug 4, 2025

No description provided.
+554 −49

Conversation
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-mc

Author: Stanislav Mekhanoshin (rampitec)

Changes

Patch is 51.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152010.diff

10 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a83caa0db8a69..d33765db9cc7d 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyBitOp3,
ImmTyMatrixAFMT,
ImmTyMatrixBFMT,
+ ImmTyMatrixAScale,
+ ImmTyMatrixBScale,
+ ImmTyMatrixAScaleFmt,
+ ImmTyMatrixBScaleFmt,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyScaleSel,
@@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
+ bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
+ bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
+ bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
+ bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyBitOp3: OS << "BitOp3"; break;
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
+ case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
+ case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
+ case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
+ case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyScaleSel: OS << "ScaleSel" ; break;
@@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
AMDGPUOperand::ImmTy Type);
ParseStatus parseMatrixAFMT(OperandVector &Operands);
ParseStatus parseMatrixBFMT(OperandVector &Operands);
+ ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScale(OperandVector &Operands);
+ ParseStatus parseMatrixBScale(OperandVector &Operands);
+ ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
+ ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
AMDGPUOperand::ImmTyMatrixBFMT);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(
+ Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
+ return tryParseMatrixScale(Operands, "matrix_a_scale",
+ AMDGPUOperand::ImmTyMatrixAScale);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
+ return tryParseMatrixScale(Operands, "matrix_b_scale",
+ AMDGPUOperand::ImmTyMatrixBScale);
+}
+
+ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(
+ Operands, Name,
+ {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
+ return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt",
+ AMDGPUOperand::ImmTyMatrixAScaleFmt);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
+ return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt",
+ AMDGPUOperand::ImmTyMatrixBScaleFmt);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
AMDGPUOperand::ImmTyMatrixBFMT, 0);
}
+ int MatrixAScaleIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
+ if (MatrixAScaleIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAScale, 0);
+ }
+
+ int MatrixBScaleIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
+ if (MatrixBScaleIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBScale, 0);
+ }
+
+ int MatrixAScaleFmtIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
+ if (MatrixAScaleFmtIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
+ }
+
+ int MatrixBScaleFmtIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
+ if (MatrixBScaleFmtIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index 42c4d8b8a9717..ee8683a549a80 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
printMatrixFMT(MI, OpNo, STI, O, 'b');
}
+void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O, char AorB) {
+ auto Imm = MI->getOperand(OpNo).getImm() & 1;
+ if (Imm == 0)
+ return;
+
+ O << " matrix_" << AorB << "_scale:";
+ switch (Imm) {
+ default:
+ O << Imm;
+ break;
+ case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
+ O << "MATRIX_SCALE_ROW0";
+ break;
+ case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
+ O << "MATRIX_SCALE_ROW1";
+ break;
+ }
+}
+
+void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScale(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScale(MI, OpNo, STI, O, 'b');
+}
+
+void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O, char AorB) {
+ auto Imm = MI->getOperand(OpNo).getImm() & 3;
+ if (Imm == 0)
+ return;
+
+ O << " matrix_" << AorB << "_scale_fmt:";
+ switch (Imm) {
+ default:
+ O << Imm;
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
+ O << "MATRIX_SCALE_FMT_E8";
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
+ O << "MATRIX_SCALE_FMT_E5M3";
+ break;
+ case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
+ O << "MATRIX_SCALE_FMT_E4M3";
+ break;
+ }
+}
+
+void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScaleFmt(MI, OpNo, STI, O, 'a');
+}
+
+void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI,
+ raw_ostream &O) {
+ printMatrixScaleFmt(MI, OpNo, STI, O, 'b');
+}
+
void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI,
raw_ostream &O) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
index f6739b14926e1..be32061c64537 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
+ void printMatrixAScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixBScale(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O,
+ char AorB);
+ void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
+ void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
+ const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpSlot(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpAttr(const MCInst *MI, unsigned OpNo,
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index c56414519a6fe..deadb7aed0f69 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
MATRIX_FMT_BF6 = 3,
MATRIX_FMT_FP4 = 4
};
+
+enum MatrixScale : unsigned {
+ MATRIX_SCALE_ROW0 = 0,
+ MATRIX_SCALE_ROW1 = 1,
+};
+
+enum MatrixScaleFmt : unsigned {
+ MATRIX_SCALE_FMT_E8 = 0,
+ MATRIX_SCALE_FMT_E5M3 = 1,
+ MATRIX_SCALE_FMT_E4M3 = 2
+};
} // namespace WMMA
namespace VOP3PEncoding {
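
Not part of the patch, just for orientation: a minimal sketch of how the new WMMA::MatrixScale and WMMA::MatrixScaleFmt values map to the assembler keywords accepted by the new parse hooks and emitted by the new printer hooks above. The helper names below are hypothetical; only the enum values and keyword strings come from the diff.

// Illustrative sketch only -- not part of the patch. Helper names are made up;
// the value/keyword pairs mirror SIDefines.h and AMDGPUInstPrinter.cpp above.
#include <string_view>

constexpr std::string_view matrixScaleKeyword(unsigned Imm) {
  // 1-bit field; MATRIX_SCALE_ROW0 (0) is the default and is omitted when printing.
  return (Imm & 1) ? "MATRIX_SCALE_ROW1" : "MATRIX_SCALE_ROW0";
}

constexpr std::string_view matrixScaleFmtKeyword(unsigned Imm) {
  // 2-bit field; MATRIX_SCALE_FMT_E8 (0) is the default and is omitted when printing.
  switch (Imm & 3) {
  case 0:  return "MATRIX_SCALE_FMT_E8";
  case 1:  return "MATRIX_SCALE_FMT_E5M3";
  case 2:  return "MATRIX_SCALE_FMT_E4M3";
  default: return ""; // 3 has no symbolic name; the printer falls back to the raw value
  }
}
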
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 4698a5805ee0c..50914a5ef231f 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;
+def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
+def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;
+
+def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
+def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;
+
def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
@@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
field bit HasNeg = HasModifiers;
field bit HasMatrixReuse = 0;
field bit HasMatrixFMT = 0;
+ field bit HasMatrixScale = 0;
field bit HasSrc0Mods = HasModifiers;
field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 95fcd4ac1c101..457c0eed4f047 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1407,9 +1407,9 @@ let WaveSizePredicate = isWave64 in {
}
class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
- bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
- bit _HasMatrixFMT = 0, bit _HasMatrixReuse = 0,
- bit _IsF4 = 0>
+ bit _IsIU, bit _IsFP8BF8XF32, bit _Has_ImodOp = 0,
+ bit _HasMatrixFMT = 0, bit _HasMatrixScale = 0,
+ bit _Scale16 = 0, bit _HasMatrixReuse = 0, bit _IsF4 = 0>
: VOP3P_Profile<VOPProfile<ArgTy>> {
bit IsIU = _IsIU;
bit NoABMods = !or(_IsFP8BF8XF32, _IsF4); // No IMOD support for A and B
@@ -1417,6 +1417,8 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
int IndexType = _IndexType;
let HasMatrixFMT = _HasMatrixFMT;
+ let HasMatrixScale = _HasMatrixScale;
+ bit Scale16 = _Scale16;
let HasMatrixReuse = _HasMatrixReuse;
bit HasIModOp = _Has_ImodOp;
@@ -1455,6 +1457,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
IsC_F16: "_f16",
IsC_BF16: "_bf16",
1: "_b32")));
+ ValueType ScaleTy = !if(Scale16, i64, i32);
// For f16 and bf16 matrices A and B, each element can be modified by
// fneg(neg_lo,neg_hi = 1). For f32 and f64, neg_lo[0:1] is allowed, but
@@ -1516,6 +1519,13 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
!eq(IndexType, 32): (ins IndexKey32bit:$index_key_32bit));
dag MatrixFMT = !if(HasMatrixFMT, (ins MatrixAFMT:$matrix_a_fmt, MatrixBFMT:$matrix_b_fmt),
(ins));
+ dag MatrixScaleSrc = !if(HasMatrixScale,
+ !if(Scale16, (ins VCSrc_b64:$scale_src0, VCSrc_b64:$scale_src1),
+ (ins VCSrc_b32:$scale_src0, VCSrc_b32:$scale_src1)),
+ (ins));
+ dag MatrixScale = !if(HasMatrixScale, (ins MatrixAScale:$matrix_a_scale, MatrixBScale:$matrix_b_scale,
+ MatrixAScaleFmt:$matrix_a_scale_fmt, MatrixBScaleFmt:$matrix_b_scale_fmt),
+ (ins));
dag MatrixReuse = !if(HasMatrixReuse, (ins MatrixAReuse:$matrix_a_reuse, MatrixBReuse:$matrix_b_reuse), (ins));
dag Clamp = !if(HasClamp, (ins Clamp0:$clamp), (ins));
dag Neg = !cond(!and(NegLoAny, NegHiAny) : (ins neg_lo0:$neg_lo, neg_hi0:$neg_hi),
@@ -1529,7 +1539,7 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
(ins VRegSrc_64:$src2),
(ins VRegSrc_32:$src2)),
IndexKey)),
- MatrixFMT, MatrixReuse, Clamp, Neg);
+ MatrixScaleSrc, MatrixFMT, MatrixScale, MatrixReuse, Clamp, Neg);
// asm
@@ -1538,13 +1548,15 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
!eq(IndexType, 16) : "$index_key_16bit",
!eq(IndexType, 32) : "$index_key_32bit");
string MatrxFMTAsm = !if(HasMatrixFMT, "$matrix_a_fmt$matrix_b_fmt", "");
+ string MatrixScaleSrcAsm = !if(HasMatrixScale, ", $scale_src0, $scale_src1", "");
+ string MatrixScaleAsm = !if(HasMatrixScale, "$matrix_a_scale$matrix_b_scale$matrix_a_scale_fmt$matrix_b_scale_fmt", "");
string MatrixReuseAsm = !if(HasMatrixReuse, "$matrix_a_reuse$matrix_b_reuse", "");
string ClampAsm = !if(HasClamp, "$clamp", "");
string NegAsm = !cond(!and(NegLoAny, NegHiAny) : "$neg_lo$neg_hi",
!and(NegLoAny, !not(NegHiAny)) : "$neg_lo",
!and(!not(NegLoAny), !not(NegHiAny)) : "");
- let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrxFMTAsm#MatrixReuseAsm#NegAsm#ClampAsm;
+ let AsmVOP3P = "$vdst, $src0, $src1, $src2"#IndexKeyAsm#MatrixScaleSrcAsm#MatrxFMTAsm#MatrixScaleAsm#MatrixReuseAsm#NegAsm#ClampAsm;
// isel patterns
bit IsAB_BF16_IMod0 = !and(IsAB_BF16, !not(HasIModOp));
@@ -1606,20 +1618,27 @@ class VOP3PWMMA_Profile<list<ValueType> ArgTy, bit _IsSWMMAC, int _IndexType,
dag MatrixFMTOutPat = !if(HasMatrixFMT, (ins i32:$matrix_a_fmt, i32:$matrix_b_fmt), (ins));
dag Src2InlineInPat = !con(!if(IsC_IMod1, (ins (VOP3PModsNegAbs i32:$src2_modifiers)), (ins)), (ins (Src2VT (WMMAVISrc Src2VT:$src2))));
dag Src2InlineOutPat = !con(!if(IsIUXF32, (ins), !if(IsC_IMod1, (ins i32:$src2_modifiers), (ins (i32 8)))), (ins Src2VT:$src2));
+ dag MatrixScaleInPat = !if(HasMatrixScale, (ins timm:$matrix_a_scale, timm:$matrix_a_scale_fmt, ScaleTy:$scale_src0,
+ timm:$matrix_b_scale, timm:$matrix_b_scale_fmt, ScaleTy:$scale_src1),
+ (ins));
dag MatrixReuseInPat = !if(HasMatrixReuse, (ins timm:$matrix_a_reuse, timm:$matrix_b_reuse), (ins));
+ dag MatrixScaleOutSrcPat = !if(HasMatrixScale, (ins ScaleTy:$scale_src0, ScaleTy:$scale_src1), (ins));
+ dag MatrixScaleOutModPat = !if(HasMatrixScale, (ins i32:$matrix_a_scale, i32:$matrix_b_scale, i32:$matrix_a_scale_fmt, i32:$matrix_b_scale_fmt), (ins));
dag MatrixReuseOutModPat = !if(HasMatrixReuse, (ins i1:$matrix_a_reuse, i1:$matrix_b_reuse), (ins));
- dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixReuseInPat, ClampPat);
- dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+ dag WmmaInPat = !con(Src0InPat, Src1InPat, Src2InPatWmma, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+ dag WmmaOutPat = !con(Src0OutPat, Src1OutPat, Src2OutPatWmma, MatrixScaleOutSrcPat, MatrixFMTOutPat,
+ MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
dag SwmmacInPat = !con(Src0InPat, Src1InPat, (ins Src2VT:$srcTiedDef), IndexInPat, MatrixReuseInPat, ClampPat);
dag SwmmacOutPat = !con(Src0OutPat, Src1OutPat, (ins Src2VT:$srcTiedDef), IndexOutPat, MatrixReuseOutModPat, ClampPat);
// wmma pattern where src2 is inline imm uses _threeaddr pseudo,
// can't use _twoaddr since it would violate src2 tied to vdst constraint.
- dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixReuseInPat, ClampPat);
- dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixFMTOutPat, MatrixReuseOutModPat, ClampPat);
+ dag WmmaInlineInPat = !con(Src0InPat, Src1InPat, Src2InlineInPat, MatrixScaleInPat, MatrixReuseInPat, ClampPat);
+ dag WmmaInlineOutPat = !con(Src0OutPat, Src1OutPat, Src2InlineOutPat, MatrixScaleOutSrcPat,
+ MatrixFMTOutPat, MatrixScaleOutModPat, MatrixReuseOutModPat, ClampPat);
}
def WMMAInstInfoTable : GenericTable {
@@ -1728,39 +1747,51 @@ def F32_FP8BF8_SWMMAC_w64 : VOP3PWMMA_Profile<[v4f32, i32, v2i32, v4f32], 1,
// *** IU4X32_SWMMAC_w64 lanes 0-31 will have 8xi4 remaining lanes are ignored
// for matrix A, index is i16; Matrix B uses all lanes
-def F32_F32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v2f32, v2f32, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_BF16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16bf16, v16bf16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F32_F16X32_WMMA_w32 : VOP3PWMMA_Profile<[v8f32, v16f16, v16f16, v8f32], 0, 0, 0, 0, 1, 0, 1>;
-def F16_F16X32_WMMA_w32 ...
[truncated]
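
Not part of the patch: a rough, self-contained sketch of the operand-modifier text the new printer callbacks append, using placeholder function and variable names (the real assembly examples live in the truncated llvm/test updates). Non-default values are printed as symbolic keywords in the order given by the AsmVOP3P string; the default value 0 is suppressed.

// Illustrative sketch only -- not part of the patch. Names are invented; the
// keywords, ordering, and default-suppression follow printMatrixScale,
// printMatrixScaleFmt, and the AsmVOP3P string in the diff above.
#include <iostream>
#include <string>

static std::string wmmaScaleModsSuffix(unsigned AScale, unsigned BScale,
                                        unsigned AFmt, unsigned BFmt) {
  static const char *FmtNames[] = {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3",
                                   "MATRIX_SCALE_FMT_E4M3"};
  std::string S;
  if (AScale & 1)
    S += " matrix_a_scale:MATRIX_SCALE_ROW1";
  if (BScale & 1)
    S += " matrix_b_scale:MATRIX_SCALE_ROW1";
  if ((AFmt & 3) && (AFmt & 3) <= 2)
    S += std::string(" matrix_a_scale_fmt:") + FmtNames[AFmt & 3];
  if ((BFmt & 3) && (BFmt & 3) <= 2)
    S += std::string(" matrix_b_scale_fmt:") + FmtNames[BFmt & 3];
  return S;
}

int main() {
  // Prints: " matrix_a_scale:MATRIX_SCALE_ROW1 matrix_b_scale_fmt:MATRIX_SCALE_FMT_E4M3"
  std::cout << wmmaScaleModsSuffix(/*AScale=*/1, /*BScale=*/0,
                                   /*AFmt=*/0, /*BFmt=*/2) << '\n';
}
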
changpeng approved these changes on Aug 4, 2025