-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[LoongArch] Optimize extractelement containing variable index #151475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-loongarch Author: ZhaoQi (zhaoqi5) ChangesFull diff: https://github.com/llvm/llvm-project/pull/151475.diff 4 Files Affected:
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index a5bf0e57e3053..4f534f1666eaa 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -2608,13 +2608,29 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op,
SDValue
LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
- EVT VecTy = Op->getOperand(0)->getValueType(0);
+ MVT EltVT = Op.getSimpleValueType();
+ SDValue Vec = Op->getOperand(0);
+ EVT VecTy = Vec->getValueType(0);
SDValue Idx = Op->getOperand(1);
unsigned NumElts = VecTy.getVectorNumElements();
+ SDLoc DL(Op);
+
+ assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type");
if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts)
return Op;
+ // TODO: Deal with other legal 256-bits vector types?
+ if (!isa<ConstantSDNode>(Idx) &&
+ (VecTy == MVT::v8i32 || VecTy == MVT::v8f32)) {
+ SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx);
+ SDValue SplatValue =
+ DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx);
+
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue,
+ DAG.getConstant(0, DL, Subtarget.getGRLenVT()));
+ }
+
return SDValue();
}
@@ -6632,6 +6648,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VREPLVEI)
NODE_NAME_CASE(VREPLGR2VR)
NODE_NAME_CASE(XVPERMI)
+ NODE_NAME_CASE(XVPERM)
NODE_NAME_CASE(VPICK_SEXT_ELT)
NODE_NAME_CASE(VPICK_ZEXT_ELT)
NODE_NAME_CASE(VREPLVE)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 6b49a98f3ae46..32a695825342e 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -141,6 +141,7 @@ enum NodeType : unsigned {
VREPLVEI,
VREPLGR2VR,
XVPERMI,
+ XVPERM,
// Extended vector element extraction
VPICK_SEXT_ELT,
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 5096a8fcda8eb..7f646ad0d6fdc 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -10,8 +10,12 @@
//
//===----------------------------------------------------------------------===//
+def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>,
+ SDTCisVec<2>, SDTCisInt<2>]>;
+
// Target nodes.
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
+def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>;
def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
@@ -1835,6 +1839,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8),
def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8),
(XVPERMI_D v4f64:$xj, immZExt8: $ui8)>;
+// XVPERM_W
+def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk),
+ (XVPERM_W v8i32:$xj, v8i32:$xk)>;
+def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk),
+ (XVPERM_W v8f32:$xj, v8i32:$xk)>;
+
// XVREPLVE0_{W/D}
def : Pat<(lasxsplatf32 FPR32:$fj),
(XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>;
diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
index 2e1618748688a..b191a9d08ab2d 100644
--- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
+++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
@@ -126,21 +126,11 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_8xi32_idx:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi.d $sp, $sp, -96
-; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT: addi.d $fp, $sp, 96
-; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
-; CHECK-NEXT: xvst $xr0, $sp, 32
-; CHECK-NEXT: addi.d $a0, $sp, 32
-; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
-; CHECK-NEXT: ld.w $a0, $a0, 0
-; CHECK-NEXT: st.w $a0, $a1, 0
-; CHECK-NEXT: addi.d $sp, $fp, -96
-; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT: addi.d $sp, $sp, 96
+; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0
+; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0
+; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
+; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <8 x i32>, ptr %src
%e = extractelement <8 x i32> %v, i32 %idx
@@ -176,21 +166,11 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
; CHECK-LABEL: extract_8xfloat_idx:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi.d $sp, $sp, -96
-; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT: addi.d $fp, $sp, 96
-; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
; CHECK-NEXT: xvld $xr0, $a0, 0
-; CHECK-NEXT: xvst $xr0, $sp, 32
-; CHECK-NEXT: addi.d $a0, $sp, 32
-; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
-; CHECK-NEXT: fld.s $fa0, $a0, 0
-; CHECK-NEXT: fst.s $fa0, $a1, 0
-; CHECK-NEXT: addi.d $sp, $fp, -96
-; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT: addi.d $sp, $sp, 96
+; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0
+; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0
+; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
+; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
; CHECK-NEXT: ret
%v = load volatile <8 x float>, ptr %src
%e = extractelement <8 x float> %v, i32 %idx
|
Any idea for other 256-bits types? |
def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, | ||
SDTCisVec<2>, SDTCisInt<2>]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, | |
SDTCisVec<2>, SDTCisInt<2>]>; | |
def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, | |
SDTCisVec<2>, SDTCisInt<2>]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will modify it. Thanks.
de080ca
to
80ccc27
Compare
if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts) | ||
return Op; | ||
|
||
// TODO: Deal with other legal 256-bits vector types? | ||
if (!isa<ConstantSDNode>(Idx) && | ||
(VecTy == MVT::v8i32 || VecTy == MVT::v8f32)) { | ||
SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); | ||
SDValue SplatValue = | ||
DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); | ||
|
||
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, | ||
DAG.getConstant(0, DL, Subtarget.getGRLenVT())); | ||
if (!isa<ConstantSDNode>(Idx)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (isa<ConstantSDNode>(Idx)) {
if (Idx->getAsZExtVal() < NumElts)
return Op;
return SDValue();
}
switch(xxx) {
...
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. Thanks.
It seems that Idx->getAsZExtVal() < NumElts
can also be deleted. I have tried that if the constant idx is greater than or equal to NumElts
, initial DAG will use undef as its idx.
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload | ||
; CHECK-NEXT: addi.d $sp, $sp, 96 | ||
; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 | ||
; CHECK-NEXT: bstrpick.d $a0, $a0, 31, 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Combine into one bstrpick.d $a0, $a2, 31, 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first bstrpick.d
is used to perform zero-extend for the idx because of its i32
type. It is indeed useless in these cases. Maybe we can remove it by performing combine for extract_vector_elt
before type legalization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood.
No description provided.