diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 734191447d67f..b708b6db2607e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16343,6 +16343,44 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
                            DAG, DL);
     }
     break;
+  case ISD::ABDU:
+  case ISD::ABDS:
+    // (trunc (abdu/abds a, b)) -> (abdu/abds (trunc a), (trunc b))
+    if (!LegalOperations || N0.hasOneUse()) {
+      EVT SrcVT = N0.getValueType();
+      EVT TruncVT = VT;
+      unsigned SrcBits = SrcVT.getScalarSizeInBits();
+      unsigned TruncBits = TruncVT.getScalarSizeInBits();
+      unsigned NeededBits = SrcBits - TruncBits;
+
+      SDValue A = N0.getOperand(0);
+      SDValue B = N0.getOperand(1);
+      bool CanFold = false;
+
+      if (N0.getOpcode() == ISD::ABDU) {
+        // Unsigned case: the fold is safe if both operands already fit in
+        // the narrow type, i.e. their top NeededBits bits are known zero.
+        KnownBits KnownA = DAG.computeKnownBits(A);
+        KnownBits KnownB = DAG.computeKnownBits(B);
+        CanFold = KnownA.countMinLeadingZeros() >= NeededBits &&
+                  KnownB.countMinLeadingZeros() >= NeededBits;
+      } else {
+        // Signed case: a value only fits in the narrow type if it has
+        // strictly more than NeededBits sign bits, hence '>' rather than
+        // '>=' (with exactly NeededBits sign bits, e.g. 0x0000FFFF for an
+        // i32->i16 truncate, the operand changes value when truncated).
+        unsigned SignBitsA = DAG.ComputeNumSignBits(A);
+        unsigned SignBitsB = DAG.ComputeNumSignBits(B);
+        CanFold = SignBitsA > NeededBits && SignBitsB > NeededBits;
+      }
+
+      if (CanFold && TLI.isOperationLegal(N0.getOpcode(), VT)) {
+        SDValue NewA = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, A);
+        SDValue NewB = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, B);
+        return DAG.getNode(N0.getOpcode(), DL, TruncVT, NewA, NewB);
+      }
+    }
+    break;
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/abd-combine.ll b/llvm/test/CodeGen/AArch64/abd-combine.ll
index d0257890d2c43..843a459beecf8 100644
--- a/llvm/test/CodeGen/AArch64/abd-combine.ll
+++ b/llvm/test/CodeGen/AArch64/abd-combine.ll
@@ -17,12 +17,9 @@ define <8 x i16> @abdu_base(<8 x i16> %src1, <8 x i16> %src2) {
 define <8 x i16> @abdu_const(<8 x i16> %src1) {
 ; CHECK-LABEL: abdu_const:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: movi v1.4s, #1
-; CHECK-NEXT: ushll v2.4s, v0.4h, #0
-; CHECK-NEXT: ushll2 v0.4s, v0.8h, #0
-; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: uabd v1.4s, v2.4s, v1.4s
-; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: movi v1.4h, #1
+; CHECK-NEXT: mov v1.d[1], v1.d[0]
+; CHECK-NEXT: uabd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT: ret
   %zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
   %sub = sub <8 x i32> %zextsrc1, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
@@ -34,12 +31,9 @@ define <8 x i16> @abdu_const(<8 x i16> %src1) {
 define <8 x i16> @abdu_const_lhs(<8 x i16> %src1) {
 ; CHECK-LABEL: abdu_const_lhs:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: movi v1.4s, #1
-; CHECK-NEXT: ushll v2.4s, v0.4h, #0
-; CHECK-NEXT: ushll2 v0.4s, v0.8h, #0
-; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: uabd v1.4s, v2.4s, v1.4s
-; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: movi v1.4h, #1
+; CHECK-NEXT: mov v1.d[1], v1.d[0]
+; CHECK-NEXT: uabd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT: ret
   %zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
   %sub = sub <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, %zextsrc1
@@ -318,12 +312,9 @@ define <8 x i16> @abds_base(<8 x i16> %src1, <8 x i16> %src2) {
 define <8 x i16> @abds_const(<8 x i16> %src1) {
 ; CHECK-LABEL: abds_const:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: movi v1.4s, #1
-; CHECK-NEXT: sshll v2.4s, v0.4h, #0
-; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0
-; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: sabd v1.4s, v2.4s, v1.4s
-; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: movi v1.4h, #1
+; CHECK-NEXT: mov v1.d[1], v1.d[0]
+; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT: ret
   %zextsrc1 = sext <8 x i16> %src1 to <8 x i32>
   %sub = sub <8 x i32> %zextsrc1, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
@@ -335,12 +326,9 @@ define <8 x i16> @abds_const(<8 x i16> %src1) {
 define <8 x i16> @abds_const_lhs(<8 x i16> %src1) {
 ; CHECK-LABEL: abds_const_lhs:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: movi v1.4s, #1
-; CHECK-NEXT: sshll v2.4s, v0.4h, #0
-; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0
-; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: sabd v1.4s, v2.4s, v1.4s
-; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: movi v1.4h, #1
+; CHECK-NEXT: mov v1.d[1], v1.d[0]
+; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT: ret
   %zextsrc1 = sext <8 x i16> %src1 to <8 x i32>
   %sub = sub <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, %zextsrc1
@@ -352,11 +340,10 @@ define <8 x i16> @abds_const_lhs(<8 x i16> %src1) {
 define <8 x i16> @abds_const_zero(<8 x i16> %src1) {
 ; CHECK-LABEL: abds_const_zero:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: sshll v1.4s, v0.4h, #0
-; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0
-; CHECK-NEXT: abs v0.4s, v0.4s
-; CHECK-NEXT: abs v1.4s, v1.4s
-; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT: abs v0.4h, v0.4h
+; CHECK-NEXT: abs v1.4h, v1.4h
+; CHECK-NEXT: mov v0.d[1], v1.d[0]
 ; CHECK-NEXT: ret
   %zextsrc1 = sext <8 x i16> %src1 to <8 x i32>
   %sub = sub <8 x i32> <i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0>, %zextsrc1
diff --git a/llvm/test/CodeGen/AArch64/arm64-neon-aba-abd.ll b/llvm/test/CodeGen/AArch64/arm64-neon-aba-abd.ll
index 6c7ddd916abdf..ccd1917ae3d85 100644
--- a/llvm/test/CodeGen/AArch64/arm64-neon-aba-abd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-neon-aba-abd.ll
@@ -575,3 +575,69 @@ define <4 x i32> @knownbits_sabd_and_mul_mask(<4 x i32> %a0, <4 x i32> %a1) {
   %6 = shufflevector <4 x i32> %5, <4 x i32> undef, <4 x i32>
   ret <4 x i32> %6
 }
+
+define <4 x i16> @trunc_abdu_foldable(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-SD-LABEL: trunc_abdu_foldable:
+; CHECK-SD: // %bb.0:
+; CHECK-SD-NEXT: uabd v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: trunc_abdu_foldable:
+; CHECK-GI: // %bb.0:
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: uabd v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT: xtn v0.4h, v0.4s
+; CHECK-GI-NEXT: ret
+  %ext_a = zext <4 x i16> %a to <4 x i32>
+  %ext_b = zext <4 x i16> %b to <4 x i32>
+  %abd = call <4 x i32> @llvm.aarch64.neon.uabd.v4i32(<4 x i32> %ext_a, <4 x i32> %ext_b)
+  %trunc = trunc <4 x i32> %abd to <4 x i16>
+  ret <4 x i16> %trunc
+}
+
+define <4 x i16> @trunc_abds_foldable(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-SD-LABEL: trunc_abds_foldable:
+; CHECK-SD: // %bb.0:
+; CHECK-SD-NEXT: sabd v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: trunc_abds_foldable:
+; CHECK-GI: // %bb.0:
+; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: sabd v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT: xtn v0.4h, v0.4s
+; CHECK-GI-NEXT: ret
+  %a32 = sext <4 x i16> %a to <4 x i32>
+  %b32 = sext <4 x i16> %b to <4 x i32>
+  %abd32 = call <4 x i32> @llvm.aarch64.neon.sabd.v4i32(<4 x i32> %a32, <4 x i32> %b32)
+  %res16 = trunc <4 x i32> %abd32 to <4 x i16>
+  ret <4 x i16> %res16
+}
+
+define <4 x i16> @trunc_abdu_not_foldable(<4 x i16> %a, <4 x i32> %b) {
+; CHECK-LABEL: trunc_abdu_not_foldable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+  %ext_a = zext <4 x i16> %a to <4 x i32>
+  %abd = call <4 x i32> @llvm.aarch64.neon.uabd.v4i32(<4 x i32> %ext_a, <4 x i32> %b)
+  %trunc = trunc <4 x i32> %abd to <4 x i16>
+  ret <4 x i16> %trunc
+}
+
+define <4 x i16> @trunc_abds_not_foldable(<4 x i16> %a, <4 x i32> %b) {
+; CHECK-LABEL: trunc_abds_not_foldable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+  %a32 = sext <4 x i16> %a to <4 x i32>
+  %abd32 = call <4 x i32> @llvm.aarch64.neon.sabd.v4i32(<4 x i32> %a32, <4 x i32> %b)
+  %res16 = trunc <4 x i32> %abd32 to <4 x i16>
+  ret <4 x i16> %res16
+}
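
Reviewer note, not part of the patch: the two narrowing conditions in the DAGCombiner change can be sanity-checked outside of LLVM. The standalone C++ program below is an illustrative sketch written for this note (it uses no LLVM APIs, and all names in it are made up). It models a 16-bit "wide" type and an 8-bit "narrow" type, so NeededBits = 8, exhaustively verifies both the unsigned leading-zeros condition and the signed sign-bits condition, and demonstrates why the signed comparison must be strict.

// Standalone sanity check for the trunc(abdu/abds) narrowing conditions.
// Wide type: 16 bits; narrow type: 8 bits; NeededBits = 16 - 8 = 8.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Absolute difference computed in the wide (16-bit) domain.
static uint16_t abdu16(uint16_t a, uint16_t b) { return a > b ? a - b : b - a; }
static uint16_t abds16(int16_t a, int16_t b) {
  int32_t d = (int32_t)a - (int32_t)b;
  return (uint16_t)(d < 0 ? -d : d);
}

// The same operations computed directly in the narrow (8-bit) domain.
static uint8_t abdu8(uint8_t a, uint8_t b) { return a > b ? a - b : b - a; }
static uint8_t abds8(int8_t a, int8_t b) {
  int16_t d = (int16_t)a - (int16_t)b;
  return (uint8_t)(d < 0 ? -d : d);
}

int main() {
  // ABDU: any 16-bit operand with >= NeededBits leading zeros, i.e. any
  // value in [0, 255], may be narrowed. Exhaustive over qualifying pairs.
  for (unsigned a = 0; a <= 0xFF; ++a)
    for (unsigned b = 0; b <= 0xFF; ++b)
      assert((uint8_t)abdu16((uint16_t)a, (uint16_t)b) ==
             abdu8((uint8_t)a, (uint8_t)b));

  // ABDS: a sign-extended 8-bit value has at least 9 sign bits at 16 bits,
  // i.e. strictly more than NeededBits, so narrowing is value-preserving.
  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b)
      assert((uint8_t)abds16((int16_t)a, (int16_t)b) ==
             abds8((int8_t)a, (int8_t)b));

  // Why the signed check must use '>': 0x00FF has exactly NeededBits (8)
  // sign bits at 16 bits. Wide: trunc(abds(0x00FF, 0)) = 0xFF; narrow:
  // abds(-1, 0) = 1. The two disagree, so such operands must not fold.
  assert((uint8_t)abds16(0x00FF, 0) == 0xFF);
  assert(abds8((int8_t)0xFF, (int8_t)0) == 1);

  puts("narrowing checks passed");
  return 0;
}

The final pair of assertions is the interesting one: it is the same off-by-one that would otherwise let an ABDS of operands with exactly NeededBits sign bits (for example, a zero-extended value whose top source bit may be set) narrow incorrectly, which is why the combine requires strictly more than NeededBits sign bits while the unsigned path only needs NeededBits known-zero leading bits.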