Skip to content

Commit e3056ae

Browse files
committed
[NFC][TTI] Explicit use of VectorType
The API for shuffles and reductions uses generic Type parameters, instead of VectorType, and so assertions and casts are used a lot. This patch makes those types explicit, which means that the clients can't be lazy, but results in less ambiguity, and that can only be a good thing. Bugzilla: https://bugs.llvm.org/show_bug.cgi?id=45562 Differential Revision: https://reviews.llvm.org/D78357
1 parent a8e15ee commit e3056ae

17 files changed

+141
-135
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -910,8 +910,8 @@ class TargetTransformInfo {
910910
/// extraction shuffle kinds to show the insert/extract point and the type of
911911
/// the subvector being inserted/extracted.
912912
/// NOTE: For subvector extractions Tp represents the source type.
913-
int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index = 0,
914-
Type *SubTp = nullptr) const;
913+
int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index = 0,
914+
VectorType *SubTp = nullptr) const;
915915

916916
/// \return The expected cost of cast instructions, such as bitcast, trunc,
917917
/// zext, etc. If there is an existing instruction that holds Opcode, it
@@ -989,10 +989,10 @@ class TargetTransformInfo {
989989
/// Split:
990990
/// (v0, v1, v2, v3)
991991
/// ((v0+v2), (v1+v3), undef, undef)
992-
int getArithmeticReductionCost(unsigned Opcode, Type *Ty,
992+
int getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
993993
bool IsPairwiseForm) const;
994-
int getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwiseForm,
995-
bool IsUnsigned) const;
994+
int getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
995+
bool IsPairwiseForm, bool IsUnsigned) const;
996996

997997
/// \returns The cost of Intrinsic instructions. Analyses the real arguments.
998998
/// Three cases are handled: 1. scalar instruction 2. vector instruction
@@ -1332,8 +1332,8 @@ class TargetTransformInfo::Concept {
13321332
OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
13331333
OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
13341334
const Instruction *CxtI = nullptr) = 0;
1335-
virtual int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index,
1336-
Type *SubTp) = 0;
1335+
virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index,
1336+
VectorType *SubTp) = 0;
13371337
virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
13381338
const Instruction *I) = 0;
13391339
virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -1356,9 +1356,9 @@ class TargetTransformInfo::Concept {
13561356
ArrayRef<unsigned> Indices, unsigned Alignment,
13571357
unsigned AddressSpace, bool UseMaskForCond = false,
13581358
bool UseMaskForGaps = false) = 0;
1359-
virtual int getArithmeticReductionCost(unsigned Opcode, Type *Ty,
1359+
virtual int getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
13601360
bool IsPairwiseForm) = 0;
1361-
virtual int getMinMaxReductionCost(Type *Ty, Type *CondTy,
1361+
virtual int getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
13621362
bool IsPairwiseForm, bool IsUnsigned) = 0;
13631363
virtual int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
13641364
ArrayRef<Type *> Tys, FastMathFlags FMF,
@@ -1731,8 +1731,8 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
17311731
return Impl.getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info,
17321732
Opd1PropInfo, Opd2PropInfo, Args, CxtI);
17331733
}
1734-
int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index,
1735-
Type *SubTp) override {
1734+
int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index,
1735+
VectorType *SubTp) override {
17361736
return Impl.getShuffleCost(Kind, Tp, Index, SubTp);
17371737
}
17381738
int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
@@ -1775,12 +1775,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
17751775
Alignment, AddressSpace,
17761776
UseMaskForCond, UseMaskForGaps);
17771777
}
1778-
int getArithmeticReductionCost(unsigned Opcode, Type *Ty,
1778+
int getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
17791779
bool IsPairwiseForm) override {
17801780
return Impl.getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm);
17811781
}
1782-
int getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwiseForm,
1783-
bool IsUnsigned) override {
1782+
int getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
1783+
bool IsPairwiseForm, bool IsUnsigned) override {
17841784
return Impl.getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned);
17851785
}
17861786
int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef<Type *> Tys,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,8 @@ class TargetTransformInfoImplBase {
438438
return 1;
439439
}
440440

441-
unsigned getShuffleCost(TTI::ShuffleKind Kind, Type *Ty, int Index,
442-
Type *SubTp) {
441+
unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Ty, int Index,
442+
VectorType *SubTp) {
443443
return 1;
444444
}
445445

@@ -512,9 +512,9 @@ class TargetTransformInfoImplBase {
512512
return 0;
513513
}
514514

515-
unsigned getArithmeticReductionCost(unsigned, Type *, bool) { return 1; }
515+
unsigned getArithmeticReductionCost(unsigned, VectorType *, bool) { return 1; }
516516

517-
unsigned getMinMaxReductionCost(Type *, Type *, bool, bool) { return 1; }
517+
unsigned getMinMaxReductionCost(VectorType *, VectorType *, bool, bool) { return 1; }
518518

519519
unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { return 0; }
520520

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
8080

8181
/// Estimate a cost of Broadcast as an extract and sequence of insert
8282
/// operations.
83-
unsigned getBroadcastShuffleOverhead(Type *Ty) {
84-
auto *VTy = cast<VectorType>(Ty);
83+
unsigned getBroadcastShuffleOverhead(VectorType *VTy) {
8584
unsigned Cost = 0;
8685
// Broadcast cost is equal to the cost of extracting the zero'th element
8786
// plus the cost of inserting it into every element of the result vector.
@@ -97,8 +96,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
9796

9897
/// Estimate a cost of shuffle as a sequence of extract and insert
9998
/// operations.
100-
unsigned getPermuteShuffleOverhead(Type *Ty) {
101-
auto *VTy = cast<VectorType>(Ty);
99+
unsigned getPermuteShuffleOverhead(VectorType *VTy) {
102100
unsigned Cost = 0;
103101
// Shuffle cost is equal to the cost of extracting element from its argument
104102
// plus the cost of inserting them onto the result vector.
@@ -118,11 +116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
118116

119117
/// Estimate a cost of subvector extraction as a sequence of extract and
120118
/// insert operations.
121-
unsigned getExtractSubvectorOverhead(Type *Ty, int Index, Type *SubTy) {
122-
assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() &&
119+
unsigned getExtractSubvectorOverhead(VectorType *VTy, int Index,
120+
VectorType *SubVTy) {
121+
assert(VTy && SubVTy &&
123122
"Can only extract subvectors from vectors");
124-
auto *VTy = cast<VectorType>(Ty);
125-
auto *SubVTy = cast<VectorType>(SubTy);
126123
int NumSubElts = SubVTy->getNumElements();
127124
assert((Index + NumSubElts) <= (int)VTy->getNumElements() &&
128125
"SK_ExtractSubvector index out of range");
@@ -142,11 +139,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
142139

143140
/// Estimate a cost of subvector insertion as a sequence of extract and
144141
/// insert operations.
145-
unsigned getInsertSubvectorOverhead(Type *Ty, int Index, Type *SubTy) {
146-
assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() &&
142+
unsigned getInsertSubvectorOverhead(VectorType *VTy, int Index,
143+
VectorType *SubVTy) {
144+
assert(VTy && SubVTy &&
147145
"Can only insert subvectors into vectors");
148-
auto *VTy = cast<VectorType>(Ty);
149-
auto *SubVTy = cast<VectorType>(SubTy);
150146
int NumSubElts = SubVTy->getNumElements();
151147
assert((Index + NumSubElts) <= (int)VTy->getNumElements() &&
152148
"SK_InsertSubvector index out of range");
@@ -683,8 +679,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
683679
return OpCost;
684680
}
685681

686-
unsigned getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index,
687-
Type *SubTp) {
682+
unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
683+
VectorType *SubTp) {
688684
switch (Kind) {
689685
case TTI::SK_Broadcast:
690686
return getBroadcastShuffleOverhead(Tp);
@@ -1198,6 +1194,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
11981194
unsigned ScalarizationCostPassed = std::numeric_limits<unsigned>::max(),
11991195
const Instruction *I = nullptr) {
12001196
auto *ConcreteTTI = static_cast<T *>(this);
1197+
auto *VecOpTy = Tys.empty() ? nullptr : dyn_cast<VectorType>(Tys[0]);
12011198

12021199
SmallVector<unsigned, 2> ISDs;
12031200
unsigned SingleCallCost = 10; // Library call cost. Make it expensive.
@@ -1320,41 +1317,43 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
13201317
case Intrinsic::masked_load:
13211318
return ConcreteTTI->getMaskedMemoryOpCost(Instruction::Load, RetTy, 0, 0);
13221319
case Intrinsic::experimental_vector_reduce_add:
1323-
return ConcreteTTI->getArithmeticReductionCost(Instruction::Add, Tys[0],
1320+
return ConcreteTTI->getArithmeticReductionCost(Instruction::Add, VecOpTy,
13241321
/*IsPairwiseForm=*/false);
13251322
case Intrinsic::experimental_vector_reduce_mul:
1326-
return ConcreteTTI->getArithmeticReductionCost(Instruction::Mul, Tys[0],
1323+
return ConcreteTTI->getArithmeticReductionCost(Instruction::Mul, VecOpTy,
13271324
/*IsPairwiseForm=*/false);
13281325
case Intrinsic::experimental_vector_reduce_and:
1329-
return ConcreteTTI->getArithmeticReductionCost(Instruction::And, Tys[0],
1326+
return ConcreteTTI->getArithmeticReductionCost(Instruction::And, VecOpTy,
13301327
/*IsPairwiseForm=*/false);
13311328
case Intrinsic::experimental_vector_reduce_or:
1332-
return ConcreteTTI->getArithmeticReductionCost(Instruction::Or, Tys[0],
1329+
return ConcreteTTI->getArithmeticReductionCost(Instruction::Or, VecOpTy,
13331330
/*IsPairwiseForm=*/false);
13341331
case Intrinsic::experimental_vector_reduce_xor:
1335-
return ConcreteTTI->getArithmeticReductionCost(Instruction::Xor, Tys[0],
1332+
return ConcreteTTI->getArithmeticReductionCost(Instruction::Xor, VecOpTy,
13361333
/*IsPairwiseForm=*/false);
13371334
case Intrinsic::experimental_vector_reduce_v2_fadd:
13381335
return ConcreteTTI->getArithmeticReductionCost(
1339-
Instruction::FAdd, Tys[0],
1336+
Instruction::FAdd, VecOpTy,
13401337
/*IsPairwiseForm=*/false); // FIXME: Add new flag for cost of strict
13411338
// reductions.
13421339
case Intrinsic::experimental_vector_reduce_v2_fmul:
13431340
return ConcreteTTI->getArithmeticReductionCost(
1344-
Instruction::FMul, Tys[0],
1341+
Instruction::FMul, VecOpTy,
13451342
/*IsPairwiseForm=*/false); // FIXME: Add new flag for cost of strict
13461343
// reductions.
13471344
case Intrinsic::experimental_vector_reduce_smax:
13481345
case Intrinsic::experimental_vector_reduce_smin:
13491346
case Intrinsic::experimental_vector_reduce_fmax:
13501347
case Intrinsic::experimental_vector_reduce_fmin:
13511348
return ConcreteTTI->getMinMaxReductionCost(
1352-
Tys[0], CmpInst::makeCmpResultType(Tys[0]), /*IsPairwiseForm=*/false,
1349+
VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType(VecOpTy)),
1350+
/*IsPairwiseForm=*/false,
13531351
/*IsUnsigned=*/false);
13541352
case Intrinsic::experimental_vector_reduce_umax:
13551353
case Intrinsic::experimental_vector_reduce_umin:
13561354
return ConcreteTTI->getMinMaxReductionCost(
1357-
Tys[0], CmpInst::makeCmpResultType(Tys[0]), /*IsPairwiseForm=*/false,
1355+
VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType(VecOpTy)),
1356+
/*IsPairwiseForm=*/false,
13581357
/*IsUnsigned=*/true);
13591358
case Intrinsic::sadd_sat:
13601359
case Intrinsic::ssub_sat: {
@@ -1639,11 +1638,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
16391638
///
16401639
/// The cost model should take into account that the actual length of the
16411640
/// vector is reduced on each iteration.
1642-
unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty,
1641+
unsigned getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
16431642
bool IsPairwise) {
1644-
assert(Ty->isVectorTy() && "Expect a vector type");
1645-
Type *ScalarTy = cast<VectorType>(Ty)->getElementType();
1646-
unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements();
1643+
Type *ScalarTy = Ty->getElementType();
1644+
unsigned NumVecElts = Ty->getNumElements();
16471645
unsigned NumReduxLevels = Log2_32(NumVecElts);
16481646
unsigned ArithCost = 0;
16491647
unsigned ShuffleCost = 0;
@@ -1655,7 +1653,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
16551653
LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
16561654
while (NumVecElts > MVTLen) {
16571655
NumVecElts /= 2;
1658-
Type *SubTy = VectorType::get(ScalarTy, NumVecElts);
1656+
VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
16591657
// Assume the pairwise shuffles add a cost.
16601658
ShuffleCost += (IsPairwise + 1) *
16611659
ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty,
@@ -1689,12 +1687,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
16891687

16901688
/// Try to calculate op costs for min/max reduction operations.
16911689
/// \param CondTy Conditional type for the Select instruction.
1692-
unsigned getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwise,
1693-
bool) {
1694-
assert(Ty->isVectorTy() && "Expect a vector type");
1695-
Type *ScalarTy = cast<VectorType>(Ty)->getElementType();
1696-
Type *ScalarCondTy = cast<VectorType>(CondTy)->getElementType();
1697-
unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements();
1690+
unsigned getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
1691+
bool IsPairwise, bool) {
1692+
Type *ScalarTy = Ty->getElementType();
1693+
Type *ScalarCondTy = CondTy->getElementType();
1694+
unsigned NumVecElts = Ty->getNumElements();
16981695
unsigned NumReduxLevels = Log2_32(NumVecElts);
16991696
unsigned CmpOpcode;
17001697
if (Ty->isFPOrFPVectorTy()) {
@@ -1714,7 +1711,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
17141711
LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
17151712
while (NumVecElts > MVTLen) {
17161713
NumVecElts /= 2;
1717-
Type *SubTy = VectorType::get(ScalarTy, NumVecElts);
1714+
VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
17181715
CondTy = VectorType::get(ScalarCondTy, NumVecElts);
17191716

17201717
// Assume the pairwise shuffles add a cost.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,8 @@ int TargetTransformInfo::getArithmeticInstrCost(
599599
return Cost;
600600
}
601601

602-
int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Ty, int Index,
603-
Type *SubTp) const {
602+
int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, VectorType *Ty,
603+
int Index, VectorType *SubTp) const {
604604
int Cost = TTIImpl->getShuffleCost(Kind, Ty, Index, SubTp);
605605
assert(Cost >= 0 && "TTI should not produce negative costs!");
606606
return Cost;
@@ -732,14 +732,16 @@ int TargetTransformInfo::getMemcpyCost(const Instruction *I) const {
732732
return Cost;
733733
}
734734

735-
int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty,
735+
int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode,
736+
VectorType *Ty,
736737
bool IsPairwiseForm) const {
737738
int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm);
738739
assert(Cost >= 0 && "TTI should not produce negative costs!");
739740
return Cost;
740741
}
741742

742-
int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy,
743+
int TargetTransformInfo::getMinMaxReductionCost(VectorType *Ty,
744+
VectorType *CondTy,
743745
bool IsPairwiseForm,
744746
bool IsUnsigned) const {
745747
int Cost =
@@ -1011,7 +1013,8 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
10111013
}
10121014

10131015
static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
1014-
unsigned &Opcode, Type *&Ty) {
1016+
unsigned &Opcode,
1017+
VectorType *&Ty) {
10151018
if (!EnableReduxCost)
10161019
return RK_None;
10171020

@@ -1076,7 +1079,7 @@ getShuffleAndOtherOprd(Value *L, Value *R) {
10761079

10771080
static ReductionKind
10781081
matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
1079-
unsigned &Opcode, Type *&Ty) {
1082+
unsigned &Opcode, VectorType *&Ty) {
10801083
if (!EnableReduxCost)
10811084
return RK_None;
10821085

@@ -1249,19 +1252,19 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
12491252
// Try to match a reduction sequence (series of shufflevector and vector
12501253
// adds followed by a extractelement).
12511254
unsigned ReduxOpCode;
1252-
Type *ReduxType;
1255+
VectorType *ReduxType;
12531256

12541257
switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
12551258
case RK_Arithmetic:
12561259
return getArithmeticReductionCost(ReduxOpCode, ReduxType,
12571260
/*IsPairwiseForm=*/false);
12581261
case RK_MinMax:
12591262
return getMinMaxReductionCost(
1260-
ReduxType, CmpInst::makeCmpResultType(ReduxType),
1263+
ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
12611264
/*IsPairwiseForm=*/false, /*IsUnsigned=*/false);
12621265
case RK_UnsignedMinMax:
12631266
return getMinMaxReductionCost(
1264-
ReduxType, CmpInst::makeCmpResultType(ReduxType),
1267+
ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
12651268
/*IsPairwiseForm=*/false, /*IsUnsigned=*/true);
12661269
case RK_None:
12671270
break;
@@ -1273,11 +1276,11 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
12731276
/*IsPairwiseForm=*/true);
12741277
case RK_MinMax:
12751278
return getMinMaxReductionCost(
1276-
ReduxType, CmpInst::makeCmpResultType(ReduxType),
1279+
ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
12771280
/*IsPairwiseForm=*/true, /*IsUnsigned=*/false);
12781281
case RK_UnsignedMinMax:
12791282
return getMinMaxReductionCost(
1280-
ReduxType, CmpInst::makeCmpResultType(ReduxType),
1283+
ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
12811284
/*IsPairwiseForm=*/true, /*IsUnsigned=*/true);
12821285
case RK_None:
12831286
break;
@@ -1298,8 +1301,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
12981301
return 0; // Model all ExtractValue nodes as free.
12991302
case Instruction::ShuffleVector: {
13001303
const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
1301-
Type *Ty = Shuffle->getType();
1302-
Type *SrcTy = Shuffle->getOperand(0)->getType();
1304+
auto *Ty = cast<VectorType>(Shuffle->getType());
1305+
auto *SrcTy = cast<VectorType>(Shuffle->getOperand(0)->getType());
13031306

13041307
// TODO: Identify and add costs for insert subvector, etc.
13051308
int SubIndex;

0 commit comments

Comments
 (0)