Skip to content

[AMDGPU] gfx1250 v_wmma_ld_scale instructions #152010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyBitOp3,
ImmTyMatrixAFMT,
ImmTyMatrixBFMT,
ImmTyMatrixAScale,
ImmTyMatrixBScale,
ImmTyMatrixAScaleFmt,
ImmTyMatrixBScaleFmt,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyScaleSel,
Expand Down Expand Up @@ -428,6 +432,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
bool isMatrixAScale() const { return isImmTy(ImmTyMatrixAScale); }
bool isMatrixBScale() const { return isImmTy(ImmTyMatrixBScale); }
bool isMatrixAScaleFmt() const { return isImmTy(ImmTyMatrixAScaleFmt); }
bool isMatrixBScaleFmt() const { return isImmTy(ImmTyMatrixBScaleFmt); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
Expand Down Expand Up @@ -1183,6 +1191,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyBitOp3: OS << "BitOp3"; break;
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
case ImmTyMatrixAScale: OS << "ImmTyMatrixAScale"; break;
case ImmTyMatrixBScale: OS << "ImmTyMatrixBScale"; break;
case ImmTyMatrixAScaleFmt: OS << "ImmTyMatrixAScaleFmt"; break;
case ImmTyMatrixBScaleFmt: OS << "ImmTyMatrixBScaleFmt"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyScaleSel: OS << "ScaleSel" ; break;
Expand Down Expand Up @@ -1728,6 +1740,14 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
AMDGPUOperand::ImmTy Type);
ParseStatus parseMatrixAFMT(OperandVector &Operands);
ParseStatus parseMatrixBFMT(OperandVector &Operands);
ParseStatus tryParseMatrixScale(OperandVector &Operands, StringRef Name,
AMDGPUOperand::ImmTy Type);
ParseStatus parseMatrixAScale(OperandVector &Operands);
ParseStatus parseMatrixBScale(OperandVector &Operands);
ParseStatus tryParseMatrixScaleFmt(OperandVector &Operands, StringRef Name,
AMDGPUOperand::ImmTy Type);
ParseStatus parseMatrixAScaleFmt(OperandVector &Operands);
ParseStatus parseMatrixBScaleFmt(OperandVector &Operands);

ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
Expand Down Expand Up @@ -7356,6 +7376,42 @@ ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
AMDGPUOperand::ImmTyMatrixBFMT);
}

// Shared parser for the matrix scale modifiers. Forwards Operands, the modifier
// prefix Name and the immediate kind Type to parseStringOrIntWithPrefix,
// together with the symbolic names MATRIX_SCALE_ROW0 and MATRIX_SCALE_ROW1.
// NOTE(review): the symbolic names are presumably mapped to their position in
// the list (ROW0 -> 0, ROW1 -> 1), which would match WMMA::MatrixScale --
// confirm against parseStringOrIntWithPrefix.
ParseStatus AMDGPUAsmParser::tryParseMatrixScale(OperandVector &Operands,
StringRef Name,
AMDGPUOperand::ImmTy Type) {
return parseStringOrIntWithPrefix(
Operands, Name, {"MATRIX_SCALE_ROW0", "MATRIX_SCALE_ROW1"}, Type);
}

// Parses the "matrix_a_scale" modifier by delegating to tryParseMatrixScale
// with the A-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixAScale(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy OperandKind = AMDGPUOperand::ImmTyMatrixAScale;
  return tryParseMatrixScale(Operands, "matrix_a_scale", OperandKind);
}

// Parses the "matrix_b_scale" modifier by delegating to tryParseMatrixScale
// with the B-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixBScale(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy OperandKind = AMDGPUOperand::ImmTyMatrixBScale;
  return tryParseMatrixScale(Operands, "matrix_b_scale", OperandKind);
}

// Shared parser for the matrix scale format modifiers. Forwards Operands, the
// modifier prefix Name and the immediate kind Type to
// parseStringOrIntWithPrefix, together with the symbolic names
// MATRIX_SCALE_FMT_E8, MATRIX_SCALE_FMT_E5M3 and MATRIX_SCALE_FMT_E4M3.
// NOTE(review): the symbolic names are presumably mapped to their position in
// the list (E8 -> 0, E5M3 -> 1, E4M3 -> 2), which would match
// WMMA::MatrixScaleFmt -- confirm against parseStringOrIntWithPrefix.
ParseStatus AMDGPUAsmParser::tryParseMatrixScaleFmt(OperandVector &Operands,
StringRef Name,
AMDGPUOperand::ImmTy Type) {
return parseStringOrIntWithPrefix(
Operands, Name,
{"MATRIX_SCALE_FMT_E8", "MATRIX_SCALE_FMT_E5M3", "MATRIX_SCALE_FMT_E4M3"},
Type);
}

// Parses the "matrix_a_scale_fmt" modifier by delegating to
// tryParseMatrixScaleFmt with the A-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixAScaleFmt(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy OperandKind = AMDGPUOperand::ImmTyMatrixAScaleFmt;
  return tryParseMatrixScaleFmt(Operands, "matrix_a_scale_fmt", OperandKind);
}

// Parses the "matrix_b_scale_fmt" modifier by delegating to
// tryParseMatrixScaleFmt with the B-matrix prefix and immediate kind.
ParseStatus AMDGPUAsmParser::parseMatrixBScaleFmt(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy OperandKind = AMDGPUOperand::ImmTyMatrixBScaleFmt;
  return tryParseMatrixScaleFmt(Operands, "matrix_b_scale_fmt", OperandKind);
}

// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
Expand Down Expand Up @@ -9489,6 +9545,34 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
AMDGPUOperand::ImmTyMatrixBFMT, 0);
}

int MatrixAScaleIdx =
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale);
if (MatrixAScaleIdx != -1) {
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAScale, 0);
}

int MatrixBScaleIdx =
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale);
if (MatrixBScaleIdx != -1) {
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixBScale, 0);
}

int MatrixAScaleFmtIdx =
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_scale_fmt);
if (MatrixAScaleFmtIdx != -1) {
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAScaleFmt, 0);
}

int MatrixBScaleFmtIdx =
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_scale_fmt);
if (MatrixBScaleFmtIdx != -1) {
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixBScaleFmt, 0);
}

if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
Expand Down
69 changes: 69 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,75 @@ void AMDGPUInstPrinter::printMatrixBFMT(const MCInst *MI, unsigned OpNo,
printMatrixFMT(MI, OpNo, STI, O, 'b');
}

// Prints the " matrix_<AorB>_scale:<value>" modifier for operand OpNo of MI.
// Only bit 0 of the immediate is considered. A zero value prints nothing;
// otherwise the symbolic WMMA::MatrixScale name is written to O, with a numeric
// fallback for any value that has no symbolic name.
void AMDGPUInstPrinter::printMatrixScale(const MCInst *MI, unsigned OpNo,
                                         const MCSubtargetInfo &STI,
                                         raw_ostream &O, char AorB) {
  const auto ScaleSel = MI->getOperand(OpNo).getImm() & 1;

  // A zero selection is left implicit in the printed instruction.
  if (!ScaleSel)
    return;

  O << " matrix_" << AorB << "_scale:";
  switch (ScaleSel) {
  case WMMA::MatrixScale::MATRIX_SCALE_ROW1:
    O << "MATRIX_SCALE_ROW1";
    break;
  case WMMA::MatrixScale::MATRIX_SCALE_ROW0:
    // Not reachable after the early return above; kept for completeness.
    O << "MATRIX_SCALE_ROW0";
    break;
  default:
    O << ScaleSel;
    break;
  }
}

// Prints the matrix_a_scale modifier by forwarding to printMatrixScale with
// the matrix letter 'a'.
void AMDGPUInstPrinter::printMatrixAScale(const MCInst *MI, unsigned OpNo,
                                          const MCSubtargetInfo &STI,
                                          raw_ostream &O) {
  constexpr char MatrixId = 'a';
  printMatrixScale(MI, OpNo, STI, O, MatrixId);
}

// Prints the matrix_b_scale modifier by forwarding to printMatrixScale with
// the matrix letter 'b'.
void AMDGPUInstPrinter::printMatrixBScale(const MCInst *MI, unsigned OpNo,
                                          const MCSubtargetInfo &STI,
                                          raw_ostream &O) {
  constexpr char MatrixId = 'b';
  printMatrixScale(MI, OpNo, STI, O, MatrixId);
}

// Prints the " matrix_<AorB>_scale_fmt:<value>" modifier for operand OpNo of
// MI. Only the low two bits of the immediate are considered. A zero value
// prints nothing; values 1 and 2 print their WMMA::MatrixScaleFmt names, and
// the remaining value (3) is printed as a plain number.
void AMDGPUInstPrinter::printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
                                            const MCSubtargetInfo &STI,
                                            raw_ostream &O, char AorB) {
  const auto FmtSel = MI->getOperand(OpNo).getImm() & 3;

  // A zero selection is left implicit in the printed instruction.
  if (!FmtSel)
    return;

  O << " matrix_" << AorB << "_scale_fmt:";
  switch (FmtSel) {
  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E5M3:
    O << "MATRIX_SCALE_FMT_E5M3";
    break;
  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E4M3:
    O << "MATRIX_SCALE_FMT_E4M3";
    break;
  case WMMA::MatrixScaleFmt::MATRIX_SCALE_FMT_E8:
    // Not reachable after the early return above; kept for completeness.
    O << "MATRIX_SCALE_FMT_E8";
    break;
  default:
    O << FmtSel;
    break;
  }
}

// Prints the matrix_a_scale_fmt modifier by forwarding to printMatrixScaleFmt
// with the matrix letter 'a'.
void AMDGPUInstPrinter::printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
                                             const MCSubtargetInfo &STI,
                                             raw_ostream &O) {
  constexpr char MatrixId = 'a';
  printMatrixScaleFmt(MI, OpNo, STI, O, MatrixId);
}

// Prints the matrix_b_scale_fmt modifier by forwarding to printMatrixScaleFmt
// with the matrix letter 'b'.
void AMDGPUInstPrinter::printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
                                             const MCSubtargetInfo &STI,
                                             raw_ostream &O) {
  constexpr char MatrixId = 'b';
  printMatrixScaleFmt(MI, OpNo, STI, O, MatrixId);
}

void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI,
raw_ostream &O) {
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,19 @@ class AMDGPUInstPrinter : public MCInstPrinter {
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixBFMT(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixScale(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O, char AorB);
void printMatrixAScale(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixBScale(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixScaleFmt(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O,
char AorB);
void printMatrixAScaleFmt(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printMatrixBScaleFmt(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpSlot(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printInterpAttr(const MCInst *MI, unsigned OpNo,
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AMDGPU/SIDefines.h
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,17 @@ enum MatrixFMT : unsigned {
MATRIX_FMT_BF6 = 3,
MATRIX_FMT_FP4 = 4
};

// Symbolic values for the matrix_a_scale / matrix_b_scale modifiers. The
// numeric values follow the order of the names the assembler accepts.
// NOTE(review): ROW0/ROW1 presumably select which row supplies the scale
// data -- confirm against the ISA documentation.
enum MatrixScale : unsigned {
  MATRIX_SCALE_ROW0 = 0,
  MATRIX_SCALE_ROW1 = 1
};

// Symbolic values for the matrix_a_scale_fmt / matrix_b_scale_fmt modifiers.
// The numeric values follow the order of the names the assembler accepts.
// NOTE(review): the E8 / E5M3 / E4M3 suffixes presumably describe the
// exponent/mantissa layout of the scale values -- confirm against the ISA
// documentation.
enum MatrixScaleFmt : unsigned {
MATRIX_SCALE_FMT_E8 = 0,
MATRIX_SCALE_FMT_E5M3 = 1,
MATRIX_SCALE_FMT_E4M3 = 2
};
} // namespace WMMA

namespace VOP3PEncoding {
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,12 @@ def bitop3_0 : DefaultOperand<BitOp3, 0>;
def MatrixAFMT : CustomOperand<i32, 1, "MatrixAFMT">;
def MatrixBFMT : CustomOperand<i32, 1, "MatrixBFMT">;

// Custom operands for the matrix_a_scale / matrix_b_scale modifiers. The name
// strings line up with the parseMatrix{A,B}Scale, printMatrix{A,B}Scale and
// isMatrix{A,B}Scale hooks in the asm parser and instruction printer.
// NOTE(review): presumably CustomOperand derives those hook names from the
// string argument -- confirm against the CustomOperand class definition.
def MatrixAScale : CustomOperand<i32, 1, "MatrixAScale">;
def MatrixBScale : CustomOperand<i32, 1, "MatrixBScale">;

// Custom operands for the matrix_a_scale_fmt / matrix_b_scale_fmt modifiers.
// The name strings line up with the parseMatrix{A,B}ScaleFmt,
// printMatrix{A,B}ScaleFmt and isMatrix{A,B}ScaleFmt hooks in the asm parser
// and instruction printer.
// NOTE(review): presumably CustomOperand derives those hook names from the
// string argument -- confirm against the CustomOperand class definition.
def MatrixAScaleFmt : CustomOperand<i32, 1, "MatrixAScaleFmt">;
def MatrixBScaleFmt : CustomOperand<i32, 1, "MatrixBScaleFmt">;

def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;

Expand Down Expand Up @@ -2680,6 +2686,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
field bit HasNeg = HasModifiers;
field bit HasMatrixReuse = 0;
field bit HasMatrixFMT = 0;
// Off by default.
// NOTE(review): presumably set by profiles whose instructions carry the
// matrix_a_scale / matrix_b_scale operands -- confirm against the profiles
// that override it.
// HasMatrixReuse is already declared just above HasMatrixFMT with the same
// type and value, so it is not repeated here.
field bit HasMatrixScale = 0;

field bit HasSrc0Mods = HasModifiers;
field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);
Expand Down
Loading