
Commit 9f87d95

Clean up usages of asserting vector getters in Type
Summary:
Remove usages of asserting vector getters in Type in preparation for the
VectorType refactor. The existence of these functions complicates the refactor
while adding little value.

Reviewers: mcrosier, efriedma, sdesmalen

Reviewed By: efriedma

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77269
1 parent 0d5f15f commit 9f87d95

3 files changed, 34 insertions(+), 38 deletions(-)

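The summary above describes the pattern applied throughout the diff: instead of relying on the asserting vector getters on Type (getVectorElementType, getVectorNumElements, getVectorIsScalable), callers cast the Type to VectorType first and use its own accessors. The sketch below is a minimal illustration of that before/after shape, not code from this commit; the helper names are hypothetical and it assumes LLVM headers contemporary with this change.

#include "llvm/IR/DerivedTypes.h" // VectorType
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h" // cast<>, dyn_cast<>

using namespace llvm;

// Hypothetical helper, not part of the commit: query the element count of a
// type the caller knows is a vector.
static unsigned numVectorElements(Type *Ty) {
  // Before: return Ty->getVectorNumElements();  (asserted internally)
  // After: cast<> makes the vector-ness assertion explicit at the call site,
  // and the query goes through VectorType's own accessor.
  return cast<VectorType>(Ty)->getNumElements();
}

// Hypothetical helper: when the type may legitimately be non-vector,
// dyn_cast<> replaces the isa<> check followed by an asserting getter.
static bool isAtLeast128BitVector(Type *Ty) {
  if (auto *VTy = dyn_cast<VectorType>(Ty))
    return VTy->getElementType()->getScalarSizeInBits() *
               VTy->getNumElements() >=
           128;
  return false;
}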

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 19 additions & 22 deletions
@@ -9376,10 +9376,9 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
 
   // A pointer vector can not be the return type of the ldN intrinsics. Need to
   // load integer vectors first and then convert to pointer vectors.
-  Type *EltTy = VecTy->getVectorElementType();
+  Type *EltTy = VecTy->getElementType();
   if (EltTy->isPointerTy())
-    VecTy =
-        VectorType::get(DL.getIntPtrType(EltTy), VecTy->getVectorNumElements());
+    VecTy = VectorType::get(DL.getIntPtrType(EltTy), VecTy->getNumElements());
 
   IRBuilder<> Builder(LI);
 
@@ -9389,15 +9388,15 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
   if (NumLoads > 1) {
     // If we're going to generate more than one load, reset the sub-vector type
     // to something legal.
-    VecTy = VectorType::get(VecTy->getVectorElementType(),
-                            VecTy->getVectorNumElements() / NumLoads);
+    VecTy = VectorType::get(VecTy->getElementType(),
+                            VecTy->getNumElements() / NumLoads);
 
     // We will compute the pointer operand of each load from the original base
     // address using GEPs. Cast the base address to a pointer to the scalar
     // element type.
     BaseAddr = Builder.CreateBitCast(
-        BaseAddr, VecTy->getVectorElementType()->getPointerTo(
-                      LI->getPointerAddressSpace()));
+        BaseAddr,
+        VecTy->getElementType()->getPointerTo(LI->getPointerAddressSpace()));
   }
 
   Type *PtrTy = VecTy->getPointerTo(LI->getPointerAddressSpace());
@@ -9418,9 +9417,8 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
     // If we're generating more than one load, compute the base address of
     // subsequent loads as an offset from the previous.
     if (LoadCount > 0)
-      BaseAddr =
-          Builder.CreateConstGEP1_32(VecTy->getVectorElementType(), BaseAddr,
-                                     VecTy->getVectorNumElements() * Factor);
+      BaseAddr = Builder.CreateConstGEP1_32(VecTy->getElementType(), BaseAddr,
+                                            VecTy->getNumElements() * Factor);
 
     CallInst *LdN = Builder.CreateCall(
         LdNFunc, Builder.CreateBitCast(BaseAddr, PtrTy), "ldN");
@@ -9435,8 +9433,8 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
       // Convert the integer vector to pointer vector if the element is pointer.
       if (EltTy->isPointerTy())
         SubVec = Builder.CreateIntToPtr(
-            SubVec, VectorType::get(SVI->getType()->getVectorElementType(),
-                                    VecTy->getVectorNumElements()));
+            SubVec, VectorType::get(SVI->getType()->getElementType(),
+                                    VecTy->getNumElements()));
       SubVecs[SVI].push_back(SubVec);
     }
   }
@@ -9488,11 +9486,10 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
          "Invalid interleave factor");
 
   VectorType *VecTy = SVI->getType();
-  assert(VecTy->getVectorNumElements() % Factor == 0 &&
-         "Invalid interleaved store");
+  assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
 
-  unsigned LaneLen = VecTy->getVectorNumElements() / Factor;
-  Type *EltTy = VecTy->getVectorElementType();
+  unsigned LaneLen = VecTy->getNumElements() / Factor;
+  Type *EltTy = VecTy->getElementType();
   VectorType *SubVecTy = VectorType::get(EltTy, LaneLen);
 
   const DataLayout &DL = SI->getModule()->getDataLayout();
@@ -9513,7 +9510,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
   // vectors to integer vectors.
   if (EltTy->isPointerTy()) {
     Type *IntTy = DL.getIntPtrType(EltTy);
-    unsigned NumOpElts = Op0->getType()->getVectorNumElements();
+    unsigned NumOpElts = cast<VectorType>(Op0->getType())->getNumElements();
 
     // Convert to the corresponding integer vector.
     Type *IntVecTy = VectorType::get(IntTy, NumOpElts);
@@ -9530,14 +9527,14 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
     // If we're going to generate more than one store, reset the lane length
     // and sub-vector type to something legal.
     LaneLen /= NumStores;
-    SubVecTy = VectorType::get(SubVecTy->getVectorElementType(), LaneLen);
+    SubVecTy = VectorType::get(SubVecTy->getElementType(), LaneLen);
 
     // We will compute the pointer operand of each store from the original base
     // address using GEPs. Cast the base address to a pointer to the scalar
     // element type.
     BaseAddr = Builder.CreateBitCast(
-        BaseAddr, SubVecTy->getVectorElementType()->getPointerTo(
-                      SI->getPointerAddressSpace()));
+        BaseAddr,
+        SubVecTy->getElementType()->getPointerTo(SI->getPointerAddressSpace()));
   }
 
   auto Mask = SVI->getShuffleMask();
@@ -9582,7 +9579,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
     // If we generating more than one store, we compute the base address of
     // subsequent stores as an offset from the previous.
     if (StoreCount > 0)
-      BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getVectorElementType(),
+      BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(),
                                             BaseAddr, LaneLen * Factor);
 
     Ops.push_back(Builder.CreateBitCast(BaseAddr, PtrTy));
@@ -9697,7 +9694,7 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
     return false;
 
   // FIXME: Update this method to support scalable addressing modes.
-  if (Ty->isVectorTy() && Ty->getVectorIsScalable())
+  if (Ty->isVectorTy() && cast<VectorType>(Ty)->isScalable())
     return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
 
   // check reg + imm case:

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 11 additions & 11 deletions
@@ -209,7 +209,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   // elements in type Ty determine the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
     return VectorType::get(ArgTy->getScalarType(),
-                           DstTy->getVectorNumElements());
+                           cast<VectorType>(DstTy)->getNumElements());
   };
 
   // Exit early if DstTy is not a vector type whose elements are at least
@@ -661,7 +661,8 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
     return LT.first * 2 * AmortizationCost;
   }
 
-  if (Ty->isVectorTy() && Ty->getVectorElementType()->isIntegerTy(8)) {
+  if (Ty->isVectorTy() &&
+      cast<VectorType>(Ty)->getElementType()->isIntegerTy(8)) {
     unsigned ProfitableNumElements;
     if (Opcode == Instruction::Store)
       // We use a custom trunc store lowering so v.4b should be profitable.
@@ -671,8 +672,8 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
       // have to promote the elements to v.2.
       ProfitableNumElements = 8;
 
-    if (Ty->getVectorNumElements() < ProfitableNumElements) {
-      unsigned NumVecElts = Ty->getVectorNumElements();
+    if (cast<VectorType>(Ty)->getNumElements() < ProfitableNumElements) {
+      unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements();
       unsigned NumVectorizableInstsToAmortize = NumVecElts * 2;
       // We generate 2 instructions per vector element.
       return NumVectorizableInstsToAmortize * NumVecElts * 2;
@@ -690,11 +691,11 @@ int AArch64TTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy,
                                                bool UseMaskForCond,
                                                bool UseMaskForGaps) {
   assert(Factor >= 2 && "Invalid interleave factor");
-  assert(isa<VectorType>(VecTy) && "Expect a vector type");
+  auto *VecVTy = cast<VectorType>(VecTy);
 
   if (!UseMaskForCond && !UseMaskForGaps &&
       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
-    unsigned NumElts = VecTy->getVectorNumElements();
+    unsigned NumElts = VecVTy->getNumElements();
     auto *SubVecTy = VectorType::get(VecTy->getScalarType(), NumElts / Factor);
 
     // ldN/stN only support legal vector types of size 64 or 128 in bits.
@@ -715,7 +716,7 @@ int AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
   for (auto *I : Tys) {
     if (!I->isVectorTy())
       continue;
-    if (I->getScalarSizeInBits() * I->getVectorNumElements() == 128)
+    if (I->getScalarSizeInBits() * cast<VectorType>(I)->getNumElements() == 128)
       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0) +
               getMemoryOpCost(Instruction::Load, I, Align(128), 0);
   }
@@ -907,7 +908,7 @@ bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
 
 bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty,
                                            TTI::ReductionFlags Flags) const {
-  assert(isa<VectorType>(Ty) && "Expected Ty to be a vector type");
+  auto *VTy = cast<VectorType>(Ty);
   unsigned ScalarBits = Ty->getScalarSizeInBits();
   switch (Opcode) {
   case Instruction::FAdd:
@@ -918,10 +919,9 @@ bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty,
   case Instruction::Mul:
     return false;
   case Instruction::Add:
-    return ScalarBits * Ty->getVectorNumElements() >= 128;
+    return ScalarBits * VTy->getNumElements() >= 128;
   case Instruction::ICmp:
-    return (ScalarBits < 64) &&
-           (ScalarBits * Ty->getVectorNumElements() >= 128);
+    return (ScalarBits < 64) && (ScalarBits * VTy->getNumElements() >= 128);
   case Instruction::FCmp:
     return Flags.NoNaN;
   default:

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 4 additions & 5 deletions
@@ -153,7 +153,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     if (!isa<VectorType>(DataType) || !ST->hasSVE())
       return false;
 
-    Type *Ty = DataType->getVectorElementType();
+    Type *Ty = cast<VectorType>(DataType)->getElementType();
     if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())
       return true;
 
@@ -180,10 +180,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
    // can be halved so that each half fits into a register. That's the case if
    // the element type fits into a register and the number of elements is a
    // power of 2 > 1.
-    if (isa<VectorType>(DataType)) {
-      unsigned NumElements = DataType->getVectorNumElements();
-      unsigned EltSize =
-          DataType->getVectorElementType()->getScalarSizeInBits();
+    if (auto *DataTypeVTy = dyn_cast<VectorType>(DataType)) {
+      unsigned NumElements = DataTypeVTy->getNumElements();
+      unsigned EltSize = DataTypeVTy->getElementType()->getScalarSizeInBits();
       return NumElements > 1 && isPowerOf2_64(NumElements) && EltSize >= 8 &&
              EltSize <= 128 && isPowerOf2_64(EltSize);
     }
