Skip to content

[RISCV] Reuse lowerToScalableOp for more nodes. NFC #151911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2025

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Aug 4, 2025

A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this.

We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.

A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this.

We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.
@llvmbot
Copy link
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

A lot of fixed-length custom lowerings just involve inserting the operands into a scalable container and extracting the result out, and lowerToScalableOp already does this.

We just need to teach it to handle operands with different element types (but same vector element count), and we can reuse it for vselect/zext/sext/setcc/fcopysign.


Full diff: https://github.com/llvm/llvm-project/pull/151911.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+22-104)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (-5)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c0ada51ef4403..a9ad0d6f403a2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7005,6 +7005,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
   OP_CASE(FDIV)
   OP_CASE(FNEG)
   OP_CASE(FABS)
+  OP_CASE(FCOPYSIGN)
   OP_CASE(FSQRT)
   OP_CASE(SMIN)
   OP_CASE(SMAX)
@@ -7072,6 +7073,15 @@ static unsigned getRISCVVLOp(SDValue Op) {
     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
       return RISCVISD::VMXOR_VL;
     return RISCVISD::XOR_VL;
+  case ISD::ANY_EXTEND:
+  case ISD::ZERO_EXTEND:
+    return RISCVISD::VZEXT_VL;
+  case ISD::SIGN_EXTEND:
+    return RISCVISD::VSEXT_VL;
+  case ISD::SETCC:
+    return RISCVISD::SETCC_VL;
+  case ISD::VSELECT:
+    return RISCVISD::VMERGE_VL;
   case ISD::VP_SELECT:
   case ISD::VP_MERGE:
     return RISCVISD::VMERGE_VL;
@@ -7412,12 +7422,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (Op.getOperand(0).getValueType().isVector() &&
         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
-    return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
+    if (Op.getValueType().isScalableVector())
+      return Op;
+    return lowerToScalableOp(Op, DAG);
   case ISD::SIGN_EXTEND:
     if (Op.getOperand(0).getValueType().isVector() &&
         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
-    return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
+    if (Op.getValueType().isScalableVector())
+      return Op;
+    return lowerToScalableOp(Op, DAG);
   case ISD::SPLAT_VECTOR_PARTS:
     return lowerSPLAT_VECTOR_PARTS(Op, DAG);
   case ISD::INSERT_VECTOR_ELT:
@@ -8159,7 +8173,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
       return SplitVectorOp(Op, DAG);
 
-    return lowerFixedLengthVectorSetccToRVV(Op, DAG);
+    return lowerToScalableOp(Op, DAG);
   }
   case ISD::ADD:
   case ISD::SUB:
@@ -8175,6 +8189,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::UREM:
   case ISD::BSWAP:
   case ISD::CTPOP:
+  case ISD::VSELECT:
     return lowerToScalableOp(Op, DAG);
   case ISD::SHL:
   case ISD::SRA:
@@ -8243,14 +8258,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return lowerToScalableOp(Op, DAG);
     assert(Op.getOpcode() != ISD::CTTZ);
     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
-  case ISD::VSELECT:
-    return lowerFixedLengthVectorSelectToRVV(Op, DAG);
   case ISD::FCOPYSIGN:
     if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
       return lowerFCOPYSIGN(Op, DAG, Subtarget);
     if (isPromotedOpNeedingSplit(Op, Subtarget))
       return SplitVectorOp(Op, DAG);
-    return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
+    return lowerToScalableOp(Op, DAG);
   case ISD::STRICT_FADD:
   case ISD::STRICT_FSUB:
   case ISD::STRICT_FMUL:
@@ -9687,33 +9700,6 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
   return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
-    SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
-  MVT ExtVT = Op.getSimpleValueType();
-  // Only custom-lower extensions from fixed-length vector types.
-  if (!ExtVT.isFixedLengthVector())
-    return Op;
-  MVT VT = Op.getOperand(0).getSimpleValueType();
-  // Grab the canonical container type for the extended type. Infer the smaller
-  // type from that to ensure the same number of vector elements, as we know
-  // the LMUL will be sufficient to hold the smaller type.
-  MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
-  // Get the extended container type manually to ensure the same number of
-  // vector elements between source and dest.
-  MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
-                                     ContainerExtVT.getVectorElementCount());
-
-  SDValue Op1 =
-      convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
-
-  SDLoc DL(Op);
-  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
-  SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
-
-  return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
-}
-
 // Custom-lower truncations from vectors to mask vectors by using a mask and a
 // setcc operation:
 //   (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
@@ -12777,31 +12763,6 @@ SDValue RISCVTargetLowering::lowerVectorCompress(SDValue Op,
   return Res;
 }
 
-SDValue
-RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
-                                                      SelectionDAG &DAG) const {
-  MVT InVT = Op.getOperand(0).getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(InVT);
-
-  MVT VT = Op.getSimpleValueType();
-
-  SDValue Op1 =
-      convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
-  SDValue Op2 =
-      convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
-
-  SDLoc DL(Op);
-  auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
-                                    DAG, Subtarget);
-  MVT MaskVT = getMaskTypeFor(ContainerVT);
-
-  SDValue Cmp =
-      DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
-                  {Op1, Op2, Op.getOperand(2), DAG.getUNDEF(MaskVT), Mask, VL});
-
-  return convertFromScalableVector(VT, Cmp, DAG, Subtarget);
-}
-
 SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
                                                      SelectionDAG &DAG) const {
   unsigned Opc = Op.getOpcode();
@@ -12928,51 +12889,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
   return Max;
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
-  SDLoc DL(Op);
-  MVT VT = Op.getSimpleValueType();
-  SDValue Mag = Op.getOperand(0);
-  SDValue Sign = Op.getOperand(1);
-  assert(Mag.getValueType() == Sign.getValueType() &&
-         "Can only handle COPYSIGN with matching types.");
-
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-  Mag = convertToScalableVector(ContainerVT, Mag, DAG, Subtarget);
-  Sign = convertToScalableVector(ContainerVT, Sign, DAG, Subtarget);
-
-  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
-  SDValue CopySign = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Mag,
-                                 Sign, DAG.getUNDEF(ContainerVT), Mask, VL);
-
-  return convertFromScalableVector(VT, CopySign, DAG, Subtarget);
-}
-
-SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
-  MVT VT = Op.getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-
-  MVT I1ContainerVT =
-      MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-
-  SDValue CC =
-      convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
-  SDValue Op1 =
-      convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
-  SDValue Op2 =
-      convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);
-
-  SDLoc DL(Op);
-  SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
-
-  SDValue Select = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, Op1,
-                               Op2, DAG.getUNDEF(ContainerVT), VL);
-
-  return convertFromScalableVector(VT, Select, DAG, Subtarget);
-}
-
 SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
                                                SelectionDAG &DAG) const {
   const auto &TSInfo =
@@ -12999,7 +12915,9 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
     // "cast" fixed length vector to a scalable vector.
     assert(useRVVForFixedLengthVectorVT(V.getSimpleValueType()) &&
            "Only fixed length vectors are supported!");
-    Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
+    MVT VContainerVT = ContainerVT.changeVectorElementType(
+        V.getSimpleValueType().getVectorElementType());
+    Ops.push_back(convertToScalableVector(VContainerVT, V, DAG, Subtarget));
   }
 
   SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ca70c46988b4e..fa50e2105a708 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -534,9 +534,6 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerMaskedScatter(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
-                                            SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
@@ -551,8 +548,6 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPCttzElements(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
-                                            unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerGET_FPENV(SDValue Op, SelectionDAG &DAG) const;

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lukel97 lukel97 merged commit 5a80274 into llvm:main Aug 4, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants