Skip to content

Commit 944cc5e

Browse files
committed
[SelectionDAGBuilder][CGP][X86] Move some of SDB's gather/scatter uniform base handling to CGP.
I've always found the "findValue" a little odd and inconsistent with other things in SDB. This simplifies the code in SDB to just handle a splat constant address or a 2 operand GEP in the same BB. This removes the need for "findValue" since the operands to the GEP are guaranteed to be available. The splat constant handling is new, but was needed to avoid regressions due to constant folding combining GEPs created in CGP. CGP is now responsible for canonicalizing gather/scatters into this form. The pattern I'm using for scalarizing, a scalar GEP followed by a GEP with an all zeroes index, seems to be subject to constant folding that the insertelement+shufflevector was not. Differential Revision: https://reviews.llvm.org/D76947
1 parent 741d3c2 commit 944cc5e

File tree

7 files changed

+261
-73
lines changed

7 files changed

+261
-73
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ class TypePromotionTransaction;
368368
bool optimizeInst(Instruction *I, bool &ModifiedDT);
369369
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
370370
Type *AccessTy, unsigned AddrSpace);
371+
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
371372
bool optimizeInlineAsmInst(CallInst *CS);
372373
bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
373374
bool optimizeExt(Instruction *&I);
@@ -2041,7 +2042,12 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, bool &ModifiedDT) {
20412042
II->eraseFromParent();
20422043
return true;
20432044
}
2045+
break;
20442046
}
2047+
case Intrinsic::masked_gather:
2048+
return optimizeGatherScatterInst(II, II->getArgOperand(0));
2049+
case Intrinsic::masked_scatter:
2050+
return optimizeGatherScatterInst(II, II->getArgOperand(1));
20452051
}
20462052

20472053
SmallVector<Value *, 2> PtrOps;
@@ -5182,6 +5188,119 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
51825188
return true;
51835189
}
51845190

5191+
/// Rewrite GEP input to gather/scatter to enable SelectionDAGBuilder to find
5192+
/// a uniform base to use for ISD::MGATHER/MSCATTER. SelectionDAGBuilder can
5193+
/// only handle a 2 operand GEP in the same basic block or a splat constant
5194+
/// vector. The 2 operands to the GEP must have a scalar pointer and a vector
5195+
/// index.
5196+
///
5197+
/// If the existing GEP has a vector base pointer that is splat, we can look
5198+
/// through the splat to find the scalar pointer. If we can't find a scalar
5199+
/// pointer there's nothing we can do.
5200+
///
5201+
/// If we have a GEP with more than 2 indices where the middle indices are all
5202+
/// zeroes, we can replace it with 2 GEPs where the second has 2 operands.
5203+
///
5204+
/// If the final index isn't a vector or is a splat, we can emit a scalar GEP
5205+
/// followed by a GEP with an all zeroes vector index. This will enable
5206+
/// SelectionDAGBuilder to use a the scalar GEP as the uniform base and have a
5207+
/// zero index.
5208+
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
5209+
Value *Ptr) {
5210+
const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
5211+
if (!GEP || !GEP->hasIndices())
5212+
return false;
5213+
5214+
// If the GEP and the gather/scatter aren't in the same BB, don't optimize.
5215+
// FIXME: We should support this by sinking the GEP.
5216+
if (MemoryInst->getParent() != GEP->getParent())
5217+
return false;
5218+
5219+
SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
5220+
5221+
bool RewriteGEP = false;
5222+
5223+
if (Ops[0]->getType()->isVectorTy()) {
5224+
Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
5225+
if (!Ops[0])
5226+
return false;
5227+
RewriteGEP = true;
5228+
}
5229+
5230+
unsigned FinalIndex = Ops.size() - 1;
5231+
5232+
// Ensure all but the last index is 0.
5233+
// FIXME: This isn't strictly required. All that's required is that they are
5234+
// all scalars or splats.
5235+
for (unsigned i = 1; i < FinalIndex; ++i) {
5236+
auto *C = dyn_cast<Constant>(Ops[i]);
5237+
if (!C)
5238+
return false;
5239+
if (isa<VectorType>(C->getType()))
5240+
C = C->getSplatValue();
5241+
auto *CI = dyn_cast_or_null<ConstantInt>(C);
5242+
if (!CI || !CI->isZero())
5243+
return false;
5244+
// Scalarize the index if needed.
5245+
Ops[i] = CI;
5246+
}
5247+
5248+
// Try to scalarize the final index.
5249+
if (Ops[FinalIndex]->getType()->isVectorTy()) {
5250+
if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
5251+
auto *C = dyn_cast<ConstantInt>(V);
5252+
// Don't scalarize all zeros vector.
5253+
if (!C || !C->isZero()) {
5254+
Ops[FinalIndex] = V;
5255+
RewriteGEP = true;
5256+
}
5257+
}
5258+
}
5259+
5260+
// If we made any changes or the we have extra operands, we need to generate
5261+
// new instructions.
5262+
if (!RewriteGEP && Ops.size() == 2)
5263+
return false;
5264+
5265+
unsigned NumElts = Ptr->getType()->getVectorNumElements();
5266+
5267+
IRBuilder<> Builder(MemoryInst);
5268+
5269+
Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
5270+
5271+
Value *NewAddr;
5272+
5273+
// If the final index isn't a vector, emit a scalar GEP containing all ops
5274+
// and a vector GEP with all zeroes final index.
5275+
if (!Ops[FinalIndex]->getType()->isVectorTy()) {
5276+
NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
5277+
Type *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
5278+
NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
5279+
} else {
5280+
Value *Base = Ops[0];
5281+
Value *Index = Ops[FinalIndex];
5282+
5283+
// Create a scalar GEP if there are more than 2 operands.
5284+
if (Ops.size() != 2) {
5285+
// Replace the last index with 0.
5286+
Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
5287+
Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
5288+
}
5289+
5290+
// Now create the GEP with scalar pointer and vector index.
5291+
NewAddr = Builder.CreateGEP(Base, Index);
5292+
}
5293+
5294+
MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
5295+
5296+
// If we have no uses, recursively delete the value and all dead instructions
5297+
// using it.
5298+
if (Ptr->use_empty())
5299+
RecursivelyDeleteTriviallyDeadInstructions(Ptr, TLInfo);
5300+
5301+
return true;
5302+
}
5303+
51855304
/// If there are any memory operands, use OptimizeMemoryInst to sink their
51865305
/// address computing into the block when possible / profitable.
51875306
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 32 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,12 +1435,6 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) {
14351435
return Val;
14361436
}
14371437

1438-
// Return true if SDValue exists for the given Value
1439-
bool SelectionDAGBuilder::findValue(const Value *V) const {
1440-
return (NodeMap.find(V) != NodeMap.end()) ||
1441-
(FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end());
1442-
}
1443-
14441438
/// getNonRegisterValue - Return an SDValue for the given Value, but
14451439
/// don't look in FuncInfo.ValueMap for a virtual register.
14461440
SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
@@ -4254,70 +4248,49 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
42544248
// In all other cases the function returns 'false'.
42554249
static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,
42564250
ISD::MemIndexType &IndexType, SDValue &Scale,
4257-
SelectionDAGBuilder *SDB) {
4251+
SelectionDAGBuilder *SDB, const BasicBlock *CurBB) {
42584252
SelectionDAG& DAG = SDB->DAG;
4259-
LLVMContext &Context = *DAG.getContext();
4253+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4254+
const DataLayout &DL = DAG.getDataLayout();
42604255

42614256
assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type");
4262-
const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
4263-
if (!GEP)
4264-
return false;
42654257

4266-
const Value *BasePtr = GEP->getPointerOperand();
4267-
if (BasePtr->getType()->isVectorTy()) {
4268-
BasePtr = getSplatValue(BasePtr);
4269-
if (!BasePtr)
4258+
// Handle splat constant pointer.
4259+
if (auto *C = dyn_cast<Constant>(Ptr)) {
4260+
C = C->getSplatValue();
4261+
if (!C)
42704262
return false;
4271-
}
42724263

4273-
unsigned FinalIndex = GEP->getNumOperands() - 1;
4274-
Value *IndexVal = GEP->getOperand(FinalIndex);
4275-
gep_type_iterator GTI = gep_type_begin(*GEP);
4264+
Base = SDB->getValue(C);
42764265

4277-
// Ensure all the other indices are 0.
4278-
for (unsigned i = 1; i < FinalIndex; ++i, ++GTI) {
4279-
auto *C = dyn_cast<Constant>(GEP->getOperand(i));
4280-
if (!C)
4281-
return false;
4282-
if (isa<VectorType>(C->getType()))
4283-
C = C->getSplatValue();
4284-
auto *CI = dyn_cast_or_null<ConstantInt>(C);
4285-
if (!CI || !CI->isZero())
4286-
return false;
4266+
unsigned NumElts = Ptr->getType()->getVectorNumElements();
4267+
EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts);
4268+
Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT);
4269+
IndexType = ISD::SIGNED_SCALED;
4270+
Scale = DAG.getTargetConstant(1, SDB->getCurSDLoc(), TLI.getPointerTy(DL));
4271+
return true;
42874272
}
42884273

4289-
// The operands of the GEP may be defined in another basic block.
4290-
// In this case we'll not find nodes for the operands.
4291-
if (!SDB->findValue(BasePtr))
4274+
const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
4275+
if (!GEP || GEP->getParent() != CurBB)
42924276
return false;
4293-
Constant *C = dyn_cast<Constant>(IndexVal);
4294-
if (!C && !SDB->findValue(IndexVal))
4277+
4278+
if (GEP->getNumOperands() != 2)
42954279
return false;
42964280

4297-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4298-
const DataLayout &DL = DAG.getDataLayout();
4299-
StructType *STy = GTI.getStructTypeOrNull();
4281+
const Value *BasePtr = GEP->getPointerOperand();
4282+
const Value *IndexVal = GEP->getOperand(GEP->getNumOperands() - 1);
4283+
4284+
// Make sure the base is scalar and the index is a vector.
4285+
if (BasePtr->getType()->isVectorTy() || !IndexVal->getType()->isVectorTy())
4286+
return false;
43004287

4301-
if (STy) {
4302-
const StructLayout *SL = DL.getStructLayout(STy);
4303-
unsigned Field = cast<Constant>(IndexVal)->getUniqueInteger().getZExtValue();
4304-
Scale = DAG.getTargetConstant(1, SDB->getCurSDLoc(), TLI.getPointerTy(DL));
4305-
Index = DAG.getConstant(SL->getElementOffset(Field),
4306-
SDB->getCurSDLoc(), TLI.getPointerTy(DL));
4307-
} else {
4308-
Scale = DAG.getTargetConstant(
4309-
DL.getTypeAllocSize(GEP->getResultElementType()),
4310-
SDB->getCurSDLoc(), TLI.getPointerTy(DL));
4311-
Index = SDB->getValue(IndexVal);
4312-
}
43134288
Base = SDB->getValue(BasePtr);
4289+
Index = SDB->getValue(IndexVal);
43144290
IndexType = ISD::SIGNED_SCALED;
4315-
4316-
if (STy || !Index.getValueType().isVector()) {
4317-
unsigned GEPWidth = cast<VectorType>(GEP->getType())->getNumElements();
4318-
EVT VT = EVT::getVectorVT(Context, Index.getValueType(), GEPWidth);
4319-
Index = DAG.getSplatBuildVector(VT, SDLoc(Index), Index);
4320-
}
4291+
Scale = DAG.getTargetConstant(
4292+
DL.getTypeAllocSize(GEP->getResultElementType()),
4293+
SDB->getCurSDLoc(), TLI.getPointerTy(DL));
43214294
return true;
43224295
}
43234296

@@ -4341,7 +4314,8 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
43414314
SDValue Index;
43424315
ISD::MemIndexType IndexType;
43434316
SDValue Scale;
4344-
bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this);
4317+
bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this,
4318+
I.getParent());
43454319

43464320
unsigned AS = Ptr->getType()->getScalarType()->getPointerAddressSpace();
43474321
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -4452,7 +4426,8 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
44524426
SDValue Index;
44534427
ISD::MemIndexType IndexType;
44544428
SDValue Scale;
4455-
bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this);
4429+
bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this,
4430+
I.getParent());
44564431
unsigned AS = Ptr->getType()->getScalarType()->getPointerAddressSpace();
44574432
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
44584433
MachinePointerInfo(AS), MachineMemOperand::MOLoad,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,6 @@ class SelectionDAGBuilder {
518518
void resolveOrClearDbgInfo();
519519

520520
SDValue getValue(const Value *V);
521-
bool findValue(const Value *V) const;
522521

523522
/// Return the SDNode for the specified IR value if it exists.
524523
SDNode *getNodeForIRValue(const Value *V) {

llvm/test/CodeGen/X86/masked_gather.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,11 +1721,10 @@ define <8 x i32> @gather_v8i32_v8i32(<8 x i32> %trigger) {
17211721
; AVX512-NEXT: vptestnmd %zmm0, %zmm0, %k0
17221722
; AVX512-NEXT: kshiftlw $8, %k0, %k0
17231723
; AVX512-NEXT: kshiftrw $8, %k0, %k1
1724-
; AVX512-NEXT: vpbroadcastd {{.*#+}} zmm0 = [3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3]
1724+
; AVX512-NEXT: vpxor %xmm0, %xmm0, %xmm0
17251725
; AVX512-NEXT: kmovw %k1, %k2
1726-
; AVX512-NEXT: vpgatherdd c(,%zmm0,4), %zmm1 {%k2}
1727-
; AVX512-NEXT: vpbroadcastd {{.*#+}} zmm0 = [28,28,28,28,28,28,28,28,28,28,28,28,28,28,28,28]
1728-
; AVX512-NEXT: vpgatherdd c(,%zmm0), %zmm2 {%k1}
1726+
; AVX512-NEXT: vpgatherdd c+12(,%zmm0), %zmm1 {%k2}
1727+
; AVX512-NEXT: vpgatherdd c+28(,%zmm0), %zmm2 {%k1}
17291728
; AVX512-NEXT: vpaddd %ymm2, %ymm2, %ymm0
17301729
; AVX512-NEXT: vpaddd %ymm0, %ymm1, %ymm0
17311730
; AVX512-NEXT: retq

llvm/test/CodeGen/X86/masked_gather_scatter.ll

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -638,30 +638,38 @@ entry:
638638
define <16 x float> @test11(float* %base, i32 %ind) {
639639
; KNL_64-LABEL: test11:
640640
; KNL_64: # %bb.0:
641-
; KNL_64-NEXT: vpbroadcastd %esi, %zmm1
641+
; KNL_64-NEXT: movslq %esi, %rax
642+
; KNL_64-NEXT: leaq (%rdi,%rax,4), %rax
643+
; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1
642644
; KNL_64-NEXT: kxnorw %k0, %k0, %k1
643-
; KNL_64-NEXT: vgatherdps (%rdi,%zmm1,4), %zmm0 {%k1}
645+
; KNL_64-NEXT: vgatherdps (%rax,%zmm1,4), %zmm0 {%k1}
644646
; KNL_64-NEXT: retq
645647
;
646648
; KNL_32-LABEL: test11:
647649
; KNL_32: # %bb.0:
648650
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
649-
; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
651+
; KNL_32-NEXT: shll $2, %eax
652+
; KNL_32-NEXT: addl {{[0-9]+}}(%esp), %eax
653+
; KNL_32-NEXT: vxorps %xmm1, %xmm1, %xmm1
650654
; KNL_32-NEXT: kxnorw %k0, %k0, %k1
651655
; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
652656
; KNL_32-NEXT: retl
653657
;
654658
; SKX-LABEL: test11:
655659
; SKX: # %bb.0:
656-
; SKX-NEXT: vpbroadcastd %esi, %zmm1
660+
; SKX-NEXT: movslq %esi, %rax
661+
; SKX-NEXT: leaq (%rdi,%rax,4), %rax
662+
; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1
657663
; SKX-NEXT: kxnorw %k0, %k0, %k1
658-
; SKX-NEXT: vgatherdps (%rdi,%zmm1,4), %zmm0 {%k1}
664+
; SKX-NEXT: vgatherdps (%rax,%zmm1,4), %zmm0 {%k1}
659665
; SKX-NEXT: retq
660666
;
661667
; SKX_32-LABEL: test11:
662668
; SKX_32: # %bb.0:
663669
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
664-
; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
670+
; SKX_32-NEXT: shll $2, %eax
671+
; SKX_32-NEXT: addl {{[0-9]+}}(%esp), %eax
672+
; SKX_32-NEXT: vxorps %xmm1, %xmm1, %xmm1
665673
; SKX_32-NEXT: kxnorw %k0, %k0, %k1
666674
; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
667675
; SKX_32-NEXT: retl

llvm/test/CodeGen/X86/pr45067.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
define void @foo(<8 x i32>* %x, <8 x i1> %y) {
77
; CHECK-LABEL: foo:
88
; CHECK: ## %bb.0:
9-
; CHECK-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
10-
; CHECK-NEXT: vpbroadcastq _global@{{.*}}(%rip), %ymm2
11-
; CHECK-NEXT: vpgatherqd %xmm1, (,%ymm2), %xmm3
9+
; CHECK-NEXT: vpcmpeqd %ymm1, %ymm1, %ymm1
10+
; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2
11+
; CHECK-NEXT: movq _global@{{.*}}(%rip), %rax
12+
; CHECK-NEXT: vpgatherdd %ymm1, (%rax,%ymm2), %ymm3
1213
; CHECK-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
1314
; CHECK-NEXT: vpslld $31, %ymm0, %ymm0
14-
; CHECK-NEXT: vinserti128 $1, %xmm3, %ymm3, %ymm1
15-
; CHECK-NEXT: vpmaskmovd %ymm1, %ymm0, (%rdi)
15+
; CHECK-NEXT: vpmaskmovd %ymm3, %ymm0, (%rdi)
1616
; CHECK-NEXT: ud2
1717
%tmp = call <8 x i32> @llvm.masked.gather.v8i32.v8p0i32(<8 x i32*> <i32* @global, i32* @global, i32* @global, i32* @global, i32* @global, i32* @global, i32* @global, i32* @global>, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
1818
call void @llvm.masked.store.v8i32.p0v8i32(<8 x i32> %tmp, <8 x i32>* %x, i32 4, <8 x i1> %y)

0 commit comments

Comments
 (0)