Skip to content

Commit e9c9329

Browse files
committed
[TTI] Add TargetCostKind argument to getUserCost
There are several different types of cost that TTI tries to provide explicit information for: throughput, latency, and code size, along with a vague 'intersection of code-size cost and execution cost'. The vectorizer is a keen user of RecipThroughput, and there are at least 'getInstructionThroughput' and 'getArithmeticInstrCost' designed to help with this cost. The latency cost has a single use and a single implementation. The intersection cost appears to cover most of the rest of the API. getUserCost is called from within TTI when the user has explicitly asked for the code size (also only one use), and it is also called by a few passes which are concerned with a mixture of size and/or a relative cost. In many cases these costs are closely related, such as when multiple instructions are required, but one evident diverging cost in this function is for div/rem. This patch adds an argument so that the cost required is explicit, so that we can make the important distinction when necessary. Differential Revision: https://reviews.llvm.org/D78635
1 parent e849e7a commit e9c9329

21 files changed

+76
-50
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ class TargetTransformInfo {
153153
enum TargetCostKind {
154154
TCK_RecipThroughput, ///< Reciprocal throughput.
155155
TCK_Latency, ///< The latency of instruction.
156-
TCK_CodeSize ///< Instruction code size.
156+
TCK_CodeSize, ///< Instruction code size.
157+
TCK_SizeAndLatency ///< The weighted sum of size and latency.
157158
};
158159

159160
/// Query the cost of a specified instruction.
@@ -172,7 +173,8 @@ class TargetTransformInfo {
172173
return getInstructionLatency(I);
173174

174175
case TCK_CodeSize:
175-
return getUserCost(I);
176+
case TCK_SizeAndLatency:
177+
return getUserCost(I, kind);
176178
}
177179
llvm_unreachable("Unknown instruction cost kind");
178180
}
@@ -263,14 +265,15 @@ class TargetTransformInfo {
263265
///
264266
/// The returned cost is defined in terms of \c TargetCostConstants, see its
265267
/// comments for a detailed explanation of the cost values.
266-
int getUserCost(const User *U, ArrayRef<const Value *> Operands) const;
268+
int getUserCost(const User *U, ArrayRef<const Value *> Operands,
269+
TargetCostKind CostKind) const;
267270

268271
/// This is a helper function which calls the three-argument getUserCost
269272
/// with \p Operands which are the current operands U has.
270-
int getUserCost(const User *U) const {
273+
int getUserCost(const User *U, TargetCostKind CostKind) const {
271274
SmallVector<const Value *, 4> Operands(U->value_op_begin(),
272275
U->value_op_end());
273-
return getUserCost(U, Operands);
276+
return getUserCost(U, Operands, CostKind);
274277
}
275278

276279
/// Return true if branch divergence exists.
@@ -1170,7 +1173,8 @@ class TargetTransformInfo::Concept {
11701173
getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JTSize,
11711174
ProfileSummaryInfo *PSI,
11721175
BlockFrequencyInfo *BFI) = 0;
1173-
virtual int getUserCost(const User *U, ArrayRef<const Value *> Operands) = 0;
1176+
virtual int getUserCost(const User *U, ArrayRef<const Value *> Operands,
1177+
TargetCostKind CostKind) = 0;
11741178
virtual bool hasBranchDivergence() = 0;
11751179
virtual bool useGPUDivergenceAnalysis() = 0;
11761180
virtual bool isSourceOfDivergence(const Value *V) = 0;
@@ -1422,8 +1426,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
14221426
int getMemcpyCost(const Instruction *I) override {
14231427
return Impl.getMemcpyCost(I);
14241428
}
1425-
int getUserCost(const User *U, ArrayRef<const Value *> Operands) override {
1426-
return Impl.getUserCost(U, Operands);
1429+
int getUserCost(const User *U, ArrayRef<const Value *> Operands,
1430+
TargetCostKind CostKind) override {
1431+
return Impl.getUserCost(U, Operands, CostKind);
14271432
}
14281433
bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); }
14291434
bool useGPUDivergenceAnalysis() override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,9 +792,11 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
792792
return static_cast<T *>(this)->getIntrinsicCost(IID, RetTy, ParamTys, U);
793793
}
794794

795-
unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands) {
795+
unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands,
796+
enum TTI::TargetCostKind CostKind) {
796797
auto *TargetTTI = static_cast<T *>(this);
797798

799+
// FIXME: Unlikely to be true for anything but CodeSize.
798800
if (const auto *CB = dyn_cast<CallBase>(U)) {
799801
const Function *F = CB->getCalledFunction();
800802
if (F) {
@@ -841,6 +843,7 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
841843
case Instruction::SRem:
842844
case Instruction::UDiv:
843845
case Instruction::URem:
846+
// FIXME: Treating div/rem as TCC_Expensive is unlikely to be true for CodeSize.
844847
return TTI::TCC_Expensive;
845848
case Instruction::IntToPtr:
846849
case Instruction::PtrToInt:
@@ -867,7 +870,7 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
867870
int getInstructionLatency(const Instruction *I) {
868871
SmallVector<const Value *, 4> Operands(I->value_op_begin(),
869872
I->value_op_end());
870-
if (getUserCost(I, Operands) == TTI::TCC_Free)
873+
if (getUserCost(I, Operands, TTI::TCK_Latency) == TTI::TCC_Free)
871874
return 0;
872875

873876
if (isa<LoadInst>(I))

llvm/lib/Analysis/CodeMetrics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB,
172172
if (InvI->cannotDuplicate())
173173
notDuplicatable = true;
174174

175-
NumInsts += TTI.getUserCost(&I);
175+
NumInsts += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
176176
}
177177

178178
if (isa<ReturnInst>(BB->getTerminator()))

llvm/lib/Analysis/InlineCost.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,8 @@ bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) {
803803
Operands.push_back(SimpleOp);
804804
else
805805
Operands.push_back(*I);
806-
return TargetTransformInfo::TCC_Free == TTI.getUserCost(&GEP, Operands);
806+
return TargetTransformInfo::TCC_Free ==
807+
TTI.getUserCost(&GEP, Operands, TargetTransformInfo::TCK_SizeAndLatency);
807808
}
808809

809810
bool CallAnalyzer::visitAlloca(AllocaInst &I) {
@@ -1051,7 +1052,8 @@ bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) {
10511052
if (auto *SROAArg = getSROAArgForValueOrNull(I.getOperand(0)))
10521053
SROAArgValues[&I] = SROAArg;
10531054

1054-
return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
1055+
return TargetTransformInfo::TCC_Free ==
1056+
TTI.getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
10551057
}
10561058

10571059
bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) {
@@ -1075,7 +1077,8 @@ bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) {
10751077
if (auto *SROAArg = getSROAArgForValueOrNull(Op))
10761078
SROAArgValues[&I] = SROAArg;
10771079

1078-
return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
1080+
return TargetTransformInfo::TCC_Free ==
1081+
TTI.getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
10791082
}
10801083

10811084
bool CallAnalyzer::visitCastInst(CastInst &I) {
@@ -1105,7 +1108,8 @@ bool CallAnalyzer::visitCastInst(CastInst &I) {
11051108
break;
11061109
}
11071110

1108-
return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I);
1111+
return TargetTransformInfo::TCC_Free ==
1112+
TTI.getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
11091113
}
11101114

11111115
bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) {
@@ -1807,7 +1811,8 @@ bool CallAnalyzer::visitUnreachableInst(UnreachableInst &I) {
18071811
bool CallAnalyzer::visitInstruction(Instruction &I) {
18081812
// Some instructions are free. All of the free intrinsics can also be
18091813
// handled by SROA, etc.
1810-
if (TargetTransformInfo::TCC_Free == TTI.getUserCost(&I))
1814+
if (TargetTransformInfo::TCC_Free ==
1815+
TTI.getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency))
18111816
return true;
18121817

18131818
// We found something we don't understand or can't handle. Mark any SROA-able

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,9 @@ unsigned TargetTransformInfo::getEstimatedNumberOfCaseClusters(
178178
}
179179

180180
int TargetTransformInfo::getUserCost(const User *U,
181-
ArrayRef<const Value *> Operands) const {
182-
int Cost = TTIImpl->getUserCost(U, Operands);
181+
ArrayRef<const Value *> Operands,
182+
enum TargetCostKind CostKind) const {
183+
int Cost = TTIImpl->getUserCost(U, Operands, CostKind);
183184
assert(Cost >= 0 && "TTI should not produce negative costs!");
184185
return Cost;
185186
}
@@ -1152,7 +1153,7 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
11521153
int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
11531154
switch (I->getOpcode()) {
11541155
case Instruction::GetElementPtr:
1155-
return getUserCost(I);
1156+
return getUserCost(I, TCK_RecipThroughput);
11561157

11571158
case Instruction::Ret:
11581159
case Instruction::PHI:

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6103,7 +6103,8 @@ static bool sinkSelectOperand(const TargetTransformInfo *TTI, Value *V) {
61036103
// If it's safe to speculatively execute, then it should not have side
61046104
// effects; therefore, it's safe to sink and possibly *not* execute.
61056105
return I && I->hasOneUse() && isSafeToSpeculativelyExecute(I) &&
6106-
TTI->getUserCost(I) >= TargetTransformInfo::TCC_Expensive;
6106+
TTI->getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency) >=
6107+
TargetTransformInfo::TCC_Expensive;
61076108
}
61086109

61096110
/// Returns true if a SelectInst should be turned into an explicit branch.

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -949,11 +949,12 @@ void GCNTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
949949
CommonTTI.getUnrollingPreferences(L, SE, UP);
950950
}
951951

952-
unsigned GCNTTIImpl::getUserCost(const User *U,
953-
ArrayRef<const Value *> Operands) {
952+
unsigned
953+
GCNTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
954+
TTI::TargetCostKind CostKind) {
954955
const Instruction *I = dyn_cast<Instruction>(U);
955956
if (!I)
956-
return BaseT::getUserCost(U, Operands);
957+
return BaseT::getUserCost(U, Operands, CostKind);
957958

958959
// Estimate different operations to be optimized out
959960
switch (I->getOpcode()) {
@@ -980,7 +981,7 @@ unsigned GCNTTIImpl::getUserCost(const User *U,
980981
return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), Args,
981982
FMF, 1, II);
982983
} else {
983-
return BaseT::getUserCost(U, Operands);
984+
return BaseT::getUserCost(U, Operands, CostKind);
984985
}
985986
}
986987
case Instruction::ShuffleVector: {
@@ -994,7 +995,7 @@ unsigned GCNTTIImpl::getUserCost(const User *U,
994995
return getShuffleCost(TTI::SK_ExtractSubvector, SrcTy, SubIndex, Ty);
995996

996997
if (Shuffle->changesLength())
997-
return BaseT::getUserCost(U, Operands);
998+
return BaseT::getUserCost(U, Operands, CostKind);
998999

9991000
if (Shuffle->isIdentity())
10001001
return 0;
@@ -1059,7 +1060,7 @@ unsigned GCNTTIImpl::getUserCost(const User *U,
10591060
break;
10601061
}
10611062

1062-
return BaseT::getUserCost(U, Operands);
1063+
return BaseT::getUserCost(U, Operands, CostKind);
10631064
}
10641065

10651066
unsigned R600TTIImpl::getHardwareNumberOfRegisters(bool Vec) const {

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
242242
int getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
243243
bool IsPairwiseForm,
244244
bool IsUnsigned);
245-
unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands);
245+
unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands,
246+
TTI::TargetCostKind CostKind);
246247
};
247248

248249
class R600TTIImpl final : public BasicTTIImplBase<R600TTIImpl> {

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1362,7 +1362,7 @@ void ARMTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
13621362

13631363
SmallVector<const Value*, 4> Operands(I.value_op_begin(),
13641364
I.value_op_end());
1365-
Cost += getUserCost(&I, Operands);
1365+
Cost += getUserCost(&I, Operands, TargetTransformInfo::TCK_CodeSize);
13661366
}
13671367
}
13681368

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,10 @@ unsigned HexagonTTIImpl::getCacheLineSize() const {
298298
return ST.getL1CacheLineSize();
299299
}
300300

301-
int HexagonTTIImpl::getUserCost(const User *U,
302-
ArrayRef<const Value *> Operands) {
301+
int
302+
HexagonTTIImpl::getUserCost(const User *U,
303+
ArrayRef<const Value *> Operands,
304+
TTI::TargetCostKind CostKind) {
303305
auto isCastFoldedIntoLoad = [this](const CastInst *CI) -> bool {
304306
if (!CI->isIntegerCast())
305307
return false;
@@ -321,7 +323,7 @@ int HexagonTTIImpl::getUserCost(const User *U,
321323
if (const CastInst *CI = dyn_cast<const CastInst>(U))
322324
if (isCastFoldedIntoLoad(CI))
323325
return TargetTransformInfo::TCC_Free;
324-
return BaseT::getUserCost(U, Operands);
326+
return BaseT::getUserCost(U, Operands, CostKind);
325327
}
326328

327329
bool HexagonTTIImpl::shouldBuildLookupTables() const {

0 commit comments

Comments
 (0)