diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 8eec91562ecfe..ee1ca4538554b 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -391,16 +391,6 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, } } -void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, - raw_ostream &O) { - auto &Op = MI->getOperand(OpNum); - assert(Op.isImm() && "Invalid operand"); - if (Op.getImm() != 0) { - O << "+"; - printOperand(MI, OpNum, O); - } -} - void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O) { int64_t Imm = MI->getOperand(OpNum).getImm(); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h index c3ff3469150e4..92155b01464e8 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h @@ -46,7 +46,6 @@ class NVPTXInstPrinter : public MCInstPrinter { StringRef Modifier = {}); void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, StringRef Modifier = {}); - void printOffseti32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O); void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O); diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp index cd404819cb837..a3496090def3c 100644 --- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp @@ -56,15 +56,12 @@ static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI, case NVPTX::LD_i16: case NVPTX::LD_i32: case NVPTX::LD_i64: - case NVPTX::LD_i8: case NVPTX::LDV_i16_v2: case NVPTX::LDV_i16_v4: case NVPTX::LDV_i32_v2: case NVPTX::LDV_i32_v4: case NVPTX::LDV_i64_v2: - case NVPTX::LDV_i64_v4: - case NVPTX::LDV_i8_v2: - case NVPTX::LDV_i8_v4: { + case NVPTX::LDV_i64_v4: { LoadInsts.push_back(&U); return true; } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 95abcded46485..6068035b2ee47 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1003,14 +1003,10 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) { // Helper function template to reduce amount of boilerplate code for // opcode selection. static std::optional -pickOpcodeForVT(MVT::SimpleValueType VT, std::optional Opcode_i8, - std::optional Opcode_i16, +pickOpcodeForVT(MVT::SimpleValueType VT, std::optional Opcode_i16, std::optional Opcode_i32, std::optional Opcode_i64) { switch (VT) { - case MVT::i1: - case MVT::i8: - return Opcode_i8; case MVT::f16: case MVT::i16: case MVT::bf16: @@ -1078,8 +1074,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { Chain}; const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; - const std::optional Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); + const std::optional Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64); if (!Opcode) return false; @@ -1164,17 +1160,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { default: llvm_unreachable("Unexpected opcode"); case NVPTXISD::LoadV2: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2, - NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v2, + NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2); break; case NVPTXISD::LoadV4: - Opcode = - pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4, - NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); + Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v4, + NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */}); break; } @@ -1230,22 +1224,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { default: llvm_unreachable("Unexpected opcode"); case ISD::LOAD: - Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8, - NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32, - NVPTX::LD_GLOBAL_NC_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16, + NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64); break; case NVPTXISD::LoadV2: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16, - NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16, + NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64); break; case NVPTXISD::LoadV4: - Opcode = pickOpcodeForVT( - TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16, - NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); + Opcode = + pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v4i16, + NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64); break; case NVPTXISD::LoadV8: - Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */}, + Opcode = pickOpcodeForVT(TargetVT, {/* no v8i16 */}, NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */}); break; } @@ -1276,8 +1269,9 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { break; } - const MVT::SimpleValueType SelectVT = - MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy; + SDLoc DL(N); + const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts; + const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; // If this is an LDU intrinsic, the address is the third operand. If its an // LDU SD node (from custom vector handling), then its the second operand @@ -1286,32 +1280,28 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) { SDValue Base, Offset; SelectADDR(Addr, Base, Offset); - SDValue Ops[] = {Base, Offset, LD->getChain()}; + SDValue Ops[] = {getI32Imm(FromTypeWidth, DL), Base, Offset, LD->getChain()}; std::optional Opcode; switch (N->getOpcode()) { default: llvm_unreachable("Unexpected opcode"); case ISD::INTRINSIC_W_CHAIN: - Opcode = - pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16, - NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_i16, + NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64); break; case NVPTXISD::LDUV2: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8, - NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32, - NVPTX::LDU_GLOBAL_v2i64); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v2i16, + NVPTX::LDU_GLOBAL_v2i32, NVPTX::LDU_GLOBAL_v2i64); break; case NVPTXISD::LDUV4: - Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8, - NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32, - {/* no v4i64 */}); + Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v4i16, + NVPTX::LDU_GLOBAL_v4i32, {/* no v4i64 */}); break; } if (!Opcode) return false; - SDLoc DL(N); SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops); ReplaceNode(LD, NVPTXLDU); @@ -1362,8 +1352,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { Chain}; const std::optional Opcode = - pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8, - NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64); + pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i16, + NVPTX::ST_i32, NVPTX::ST_i64); if (!Opcode) return false; @@ -1423,16 +1413,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { default: return false; case NVPTXISD::StoreV2: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v2, NVPTX::STV_i16_v2, - NVPTX::STV_i32_v2, NVPTX::STV_i64_v2); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v2, NVPTX::STV_i32_v2, + NVPTX::STV_i64_v2); break; case NVPTXISD::StoreV4: - Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v4, NVPTX::STV_i16_v4, - NVPTX::STV_i32_v4, NVPTX::STV_i64_v4); + Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, + NVPTX::STV_i64_v4); break; case NVPTXISD::StoreV8: - Opcode = pickOpcodeForVT(EltVT, {/* no v8i8 */}, {/* no v8i16 */}, - NVPTX::STV_i32_v8, {/* no v8i64 */}); + Opcode = pickOpcodeForVT(EltVT, {/* no v8i16 */}, NVPTX::STV_i32_v8, + {/* no v8i64 */}); break; } @@ -1687,10 +1677,11 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) { auto API = APF.bitcastToAPInt(); API = API.concat(API); auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); - return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_B32_i, DL, VT, Const), + 0); } auto Const = CurDAG->getTargetConstantFP(APF, DL, VT); - return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16i, DL, VT, Const), 0); + return SDValue(CurDAG->getMachineNode(NVPTX::MOV_BF16_i, DL, VT, Const), 0); }; switch (N->getOpcode()) { diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td index 86dcb4a9384f1..719be0300940e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td @@ -11,15 +11,9 @@ // //===----------------------------------------------------------------------===// -// Vector instruction type enum -class VecInstTypeEnum val> { - bits<4> Value=val; -} -def VecNOP : VecInstTypeEnum<0>; - // Generic NVPTX Format -class NVPTXInst pattern> +class NVPTXInst pattern = []> : Instruction { field bits<14> Inst; @@ -30,7 +24,6 @@ class NVPTXInst pattern> let Pattern = pattern; // TSFlagFields - bits<4> VecInstType = VecNOP.Value; bit IsLoad = false; bit IsStore = false; @@ -45,7 +38,6 @@ class NVPTXInst pattern> // 2**(2-1) = 2. bits<2> IsSuld = 0; - let TSFlags{3...0} = VecInstType; let TSFlags{4} = IsLoad; let TSFlags{5} = IsStore; let TSFlags{6} = IsTex; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index e218ef17bb09b..34fe467c94563 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -35,23 +35,23 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg); const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg); - if (RegInfo.getRegSizeInBits(*DestRC) != RegInfo.getRegSizeInBits(*SrcRC)) + if (DestRC != SrcRC) report_fatal_error("Copy one register into another with a different width"); unsigned Op; - if (DestRC == &NVPTX::B1RegClass) { - Op = NVPTX::IMOV1r; - } else if (DestRC == &NVPTX::B16RegClass) { - Op = NVPTX::MOV16r; - } else if (DestRC == &NVPTX::B32RegClass) { - Op = NVPTX::IMOV32r; - } else if (DestRC == &NVPTX::B64RegClass) { - Op = NVPTX::IMOV64r; - } else if (DestRC == &NVPTX::B128RegClass) { - Op = NVPTX::IMOV128r; - } else { + if (DestRC == &NVPTX::B1RegClass) + Op = NVPTX::MOV_B1_r; + else if (DestRC == &NVPTX::B16RegClass) + Op = NVPTX::MOV_B16_r; + else if (DestRC == &NVPTX::B32RegClass) + Op = NVPTX::MOV_B32_r; + else if (DestRC == &NVPTX::B64RegClass) + Op = NVPTX::MOV_B64_r; + else if (DestRC == &NVPTX::B128RegClass) + Op = NVPTX::MOV_B128_r; + else llvm_unreachable("Bad register copy"); - } + BuildMI(MBB, I, DL, get(Op), DestReg) .addReg(SrcReg, getKillRegState(KillSrc)); } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 6000b40694763..d8047d31ff6f0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td" let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand; def bf16imm : Operand; - } -// List of vector specific properties -def isVecLD : VecInstTypeEnum<1>; -def isVecST : VecInstTypeEnum<2>; -def isVecBuild : VecInstTypeEnum<3>; -def isVecShuffle : VecInstTypeEnum<4>; -def isVecExtract : VecInstTypeEnum<5>; -def isVecInsert : VecInstTypeEnum<6>; -def isVecDest : VecInstTypeEnum<7>; -def isVecOther : VecInstTypeEnum<15>; - //===----------------------------------------------------------------------===// // NVPTX Operand Definitions. //===----------------------------------------------------------------------===// @@ -484,46 +473,28 @@ let hasSideEffects = false in { // takes a CvtMode immediate that defines the conversion mode to use. It can // be CvtNONE to omit a conversion mode. multiclass CVT_FROM_ALL Preds = []> { - def _s8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">, - Requires; - def _u8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">, - Requires; - def _s16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">, - Requires; - def _u16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">, - Requires; - def _s32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">, - Requires; - def _u32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">, - Requires; - def _s64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">, - Requires; - def _u64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">, - Requires; + foreach sign = ["s", "u"] in { + def _ # sign # "8" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">, + Requires; + def _ # sign # "16" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">, + Requires; + def _ # sign # "32" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">, + Requires; + def _ # sign # "64" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B64:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">, + Requires; + } def _f16 : BasicFlagsNVPTXInst<(outs RC:$dst), (ins B16:$src), (ins CvtMode:$mode), @@ -554,14 +525,12 @@ let hasSideEffects = false in { } // Generate cvts from all types to all types. - defm CVT_s8 : CVT_FROM_ALL<"s8", B16>; - defm CVT_u8 : CVT_FROM_ALL<"u8", B16>; - defm CVT_s16 : CVT_FROM_ALL<"s16", B16>; - defm CVT_u16 : CVT_FROM_ALL<"u16", B16>; - defm CVT_s32 : CVT_FROM_ALL<"s32", B32>; - defm CVT_u32 : CVT_FROM_ALL<"u32", B32>; - defm CVT_s64 : CVT_FROM_ALL<"s64", B64>; - defm CVT_u64 : CVT_FROM_ALL<"u64", B64>; + foreach sign = ["s", "u"] in { + defm CVT_ # sign # "8" : CVT_FROM_ALL; + defm CVT_ # sign # "16" : CVT_FROM_ALL; + defm CVT_ # sign # "32" : CVT_FROM_ALL; + defm CVT_ # sign # "64" : CVT_FROM_ALL; + } defm CVT_f16 : CVT_FROM_ALL<"f16", B16>; defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>; defm CVT_f32 : CVT_FROM_ALL<"f32", B32>; @@ -569,18 +538,12 @@ let hasSideEffects = false in { // These cvts are different from those above: The source and dest registers // are of the same type. - def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "cvt.s16.s8">; - def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s8">; - def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s16">; - def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s8">; - def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s16">; - def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s32">; + def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">; + def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">; + def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">; + def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">; + def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">; + def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">; multiclass CVT_FROM_FLOAT_V2_SM80 { def _f32 : @@ -782,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>; def ADD16x2 : I16x2<"add.s", add>; -// in32 and int64 addition and subtraction with carry-out. +// int32 and int64 addition and subtraction with carry-out. defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>; defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>; @@ -803,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>; defm SREM : I3<"rem.s", srem, commutative = false>; defm UREM : I3<"rem.u", urem, commutative = false>; -// Integer absolute value. NumBits should be one minus the bit width of RC. -// This idiom implements the algorithm at -// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs. -multiclass ABS { - def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a), - "abs" # SizeName, - [(set T:$dst, (abs T:$a))]>; +foreach t = [I16RT, I32RT, I64RT] in { + def ABS_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "abs.s" # t.Size, + [(set t.Ty:$dst, (abs t.Ty:$a))]>; + + def NEG_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "neg.s" # t.Size, + [(set t.Ty:$dst, (ineg t.Ty:$src))]>; } -defm ABS_16 : ABS; -defm ABS_32 : ABS; -defm ABS_64 : ABS; // Integer min/max. defm SMAX : I3<"max.s", smax, commutative = true>; @@ -830,116 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>; // // Wide multiplication // -def MULWIDES64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">; -def MULWIDES64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">; - -def MULWIDEU64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">; -def MULWIDEU64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">; - -def MULWIDES32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">; -def MULWIDES32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">; - -def MULWIDEU32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">; -def MULWIDEU32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">; def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>; -def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; -def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; +def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; +def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; -// Matchers for signed, unsigned mul.wide ISD nodes. -let Predicates = [hasOptEnabled] in { - def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>; - def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>; +multiclass MULWIDEInst { + def suffix # _rr : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>; + def suffix # _ri : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>; } +defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>; +defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>; + // // Integer multiply-add // -def mul_oneuse : OneUse2; - -multiclass MAD { - def rrr: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>; - - def rir: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>; - def rri: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>; - def rii: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>; -} - -let Predicates = [hasOptEnabled] in { -defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>; -defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>; -defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>; -} - -multiclass MAD_WIDE { +multiclass MADInst { def rrr: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2 small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>; def rri: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2 small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>; def rir: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2 small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>; def rii: - BasicNVPTXInst<(outs BigT.RC:$dst), - (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c), - "mad.wide." # PtxSuffix, - [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2 small_t.Ty:$a, imm:$b), imm:$c))]>; } -def mul_wide_unsigned_oneuse : OneUse2; -def mul_wide_signed_oneuse : OneUse2; - let Predicates = [hasOptEnabled] in { -defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>; -defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>; -defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>; -defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>; -} + defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>; + defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>; + defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>; -foreach t = [I16RT, I32RT, I64RT] in { - def NEG_S # t.Size : - BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), - "neg.s" # t.Size, - [(set t.Ty:$dst, (ineg t.Ty:$src))]>; + defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>; + defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>; + defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>; + defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>; } //----------------------------------- @@ -1050,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b), def FRCP32_approx_r : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; @@ -1060,14 +969,12 @@ def FRCP32_approx_r : // def FDIV32_approx_rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; def FDIV32_approx_ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; // @@ -1090,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b), // def FDIV32rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; def FDIV32ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; // @@ -1111,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b), def FRCP32r_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>; // @@ -1120,14 +1024,12 @@ def FRCP32r_prec : // def FDIV32rr_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>; def FDIV32ri_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>; @@ -1206,10 +1108,8 @@ def TANH_APPROX_f32 : // Template for three-arg bitwise operations. Takes three args, Creates .b16, // .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr. multiclass BITWISE { - defm b1 : I3Inst; - defm b16 : I3Inst; - defm b32 : I3Inst; - defm b64 : I3Inst; + foreach t = [I1RT, I16RT, I32RT, I64RT] in + defm _ # t.PtxType : I3Inst; } defm OR : BITWISE<"or", or>; @@ -1217,48 +1117,40 @@ defm AND : BITWISE<"and", and>; defm XOR : BITWISE<"xor", xor>; // PTX does not support mul on predicates, convert to and instructions -def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>; -def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>; +def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>; +def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>; foreach op = [add, sub] in { - def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>; - def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>; + def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>; + def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>; } // These transformations were once reliably performed by instcombine, but thanks // to poison semantics they are no longer safe for LLVM IR, perform them here // instead. -def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>; -def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>; +def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>; +def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>; // Lower logical v2i16/v4i8 ops as bitwise ops on b32. foreach vt = [v2i16, v4i8] in { - def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>; - def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>; - def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>; + def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>; + def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>; + def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>; // The constants get legalized into a bitcast from i32, so that's what we need // to match here. def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ORb32ri $a, imm:$b)>; + (OR_b32ri $a, imm:$b)>; def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))), - (XORb32ri $a, imm:$b)>; + (XOR_b32ri $a, imm:$b)>; def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ANDb32ri $a, imm:$b)>; -} - -def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src), - "not.pred", - [(set i1:$dst, (not i1:$src))]>; -def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "not.b16", - [(set i16:$dst, (not i16:$src))]>; -def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "not.b32", - [(set i32:$dst, (not i32:$src))]>; -def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "not.b64", - [(set i64:$dst, (not i64:$src))]>; + (AND_b32ri $a, imm:$b)>; +} + +foreach t = [I1RT, I16RT, I32RT, I64RT] in + def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "not." # t.PtxType, + [(set t.Ty:$dst, (not t.Ty:$src))]>; // Template for left/right shifts. Takes three operands, // [dest (reg), src (reg), shift (reg or imm)]. @@ -1266,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), // // This template also defines a 32-bit shift (imm, imm) instruction. multiclass SHIFT { - def i64rr : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, i32:$b))]>; - def i64ri : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>; - def i32rr : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, i32:$b))]>; - def i32ri : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>; - def i32ii : - BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>; - def i16rr : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, i32:$b))]>; - def i16ri : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>; + let hasSideEffects = false in { + foreach t = [I64RT, I32RT, I16RT] in { + def t.Size # _rr : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>; + def t.Size # _ri : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>; + def t.Size # _ii : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>; + } + } } defm SHL : SHIFT<"shl.b", shl>; @@ -1301,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>; defm SRL : SHIFT<"shr.u", srl>; // Bit-reverse -def BREV32 : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a), - "brev.b32", - [(set i32:$dst, (bitreverse i32:$a))]>; -def BREV64 : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a), - "brev.b64", - [(set i64:$dst, (bitreverse i64:$a))]>; +foreach t = [I64RT, I32RT] in + def BREV_ # t.PtxType : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "brev." # t.PtxType, + [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>; // @@ -1562,10 +1439,7 @@ def SETP_bf16x2rr : def addr : ComplexPattern; -def ADDR_base : Operand { - let PrintMethod = "printOperand"; -} - +def ADDR_base : Operand; def ADDR : Operand { let PrintMethod = "printMemOperand"; let MIOperandInfo = (ops ADDR_base, i32imm); @@ -1579,10 +1453,6 @@ def MmaCode : Operand { let PrintMethod = "printMmaCode"; } -def Offseti32imm : Operand { - let PrintMethod = "printOffseti32imm"; -} - // Get pointer to local stack. let hasSideEffects = false in { def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num), @@ -1594,33 +1464,31 @@ let hasSideEffects = false in { // copyPhysreg is hard-coded in NVPTXInstrInfo.cpp let hasSideEffects = false, isAsCheapAsAMove = true in { - // Class for register-to-register moves - class MOVr : - BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), - "mov." # OpStr>; - - // Class for immediate-to-register moves - class MOVi : - BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src), - "mov." # OpStr, - [(set VT:$dst, ImmNode:$src)]>; -} + let isMoveReg = true in + class MOVr : + BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>; -def IMOV1r : MOVr; -def MOV16r : MOVr; -def IMOV32r : MOVr; -def IMOV64r : MOVr; -def IMOV128r : MOVr; + let isMoveImm = true in + class MOVi : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src), + "mov." # suffix, + [(set t.Ty:$dst, t.ImmNode:$src)]>; +} +def MOV_B1_r : MOVr; +def MOV_B16_r : MOVr; +def MOV_B32_r : MOVr; +def MOV_B64_r : MOVr; +def MOV_B128_r : MOVr; -def IMOV1i : MOVi; -def IMOV16i : MOVi; -def IMOV32i : MOVi; -def IMOV64i : MOVi; -def FMOV16i : MOVi; -def BFMOV16i : MOVi; -def FMOV32i : MOVi; -def FMOV64i : MOVi; +def MOV_B1_i : MOVi; +def MOV_B16_i : MOVi; +def MOV_B32_i : MOVi; +def MOV_B64_i : MOVi; +def MOV_F16_i : MOVi; +def MOV_BF16_i : MOVi; +def MOV_F32_i : MOVi; +def MOV_F64_i : MOVi; def to_tglobaladdr : SDNodeXFormgetTargetFrameIndex(N->getIndex(), N->getValueType(0)); }]>; -def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>; -def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>; +def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>; +def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>; -def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>; -def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>; +def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>; +def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>; //---- Copy Frame Index ---- def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr), @@ -1831,7 +1699,6 @@ class LD "\t$dst, [$addr];", []>; let mayLoad=1, hasSideEffects=0 in { - def LD_i8 : LD; def LD_i16 : LD; def LD_i32 : LD; def LD_i64 : LD; @@ -1847,7 +1714,6 @@ class ST " \t[$addr], $src;", []>; let mayStore=1, hasSideEffects=0 in { - def ST_i8 : ST; def ST_i16 : ST; def ST_i32 : ST; def ST_i64 : ST; @@ -1880,7 +1746,6 @@ multiclass LD_VEC { "[$addr];", []>; } let mayLoad=1, hasSideEffects=0 in { - defm LDV_i8 : LD_VEC; defm LDV_i16 : LD_VEC; defm LDV_i32 : LD_VEC; defm LDV_i64 : LD_VEC; @@ -1914,7 +1779,6 @@ multiclass ST_VEC { } let mayStore=1, hasSideEffects=0 in { - defm STV_i8 : ST_VEC; defm STV_i16 : ST_VEC; defm STV_i32 : ST_VEC; defm STV_i64 : ST_VEC; @@ -2084,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>; // truncate i64 def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>; def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>; -def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>; // truncate i32 def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>; -def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>; // truncate i16 -def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>; // sext_inreg def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>; @@ -2335,32 +2199,20 @@ defm : CVT_ROUND; //----------------------------------- let isTerminator=1 in { - let isReturn=1, isBarrier=1 in + let isReturn=1, isBarrier=1 in def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>; - let isBranch=1 in - def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), + let isBranch=1 in { + def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), "@$a bra \t$target;", [(brcond i1:$a, bb:$target)]>; - let isBranch=1 in - def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), - "@!$a bra \t$target;", []>; - let isBranch=1, isBarrier=1 in + let isBarrier=1 in def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target), - "bra.uni", [(br bb:$target)]>; + "bra.uni", [(br bb:$target)]>; + } } -def : Pat<(brcond i32:$a, bb:$target), - (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>; - -// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a -// conditional branch if the target block is the next block so that the code -// can fall through to the target block. The inversion is done by 'xor -// condition, 1', which will be translated to (setne condition, -1). Since ptx -// supports '@!pred bra target', we should use it. -def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), - (CBranchOther $a, bb:$target)>; // trap instruction def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 0a00220d94289..d33719236b172 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -243,63 +243,82 @@ foreach sync = [false, true] in { } // vote.{all,any,uni,ballot} -multiclass VOTE { - def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred), - "vote." # mode, - [(set regclass:$dest, (IntOp i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; -} +let Predicates = [hasPTX<60>, hasSM<30>] in { + multiclass VOTE { + def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred), + "vote." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i1:$pred))]>; + } -defm VOTE_ALL : VOTE; -defm VOTE_ANY : VOTE; -defm VOTE_UNI : VOTE; -defm VOTE_BALLOT : VOTE; + defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>; + defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>; + defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>; + defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>; + + // vote.sync.{all,any,uni,ballot} + multiclass VOTE_SYNC { + def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>; + def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>; + } -// vote.sync.{all,any,uni,ballot} -multiclass VOTE_SYNC { - def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; - def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; + defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>; + defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>; + defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>; + defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>; } - -defm VOTE_SYNC_ALL : VOTE_SYNC; -defm VOTE_SYNC_ANY : VOTE_SYNC; -defm VOTE_SYNC_UNI : VOTE_SYNC; -defm VOTE_SYNC_BALLOT : VOTE_SYNC; - // elect.sync +let Predicates = [hasPTX<80>, hasSM<90>] in { def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>; def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>; +} + +let Predicates = [hasPTX<60>, hasSM<70>] in { + multiclass MATCH_ANY_SYNC { + def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>; + } -multiclass MATCH_ANY_SYNC { - def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; + defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC; + defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC; + + multiclass MATCH_ALLP_SYNC { + def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>; + } + defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC; + defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC; } // activemask.b32 @@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins), [(set i32:$dest, (int_nvvm_activemask))]>, Requires<[hasPTX<62>, hasSM<30>]>; -defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC; -defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC; - -multiclass MATCH_ALLP_SYNC { - def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; -} -defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC; -defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC; - multiclass REDUX_SYNC { def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask), "redux.sync." # BinOp # "." # PTXType, @@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">; //----------------------------------- // Explicit Memory Fence Functions //----------------------------------- -class MEMBAR : - BasicNVPTXInst<(outs), (ins), - StrOp, [(IntOP)]>; +class NullaryInst : + BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>; -def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>; -def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>; -def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>; +def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>; +def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>; +def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>; def INT_FENCE_SC_CLUSTER: - MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, + NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, Requires<[hasPTX<78>, hasSM<90>]>; // Proxy fence (uni-directional) -// fence.proxy.tensormap.release variants - class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE : - BasicNVPTXInst<(outs), (ins), - "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>, + NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>, Requires<[hasPTX<83>, hasSM<90>]>; def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA: @@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 : CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16, int_nvvm_cp_async_cg_shared_global_16_s>; -def CP_ASYNC_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + def CP_ASYNC_COMMIT_GROUP : + NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>; -def CP_ASYNC_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", - [(int_nvvm_cp_async_wait_group timm:$n)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", + [(int_nvvm_cp_async_wait_group timm:$n)]>; -def CP_ASYNC_WAIT_ALL : - BasicNVPTXInst<(outs), (ins), "cp.async.wait_all", - [(int_nvvm_cp_async_wait_all)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_ALL : + NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>; +} -// cp.async.bulk variants of the commit/wait group -def CP_ASYNC_BULK_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group", - [(int_nvvm_cp_async_bulk_commit_group)]>, - Requires<[hasPTX<80>, hasSM<90>]>; +let Predicates = [hasPTX<80>, hasSM<90>] in { + // cp.async.bulk variants of the commit/wait group + def CP_ASYNC_BULK_COMMIT_GROUP : + NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>; -def CP_ASYNC_BULK_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", - [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", + [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>; -def CP_ASYNC_BULK_WAIT_GROUP_READ : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", - [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP_READ : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", + [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>; +} //------------------------------ // TMA Async Bulk Copy Functions @@ -974,33 +952,30 @@ defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", //Prefetch and Prefetchu -class PREFETCH_INTRS : - BasicNVPTXInst<(outs), (ins ADDR:$addr), - InstName, - [(!cast(!strconcat("int_nvvm_", - !subst(".", "_", InstName))) addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; - +let Predicates = [hasPTX<80>, hasSM<90>] in { + class PREFETCH_INTRS : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + InstName, + [(!cast(!strconcat("int_nvvm_", + !subst(".", "_", InstName))) addr:$addr)]>; -def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; -def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; -def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; -def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; -def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; -def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; + def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; + def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; + def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; + def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; + def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; + def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; -def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_normal", - [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_normal", + [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>; -def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_last", - [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_last", + [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>; - -def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; + def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; +} //Applypriority intrinsics class APPLYPRIORITY_L2_INTRS : @@ -1031,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">; // MBarrier Functions //----------------------------------- -multiclass MBARRIER_INIT { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), - "mbarrier.init" # AddrSpace # ".b64", - [(Intrin addr:$addr, i32:$count)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; -defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", - int_nvvm_mbarrier_init_shared>; - -multiclass MBARRIER_INVAL { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "mbarrier.inval" # AddrSpace # ".b64", - [(Intrin addr:$addr)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; -defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", - int_nvvm_mbarrier_inval_shared>; - -multiclass MBARRIER_ARRIVE { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; -defm MBARRIER_ARRIVE_SHARED : - MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; - -multiclass MBARRIER_ARRIVE_NOCOMPLETE { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_NOCOMPLETE : - MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; -defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; - -multiclass MBARRIER_ARRIVE_DROP { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive_drop" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP : - MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; -defm MBARRIER_ARRIVE_DROP_SHARED : - MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; - -multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", - int_nvvm_mbarrier_arrive_drop_noComplete_shared>; - -multiclass MBARRIER_TEST_WAIT { - def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), - "mbarrier.test_wait" # AddrSpace # ".b64", - [(set i1:$res, (Intrin addr:$addr, i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + class MBARRIER_INIT : + BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), + "mbarrier.init" # AddrSpace # ".b64", + [(Intrin addr:$addr, i32:$count)]>; + + def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; + def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", + int_nvvm_mbarrier_init_shared>; + + class MBARRIER_INVAL : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + "mbarrier.inval" # AddrSpace # ".b64", + [(Intrin addr:$addr)]>; + + def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; + def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", + int_nvvm_mbarrier_inval_shared>; + + class MBARRIER_ARRIVE : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; + def MBARRIER_ARRIVE_SHARED : + MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; + + class MBARRIER_ARRIVE_NOCOMPLETE : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_NOCOMPLETE : + MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; + def MBARRIER_ARRIVE_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; + + class MBARRIER_ARRIVE_DROP : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive_drop" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE_DROP : + MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; + def MBARRIER_ARRIVE_DROP_SHARED : + MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; + + class MBARRIER_ARRIVE_DROP_NOCOMPLETE : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_DROP_NOCOMPLETE : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; + def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", + int_nvvm_mbarrier_arrive_drop_noComplete_shared>; + + class MBARRIER_TEST_WAIT : + BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), + "mbarrier.test_wait" # AddrSpace # ".b64", + [(set i1:$res, (Intrin addr:$addr, i64:$state))]>; + + def MBARRIER_TEST_WAIT : + MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; + def MBARRIER_TEST_WAIT_SHARED : + MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; + + def MBARRIER_PENDING_COUNT : + BasicNVPTXInst<(outs B32:$res), (ins B64:$state), + "mbarrier.pending_count.b64", + [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>; } - -defm MBARRIER_TEST_WAIT : - MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; -defm MBARRIER_TEST_WAIT_SHARED : - MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; - -class MBARRIER_PENDING_COUNT : - BasicNVPTXInst<(outs B32:$res), (ins B64:$state), - "mbarrier.pending_count.b64", - [(set i32:$res, (Intrin i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; - -def MBARRIER_PENDING_COUNT : - MBARRIER_PENDING_COUNT; - //----------------------------------- // Math Functions //----------------------------------- @@ -1449,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>; def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>; -def COPYSIGN_F : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1), - "copysign.f32", - [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>; - -def COPYSIGN_D : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1), - "copysign.f64", - [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>; +foreach t = [F32RT, F64RT] in + def COPYSIGN_ # t : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1), + "copysign." # t.PtxType, + [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>; // // Neg bf16, bf16x2 @@ -2255,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">; // Scalar -class LDU_G - : NVPTXInst<(outs regclass:$result), (ins ADDR:$src), - "ldu.global." # TyStr # " \t$result, [$src];", []>; +class LDU_G + : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.b$fromWidth \t$result, [$src];", []>; -def LDU_GLOBAL_i8 : LDU_G<"b8", B16>; -def LDU_GLOBAL_i16 : LDU_G<"b16", B16>; -def LDU_GLOBAL_i32 : LDU_G<"b32", B32>; -def LDU_GLOBAL_i64 : LDU_G<"b64", B64>; +def LDU_GLOBAL_i16 : LDU_G; +def LDU_GLOBAL_i32 : LDU_G; +def LDU_GLOBAL_i64 : LDU_G; // vector // Elementized vector ldu -class VLDU_G_ELE_V2 +class VLDU_G_ELE_V2 : NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins ADDR:$src), - "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>; + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>; -class VLDU_G_ELE_V4 - : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins ADDR:$src), - "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; +class VLDU_G_ELE_V4 + : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; -def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>; -def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>; -def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>; -def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>; +def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2; +def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2; +def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2; -def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>; -def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>; -def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>; +def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4; +def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4; //----------------------------------- @@ -2327,12 +2278,10 @@ class VLDG_G_ELE_V8 : "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>; // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads. -def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2; def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2; def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2; def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2; -def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4; def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4; def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4; @@ -2342,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8; multiclass NG_TO_G Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta." # Str # ".u32", []>, Requires; + "cvta." # Str # ".u32">, Requires; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta." # Str # ".u64", []>, Requires; + "cvta." # Str # ".u64">, Requires; } multiclass G_TO_NG Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta.to." # Str # ".u32", []>, Requires; + "cvta.to." # Str # ".u32">, Requires; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta.to." # Str # ".u64", []>, Requires; + "cvta.to." # Str # ".u64">, Requires; } foreach space = ["local", "shared", "global", "const", "param"] in { @@ -4614,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT : PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>; let hasSideEffects = 1 in { -def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; -def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; -def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; + def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; + def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; + def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; } def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>; @@ -5096,37 +5045,36 @@ foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT; multiclass MAPA { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, i32:$b))]>; + def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, imm:$b))]>; + def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, i32:$b))]>; + def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, imm:$b))]>; + } } + defm mapa : MAPA<"", int_nvvm_mapa>; defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>; multiclass GETCTARANK { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), - "getctarank" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), - "getctarank" # suffix # ".u64", - [(set i32:$d, (Intr i64:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), + "getctarank" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a))]>; + def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), + "getctarank" # suffix # ".u64", + [(set i32:$d, (Intr i64:$a))]>; + } } defm getctarank : GETCTARANK<"", int_nvvm_getctarank>; @@ -5165,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm: [(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>; } // isConvergent = true -def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.launch_dependents", - [(int_nvvm_griddepcontrol_launch_dependents)]>, - Requires<[hasSM<90>, hasPTX<78>]>; - -def GRIDDEPCONTROL_WAIT : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.wait", - [(int_nvvm_griddepcontrol_wait)]>, - Requires<[hasSM<90>, hasPTX<78>]>; +let Predicates = [hasSM<90>, hasPTX<78>] in { + def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents", + [(int_nvvm_griddepcontrol_launch_dependents)]>; + def GRIDDEPCONTROL_WAIT : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait", + [(int_nvvm_griddepcontrol_wait)]>; +} def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>; // Tcgen05 intrinsics -let isConvergent = true in { +let isConvergent = true, Predicates = [hasTcgen05Instructions] in { multiclass TCGEN05_ALLOC_INTR { def "" : BasicNVPTXInst<(outs), (ins ADDR:$dst, B32:$ncols), "tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32", - [(Intr addr:$dst, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$dst, B32:$ncols)]>; } defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>; @@ -5200,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR { def "" : BasicNVPTXInst<(outs), (ins B32:$tmem_addr, B32:$ncols), "tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32", - [(Intr B32:$tmem_addr, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr B32:$tmem_addr, B32:$ncols)]>; } defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>; defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>; @@ -5209,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2 multiclass TCGEN05_RELINQ_PERMIT_INTR { def "" : BasicNVPTXInst<(outs), (ins), "tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned", - [(Intr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr)]>; } defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>; defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>; -def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned", - [(int_nvvm_tcgen05_wait_ld)]>, - Requires<[hasTcgen05Instructions]>; - -def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned", - [(int_nvvm_tcgen05_wait_st)]>, - Requires<[hasTcgen05Instructions]>; +def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>; +def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>; multiclass TCGEN05_COMMIT_INTR { defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster"; @@ -5232,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR { def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar), prefix # ".b64", - [(Intr addr:$mbar)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$mbar)]>; def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc), prefix # ".multicast::cluster.b64", - [(IntrMC addr:$mbar, B16:$mc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrMC addr:$mbar, B16:$mc)]>; } defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">; @@ -5249,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR { def "" : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr), "tcgen05.shift.cta_group::" # num # ".down", - [(Intr addr:$tmem_addr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$tmem_addr)]>; } defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>; defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>; @@ -5270,13 +5204,11 @@ multiclass TCGEN05_CP_INTR { def _cg1 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm, - [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>; def _cg2 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm, - [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>; } foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { @@ -5289,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { } } // isConvergent -let hasSideEffects = 1 in { +let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in { -def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::before_thread_sync", - [(int_nvvm_tcgen05_fence_before_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_before_thread_sync: NullaryInst< + "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>; -def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::after_thread_sync", - [(int_nvvm_tcgen05_fence_after_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_after_thread_sync: NullaryInst< + "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>; } // hasSideEffects @@ -5392,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in { // Bulk store instructions def st_bulk_imm : TImmLeaf; -def INT_NVVM_ST_BULK_GENERIC : - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk", - [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; +let Predicates = [hasSM<100>, hasPTX<86>] in { + def INT_NVVM_ST_BULK_GENERIC : + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk", + [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; -def INT_NVVM_ST_BULK_SHARED_CTA: - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk.shared::cta", - [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; + def INT_NVVM_ST_BULK_SHARED_CTA: + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk.shared::cta", + [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; +} // // clusterlaunchcontorl Instructions diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td index d40886a56d6a4..2e81ab122d1df 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -38,14 +38,6 @@ foreach i = 0...4 in { def R#i : NVPTXReg<"%r"#i>; // 32-bit def RL#i : NVPTXReg<"%rd"#i>; // 64-bit def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit - def H#i : NVPTXReg<"%h"#i>; // 16-bit float - def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float - - // Arguments - def ia#i : NVPTXReg<"%ia"#i>; - def la#i : NVPTXReg<"%la"#i>; - def fa#i : NVPTXReg<"%fa"#i>; - def da#i : NVPTXReg<"%da"#i>; } foreach i = 0...31 in {