Skip to content

Commit 433897d

Browse files
committed
[InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by immediate amounts to generic shifts (PR40391)
The slli/srli/srai 'immediate' vector shifts (although it's not immediate anymore, to match gcc) can be replaced with generic shifts if the shift amount is known to be in range.
1 parent d4d62fc commit 433897d

File tree

2 files changed

+108
-58
lines changed

2 files changed

+108
-58
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 66 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -296,77 +296,109 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
296296
InstCombiner::BuilderTy &Builder) {
297297
bool LogicalShift = false;
298298
bool ShiftLeft = false;
299+
bool IsImm = false;
299300

300301
switch (II.getIntrinsicID()) {
301302
default: llvm_unreachable("Unexpected intrinsic!");
302-
case Intrinsic::x86_sse2_psra_d:
303-
case Intrinsic::x86_sse2_psra_w:
304303
case Intrinsic::x86_sse2_psrai_d:
305304
case Intrinsic::x86_sse2_psrai_w:
306-
case Intrinsic::x86_avx2_psra_d:
307-
case Intrinsic::x86_avx2_psra_w:
308305
case Intrinsic::x86_avx2_psrai_d:
309306
case Intrinsic::x86_avx2_psrai_w:
310-
case Intrinsic::x86_avx512_psra_q_128:
311307
case Intrinsic::x86_avx512_psrai_q_128:
312-
case Intrinsic::x86_avx512_psra_q_256:
313308
case Intrinsic::x86_avx512_psrai_q_256:
314-
case Intrinsic::x86_avx512_psra_d_512:
315-
case Intrinsic::x86_avx512_psra_q_512:
316-
case Intrinsic::x86_avx512_psra_w_512:
317309
case Intrinsic::x86_avx512_psrai_d_512:
318310
case Intrinsic::x86_avx512_psrai_q_512:
319311
case Intrinsic::x86_avx512_psrai_w_512:
320-
LogicalShift = false; ShiftLeft = false;
312+
IsImm = true;
313+
LLVM_FALLTHROUGH;
314+
case Intrinsic::x86_sse2_psra_d:
315+
case Intrinsic::x86_sse2_psra_w:
316+
case Intrinsic::x86_avx2_psra_d:
317+
case Intrinsic::x86_avx2_psra_w:
318+
case Intrinsic::x86_avx512_psra_q_128:
319+
case Intrinsic::x86_avx512_psra_q_256:
320+
case Intrinsic::x86_avx512_psra_d_512:
321+
case Intrinsic::x86_avx512_psra_q_512:
322+
case Intrinsic::x86_avx512_psra_w_512:
323+
LogicalShift = false;
324+
ShiftLeft = false;
321325
break;
322-
case Intrinsic::x86_sse2_psrl_d:
323-
case Intrinsic::x86_sse2_psrl_q:
324-
case Intrinsic::x86_sse2_psrl_w:
325326
case Intrinsic::x86_sse2_psrli_d:
326327
case Intrinsic::x86_sse2_psrli_q:
327328
case Intrinsic::x86_sse2_psrli_w:
328-
case Intrinsic::x86_avx2_psrl_d:
329-
case Intrinsic::x86_avx2_psrl_q:
330-
case Intrinsic::x86_avx2_psrl_w:
331329
case Intrinsic::x86_avx2_psrli_d:
332330
case Intrinsic::x86_avx2_psrli_q:
333331
case Intrinsic::x86_avx2_psrli_w:
334-
case Intrinsic::x86_avx512_psrl_d_512:
335-
case Intrinsic::x86_avx512_psrl_q_512:
336-
case Intrinsic::x86_avx512_psrl_w_512:
337332
case Intrinsic::x86_avx512_psrli_d_512:
338333
case Intrinsic::x86_avx512_psrli_q_512:
339334
case Intrinsic::x86_avx512_psrli_w_512:
340-
LogicalShift = true; ShiftLeft = false;
335+
IsImm = true;
336+
LLVM_FALLTHROUGH;
337+
case Intrinsic::x86_sse2_psrl_d:
338+
case Intrinsic::x86_sse2_psrl_q:
339+
case Intrinsic::x86_sse2_psrl_w:
340+
case Intrinsic::x86_avx2_psrl_d:
341+
case Intrinsic::x86_avx2_psrl_q:
342+
case Intrinsic::x86_avx2_psrl_w:
343+
case Intrinsic::x86_avx512_psrl_d_512:
344+
case Intrinsic::x86_avx512_psrl_q_512:
345+
case Intrinsic::x86_avx512_psrl_w_512:
346+
LogicalShift = true;
347+
ShiftLeft = false;
341348
break;
342-
case Intrinsic::x86_sse2_psll_d:
343-
case Intrinsic::x86_sse2_psll_q:
344-
case Intrinsic::x86_sse2_psll_w:
345349
case Intrinsic::x86_sse2_pslli_d:
346350
case Intrinsic::x86_sse2_pslli_q:
347351
case Intrinsic::x86_sse2_pslli_w:
348-
case Intrinsic::x86_avx2_psll_d:
349-
case Intrinsic::x86_avx2_psll_q:
350-
case Intrinsic::x86_avx2_psll_w:
351352
case Intrinsic::x86_avx2_pslli_d:
352353
case Intrinsic::x86_avx2_pslli_q:
353354
case Intrinsic::x86_avx2_pslli_w:
354-
case Intrinsic::x86_avx512_psll_d_512:
355-
case Intrinsic::x86_avx512_psll_q_512:
356-
case Intrinsic::x86_avx512_psll_w_512:
357355
case Intrinsic::x86_avx512_pslli_d_512:
358356
case Intrinsic::x86_avx512_pslli_q_512:
359357
case Intrinsic::x86_avx512_pslli_w_512:
360-
LogicalShift = true; ShiftLeft = true;
358+
IsImm = true;
359+
LLVM_FALLTHROUGH;
360+
case Intrinsic::x86_sse2_psll_d:
361+
case Intrinsic::x86_sse2_psll_q:
362+
case Intrinsic::x86_sse2_psll_w:
363+
case Intrinsic::x86_avx2_psll_d:
364+
case Intrinsic::x86_avx2_psll_q:
365+
case Intrinsic::x86_avx2_psll_w:
366+
case Intrinsic::x86_avx512_psll_d_512:
367+
case Intrinsic::x86_avx512_psll_q_512:
368+
case Intrinsic::x86_avx512_psll_w_512:
369+
LogicalShift = true;
370+
ShiftLeft = true;
361371
break;
362372
}
363373
assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
364374

375+
auto Vec = II.getArgOperand(0);
376+
auto Amt = II.getArgOperand(1);
377+
auto VT = cast<VectorType>(Vec->getType());
378+
auto SVT = VT->getElementType();
379+
unsigned VWidth = VT->getNumElements();
380+
unsigned BitWidth = SVT->getPrimitiveSizeInBits();
381+
382+
// If the shift amount is guaranteed to be in-range we can replace it with a
383+
// generic shift.
384+
if (IsImm) {
385+
assert(Amt->getType()->isIntegerTy(32) &&
386+
"Unexpected shift-by-immediate type");
387+
KnownBits KnownAmtBits =
388+
llvm::computeKnownBits(Amt, II.getModule()->getDataLayout());
389+
if (KnownAmtBits.getMaxValue().ult(BitWidth)) {
390+
Amt = Builder.CreateZExtOrTrunc(Amt, SVT);
391+
Amt = Builder.CreateVectorSplat(VWidth, Amt);
392+
return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
393+
: Builder.CreateLShr(Vec, Amt))
394+
: Builder.CreateAShr(Vec, Amt));
395+
}
396+
}
397+
365398
// Simplify if count is constant.
366-
auto Arg1 = II.getArgOperand(1);
367-
auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
368-
auto CDV = dyn_cast<ConstantDataVector>(Arg1);
369-
auto CInt = dyn_cast<ConstantInt>(Arg1);
399+
auto CAZ = dyn_cast<ConstantAggregateZero>(Amt);
400+
auto CDV = dyn_cast<ConstantDataVector>(Amt);
401+
auto CInt = dyn_cast<ConstantInt>(Amt);
370402
if (!CAZ && !CDV && !CInt)
371403
return nullptr;
372404

@@ -390,12 +422,6 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
390422
else if (CInt)
391423
Count = CInt->getValue();
392424

393-
auto Vec = II.getArgOperand(0);
394-
auto VT = cast<VectorType>(Vec->getType());
395-
auto SVT = VT->getElementType();
396-
unsigned VWidth = VT->getNumElements();
397-
unsigned BitWidth = SVT->getPrimitiveSizeInBits();
398-
399425
// If shift-by-zero then just return the original value.
400426
if (Count.isNullValue())
401427
return Vec;

llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll

Lines changed: 42 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2680,9 +2680,12 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
26802680

26812681
define <8 x i16> @sse2_psrai_w_128_masked(<8 x i16> %v, i32 %a) {
26822682
; CHECK-LABEL: @sse2_psrai_w_128_masked(
2683-
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 15
2684-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i16> @llvm.x86.sse2.psrai.w(<8 x i16> [[V:%.*]], i32 [[TMP1]])
2685-
; CHECK-NEXT: ret <8 x i16> [[TMP2]]
2683+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[A:%.*]] to i16
2684+
; CHECK-NEXT: [[TMP2:%.*]] = and i16 [[TMP1]], 15
2685+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i16> undef, i16 [[TMP2]], i32 0
2686+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i16> [[DOTSPLATINSERT]], <8 x i16> undef, <8 x i32> zeroinitializer
2687+
; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i16> [[V:%.*]], [[DOTSPLAT]]
2688+
; CHECK-NEXT: ret <8 x i16> [[TMP3]]
26862689
;
26872690
%1 = and i32 %a, 15
26882691
%2 = tail call <8 x i16> @llvm.x86.sse2.psrai.w(<8 x i16> %v, i32 %1)
@@ -2692,7 +2695,9 @@ define <8 x i16> @sse2_psrai_w_128_masked(<8 x i16> %v, i32 %a) {
26922695
define <8 x i32> @avx2_psrai_d_256_masked(<8 x i32> %v, i32 %a) {
26932696
; CHECK-LABEL: @avx2_psrai_d_256_masked(
26942697
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 31
2695-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psrai.d(<8 x i32> [[V:%.*]], i32 [[TMP1]])
2698+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i32> undef, i32 [[TMP1]], i32 0
2699+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i32> [[DOTSPLATINSERT]], <8 x i32> undef, <8 x i32> zeroinitializer
2700+
; CHECK-NEXT: [[TMP2:%.*]] = ashr <8 x i32> [[V:%.*]], [[DOTSPLAT]]
26962701
; CHECK-NEXT: ret <8 x i32> [[TMP2]]
26972702
;
26982703
%1 = and i32 %a, 31
@@ -2703,8 +2708,11 @@ define <8 x i32> @avx2_psrai_d_256_masked(<8 x i32> %v, i32 %a) {
27032708
define <8 x i64> @avx512_psrai_q_512_masked(<8 x i64> %v, i32 %a) {
27042709
; CHECK-LABEL: @avx512_psrai_q_512_masked(
27052710
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 63
2706-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psrai.q.512(<8 x i64> [[V:%.*]], i32 [[TMP1]])
2707-
; CHECK-NEXT: ret <8 x i64> [[TMP2]]
2711+
; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
2712+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> undef, i64 [[TMP2]], i32 0
2713+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> undef, <8 x i32> zeroinitializer
2714+
; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i64> [[V:%.*]], [[DOTSPLAT]]
2715+
; CHECK-NEXT: ret <8 x i64> [[TMP3]]
27082716
;
27092717
%1 = and i32 %a, 63
27102718
%2 = tail call <8 x i64> @llvm.x86.avx512.psrai.q.512(<8 x i64> %v, i32 %1)
@@ -2714,7 +2722,9 @@ define <8 x i64> @avx512_psrai_q_512_masked(<8 x i64> %v, i32 %a) {
27142722
define <4 x i32> @sse2_psrli_d_128_masked(<4 x i32> %v, i32 %a) {
27152723
; CHECK-LABEL: @sse2_psrli_d_128_masked(
27162724
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 31
2717-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32> [[V:%.*]], i32 [[TMP1]])
2725+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> undef, i32 [[TMP1]], i32 0
2726+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> undef, <4 x i32> zeroinitializer
2727+
; CHECK-NEXT: [[TMP2:%.*]] = lshr <4 x i32> [[V:%.*]], [[DOTSPLAT]]
27182728
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
27192729
;
27202730
%1 = and i32 %a, 31
@@ -2725,8 +2735,11 @@ define <4 x i32> @sse2_psrli_d_128_masked(<4 x i32> %v, i32 %a) {
27252735
define <4 x i64> @avx2_psrli_q_256_masked(<4 x i64> %v, i32 %a) {
27262736
; CHECK-LABEL: @avx2_psrli_q_256_masked(
27272737
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 63
2728-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psrli.q(<4 x i64> [[V:%.*]], i32 [[TMP1]])
2729-
; CHECK-NEXT: ret <4 x i64> [[TMP2]]
2738+
; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
2739+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i64> undef, i64 [[TMP2]], i32 0
2740+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x i64> [[DOTSPLATINSERT]], <4 x i64> undef, <4 x i32> zeroinitializer
2741+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i64> [[V:%.*]], [[DOTSPLAT]]
2742+
; CHECK-NEXT: ret <4 x i64> [[TMP3]]
27302743
;
27312744
%1 = and i32 %a, 63
27322745
%2 = tail call <4 x i64> @llvm.x86.avx2.psrli.q(<4 x i64> %v, i32 %1)
@@ -2735,9 +2748,12 @@ define <4 x i64> @avx2_psrli_q_256_masked(<4 x i64> %v, i32 %a) {
27352748

27362749
define <32 x i16> @avx512_psrli_w_512_masked(<32 x i16> %v, i32 %a) {
27372750
; CHECK-LABEL: @avx512_psrli_w_512_masked(
2738-
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 15
2739-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrli.w.512(<32 x i16> [[V:%.*]], i32 [[TMP1]])
2740-
; CHECK-NEXT: ret <32 x i16> [[TMP2]]
2751+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[A:%.*]] to i16
2752+
; CHECK-NEXT: [[TMP2:%.*]] = and i16 [[TMP1]], 15
2753+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <32 x i16> undef, i16 [[TMP2]], i32 0
2754+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <32 x i16> [[DOTSPLATINSERT]], <32 x i16> undef, <32 x i32> zeroinitializer
2755+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <32 x i16> [[V:%.*]], [[DOTSPLAT]]
2756+
; CHECK-NEXT: ret <32 x i16> [[TMP3]]
27412757
;
27422758
%1 = and i32 %a, 15
27432759
%2 = tail call <32 x i16> @llvm.x86.avx512.psrli.w.512(<32 x i16> %v, i32 %1)
@@ -2747,8 +2763,11 @@ define <32 x i16> @avx512_psrli_w_512_masked(<32 x i16> %v, i32 %a) {
27472763
define <2 x i64> @sse2_pslli_q_128_masked(<2 x i64> %v, i32 %a) {
27482764
; CHECK-LABEL: @sse2_pslli_q_128_masked(
27492765
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 63
2750-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.sse2.pslli.q(<2 x i64> [[V:%.*]], i32 [[TMP1]])
2751-
; CHECK-NEXT: ret <2 x i64> [[TMP2]]
2766+
; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
2767+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> undef, i64 [[TMP2]], i32 0
2768+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <2 x i64> [[DOTSPLATINSERT]], <2 x i64> undef, <2 x i32> zeroinitializer
2769+
; CHECK-NEXT: [[TMP3:%.*]] = shl <2 x i64> [[V:%.*]], [[DOTSPLAT]]
2770+
; CHECK-NEXT: ret <2 x i64> [[TMP3]]
27522771
;
27532772
%1 = and i32 %a, 63
27542773
%2 = tail call <2 x i64> @llvm.x86.sse2.pslli.q(<2 x i64> %v, i32 %1)
@@ -2757,9 +2776,12 @@ define <2 x i64> @sse2_pslli_q_128_masked(<2 x i64> %v, i32 %a) {
27572776

27582777
define <16 x i16> @avx2_pslli_w_256_masked(<16 x i16> %v, i32 %a) {
27592778
; CHECK-LABEL: @avx2_pslli_w_256_masked(
2760-
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 15
2761-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i16> @llvm.x86.avx2.pslli.w(<16 x i16> [[V:%.*]], i32 [[TMP1]])
2762-
; CHECK-NEXT: ret <16 x i16> [[TMP2]]
2779+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[A:%.*]] to i16
2780+
; CHECK-NEXT: [[TMP2:%.*]] = and i16 [[TMP1]], 15
2781+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <16 x i16> undef, i16 [[TMP2]], i32 0
2782+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <16 x i16> [[DOTSPLATINSERT]], <16 x i16> undef, <16 x i32> zeroinitializer
2783+
; CHECK-NEXT: [[TMP3:%.*]] = shl <16 x i16> [[V:%.*]], [[DOTSPLAT]]
2784+
; CHECK-NEXT: ret <16 x i16> [[TMP3]]
27632785
;
27642786
%1 = and i32 %a, 15
27652787
%2 = tail call <16 x i16> @llvm.x86.avx2.pslli.w(<16 x i16> %v, i32 %1)
@@ -2769,7 +2791,9 @@ define <16 x i16> @avx2_pslli_w_256_masked(<16 x i16> %v, i32 %a) {
27692791
define <16 x i32> @avx512_pslli_d_512_masked(<16 x i32> %v, i32 %a) {
27702792
; CHECK-LABEL: @avx512_pslli_d_512_masked(
27712793
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 31
2772-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i32> @llvm.x86.avx512.pslli.d.512(<16 x i32> [[V:%.*]], i32 [[TMP1]])
2794+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <16 x i32> undef, i32 [[TMP1]], i32 0
2795+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <16 x i32> [[DOTSPLATINSERT]], <16 x i32> undef, <16 x i32> zeroinitializer
2796+
; CHECK-NEXT: [[TMP2:%.*]] = shl <16 x i32> [[V:%.*]], [[DOTSPLAT]]
27732797
; CHECK-NEXT: ret <16 x i32> [[TMP2]]
27742798
;
27752799
%1 = and i32 %a, 31

0 commit comments

Comments
 (0)