Skip to content

Commit dd0737b

Browse files
authored
[AMDGPU] gfx1250 v_wmma_ld_scale instructions (#152010)
1 parent 215e6be commit dd0737b

File tree

10 files changed

+554
-49
lines changed

10 files changed

+554
-49
lines changed

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
178178
ImmTyBitOp3,
179179
ImmTyMatrixAFMT,
180180
ImmTyMatrixBFMT,
181+
ImmTyMatrixAScale,
182+
ImmTyMatrixBScale,
183+
ImmTyMatrixAScaleFmt,
184+
ImmTyMatrixBScaleFmt,
181185
ImmTyMatrixAReuse,
182186
ImmTyMatrixBReuse,
183187
ImmTyScaleSel,
@@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
428432
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
429433
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
430434
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
435+
bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
436+
bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
437+
bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
438+
bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
431439
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
432440
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
433441
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11831191
case ImmTyBitOp3: OS << "BitOp3"; break;
11841192
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
11851193
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
1194+
case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
1195+
case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
1196+
case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
1197+
case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
11861198
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
11871199
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
11881200
case ImmTyScaleSel: OS << "ScaleSel" ; break;
@@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17281740
AMDGPUOperand::ImmTy Type);
17291741
ParseStatus parseMatrixAFMT(OperandVector &Operands);
17301742
ParseStatus parseMatrixBFMT(OperandVector &Operands);
1743+
ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
1744+
AMDGPUOperand::ImmTy Type);
1745+
ParseStatus parseMatrixAScale(OperandVector &Operands);
1746+
ParseStatus parseMatrixBScale(OperandVector &Operands);
1747+
ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
1748+
AMDGPUOperand::ImmTy Type);
1749+
ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
1750+
ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);
17311751

17321752
ParseStatus parseDfmtNfmt(int64_t &Format);
17331753
ParseStatus parseUfmt(int64_t &Format);
@@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
73567376
AMDGPUOperand::ImmTyMatrixBFMT);
73577377
}
73587378

7379+
/// Shared parser for the matrix_a_scale / matrix_b_scale modifiers.
///
/// Forwards to parseStringOrIntWithPrefix with the two symbolic row names.
/// \param Operands operand list handed through to the helper.
/// \param Name     modifier prefix to look for (supplied by the A/B wrappers).
/// \param Type     immediate kind handed through to the helper.
/// \returns whatever parseStringOrIntWithPrefix reports.
// NOTE(review): the names are listed in the same order as the
// WMMA::MatrixScale enumerators (ROW0 = 0, ROW1 = 1); presumably the helper
// maps a name to its list position - confirm against its definition.
ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
                                                 StringRef Name,
                                                 AMDGPUOperand::ImmTy Type) {
  ParseStatus Status = parseStringOrIntWithPrefix(
      Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
  return Status;
}
7385+
7386+
/// Parse the "matrix_a_scale" modifier by delegating to tryParseMatrixScale
/// with the A-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixAScale;
  return tryParseMatrixScale(Operands, "matrix_a_scale", Kind);
}
7390+
7391+
/// Parse the "matrix_b_scale" modifier by delegating to tryParseMatrixScale
/// with the B-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixBScale;
  return tryParseMatrixScale(Operands, "matrix_b_scale", Kind);
}
7395+
7396+
/// Shared parser for the matrix_a_scale_fmt / matrix_b_scale_fmt modifiers.
///
/// Forwards to parseStringOrIntWithPrefix with the three symbolic format
/// names.
/// \param Operands operand list handed through to the helper.
/// \param Name     modifier prefix to look for (supplied by the A/B wrappers).
/// \param Type     immediate kind handed through to the helper.
/// \returns whatever parseStringOrIntWithPrefix reports.
// NOTE(review): the names are listed in the same order as the
// WMMA::MatrixScaleFmt enumerators (E8 = 0, E5M3 = 1, E4M3 = 2); presumably
// the helper maps a name to its list position - confirm against its
// definition.
ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
                                                    StringRef Name,
                                                    AMDGPUOperand::ImmTy Type) {
  ParseStatus Status = parseStringOrIntWithPrefix(
      Operands, Name,
      {"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
      Type);
  return Status;
}
7404+
7405+
/// Parse the "matrix_a_scale_fmt" modifier by delegating to
/// tryParseMatrixScaleFmt with the A-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixAScaleFmt;
  return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt", Kind);
}
7409+
7410+
/// Parse the "matrix_b_scale_fmt" modifier by delegating to
/// tryParseMatrixScaleFmt with the B-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixBScaleFmt;
  return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt", Kind);
}
7414+
73597415
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
73607416
// values to live in a joint format operand in the MCInst encoding.
73617417
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
94899545
AMDGPUOperand::ImmTyMatrixBFMT, 0);
94909546
}
94919547

9548+
int MatrixAScaleIdx =
9549+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
9550+
if (MatrixAScaleIdx != -1) {
9551+
addOptionalImmOperand(Inst, Operands, OptIdx,
9552+
AMDGPUOperand::ImmTyMatrixAScale, 0);
9553+
}
9554+
9555+
int MatrixBScaleIdx =
9556+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
9557+
if (MatrixBScaleIdx != -1) {
9558+
addOptionalImmOperand(Inst, Operands, OptIdx,
9559+
AMDGPUOperand::ImmTyMatrixBScale, 0);
9560+
}
9561+
9562+
int MatrixAScaleFmtIdx =
9563+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
9564+
if (MatrixAScaleFmtIdx != -1) {
9565+
addOptionalImmOperand(Inst, Operands, OptIdx,
9566+
AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
9567+
}
9568+
9569+
int MatrixBScaleFmtIdx =
9570+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
9571+
if (MatrixBScaleFmtIdx != -1) {
9572+
addOptionalImmOperand(Inst, Operands, OptIdx,
9573+
AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
9574+
}
9575+
94929576
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
94939577
addOptionalImmOperand(Inst, Operands, OptIdx,
94949578
AMDGPUOperand::ImmTyMatrixAReuse, 0);

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
13931393
printMatrixFMT(MI, OpNo, STI, O, 'b');
13941394
}
13951395

1396+
void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
1397+
const MCSubtargetInfo &STI,
1398+
raw_ostream &O, char AorB) {
1399+
auto Imm = MI->getOperand(OpNo).getImm() & 1;
1400+
if (Imm == 0)
1401+
return;
1402+
1403+
O << " matrix_" << AorB << "_scale:";
1404+
switch (Imm) {
1405+
default:
1406+
O << Imm;
1407+
break;
1408+
case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
1409+
O << "MATRIX_SCALE_ROW0";
1410+
break;
1411+
case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
1412+
O << "MATRIX_SCALE_ROW1";
1413+
break;
1414+
}
1415+
}
1416+
1417+
void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
1418+
const MCSubtargetInfo &STI,
1419+
raw_ostream &O) {
1420+
printMatrixScale(MI, OpNo, STI, O, 'a');
1421+
}
1422+
1423+
void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
1424+
const MCSubtargetInfo &STI,
1425+
raw_ostream &O) {
1426+
printMatrixScale(MI, OpNo, STI, O, 'b');
1427+
}
1428+
1429+
void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
1430+
const MCSubtargetInfo &STI,
1431+
raw_ostream &O, char AorB) {
1432+
auto Imm = MI->getOperand(OpNo).getImm() & 3;
1433+
if (Imm == 0)
1434+
return;
1435+
1436+
O << " matrix_" << AorB << "_scale_fmt:";
1437+
switch (Imm) {
1438+
default:
1439+
O << Imm;
1440+
break;
1441+
case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
1442+
O << "MATRIX_SCALE_FMT_E8";
1443+
break;
1444+
case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
1445+
O << "MATRIX_SCALE_FMT_E5M3";
1446+
break;
1447+
case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
1448+
O << "MATRIX_SCALE_FMT_E4M3";
1449+
break;
1450+
}
1451+
}
1452+
1453+
void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
1454+
const MCSubtargetInfo &STI,
1455+
raw_ostream &O) {
1456+
printMatrixScaleFmt(MI, OpNo, STI, O, 'a');
1457+
}
1458+
1459+
void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
1460+
const MCSubtargetInfo &STI,
1461+
raw_ostream &O) {
1462+
printMatrixScaleFmt(MI, OpNo, STI, O, 'b');
1463+
}
1464+
13961465
void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
13971466
const MCSubtargetInfo &STI,
13981467
raw_ostream &O) {

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
140140
const MCSubtargetInfo &STI, raw_ostream &O);
141141
void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
142142
const MCSubtargetInfo &STI, raw_ostream &O);
143+
void printMatrixScale(const MCInst *MI, unsigned OpNo,
144+
const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
145+
void printMatrixAScale(const MCInst *MI, unsigned OpNo,
146+
const MCSubtargetInfo &STI, raw_ostream &O);
147+
void printMatrixBScale(const MCInst *MI, unsigned OpNo,
148+
const MCSubtargetInfo &STI, raw_ostream &O);
149+
void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
150+
const MCSubtargetInfo &STI, raw_ostream &O,
151+
char AorB);
152+
void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
153+
const MCSubtargetInfo &STI, raw_ostream &O);
154+
void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
155+
const MCSubtargetInfo &STI, raw_ostream &O);
143156
void printInterpSlot(const MCInst *MI, unsigned OpNo,
144157
const MCSubtargetInfo &STI, raw_ostream &O);
145158
void printInterpAttr(const MCInst *MI, unsigned OpNo,

llvm/lib/Target/AMDGPU/SIDefines.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
10181018
MATRIX_FMT_BF6 = 3,
10191019
MATRIX_FMT_FP4 = 4
10201020
};
1021+
1022+
// Immediate encodings of the matrix_a_scale / matrix_b_scale modifiers.
// The instruction printer omits the modifier entirely for ROW0.
enum MatrixScale : unsigned {
  MATRIX_SCALE_ROW0 = 0,
  MATRIX_SCALE_ROW1 = 1
};
1026+
1027+
// Immediate encodings of the matrix_a_scale_fmt / matrix_b_scale_fmt
// modifiers. The instruction printer omits the modifier entirely for E8.
enum MatrixScaleFmt : unsigned {
  MATRIX_SCALE_FMT_E8 = 0,   // printed as nothing (value 0)
  MATRIX_SCALE_FMT_E5M3 = 1, // printed as "MATRIX_SCALE_FMT_E5M3"
  MATRIX_SCALE_FMT_E4M3 = 2  // printed as "MATRIX_SCALE_FMT_E4M3"
};
10211032
} // namespace WMMA
10221033

10231034
namespace VOP3PEncoding {

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
13101310
def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
13111311
def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;
13121312

1313+
def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
1314+
def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;
1315+
1316+
def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
1317+
def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;
1318+
13131319
def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
13141320
def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
13151321

@@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
26802686
field bit HasNeg = HasModifiers;
26812687
field bit HasMatrixReuse = 0;
26822688
field bit HasMatrixFMT = 0;
2689+
// Profile flag for the matrix scale operands; off by default.
// NOTE(review): the consumers of this bit are not visible in this hunk.
// HasMatrixReuse is already declared a few lines above in this class body,
// so it is not declared again here.
field bit HasMatrixScale = 0;
26832691

26842692
field bit HasSrc0Mods = HasModifiers;
26852693
field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);

0 commit comments

Comments
 (0)