diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index a5bf0e57e3053..b77168c65b532 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -412,6 +412,11 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::BITCAST); } + // Set DAG combine for 'LASX' feature. + + if (Subtarget.hasExtLASX()) + setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); + // Compute derived properties from the register classes. computeRegisterProperties(Subtarget.getRegisterInfo()); @@ -2608,14 +2613,91 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op, SDValue LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { - EVT VecTy = Op->getOperand(0)->getValueType(0); + MVT EltVT = Op.getSimpleValueType(); + SDValue Vec = Op->getOperand(0); + EVT VecTy = Vec->getValueType(0); SDValue Idx = Op->getOperand(1); - unsigned NumElts = VecTy.getVectorNumElements(); + SDLoc DL(Op); + MVT GRLenVT = Subtarget.getGRLenVT(); + + assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type"); - if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts) + if (isa<ConstantSDNode>(Idx)) return Op; - return SDValue(); + switch (VecTy.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected type"); + case MVT::v32i8: + case MVT::v16i16: { + // Consider the source vector as v8i32 type. + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Compute the adjusted index and use it to broadcast the vector. + // The original desired i8/i16 element is now replicated in each + // i32 lane of the splatted vector. + SDValue NewIdx = DAG.getNode( + LoongArchISD::BSTRPICK, DL, GRLenVT, Idx, + DAG.getConstant(31, DL, GRLenVT), + DAG.getConstant(((VecTy == MVT::v32i8) ? 
2 : 1), DL, GRLenVT)); + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, NewIdx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx); + SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue); + + // Compute the local index of the original i8/i16 element within the + // i32 element and then use it to broadcast the vector. Each elements + // of the vector will be the desired element. + SDValue LocalIdx = DAG.getNode( + ISD::AND, DL, GRLenVT, Idx, + DAG.getConstant(((VecTy == MVT::v32i8) ? 3 : 1), DL, GRLenVT)); + SDValue ExtractVec = + DAG.getNode(LoongArchISD::VREPLVE, DL, VecTy, SplatVec, LocalIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v8i32: + case MVT::v8f32: { + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v4i64: + case MVT::v4f64: { + // Consider the source vector as v8i32 type. + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Split the original element index into low and high parts: + // Lo = Idx * 2, Hi = Idx * 2 + 1. + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatIdxLo = DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32, + SplatIdx, DAG.getConstant(1, DL, GRLenVT)); + SDValue SplatIdxHi = + DAG.getNode(ISD::ADD, DL, MVT::v8i32, SplatIdxLo, + DAG.getSplatBuildVector(MVT::v8i32, DL, + DAG.getConstant(1, DL, GRLenVT))); + + // Use the broadcasted index to broadcast the low and high parts of the + // vector separately. 
+ SDValue SplatVecLo = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo); + SDValue SplatVecHi = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi); + + // Combine the low and high i32 parts to reconstruct the original i64/f64 + // element. + SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32, + SplatVecHi, SplatVecLo); + SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } + } } SDValue @@ -5830,6 +5912,42 @@ performSPLIT_PAIR_F64Combine(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue +performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const LoongArchSubtarget &Subtarget) { + if (!DCI.isBeforeLegalize()) + return SDValue(); + + MVT EltVT = N->getSimpleValueType(0); + SDValue Vec = N->getOperand(0); + EVT VecTy = Vec->getValueType(0); + SDValue Idx = N->getOperand(1); + unsigned IdxOp = Idx.getOpcode(); + SDLoc DL(N); + + if (!VecTy.is256BitVector() || isa<ConstantSDNode>(Idx)) + return SDValue(); + + // Combine: + // t2 = truncate t1 + // t3 = {zero/sign/any}_extend t2 + // t4 = extract_vector_elt t0, t3 + // to: + // t4 = extract_vector_elt t0, t1 + if (IdxOp == ISD::ZERO_EXTEND || IdxOp == ISD::SIGN_EXTEND || + IdxOp == ISD::ANY_EXTEND) { + SDValue IdxOrig = Idx.getOperand(0); + if (!(IdxOrig.getOpcode() == ISD::TRUNCATE)) + return SDValue(); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec, + IdxOrig.getOperand(0)); + } + + return SDValue(); +} + SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -5859,6 +5977,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N, return performVMSKLTZCombine(N, DAG, DCI, Subtarget); case LoongArchISD::SPLIT_PAIR_F64: return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget); + case 
ISD::EXTRACT_VECTOR_ELT: + return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget); } return SDValue(); } @@ -6632,6 +6752,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VREPLVEI) NODE_NAME_CASE(VREPLGR2VR) NODE_NAME_CASE(XVPERMI) + NODE_NAME_CASE(XVPERM) NODE_NAME_CASE(VPICK_SEXT_ELT) NODE_NAME_CASE(VPICK_ZEXT_ELT) NODE_NAME_CASE(VREPLVE) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h index 6b49a98f3ae46..32a695825342e 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h @@ -141,6 +141,7 @@ enum NodeType : unsigned { VREPLVEI, VREPLGR2VR, XVPERMI, + XVPERM, // Extended vector element extraction VPICK_SEXT_ELT, diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 5096a8fcda8eb..b790ac3e73ec7 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -10,8 +10,12 @@ // //===----------------------------------------------------------------------===// +def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, + SDTCisVec<2>, SDTCisInt<2>]>; + // Target nodes. 
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>; +def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>; def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>; @@ -1835,6 +1839,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8), def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8), (XVPERMI_D v4f64:$xj, immZExt8: $ui8)>; +// XVPERM_W +def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk), + (XVPERM_W v8i32:$xj, v8i32:$xk)>; +def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk), + (XVPERM_W v8f32:$xj, v8i32:$xk)>; + // XVREPLVE0_{W/D} def : Pat<(lasxsplatf32 FPR32:$fj), (XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>; diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll index 2e1618748688a..72542df328250 100644 --- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll +++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll @@ -76,21 +76,13 @@ define void @extract_4xdouble(ptr %src, ptr %dst) nounwind { define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_32xi8_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 0 -; CHECK-NEXT: ld.b $a0, $a0, 0 -; CHECK-NEXT: st.b $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; 
CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 2 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: andi $a0, $a2, 3 +; CHECK-NEXT: xvreplve.b $xr0, $xr0, $a0 +; CHECK-NEXT: xvstelm.b $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <32 x i8>, ptr %src %e = extractelement <32 x i8> %v, i32 %idx @@ -101,21 +93,13 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_16xi16_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 1 -; CHECK-NEXT: ld.h $a0, $a0, 0 -; CHECK-NEXT: st.h $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 1 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: andi $a0, $a2, 1 +; CHECK-NEXT: xvreplve.h $xr0, $xr0, $a0 +; CHECK-NEXT: xvstelm.h $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <16 x i16>, ptr %src %e = extractelement <16 x i16> %v, i32 %idx @@ -126,21 +110,10 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_8xi32_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld 
$xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2 -; CHECK-NEXT: ld.w $a0, $a0, 0 -; CHECK-NEXT: st.w $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <8 x i32>, ptr %src %e = extractelement <8 x i32> %v, i32 %idx @@ -151,21 +124,14 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xi64_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3 -; CHECK-NEXT: ld.d $a0, $a0, 0 -; CHECK-NEXT: st.d $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 +; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 +; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2 +; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <4 x i64>, ptr %src %e = extractelement <4 x i64> %v, i32 %idx @@ -176,21 +142,10 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: 
extract_8xfloat_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2 -; CHECK-NEXT: fld.s $fa0, $a0, 0 -; CHECK-NEXT: fst.s $fa0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <8 x float>, ptr %src %e = extractelement <8 x float> %v, i32 %idx @@ -201,21 +156,14 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xdouble_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3 -; CHECK-NEXT: fld.d $fa0, $a0, 0 -; CHECK-NEXT: fst.d $fa0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 +; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 +; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2 +; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0 
; CHECK-NEXT: ret %v = load volatile <4 x double>, ptr %src %e = extractelement <4 x double> %v, i32 %idx