[DAG] Fold trunc(abdu(x,y)) and trunc(abds(x,y)) if they have sufficient leading zero/sign bits #151471

Open · wants to merge 2 commits into main
46 changes: 46 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -607,6 +607,7 @@ namespace {
SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
const SDLoc &DL);
SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
SDValue foldAbdToNarrowType(EVT VT, SDNode *N, const SDLoc &DL);
SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
SDValue False, ISD::CondCode CC, const SDLoc &DL);
@@ -3925,6 +3926,47 @@ SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
return SDValue();
}

// trunc(ABDU/S(A, B)) -> ABDU/S(trunc(A), trunc(B))
SDValue DAGCombiner::foldAbdToNarrowType(EVT VT, SDNode *N, const SDLoc &DL) {
SDValue Op = N->getOperand(0);

unsigned Opcode = Op.getOpcode();
if (Opcode != ISD::ABDU && Opcode != ISD::ABDS)
return SDValue();

SDValue Operand0 = Op.getOperand(0);
SDValue Operand1 = Op.getOperand(1);

// Early exit if either operand is zero.
if (ISD::isBuildVectorAllZeros(Operand0.getNode()) ||
ISD::isBuildVectorAllZeros(Operand1.getNode()))
return SDValue();
Review comment (Collaborator): Why is this necessary? Doesn't visitABD handle this eventually anyhow?

EVT SrcVT = Op.getValueType();
EVT TruncVT = N->getValueType(0);
unsigned NumSrcBits = SrcVT.getScalarSizeInBits();
unsigned NumTruncBits = TruncVT.getScalarSizeInBits();
unsigned NeededBits = NumSrcBits - NumTruncBits;

bool CanFold = false;

if (Opcode == ISD::ABDU) {
KnownBits Known = DAG.computeKnownBits(Op);
Review comment (Collaborator): Why are you testing the ABD result instead of the operands like the alive2 tests? Is the fold still correct?

CanFold = Known.countMinLeadingZeros() >= NeededBits;
} else {
unsigned SignBits = DAG.ComputeNumSignBits(Op);
CanFold = SignBits >= NeededBits;
}

if (CanFold) {
Review comment (Collaborator): Just because we can fold doesn't mean we should - look at the switch statement at the bottom of visitTRUNCATE - it has a series of legality/profitability checks for different opcodes so that we don't always fold trunc(abd(x,y)) -> abd(trunc(x),trunc(y)) if it'd be more costly.

SDValue NewOp0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Operand0);
SDValue NewOp1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Operand1);
return DAG.getNode(Opcode, DL, TruncVT, NewOp0, NewOp1);
}

return SDValue();
}

// Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
counting leading ones. Broadly, it replaces the subtraction with a left
// shift.
@@ -16275,6 +16317,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
return NewVSel;

// fold trunc(ABDU/S(A, B)) -> ABDU/S(trunc(A), trunc(B))
if (SDValue V = foldAbdToNarrowType(VT, N, SDLoc(N)))
Review comment (Collaborator): Better to handle these in the switch statement below - which is what it's there for.

return V;

// Narrow a suitable binary operation with a non-opaque constant operand by
// moving it ahead of the truncate. This is limited to pre-legalization
// because targets may prefer a wider type during later combines and invert
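
(Editorial aside, not part of the patch.) A minimal standalone sketch of the operand-based precondition that the PR title and the alive2-style checks describe, brute-forced on an 8-bit to 4-bit narrowing under ordinary two's-complement arithmetic; the helper names below are illustrative, not LLVM's. When both abdu operands already fit unsigned in the narrow type (enough leading zeros), or both abds operands are sign-representable in it (enough sign bits), truncating the wide absolute difference equals the narrow absolute difference of the truncated operands:

#include <cassert>
#include <cstdlib>

// Interpret the low 'Bits' bits of X as a signed two's-complement value.
static int toSigned(unsigned X, unsigned Bits) {
  unsigned Sign = 1u << (Bits - 1);
  return (int)((X & ((1u << Bits) - 1)) ^ Sign) - (int)Sign;
}

// Unsigned / signed absolute difference, reduced modulo 2^Bits.
static unsigned abdu(unsigned A, unsigned B, unsigned Bits) {
  unsigned Mask = (1u << Bits) - 1;
  A &= Mask;
  B &= Mask;
  return (A > B ? A - B : B - A) & Mask;
}
static unsigned abds(unsigned A, unsigned B, unsigned Bits) {
  unsigned Mask = (1u << Bits) - 1;
  return (unsigned)std::abs(toSigned(A, Bits) - toSigned(B, Bits)) & Mask;
}

int main() {
  const unsigned WideBits = 8, NarrowBits = 4;
  const unsigned WideMask = (1u << WideBits) - 1;
  const unsigned NarrowMask = (1u << NarrowBits) - 1;
  const int NarrowMin = -(1 << (NarrowBits - 1));
  const int NarrowMax = (1 << (NarrowBits - 1)) - 1;

  for (unsigned A = 0; A <= WideMask; ++A) {
    for (unsigned B = 0; B <= WideMask; ++B) {
      // ABDU: if both operands have at least WideBits - NarrowBits leading
      // zeros (they fit unsigned in the narrow type), then
      // trunc(abdu(A, B)) == abdu(trunc(A), trunc(B)).
      if (A <= NarrowMask && B <= NarrowMask)
        assert((abdu(A, B, WideBits) & NarrowMask) ==
               abdu(A & NarrowMask, B & NarrowMask, NarrowBits));

      // ABDS: if both operands have more than WideBits - NarrowBits sign
      // bits (their signed values fit in the narrow type), the signed form
      // folds the same way.
      int SA = toSigned(A, WideBits), SB = toSigned(B, WideBits);
      if (SA >= NarrowMin && SA <= NarrowMax && SB >= NarrowMin &&
          SB <= NarrowMax)
        assert((abds(A, B, WideBits) & NarrowMask) ==
               abds(A & NarrowMask, B & NarrowMask, NarrowBits));
    }
  }
  return 0;
}

Both asserts are expected to hold over all 65,536 operand pairs.
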
36 changes: 12 additions & 24 deletions llvm/test/CodeGen/AArch64/abd-combine.ll
@@ -17,12 +17,9 @@ define <8 x i16> @abdu_base(<8 x i16> %src1, <8 x i16> %src2) {
define <8 x i16> @abdu_const(<8 x i16> %src1) {
; CHECK-LABEL: abdu_const:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: ushll v2.4s, v0.4h, #0
; CHECK-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: uabd v1.4s, v2.4s, v1.4s
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: movi v1.4h, #1
; CHECK-NEXT: mov v1.d[1], v1.d[0]
; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: ret
%zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
%sub = sub <8 x i32> %zextsrc1, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
@@ -34,12 +31,9 @@ define <8 x i16> @abdu_const(<8 x i16> %src1) {
define <8 x i16> @abdu_const_lhs(<8 x i16> %src1) {
; CHECK-LABEL: abdu_const_lhs:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: ushll v2.4s, v0.4h, #0
; CHECK-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: uabd v1.4s, v2.4s, v1.4s
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: movi v1.4h, #1
; CHECK-NEXT: mov v1.d[1], v1.d[0]
; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: ret
%zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
%sub = sub <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, %zextsrc1
@@ -318,12 +312,9 @@ define <8 x i16> @abds_base(<8 x i16> %src1, <8 x i16> %src2) {
define <8 x i16> @abds_const(<8 x i16> %src1) {
; CHECK-LABEL: abds_const:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: sshll v2.4s, v0.4h, #0
; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0
; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: sabd v1.4s, v2.4s, v1.4s
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: movi v1.4h, #1
; CHECK-NEXT: mov v1.d[1], v1.d[0]
; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: ret
%zextsrc1 = sext <8 x i16> %src1 to <8 x i32>
%sub = sub <8 x i32> %zextsrc1, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
@@ -335,12 +326,9 @@ define <8 x i16> @abds_const(<8 x i16> %src1) {
define <8 x i16> @abds_const_lhs(<8 x i16> %src1) {
; CHECK-LABEL: abds_const_lhs:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: sshll v2.4s, v0.4h, #0
; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0
; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: sabd v1.4s, v2.4s, v1.4s
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: movi v1.4h, #1
; CHECK-NEXT: mov v1.d[1], v1.d[0]
; CHECK-NEXT: sabd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: ret
%zextsrc1 = sext <8 x i16> %src1 to <8 x i32>
%sub = sub <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, %zextsrc1
66 changes: 66 additions & 0 deletions llvm/test/CodeGen/AArch64/arm64-neon-aba-abd.ll
@@ -575,3 +575,69 @@ define <4 x i32> @knownbits_sabd_and_mul_mask(<4 x i32> %a0, <4 x i32> %a1) {
%6 = shufflevector <4 x i32> %5, <4 x i32> undef, <4 x i32> <i32 0, i32 0, i32 3, i32 3>
ret <4 x i32> %6
}

define <4 x i16> @trunc_abdu_foldable(<4 x i16> %a, <4 x i16> %b) {
; CHECK-SD-LABEL: trunc_abdu_foldable:
; CHECK-SD: // %bb.0:
; CHECK-SD-NEXT: uabd v0.4h, v0.4h, v1.4h
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: trunc_abdu_foldable:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
; CHECK-GI-NEXT: uabd v0.4s, v0.4s, v1.4s
; CHECK-GI-NEXT: xtn v0.4h, v0.4s
; CHECK-GI-NEXT: ret
%ext_a = zext <4 x i16> %a to <4 x i32>
%ext_b = zext <4 x i16> %b to <4 x i32>
%abd = call <4 x i32> @llvm.aarch64.neon.uabd.v4i32(<4 x i32> %ext_a, <4 x i32> %ext_b)
%trunc = trunc <4 x i32> %abd to <4 x i16>
ret <4 x i16> %trunc
}

define <4 x i16> @trunc_abds_foldable(<4 x i16> %a, <4 x i16> %b) {
; CHECK-SD-LABEL: trunc_abds_foldable:
; CHECK-SD: // %bb.0:
; CHECK-SD-NEXT: sabd v0.4h, v0.4h, v1.4h
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: trunc_abds_foldable:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
; CHECK-GI-NEXT: sabd v0.4s, v0.4s, v1.4s
; CHECK-GI-NEXT: xtn v0.4h, v0.4s
; CHECK-GI-NEXT: ret
%a32 = sext <4 x i16> %a to <4 x i32>
%b32 = sext <4 x i16> %b to <4 x i32>
%abd32 = call <4 x i32> @llvm.aarch64.neon.sabd.v4i32(<4 x i32> %a32, <4 x i32> %b32)
%res16 = trunc <4 x i32> %abd32 to <4 x i16>
ret <4 x i16> %res16
}

define <4 x i16> @trunc_abdu_not_foldable(<4 x i16> %a, <4 x i32> %b) {
; CHECK-LABEL: trunc_abdu_not_foldable:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: uabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%ext_a = zext <4 x i16> %a to <4 x i32>
%abd = call <4 x i32> @llvm.aarch64.neon.uabd.v4i32(<4 x i32> %ext_a, <4 x i32> %b)
%trunc = trunc <4 x i32> %abd to <4 x i16>
ret <4 x i16> %trunc
}

define <4 x i16> @trunc_abds_not_foldable(<4 x i16> %a, <4 x i32> %b) {
; CHECK-LABEL: trunc_abds_not_foldable:
; CHECK: // %bb.0:
; CHECK-NEXT: sshll v0.4s, v0.4h, #0
; CHECK-NEXT: sabd v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%a32 = sext <4 x i16> %a to <4 x i32>
%abd32 = call <4 x i32> @llvm.aarch64.neon.sabd.v4i32(<4 x i32> %a32, <4 x i32> %b)
%res16 = trunc <4 x i32> %abd32 to <4 x i16>
ret <4 x i16> %res16
}
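
(Editorial aside, not part of the test file.) A small standalone check, with hypothetical lane values, of why the not_foldable cases above must keep the wide abd: once one operand can carry a value that does not fit in i16, truncation no longer commutes with the absolute difference.

#include <cassert>
#include <cstdint>

// Unsigned absolute difference at i32 and i16 width.
static uint32_t abdu32(uint32_t A, uint32_t B) { return A > B ? A - B : B - A; }
static uint16_t abdu16(uint16_t A, uint16_t B) { return A > B ? A - B : B - A; }

int main() {
  // One lane of trunc_abdu_not_foldable: %ext_a is zext'ed from i16 (fits in
  // 16 bits), but %b is an arbitrary i32 lane. Values chosen for illustration.
  uint32_t A = 1;        // representable in i16
  uint32_t B = 0x10000;  // not representable in i16
  uint16_t WideThenTrunc = (uint16_t)abdu32(A, B);             // 0xFFFF
  uint16_t TruncThenNarrow = abdu16((uint16_t)A, (uint16_t)B); // 0x0001
  assert(WideThenTrunc != TruncThenNarrow); // narrowing would change the result
  return 0;
}

Here trunc(abdu(1, 0x10000)) is 0xFFFF while abdu(trunc(1), trunc(0x10000)) is 1, so the combine must not fire when the upper bits of an operand are unknown.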