Skip to content

[LoongArch] Optimize extractelement containing variable index #151475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 125 additions & 4 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::BITCAST);
}

// Set DAG combine for 'LASX' feature.

if (Subtarget.hasExtLASX())
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);

// Compute derived properties from the register classes.
computeRegisterProperties(Subtarget.getRegisterInfo());

Expand Down Expand Up @@ -2608,14 +2613,91 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op,
// Custom lowering for EXTRACT_VECTOR_ELT whose source is a 256-bit vector
// (asserted below). A constant index returns Op unchanged. A variable index
// is rewritten per source vector type into XVPERM-based broadcasts of the
// requested element, followed by an extract of element 0.
//
// NOTE(review): this text was captured from a rendered diff without +/-
// markers. The first `EVT VecTy` declaration, the first
// `if (isa<ConstantSDNode>(Idx) && ...)` line and the `return SDValue();`
// before the switch look like the removed side of the change -- confirm
// against the actual file.
SDValue
LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
EVT VecTy = Op->getOperand(0)->getValueType(0);
MVT EltVT = Op.getSimpleValueType();
SDValue Vec = Op->getOperand(0);
EVT VecTy = Vec->getValueType(0);
SDValue Idx = Op->getOperand(1);
unsigned NumElts = VecTy.getVectorNumElements();
SDLoc DL(Op);
MVT GRLenVT = Subtarget.getGRLenVT();

assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type");

// A constant index needs no rewriting here; hand the node back as-is.
if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts)
if (isa<ConstantSDNode>(Idx))
return Op;

return SDValue();
// Variable index: dispatch on the source vector type.
switch (VecTy.getSimpleVT().SimpleTy) {
default:
llvm_unreachable("Unexpected type");
case MVT::v32i8:
case MVT::v16i16: {
// Consider the source vector as v8i32 type.
SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec);

// Compute the adjusted index and use it to broadcast the vector.
// The original desired i8/i16 element is now replicated in each
// i32 lane of the splatted vector.
// BSTRPICK keeps bits [31:2] (i8) or [31:1] (i16) of Idx, i.e. the index of
// the i32 lane that contains the requested element.
SDValue NewIdx = DAG.getNode(
LoongArchISD::BSTRPICK, DL, GRLenVT, Idx,
DAG.getConstant(31, DL, GRLenVT),
DAG.getConstant(((VecTy == MVT::v32i8) ? 2 : 1), DL, GRLenVT));
SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, NewIdx);
SDValue SplatValue =
DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx);
SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue);

// Compute the local index of the original i8/i16 element within the
// i32 element (Idx & 3 for i8, Idx & 1 for i16) and then use it to
// broadcast the vector. Every element of the resulting vector is then
// the desired element.
SDValue LocalIdx = DAG.getNode(
ISD::AND, DL, GRLenVT, Idx,
DAG.getConstant(((VecTy == MVT::v32i8) ? 3 : 1), DL, GRLenVT));
SDValue ExtractVec =
DAG.getNode(LoongArchISD::VREPLVE, DL, VecTy, SplatVec, LocalIdx);

return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec,
DAG.getConstant(0, DL, GRLenVT));
}
case MVT::v8i32:
case MVT::v8f32: {
// 32-bit elements: permute with the splatted index directly, then read
// element 0 of the result.
SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx);
SDValue SplatValue =
DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx);

return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue,
DAG.getConstant(0, DL, GRLenVT));
}
case MVT::v4i64:
case MVT::v4f64: {
// Consider the source vector as v8i32 type.
SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec);

// Split the original element index into low and high parts:
// Lo = Idx * 2, Hi = Idx * 2 + 1.
SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx);
SDValue SplatIdxLo = DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32,
SplatIdx, DAG.getConstant(1, DL, GRLenVT));
SDValue SplatIdxHi =
DAG.getNode(ISD::ADD, DL, MVT::v8i32, SplatIdxLo,
DAG.getSplatBuildVector(MVT::v8i32, DL,
DAG.getConstant(1, DL, GRLenVT)));

// Use the broadcasted index to broadcast the low and high parts of the
// vector separately.
SDValue SplatVecLo =
DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo);
SDValue SplatVecHi =
DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi);

// Combine the low and high i32 parts to reconstruct the original i64/f64
// element.
SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32,
SplatVecHi, SplatVecLo);
SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue);

return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec,
DAG.getConstant(0, DL, GRLenVT));
}
}
}

SDValue
Expand Down Expand Up @@ -5830,6 +5912,42 @@ performSPLIT_PAIR_F64Combine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

/// Target DAG combine for ISD::EXTRACT_VECTOR_ELT.
///
/// Only acts before legalization, and only on a non-constant index into a
/// 256-bit vector. In that case an index of the form
///   {zero/sign/any}_extend (truncate X)
/// is replaced by X itself, so the extract uses the original index value.
///
/// \param N          The EXTRACT_VECTOR_ELT node (operand 0: vector,
///                   operand 1: index).
/// \param DAG        The DAG in which the replacement node is created.
/// \param DCI        Combiner state; used to check the legalization phase.
/// \param Subtarget  Unused; kept for signature parity with the sibling
///                   combine helpers.
/// \returns The rewritten node, or an empty SDValue if nothing was combined.
static SDValue
performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const LoongArchSubtarget &Subtarget) {
  if (!DCI.isBeforeLegalize())
    return SDValue();

  SDValue Vec = N->getOperand(0);
  SDValue Idx = N->getOperand(1);

  if (!Vec->getValueType(0).is256BitVector() || isa<ConstantSDNode>(Idx))
    return SDValue();

  // Combine:
  //   t2 = truncate t1
  //   t3 = {zero/sign/any}_extend t2
  //   t4 = extract_vector_elt t0, t3
  // to:
  //   t4 = extract_vector_elt t0, t1
  const unsigned IdxOp = Idx.getOpcode();
  if (IdxOp != ISD::ZERO_EXTEND && IdxOp != ISD::SIGN_EXTEND &&
      IdxOp != ISD::ANY_EXTEND)
    return SDValue();

  SDValue Truncated = Idx.getOperand(0);
  if (Truncated.getOpcode() != ISD::TRUNCATE)
    return SDValue();

  // Query the result type as an EVT rather than via getSimpleValueType():
  // this combine runs before legalization, where a node's result type may
  // still be an extended (non-simple) type, and asking for a simple MVT on
  // such a node asserts. The type is also only looked up once the node is
  // known to match.
  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), N->getValueType(0),
                     Vec, Truncated.getOperand(0));
}

SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand Down Expand Up @@ -5859,6 +5977,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performVMSKLTZCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::SPLIT_PAIR_F64:
return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget);
case ISD::EXTRACT_VECTOR_ELT:
return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget);
}
return SDValue();
}
Expand Down Expand Up @@ -6632,6 +6752,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VREPLVEI)
NODE_NAME_CASE(VREPLGR2VR)
NODE_NAME_CASE(XVPERMI)
NODE_NAME_CASE(XVPERM)
NODE_NAME_CASE(VPICK_SEXT_ELT)
NODE_NAME_CASE(VPICK_ZEXT_ELT)
NODE_NAME_CASE(VREPLVE)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ enum NodeType : unsigned {
VREPLVEI,
VREPLGR2VR,
XVPERMI,
XVPERM,

// Extended vector element extraction
VPICK_SEXT_ELT,
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
//
//===----------------------------------------------------------------------===//

// Type profile for the XVPERM node: one result and two operands. The result
// is a vector, operand 1 has the same type as the result, and operand 2 is an
// integer vector.
def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>,
SDTCisVec<2>, SDTCisInt<2>]>;

// Target nodes.
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>;
def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
Expand Down Expand Up @@ -1835,6 +1839,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8),
def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8),
(XVPERMI_D v4f64:$xj, immZExt8: $ui8)>;

// XVPERM_W: select loongarch_xvperm on v8i32 or v8f32 data, with a v8i32
// second operand, to the XVPERM_W instruction.
def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk),
(XVPERM_W v8i32:$xj, v8i32:$xk)>;
def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk),
(XVPERM_W v8f32:$xj, v8i32:$xk)>;

// XVREPLVE0_{W/D}
def : Pat<(lasxsplatf32 FPR32:$fj),
(XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>;
Expand Down
116 changes: 32 additions & 84 deletions llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,13 @@ define void @extract_4xdouble(ptr %src, ptr %dst) nounwind {
define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_32xi8_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 0
; CHECK-NEXT: ld.b $a0, $a0, 0
; CHECK-NEXT: st.b $a0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 2
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: andi $a0, $a2, 3
; CHECK-NEXT: xvreplve.b $xr0, $xr0, $a0
; CHECK-NEXT: xvstelm.b $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <32 x i8>, ptr %src
%e = extractelement <32 x i8> %v, i32 %idx
Expand All @@ -101,21 +93,13 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_16xi16_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 1
; CHECK-NEXT: ld.h $a0, $a0, 0
; CHECK-NEXT: st.h $a0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 1
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: andi $a0, $a2, 1
; CHECK-NEXT: xvreplve.h $xr0, $xr0, $a0
; CHECK-NEXT: xvstelm.h $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <16 x i16>, ptr %src
%e = extractelement <16 x i16> %v, i32 %idx
Expand All @@ -126,21 +110,10 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_8xi32_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
; CHECK-NEXT: ld.w $a0, $a0, 0
; CHECK-NEXT: st.w $a0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <8 x i32>, ptr %src
%e = extractelement <8 x i32> %v, i32 %idx
Expand All @@ -151,21 +124,14 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_4xi64_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3
; CHECK-NEXT: ld.d $a0, $a0, 0
; CHECK-NEXT: st.d $a0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
; CHECK-NEXT: xvslli.w $xr1, $xr1, 1
; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1
; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2
; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <4 x i64>, ptr %src
%e = extractelement <4 x i64> %v, i32 %idx
Expand All @@ -176,21 +142,10 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_8xfloat_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
; CHECK-NEXT: fld.s $fa0, $a0, 0
; CHECK-NEXT: fst.s $fa0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <8 x float>, ptr %src
%e = extractelement <8 x float> %v, i32 %idx
Expand All @@ -201,21 +156,14 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_4xdouble_idx:
; CHECK: # %bb.0:
; CHECK-NEXT: addi.d $sp, $sp, -96
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
; CHECK-NEXT: addi.d $fp, $sp, 96
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
; CHECK-NEXT: xvst $xr0, $sp, 32
; CHECK-NEXT: addi.d $a0, $sp, 32
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3
; CHECK-NEXT: fld.d $fa0, $a0, 0
; CHECK-NEXT: fst.d $fa0, $a1, 0
; CHECK-NEXT: addi.d $sp, $fp, -96
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
; CHECK-NEXT: addi.d $sp, $sp, 96
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
; CHECK-NEXT: xvslli.w $xr1, $xr1, 1
; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1
; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2
; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <4 x double>, ptr %src
%e = extractelement <4 x double> %v, i32 %idx
Expand Down