@@ -80,8 +80,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
80
80
81
81
/// Estimate a cost of Broadcast as an extract and sequence of insert
82
82
/// operations.
83
- unsigned getBroadcastShuffleOverhead (Type *Ty) {
84
- auto *VTy = cast<VectorType>(Ty);
83
+ unsigned getBroadcastShuffleOverhead (VectorType *VTy) {
85
84
unsigned Cost = 0 ;
86
85
// Broadcast cost is equal to the cost of extracting the zero'th element
87
86
// plus the cost of inserting it into every element of the result vector.
@@ -97,8 +96,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
97
96
98
97
/// Estimate a cost of shuffle as a sequence of extract and insert
99
98
/// operations.
100
- unsigned getPermuteShuffleOverhead (Type *Ty) {
101
- auto *VTy = cast<VectorType>(Ty);
99
+ unsigned getPermuteShuffleOverhead (VectorType *VTy) {
102
100
unsigned Cost = 0 ;
103
101
// Shuffle cost is equal to the cost of extracting element from its argument
104
102
// plus the cost of inserting them onto the result vector.
@@ -118,11 +116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
118
116
119
117
/// Estimate a cost of subvector extraction as a sequence of extract and
120
118
/// insert operations.
121
- unsigned getExtractSubvectorOverhead (Type *Ty, int Index, Type *SubTy) {
122
- assert (Ty && Ty->isVectorTy () && SubTy && SubTy->isVectorTy () &&
119
+ unsigned getExtractSubvectorOverhead (VectorType *VTy, int Index,
120
+ VectorType *SubVTy) {
121
+ assert (VTy && SubVTy &&
123
122
" Can only extract subvectors from vectors" );
124
- auto *VTy = cast<VectorType>(Ty);
125
- auto *SubVTy = cast<VectorType>(SubTy);
126
123
int NumSubElts = SubVTy->getNumElements ();
127
124
assert ((Index + NumSubElts) <= (int )VTy->getNumElements () &&
128
125
" SK_ExtractSubvector index out of range" );
@@ -142,11 +139,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
142
139
143
140
/// Estimate a cost of subvector insertion as a sequence of extract and
144
141
/// insert operations.
145
- unsigned getInsertSubvectorOverhead (Type *Ty, int Index, Type *SubTy) {
146
- assert (Ty && Ty->isVectorTy () && SubTy && SubTy->isVectorTy () &&
142
+ unsigned getInsertSubvectorOverhead (VectorType *VTy, int Index,
143
+ VectorType *SubVTy) {
144
+ assert (VTy && SubVTy &&
147
145
" Can only insert subvectors into vectors" );
148
- auto *VTy = cast<VectorType>(Ty);
149
- auto *SubVTy = cast<VectorType>(SubTy);
150
146
int NumSubElts = SubVTy->getNumElements ();
151
147
assert ((Index + NumSubElts) <= (int )VTy->getNumElements () &&
152
148
" SK_InsertSubvector index out of range" );
@@ -683,8 +679,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
683
679
return OpCost;
684
680
}
685
681
686
- unsigned getShuffleCost (TTI::ShuffleKind Kind, Type *Tp, int Index,
687
- Type *SubTp) {
682
+ unsigned getShuffleCost (TTI::ShuffleKind Kind, VectorType *Tp, int Index,
683
+ VectorType *SubTp) {
688
684
switch (Kind) {
689
685
case TTI::SK_Broadcast:
690
686
return getBroadcastShuffleOverhead (Tp);
@@ -1198,6 +1194,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1198
1194
unsigned ScalarizationCostPassed = std::numeric_limits<unsigned >::max(),
1199
1195
const Instruction *I = nullptr) {
1200
1196
auto *ConcreteTTI = static_cast <T *>(this );
1197
+ auto *VecOpTy = Tys.empty () ? nullptr : dyn_cast<VectorType>(Tys[0 ]);
1201
1198
1202
1199
SmallVector<unsigned , 2 > ISDs;
1203
1200
unsigned SingleCallCost = 10 ; // Library call cost. Make it expensive.
@@ -1320,41 +1317,43 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1320
1317
case Intrinsic::masked_load:
1321
1318
return ConcreteTTI->getMaskedMemoryOpCost (Instruction::Load, RetTy, 0 , 0 );
1322
1319
case Intrinsic::experimental_vector_reduce_add:
1323
- return ConcreteTTI->getArithmeticReductionCost (Instruction::Add, Tys[ 0 ] ,
1320
+ return ConcreteTTI->getArithmeticReductionCost (Instruction::Add, VecOpTy ,
1324
1321
/* IsPairwiseForm=*/ false );
1325
1322
case Intrinsic::experimental_vector_reduce_mul:
1326
- return ConcreteTTI->getArithmeticReductionCost (Instruction::Mul, Tys[ 0 ] ,
1323
+ return ConcreteTTI->getArithmeticReductionCost (Instruction::Mul, VecOpTy ,
1327
1324
/* IsPairwiseForm=*/ false );
1328
1325
case Intrinsic::experimental_vector_reduce_and:
1329
- return ConcreteTTI->getArithmeticReductionCost (Instruction::And, Tys[ 0 ] ,
1326
+ return ConcreteTTI->getArithmeticReductionCost (Instruction::And, VecOpTy ,
1330
1327
/* IsPairwiseForm=*/ false );
1331
1328
case Intrinsic::experimental_vector_reduce_or:
1332
- return ConcreteTTI->getArithmeticReductionCost (Instruction::Or, Tys[ 0 ] ,
1329
+ return ConcreteTTI->getArithmeticReductionCost (Instruction::Or, VecOpTy ,
1333
1330
/* IsPairwiseForm=*/ false );
1334
1331
case Intrinsic::experimental_vector_reduce_xor:
1335
- return ConcreteTTI->getArithmeticReductionCost (Instruction::Xor, Tys[ 0 ] ,
1332
+ return ConcreteTTI->getArithmeticReductionCost (Instruction::Xor, VecOpTy ,
1336
1333
/* IsPairwiseForm=*/ false );
1337
1334
case Intrinsic::experimental_vector_reduce_v2_fadd:
1338
1335
return ConcreteTTI->getArithmeticReductionCost (
1339
- Instruction::FAdd, Tys[ 0 ] ,
1336
+ Instruction::FAdd, VecOpTy ,
1340
1337
/* IsPairwiseForm=*/ false ); // FIXME: Add new flag for cost of strict
1341
1338
// reductions.
1342
1339
case Intrinsic::experimental_vector_reduce_v2_fmul:
1343
1340
return ConcreteTTI->getArithmeticReductionCost (
1344
- Instruction::FMul, Tys[ 0 ] ,
1341
+ Instruction::FMul, VecOpTy ,
1345
1342
/* IsPairwiseForm=*/ false ); // FIXME: Add new flag for cost of strict
1346
1343
// reductions.
1347
1344
case Intrinsic::experimental_vector_reduce_smax:
1348
1345
case Intrinsic::experimental_vector_reduce_smin:
1349
1346
case Intrinsic::experimental_vector_reduce_fmax:
1350
1347
case Intrinsic::experimental_vector_reduce_fmin:
1351
1348
return ConcreteTTI->getMinMaxReductionCost (
1352
- Tys[0 ], CmpInst::makeCmpResultType (Tys[0 ]), /* IsPairwiseForm=*/ false ,
1349
+ VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType (VecOpTy)),
1350
+ /* IsPairwiseForm=*/ false ,
1353
1351
/* IsUnsigned=*/ false );
1354
1352
case Intrinsic::experimental_vector_reduce_umax:
1355
1353
case Intrinsic::experimental_vector_reduce_umin:
1356
1354
return ConcreteTTI->getMinMaxReductionCost (
1357
- Tys[0 ], CmpInst::makeCmpResultType (Tys[0 ]), /* IsPairwiseForm=*/ false ,
1355
+ VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType (VecOpTy)),
1356
+ /* IsPairwiseForm=*/ false ,
1358
1357
/* IsUnsigned=*/ true );
1359
1358
case Intrinsic::sadd_sat:
1360
1359
case Intrinsic::ssub_sat: {
@@ -1639,11 +1638,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1639
1638
///
1640
1639
/// The cost model should take into account that the actual length of the
1641
1640
/// vector is reduced on each iteration.
1642
- unsigned getArithmeticReductionCost (unsigned Opcode, Type *Ty,
1641
+ unsigned getArithmeticReductionCost (unsigned Opcode, VectorType *Ty,
1643
1642
bool IsPairwise) {
1644
- assert (Ty->isVectorTy () && " Expect a vector type" );
1645
- Type *ScalarTy = cast<VectorType>(Ty)->getElementType ();
1646
- unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements ();
1643
+ Type *ScalarTy = Ty->getElementType ();
1644
+ unsigned NumVecElts = Ty->getNumElements ();
1647
1645
unsigned NumReduxLevels = Log2_32 (NumVecElts);
1648
1646
unsigned ArithCost = 0 ;
1649
1647
unsigned ShuffleCost = 0 ;
@@ -1655,7 +1653,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1655
1653
LT.second .isVector () ? LT.second .getVectorNumElements () : 1 ;
1656
1654
while (NumVecElts > MVTLen) {
1657
1655
NumVecElts /= 2 ;
1658
- Type *SubTy = VectorType::get (ScalarTy, NumVecElts);
1656
+ VectorType *SubTy = VectorType::get (ScalarTy, NumVecElts);
1659
1657
// Assume the pairwise shuffles add a cost.
1660
1658
ShuffleCost += (IsPairwise + 1 ) *
1661
1659
ConcreteTTI->getShuffleCost (TTI::SK_ExtractSubvector, Ty,
@@ -1689,12 +1687,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1689
1687
1690
1688
/// Try to calculate op costs for min/max reduction operations.
1691
1689
/// \param CondTy Conditional type for the Select instruction.
1692
- unsigned getMinMaxReductionCost (Type *Ty, Type *CondTy, bool IsPairwise,
1693
- bool ) {
1694
- assert (Ty->isVectorTy () && " Expect a vector type" );
1695
- Type *ScalarTy = cast<VectorType>(Ty)->getElementType ();
1696
- Type *ScalarCondTy = cast<VectorType>(CondTy)->getElementType ();
1697
- unsigned NumVecElts = cast<VectorType>(Ty)->getNumElements ();
1690
+ unsigned getMinMaxReductionCost (VectorType *Ty, VectorType *CondTy,
1691
+ bool IsPairwise, bool ) {
1692
+ Type *ScalarTy = Ty->getElementType ();
1693
+ Type *ScalarCondTy = CondTy->getElementType ();
1694
+ unsigned NumVecElts = Ty->getNumElements ();
1698
1695
unsigned NumReduxLevels = Log2_32 (NumVecElts);
1699
1696
unsigned CmpOpcode;
1700
1697
if (Ty->isFPOrFPVectorTy ()) {
@@ -1714,7 +1711,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
1714
1711
LT.second .isVector () ? LT.second .getVectorNumElements () : 1 ;
1715
1712
while (NumVecElts > MVTLen) {
1716
1713
NumVecElts /= 2 ;
1717
- Type *SubTy = VectorType::get (ScalarTy, NumVecElts);
1714
+ VectorType *SubTy = VectorType::get (ScalarTy, NumVecElts);
1718
1715
CondTy = VectorType::get (ScalarCondTy, NumVecElts);
1719
1716
1720
1717
// Assume the pairwise shuffles add a cost.
0 commit comments