Skip to content

Commit 5a80274

Browse files
authored
[RISCV] Reuse lowerToScalableOp for more nodes. NFC (#151911)
A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this. We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.
1 parent afbabb1 commit 5a80274

File tree

2 files changed

+22
-109
lines changed

2 files changed

+22
-109
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 22 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -7012,6 +7012,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
70127012
OP_CASE(FDIV)
70137013
OP_CASE(FNEG)
70147014
OP_CASE(FABS)
7015+
OP_CASE(FCOPYSIGN)
70157016
OP_CASE(FSQRT)
70167017
OP_CASE(SMIN)
70177018
OP_CASE(SMAX)
@@ -7079,6 +7080,15 @@ static unsigned getRISCVVLOp(SDValue Op) {
70797080
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
70807081
return RISCVISD::VMXOR_VL;
70817082
return RISCVISD::XOR_VL;
7083+
case ISD::ANY_EXTEND:
7084+
case ISD::ZERO_EXTEND:
7085+
return RISCVISD::VZEXT_VL;
7086+
case ISD::SIGN_EXTEND:
7087+
return RISCVISD::VSEXT_VL;
7088+
case ISD::SETCC:
7089+
return RISCVISD::SETCC_VL;
7090+
case ISD::VSELECT:
7091+
return RISCVISD::VMERGE_VL;
70827092
case ISD::VP_SELECT:
70837093
case ISD::VP_MERGE:
70847094
return RISCVISD::VMERGE_VL;
@@ -7419,12 +7429,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
74197429
if (Op.getOperand(0).getValueType().isVector() &&
74207430
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
74217431
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
7422-
return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
7432+
if (Op.getValueType().isScalableVector())
7433+
return Op;
7434+
return lowerToScalableOp(Op, DAG);
74237435
case ISD::SIGN_EXTEND:
74247436
if (Op.getOperand(0).getValueType().isVector() &&
74257437
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
74267438
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
7427-
return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
7439+
if (Op.getValueType().isScalableVector())
7440+
return Op;
7441+
return lowerToScalableOp(Op, DAG);
74287442
case ISD::SPLAT_VECTOR_PARTS:
74297443
return lowerSPLAT_VECTOR_PARTS(Op, DAG);
74307444
case ISD::INSERT_VECTOR_ELT:
@@ -8166,7 +8180,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
81668180
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
81678181
return SplitVectorOp(Op, DAG);
81688182

8169-
return lowerFixedLengthVectorSetccToRVV(Op, DAG);
8183+
return lowerToScalableOp(Op, DAG);
81708184
}
81718185
case ISD::ADD:
81728186
case ISD::SUB:
@@ -8182,6 +8196,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
81828196
case ISD::UREM:
81838197
case ISD::BSWAP:
81848198
case ISD::CTPOP:
8199+
case ISD::VSELECT:
81858200
return lowerToScalableOp(Op, DAG);
81868201
case ISD::SHL:
81878202
case ISD::SRA:
@@ -8250,14 +8265,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
82508265
return lowerToScalableOp(Op, DAG);
82518266
assert(Op.getOpcode() != ISD::CTTZ);
82528267
return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
8253-
case ISD::VSELECT:
8254-
return lowerFixedLengthVectorSelectToRVV(Op, DAG);
82558268
case ISD::FCOPYSIGN:
82568269
if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
82578270
return lowerFCOPYSIGN(Op, DAG, Subtarget);
82588271
if (isPromotedOpNeedingSplit(Op, Subtarget))
82598272
return SplitVectorOp(Op, DAG);
8260-
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
8273+
return lowerToScalableOp(Op, DAG);
82618274
case ISD::STRICT_FADD:
82628275
case ISD::STRICT_FSUB:
82638276
case ISD::STRICT_FMUL:
@@ -9694,33 +9707,6 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
96949707
return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
96959708
}
96969709

9697-
SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
9698-
SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
9699-
MVT ExtVT = Op.getSimpleValueType();
9700-
// Only custom-lower extensions from fixed-length vector types.
9701-
if (!ExtVT.isFixedLengthVector())
9702-
return Op;
9703-
MVT VT = Op.getOperand(0).getSimpleValueType();
9704-
// Grab the canonical container type for the extended type. Infer the smaller
9705-
// type from that to ensure the same number of vector elements, as we know
9706-
// the LMUL will be sufficient to hold the smaller type.
9707-
MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
9708-
// Get the extended container type manually to ensure the same number of
9709-
// vector elements between source and dest.
9710-
MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
9711-
ContainerExtVT.getVectorElementCount());
9712-
9713-
SDValue Op1 =
9714-
convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
9715-
9716-
SDLoc DL(Op);
9717-
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
9718-
9719-
SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
9720-
9721-
return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
9722-
}
9723-
97249710
// Custom-lower truncations from vectors to mask vectors by using a mask and a
97259711
// setcc operation:
97269712
// (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
@@ -12834,31 +12820,6 @@ SDValue RISCVTargetLowering::lowerVectorCompress(SDValue Op,
1283412820
return Res;
1283512821
}
1283612822

12837-
SDValue
12838-
RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
12839-
SelectionDAG &DAG) const {
12840-
MVT InVT = Op.getOperand(0).getSimpleValueType();
12841-
MVT ContainerVT = getContainerForFixedLengthVector(InVT);
12842-
12843-
MVT VT = Op.getSimpleValueType();
12844-
12845-
SDValue Op1 =
12846-
convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
12847-
SDValue Op2 =
12848-
convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
12849-
12850-
SDLoc DL(Op);
12851-
auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
12852-
DAG, Subtarget);
12853-
MVT MaskVT = getMaskTypeFor(ContainerVT);
12854-
12855-
SDValue Cmp =
12856-
DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
12857-
{Op1, Op2, Op.getOperand(2), DAG.getUNDEF(MaskVT), Mask, VL});
12858-
12859-
return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
12860-
}
12861-
1286212823
SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
1286312824
SelectionDAG &DAG) const {
1286412825
unsigned Opc = Op.getOpcode();
@@ -12985,51 +12946,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
1298512946
return Max;
1298612947
}
1298712948

12988-
SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
12989-
SDValue Op, SelectionDAG &DAG) const {
12990-
SDLoc DL(Op);
12991-
MVT VT = Op.getSimpleValueType();
12992-
SDValue Mag = Op.getOperand(0);
12993-
SDValue Sign = Op.getOperand(1);
12994-
assert(Mag.getValueType() == Sign.getValueType() &&
12995-
"Can only handle COPYSIGN with matching types.");
12996-
12997-
MVT ContainerVT = getContainerForFixedLengthVector(VT);
12998-
Mag = convertToScalableVector(ContainerVT, Mag, DAG, Subtarget);
12999-
Sign = convertToScalableVector(ContainerVT, Sign, DAG, Subtarget);
13000-
13001-
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
13002-
13003-
SDValue CopySign = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Mag,
13004-
Sign, DAG.getUNDEF(ContainerVT), Mask, VL);
13005-
13006-
return convertFromScalableVector(VT, CopySign, DAG, Subtarget);
13007-
}
13008-
13009-
SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
13010-
SDValue Op, SelectionDAG &DAG) const {
13011-
MVT VT = Op.getSimpleValueType();
13012-
MVT ContainerVT = getContainerForFixedLengthVector(VT);
13013-
13014-
MVT I1ContainerVT =
13015-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
13016-
13017-
SDValue CC =
13018-
convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
13019-
SDValue Op1 =
13020-
convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
13021-
SDValue Op2 =
13022-
convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);
13023-
13024-
SDLoc DL(Op);
13025-
SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
13026-
13027-
SDValue Select = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, Op1,
13028-
Op2, DAG.getUNDEF(ContainerVT), VL);
13029-
13030-
return convertFromScalableVector(VT, Select, DAG, Subtarget);
13031-
}
13032-
1303312949
SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
1303412950
SelectionDAG &DAG) const {
1303512951
const auto &TSInfo =
@@ -13056,7 +12972,9 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
1305612972
// "cast" fixed length vector to a scalable vector.
1305712973
assert(useRVVForFixedLengthVectorVT(V.getSimpleValueType()) &&
1305812974
"Only fixed length vectors are supported!");
13059-
Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
12975+
MVT VContainerVT = ContainerVT.changeVectorElementType(
12976+
V.getSimpleValueType().getVectorElementType());
12977+
Ops.push_back(convertToScalableVector(VContainerVT, V, DAG, Subtarget));
1306012978
}
1306112979

1306212980
SDLoc DL(Op);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,6 @@ class RISCVTargetLowering : public TargetLowering {
534534
SDValue lowerMaskedScatter(SDValue Op, SelectionDAG &DAG) const;
535535
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
536536
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
537-
SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
538-
SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
539-
SelectionDAG &DAG) const;
540537
SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
541538
SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
542539
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
@@ -551,8 +548,6 @@ class RISCVTargetLowering : public TargetLowering {
551548
SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
552549
SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
553550
SDValue lowerVPCttzElements(SDValue Op, SelectionDAG &DAG) const;
554-
SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
555-
unsigned ExtendOpc) const;
556551
SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
557552
SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
558553
SDValue lowerGET_FPENV(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)