-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[RISCV] Reuse lowerToScalableOp for more nodes. NFC #151911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this. We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.
@llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) ChangesA lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this. We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign. Full diff: https://github.com/llvm/llvm-project/pull/151911.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c0ada51ef4403..a9ad0d6f403a2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7005,6 +7005,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
OP_CASE(FDIV)
OP_CASE(FNEG)
OP_CASE(FABS)
+ OP_CASE(FCOPYSIGN)
OP_CASE(FSQRT)
OP_CASE(SMIN)
OP_CASE(SMAX)
@@ -7072,6 +7073,15 @@ static unsigned getRISCVVLOp(SDValue Op) {
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
return RISCVISD::VMXOR_VL;
return RISCVISD::XOR_VL;
+ case ISD::ANY_EXTEND:
+ case ISD::ZERO_EXTEND:
+ return RISCVISD::VZEXT_VL;
+ case ISD::SIGN_EXTEND:
+ return RISCVISD::VSEXT_VL;
+ case ISD::SETCC:
+ return RISCVISD::SETCC_VL;
+ case ISD::VSELECT:
+ return RISCVISD::VMERGE_VL;
case ISD::VP_SELECT:
case ISD::VP_MERGE:
return RISCVISD::VMERGE_VL;
@@ -7412,12 +7422,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
if (Op.getOperand(0).getValueType().isVector() &&
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
- return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
+ if (Op.getValueType().isScalableVector())
+ return Op;
+ return lowerToScalableOp(Op, DAG);
case ISD::SIGN_EXTEND:
if (Op.getOperand(0).getValueType().isVector() &&
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
- return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
+ if (Op.getValueType().isScalableVector())
+ return Op;
+ return lowerToScalableOp(Op, DAG);
case ISD::SPLAT_VECTOR_PARTS:
return lowerSPLAT_VECTOR_PARTS(Op, DAG);
case ISD::INSERT_VECTOR_ELT:
@@ -8159,7 +8173,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
return SplitVectorOp(Op, DAG);
- return lowerFixedLengthVectorSetccToRVV(Op, DAG);
+ return lowerToScalableOp(Op, DAG);
}
case ISD::ADD:
case ISD::SUB:
@@ -8175,6 +8189,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::UREM:
case ISD::BSWAP:
case ISD::CTPOP:
+ case ISD::VSELECT:
return lowerToScalableOp(Op, DAG);
case ISD::SHL:
case ISD::SRA:
@@ -8243,14 +8258,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerToScalableOp(Op, DAG);
assert(Op.getOpcode() != ISD::CTTZ);
return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
- case ISD::VSELECT:
- return lowerFixedLengthVectorSelectToRVV(Op, DAG);
case ISD::FCOPYSIGN:
if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
return lowerFCOPYSIGN(Op, DAG, Subtarget);
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
- return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
+ return lowerToScalableOp(Op, DAG);
case ISD::STRICT_FADD:
case ISD::STRICT_FSUB:
case ISD::STRICT_FMUL:
@@ -9687,33 +9700,6 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
}
-SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
- SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
- MVT ExtVT = Op.getSimpleValueType();
- // Only custom-lower extensions from fixed-length vector types.
- if (!ExtVT.isFixedLengthVector())
- return Op;
- MVT VT = Op.getOperand(0).getSimpleValueType();
- // Grab the canonical container type for the extended type. Infer the smaller
- // type from that to ensure the same number of vector elements, as we know
- // the LMUL will be sufficient to hold the smaller type.
- MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
- // Get the extended container type manually to ensure the same number of
- // vector elements between source and dest.
- MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
- ContainerExtVT.getVectorElementCount());
-
- SDValue Op1 =
- convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
-
- SDLoc DL(Op);
- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
- SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
-
- return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
-}
-
// Custom-lower truncations from vectors to mask vectors by using a mask and a
// setcc operation:
// (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
@@ -12777,31 +12763,6 @@ SDValue RISCVTargetLowering::lowerVectorCompress(SDValue Op,
return Res;
}
-SDValue
-RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
- SelectionDAG &DAG) const {
- MVT InVT = Op.getOperand(0).getSimpleValueType();
- MVT ContainerVT = getContainerForFixedLengthVector(InVT);
-
- MVT VT = Op.getSimpleValueType();
-
- SDValue Op1 =
- convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
- SDValue Op2 =
- convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
-
- SDLoc DL(Op);
- auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
- DAG, Subtarget);
- MVT MaskVT = getMaskTypeFor(ContainerVT);
-
- SDValue Cmp =
- DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
- {Op1, Op2, Op.getOperand(2), DAG.getUNDEF(MaskVT), Mask, VL});
-
- return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
-}
-
SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
SelectionDAG &DAG) const {
unsigned Opc = Op.getOpcode();
@@ -12928,51 +12889,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
return Max;
}
-SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
- SDValue Op, SelectionDAG &DAG) const {
- SDLoc DL(Op);
- MVT VT = Op.getSimpleValueType();
- SDValue Mag = Op.getOperand(0);
- SDValue Sign = Op.getOperand(1);
- assert(Mag.getValueType() == Sign.getValueType() &&
- "Can only handle COPYSIGN with matching types.");
-
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
- Mag = convertToScalableVector(ContainerVT, Mag, DAG, Subtarget);
- Sign = convertToScalableVector(ContainerVT, Sign, DAG, Subtarget);
-
- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
- SDValue CopySign = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Mag,
- Sign, DAG.getUNDEF(ContainerVT), Mask, VL);
-
- return convertFromScalableVector(VT, CopySign, DAG, Subtarget);
-}
-
-SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
- SDValue Op, SelectionDAG &DAG) const {
- MVT VT = Op.getSimpleValueType();
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
-
- MVT I1ContainerVT =
- MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-
- SDValue CC =
- convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
- SDValue Op1 =
- convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
- SDValue Op2 =
- convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);
-
- SDLoc DL(Op);
- SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
-
- SDValue Select = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, Op1,
- Op2, DAG.getUNDEF(ContainerVT), VL);
-
- return convertFromScalableVector(VT, Select, DAG, Subtarget);
-}
-
SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
SelectionDAG &DAG) const {
const auto &TSInfo =
@@ -12999,7 +12915,9 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
// "cast" fixed length vector to a scalable vector.
assert(useRVVForFixedLengthVectorVT(V.getSimpleValueType()) &&
"Only fixed length vectors are supported!");
- Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
+ MVT VContainerVT = ContainerVT.changeVectorElementType(
+ V.getSimpleValueType().getVectorElementType());
+ Ops.push_back(convertToScalableVector(VContainerVT, V, DAG, Subtarget));
}
SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ca70c46988b4e..fa50e2105a708 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -534,9 +534,6 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerMaskedScatter(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
- SelectionDAG &DAG) const;
SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
@@ -551,8 +548,6 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPCttzElements(SDValue Op, SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
- unsigned ExtendOpc) const;
SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerGET_FPENV(SDValue Op, SelectionDAG &DAG) const;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this.
We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.