Skip to content

Commit 5034df8

Browse files
committed
[SampleProfile] Use CallBase in function arguments and data structures to reduce the number of explicit casts. NFCI
Removing CallSite left us with a bunch of explicit casts from Instruction to CallBase. This moves the casts earlier, making function arguments and data structure types CallBase so that we don't have to cast when we use them. Differential Revision: https://reviews.llvm.org/D78246
1 parent a6f1976 commit 5034df8

File tree

1 file changed

+45
-52
lines changed

1 file changed

+45
-52
lines changed

llvm/lib/Transforms/IPO/SampleProfile.cpp

Lines changed: 45 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -329,18 +329,19 @@ class SampleProfileLoader {
329329
bool emitAnnotations(Function &F);
330330
ErrorOr<uint64_t> getInstWeight(const Instruction &I);
331331
ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB);
332-
const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const;
332+
const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const;
333333
std::vector<const FunctionSamples *>
334334
findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const;
335335
mutable DenseMap<const DILocation *, const FunctionSamples *> DILocation2SampleMap;
336336
const FunctionSamples *findFunctionSamples(const Instruction &I) const;
337-
bool inlineCallInstruction(Instruction *I);
337+
bool inlineCallInstruction(CallBase &CB);
338338
bool inlineHotFunctions(Function &F,
339339
DenseSet<GlobalValue::GUID> &InlinedGUIDs);
340340
// Inline cold/small functions in addition to hot ones
341-
bool shouldInlineColdCallee(Instruction &CallInst);
341+
bool shouldInlineColdCallee(CallBase &CallInst);
342342
void emitOptimizationRemarksForInlineCandidates(
343-
const SmallVector<Instruction *, 10> &Candidates, const Function &F, bool Hot);
343+
const SmallVectorImpl<CallBase *> &Candidates, const Function &F,
344+
bool Hot);
344345
void printEdgeWeight(raw_ostream &OS, Edge E);
345346
void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const;
346347
void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB);
@@ -718,9 +719,9 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
718719
// (findCalleeFunctionSamples returns non-empty result), but not inlined here,
719720
// it means that the inlined callsite has no sample, thus the call
720721
// instruction should have 0 count.
721-
if ((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
722-
!cast<CallBase>(Inst).isIndirectCall() && findCalleeFunctionSamples(Inst))
723-
return 0;
722+
if (auto *CB = dyn_cast<CallBase>(&Inst))
723+
if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB))
724+
return 0;
724725

725726
const DILocation *DIL = DLoc;
726727
uint32_t LineOffset = FunctionSamples::getOffset(DIL);
@@ -808,7 +809,7 @@ bool SampleProfileLoader::computeBlockWeights(Function &F) {
808809
///
809810
/// \returns The FunctionSamples pointer to the inlined instance.
810811
const FunctionSamples *
811-
SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const {
812+
SampleProfileLoader::findCalleeFunctionSamples(const CallBase &Inst) const {
812813
const DILocation *DIL = Inst.getDebugLoc();
813814
if (!DIL) {
814815
return nullptr;
@@ -892,15 +893,11 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const {
892893
return it.first->second;
893894
}
894895

895-
// FIXME(CallSite): Parameter should be CallBase&, as it's assumed to be that,
896-
// and non-null.
897-
bool SampleProfileLoader::inlineCallInstruction(Instruction *I) {
898-
assert(isa<CallInst>(I) || isa<InvokeInst>(I));
899-
CallBase &CS = *cast<CallBase>(I);
900-
Function *CalledFunction = CS.getCalledFunction();
896+
bool SampleProfileLoader::inlineCallInstruction(CallBase &CB) {
897+
Function *CalledFunction = CB.getCalledFunction();
901898
assert(CalledFunction);
902-
DebugLoc DLoc = I->getDebugLoc();
903-
BasicBlock *BB = I->getParent();
899+
DebugLoc DLoc = CB.getDebugLoc();
900+
BasicBlock *BB = CB.getParent();
904901
InlineParams Params = getInlineParams();
905902
Params.ComputeFullInlineCost = true;
906903
// Checks if there is anything in the reachable portion of the callee at
@@ -909,16 +906,15 @@ bool SampleProfileLoader::inlineCallInstruction(Instruction *I) {
909906
// when cost exceeds threshold without checking all IRs in the callee.
910907
// The actual cost does not matter because we only check isNever() to
911908
// see if it is legal to inline the callsite.
912-
InlineCost Cost =
913-
getInlineCost(cast<CallBase>(*I), Params, GetTTI(*CalledFunction), GetAC,
914-
None, GetTLI, nullptr, nullptr);
909+
InlineCost Cost = getInlineCost(CB, Params, GetTTI(*CalledFunction), GetAC,
910+
None, GetTLI, nullptr, nullptr);
915911
if (Cost.isNever()) {
916912
ORE->emit(OptimizationRemarkAnalysis(CSINLINE_DEBUG, "InlineFail", DLoc, BB)
917913
<< "incompatible inlining");
918914
return false;
919915
}
920916
InlineFunctionInfo IFI(nullptr, &GetAC);
921-
if (InlineFunction(CS, IFI).isSuccess()) {
917+
if (InlineFunction(CB, IFI).isSuccess()) {
922918
// The call to InlineFunction erases the call instruction, so we can't pass it here.
923919
ORE->emit(OptimizationRemark(CSINLINE_DEBUG, "InlineSuccess", DLoc, BB)
924920
<< "inlined callee '" << ore::NV("Callee", CalledFunction)
@@ -928,26 +924,25 @@ bool SampleProfileLoader::inlineCallInstruction(Instruction *I) {
928924
return false;
929925
}
930926

931-
bool SampleProfileLoader::shouldInlineColdCallee(Instruction &CallInst) {
927+
bool SampleProfileLoader::shouldInlineColdCallee(CallBase &CallInst) {
932928
if (!ProfileSizeInline)
933929
return false;
934930

935-
Function *Callee = cast<CallBase>(CallInst).getCalledFunction();
931+
Function *Callee = CallInst.getCalledFunction();
936932
if (Callee == nullptr)
937933
return false;
938934

939-
InlineCost Cost =
940-
getInlineCost(cast<CallBase>(CallInst), getInlineParams(),
941-
GetTTI(*Callee), GetAC, None, GetTLI, nullptr, nullptr);
935+
InlineCost Cost = getInlineCost(CallInst, getInlineParams(), GetTTI(*Callee),
936+
GetAC, None, GetTLI, nullptr, nullptr);
942937

943938
return Cost.getCost() <= SampleColdCallSiteThreshold;
944939
}
945940

946941
void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates(
947-
const SmallVector<Instruction *, 10> &Candidates, const Function &F,
942+
const SmallVectorImpl<CallBase *> &Candidates, const Function &F,
948943
bool Hot) {
949944
for (auto I : Candidates) {
950-
Function *CalledFunction = cast<CallBase>(I)->getCalledFunction();
945+
Function *CalledFunction = I->getCalledFunction();
951946
if (CalledFunction) {
952947
ORE->emit(OptimizationRemarkAnalysis(CSINLINE_DEBUG, "InlineAttempt",
953948
I->getDebugLoc(), I->getParent())
@@ -984,45 +979,43 @@ bool SampleProfileLoader::inlineHotFunctions(
984979
"ProfAccForSymsInList should be false when profile-sample-accurate "
985980
"is enabled");
986981

987-
// FIXME(CallSite): refactor the vectors here, as they operate with CallBase
988-
// values
989-
DenseMap<Instruction *, const FunctionSamples *> localNotInlinedCallSites;
982+
DenseMap<CallBase *, const FunctionSamples *> localNotInlinedCallSites;
990983
bool Changed = false;
991984
while (true) {
992985
bool LocalChanged = false;
993-
SmallVector<Instruction *, 10> CIS;
986+
SmallVector<CallBase *, 10> CIS;
994987
for (auto &BB : F) {
995988
bool Hot = false;
996-
SmallVector<Instruction *, 10> AllCandidates;
997-
SmallVector<Instruction *, 10> ColdCandidates;
989+
SmallVector<CallBase *, 10> AllCandidates;
990+
SmallVector<CallBase *, 10> ColdCandidates;
998991
for (auto &I : BB.getInstList()) {
999992
const FunctionSamples *FS = nullptr;
1000-
if ((isa<CallInst>(I) || isa<InvokeInst>(I)) &&
1001-
!isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(I))) {
1002-
AllCandidates.push_back(&I);
1003-
if (FS->getEntrySamples() > 0)
1004-
localNotInlinedCallSites.try_emplace(&I, FS);
1005-
if (callsiteIsHot(FS, PSI))
1006-
Hot = true;
1007-
else if (shouldInlineColdCallee(I))
1008-
ColdCandidates.push_back(&I);
993+
if (auto *CB = dyn_cast<CallBase>(&I)) {
994+
if (!isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(*CB))) {
995+
AllCandidates.push_back(CB);
996+
if (FS->getEntrySamples() > 0)
997+
localNotInlinedCallSites.try_emplace(CB, FS);
998+
if (callsiteIsHot(FS, PSI))
999+
Hot = true;
1000+
else if (shouldInlineColdCallee(*CB))
1001+
ColdCandidates.push_back(CB);
1002+
}
10091003
}
10101004
}
10111005
if (Hot) {
10121006
CIS.insert(CIS.begin(), AllCandidates.begin(), AllCandidates.end());
10131007
emitOptimizationRemarksForInlineCandidates(AllCandidates, F, true);
1014-
}
1015-
else {
1008+
} else {
10161009
CIS.insert(CIS.begin(), ColdCandidates.begin(), ColdCandidates.end());
10171010
emitOptimizationRemarksForInlineCandidates(ColdCandidates, F, false);
10181011
}
10191012
}
1020-
for (auto I : CIS) {
1021-
Function *CalledFunction = cast<CallBase>(I)->getCalledFunction();
1013+
for (CallBase *I : CIS) {
1014+
Function *CalledFunction = I->getCalledFunction();
10221015
// Do not inline recursive calls.
10231016
if (CalledFunction == &F)
10241017
continue;
1025-
if (cast<CallBase>(I)->isIndirectCall()) {
1018+
if (I->isIndirectCall()) {
10261019
if (PromotedInsns.count(I))
10271020
continue;
10281021
uint64_t Sum;
@@ -1049,15 +1042,15 @@ bool SampleProfileLoader::inlineHotFunctions(
10491042
if (R != SymbolMap.end() && R->getValue() &&
10501043
!R->getValue()->isDeclaration() &&
10511044
R->getValue()->getSubprogram() &&
1052-
isLegalToPromote(*cast<CallBase>(I), R->getValue(), &Reason)) {
1045+
isLegalToPromote(*I, R->getValue(), &Reason)) {
10531046
uint64_t C = FS->getEntrySamples();
10541047
Instruction *DI =
10551048
pgo::promoteIndirectCall(I, R->getValue(), C, Sum, false, ORE);
10561049
Sum -= C;
10571050
PromotedInsns.insert(I);
10581051
// If profile mismatches, we should not attempt to inline DI.
10591052
if ((isa<CallInst>(DI) || isa<InvokeInst>(DI)) &&
1060-
inlineCallInstruction(DI)) {
1053+
inlineCallInstruction(*cast<CallBase>(DI))) {
10611054
localNotInlinedCallSites.erase(I);
10621055
LocalChanged = true;
10631056
++NumCSInlined;
@@ -1070,7 +1063,7 @@ bool SampleProfileLoader::inlineHotFunctions(
10701063
}
10711064
} else if (CalledFunction && CalledFunction->getSubprogram() &&
10721065
!CalledFunction->isDeclaration()) {
1073-
if (inlineCallInstruction(I)) {
1066+
if (inlineCallInstruction(*I)) {
10741067
localNotInlinedCallSites.erase(I);
10751068
LocalChanged = true;
10761069
++NumCSInlined;
@@ -1089,8 +1082,8 @@ bool SampleProfileLoader::inlineHotFunctions(
10891082

10901083
// Accumulate not inlined callsite information into notInlinedSamples
10911084
for (const auto &Pair : localNotInlinedCallSites) {
1092-
Instruction *I = Pair.getFirst();
1093-
Function *Callee = cast<CallBase>(I)->getCalledFunction();
1085+
CallBase *I = Pair.getFirst();
1086+
Function *Callee = I->getCalledFunction();
10941087
if (!Callee || Callee->isDeclaration())
10951088
continue;
10961089

0 commit comments

Comments
 (0)