Skip to content

Commit b58b62c

Browse files
committed
[PredicateInfo] Support existing PredicateType by adding PredicatePHI
when needing introduction of phi nodes Resolves #150606 Currently `ssa.copy` is used mostly for straight line code, i.e, without joins or uses of phi nodes. With this, passes would be able to pick up the relevant info and further optimize the IR.
1 parent ece7a72 commit b58b62c

File tree

2 files changed

+202
-1
lines changed

2 files changed

+202
-1
lines changed

llvm/include/llvm/Transforms/Utils/PredicateInfo.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Value;
6767
class IntrinsicInst;
6868
class raw_ostream;
6969

70-
enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
70+
enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_PHI };
7171

7272
/// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op
7373
/// is the value the constraint applies to (the ssa.copy result).
@@ -171,6 +171,18 @@ class PredicateSwitch : public PredicateWithEdge {
171171
}
172172
};
173173

174+
class PredicatePHI : public PredicateBase {
175+
public:
176+
BasicBlock *PHIBlock;
177+
SmallVector<std::pair<BasicBlock *, PredicateBase *>, 4> IncomingPredicates;
178+
179+
PredicatePHI(Value *Op, BasicBlock *PHIBB)
180+
: PredicateBase(PT_PHI, Op, nullptr), PHIBlock(PHIBB) {}
181+
static bool classof(const PredicateBase *PB) { return PB->Type == PT_PHI; }
182+
183+
LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
184+
};
185+
174186
/// Encapsulates PredicateInfo, including all data associated with memory
175187
/// accesses.
176188
class PredicateInfo {

llvm/lib/Transforms/Utils/PredicateInfo.cpp

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ class PredicateInfoBuilder {
213213
// whether it returned a valid result.
214214
DenseMap<Value *, unsigned int> ValueInfoNums;
215215

216+
DenseMap<BasicBlock *, SmallVector<Value *, 4>> PHICandidates;
217+
216218
BumpPtrAllocator &Allocator;
217219

218220
ValueInfo &getOrCreateValueInfo(Value *);
@@ -224,6 +226,13 @@ class PredicateInfoBuilder {
224226
SmallVectorImpl<Value *> &OpsToRename);
225227
void processSwitch(SwitchInst *, BasicBlock *,
226228
SmallVectorImpl<Value *> &OpsToRename);
229+
void identifyPHIInsertionPoints(SmallVectorImpl<Value *> &OpsToRename);
230+
bool needsPHIForPredicateInfo(Value *Op, BasicBlock *BB);
231+
void insertPredicatePHIs(Value *Op, BasicBlock *BB);
232+
void computePredicateIDF(Value *Op,
233+
const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
234+
SmallPtrSet<BasicBlock *, 8> &PHIBlocks);
235+
void processPredicatePHIs(SmallVectorImpl<Value *> &OpsToRename);
227236
void renameUses(SmallVectorImpl<Value *> &OpsToRename);
228237
void addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op,
229238
PredicateBase *PB);
@@ -452,6 +461,155 @@ void PredicateInfoBuilder::processSwitch(
452461
}
453462
}
454463

464+
void PredicateInfoBuilder::identifyPHIInsertionPoints(
465+
SmallVectorImpl<Value *> &OpsToRename) {
466+
for (Value *Op : OpsToRename) {
467+
const auto &ValueInfo = getValueInfo(Op);
468+
SmallPtrSet<BasicBlock *, 8> DefiningBlocks;
469+
470+
for (const auto *PInfo : ValueInfo.Infos) {
471+
if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
472+
DefiningBlocks.insert(PBranch->To);
473+
} else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
474+
DefiningBlocks.insert(PSwitch->To);
475+
} else if (auto *PAssume = dyn_cast<PredicateAssume>(PInfo)) {
476+
DefiningBlocks.insert(PAssume->AssumeInst->getParent());
477+
}
478+
}
479+
480+
if (DefiningBlocks.size() > 1) {
481+
SmallPtrSet<BasicBlock *, 8> PHIBlocks;
482+
computePredicateIDF(Op, DefiningBlocks, PHIBlocks);
483+
484+
for (BasicBlock *PHIBlock : PHIBlocks) {
485+
if (needsPHIForPredicateInfo(Op, PHIBlock)) {
486+
PHICandidates[PHIBlock].push_back(Op);
487+
}
488+
}
489+
}
490+
}
491+
}
492+
493+
void PredicateInfoBuilder::computePredicateIDF(
494+
Value *Op, const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
495+
SmallPtrSet<BasicBlock *, 8> &PHIBlocks) {
496+
497+
SmallVector<BasicBlock *, 8> Worklist(DefiningBlocks.begin(),
498+
DefiningBlocks.end());
499+
500+
while (!Worklist.empty()) {
501+
BasicBlock *BB = Worklist.pop_back_val();
502+
503+
DomTreeNode *Node = DT.getNode(BB);
504+
if (!Node)
505+
continue;
506+
507+
for (BasicBlock *Succ : successors(BB)) {
508+
if (!DT.dominates(BB, Succ)) {
509+
BasicBlock *IDom = DT.getNode(BB)->getIDom()->getBlock();
510+
if (DT.dominates(IDom, Succ)) {
511+
if (PHIBlocks.insert(Succ).second) {
512+
bool HasOpUse = false;
513+
for (auto &I : *Succ) {
514+
for (Use &U : I.uses()) {
515+
if (U.get() == Op) {
516+
HasOpUse = true;
517+
break;
518+
}
519+
}
520+
}
521+
if (HasOpUse) {
522+
Worklist.push_back(Succ);
523+
}
524+
}
525+
}
526+
}
527+
}
528+
}
529+
}
530+
531+
bool PredicateInfoBuilder::needsPHIForPredicateInfo(Value *Op, BasicBlock *BB) {
532+
if (BB->getSinglePredecessor())
533+
return false;
534+
535+
const auto &ValueInfo = getValueInfo(Op);
536+
SmallDenseSet<PredicateBase *, 4> PredPredicates;
537+
538+
for (BasicBlock *Pred : predecessors(BB)) {
539+
PredicateBase *PredInfo = nullptr;
540+
541+
for (const auto *PInfo : ValueInfo.Infos) {
542+
if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
543+
if (PBranch->From == Pred && PBranch->To == BB) {
544+
PredInfo = const_cast<PredicateBase *>(PInfo);
545+
break;
546+
}
547+
} else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
548+
if (PSwitch->From == Pred && PSwitch->To == BB) {
549+
PredInfo = const_cast<PredicateBase *>(PInfo);
550+
break;
551+
}
552+
}
553+
}
554+
555+
if (PredInfo) {
556+
PredPredicates.insert(PredInfo);
557+
}
558+
}
559+
560+
return PredPredicates.size() > 1 ||
561+
(PredPredicates.size() == 1 && pred_size(BB) > 1);
562+
}
563+
564+
void PredicateInfoBuilder::insertPredicatePHIs(Value *Op, BasicBlock *BB) {
565+
IRBuilder<> Builder(&BB->front());
566+
PHINode *PHI = Builder.CreatePHI(Op->getType(), pred_size(BB),
567+
Op->getName() + ".predicate.phi");
568+
569+
auto *PPhi = new (Allocator) PredicatePHI(Op, BB);
570+
PPhi->RenamedOp = PHI;
571+
572+
const auto &ValueInfo = getValueInfo(Op);
573+
for (BasicBlock *Pred : predecessors(BB)) {
574+
Value *IncomingValue = Op;
575+
576+
for (const auto *PInfo : ValueInfo.Infos) {
577+
if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
578+
if (PBranch->From == Pred && PBranch->To == BB) {
579+
PPhi->IncomingPredicates.push_back(
580+
{Pred, const_cast<PredicateBase *>(PInfo)});
581+
break;
582+
}
583+
} else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
584+
if (PSwitch->From == Pred && PSwitch->To == BB) {
585+
PPhi->IncomingPredicates.push_back(
586+
{Pred, const_cast<PredicateBase *>(PInfo)});
587+
break;
588+
}
589+
}
590+
}
591+
592+
PHI->addIncoming(IncomingValue, Pred);
593+
}
594+
595+
PI.PredicateMap.insert({PHI, PPhi});
596+
597+
auto &OperandInfo = getOrCreateValueInfo(Op);
598+
OperandInfo.Infos.push_back(PPhi);
599+
}
600+
601+
void PredicateInfoBuilder::processPredicatePHIs(
602+
SmallVectorImpl<Value *> &OpsToRename) {
603+
604+
for (const auto &Entry : PHICandidates) {
605+
BasicBlock *PHIBlock = Entry.first;
606+
607+
for (Value *Op : Entry.second) {
608+
insertPredicatePHIs(Op, PHIBlock);
609+
}
610+
}
611+
}
612+
455613
// Build predicate info for our function
456614
void PredicateInfoBuilder::buildPredicateInfo() {
457615
DT.updateDFSNumbers();
@@ -478,6 +636,10 @@ void PredicateInfoBuilder::buildPredicateInfo() {
478636
if (DT.isReachableFromEntry(II->getParent()))
479637
processAssume(II, II->getParent(), OpsToRename);
480638
}
639+
640+
identifyPHIInsertionPoints(OpsToRename);
641+
processPredicatePHIs(OpsToRename);
642+
481643
// Now rename all our operations.
482644
renameUses(OpsToRename);
483645
}
@@ -773,10 +935,33 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
773935
}
774936

775937
return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
938+
case PT_PHI:
939+
return cast<PredicatePHI>(this)->getConstraint();
776940
}
777941
llvm_unreachable("Unknown predicate type");
778942
}
779943

944+
std::optional<PredicateConstraint> PredicatePHI::getConstraint() const {
945+
// For PHI predicates, find the common constraint across all incoming edges
946+
if (IncomingPredicates.empty())
947+
return std::nullopt;
948+
949+
auto FirstConstraint = IncomingPredicates[0].second->getConstraint();
950+
if (!FirstConstraint)
951+
return std::nullopt;
952+
953+
// Verify all incoming predicates have the same constraint
954+
for (size_t I = 1; I < IncomingPredicates.size(); ++I) {
955+
auto Constraint = IncomingPredicates[I].second->getConstraint();
956+
if (!Constraint || Constraint->Predicate != FirstConstraint->Predicate ||
957+
Constraint->OtherOp != FirstConstraint->OtherOp) {
958+
return std::nullopt;
959+
}
960+
}
961+
962+
return FirstConstraint;
963+
}
964+
780965
void PredicateInfo::verifyPredicateInfo() const {}
781966

782967
// Replace ssa_copy calls created by PredicateInfo with their operand.
@@ -838,6 +1023,10 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
8381023
} else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
8391024
OS << "; assume predicate info {"
8401025
<< " Comparison:" << *PA->Condition;
1026+
} else if (const auto *PP = dyn_cast<PredicatePHI>(PI)) {
1027+
OS << "; phi predicate info { PHIBlock: ";
1028+
PP->PHIBlock->printAsOperand(OS);
1029+
OS << " IncomingEdges: " << PP->IncomingPredicates.size();
8411030
}
8421031
OS << ", RenamedOp: ";
8431032
PI->RenamedOp->printAsOperand(OS, false);

0 commit comments

Comments
 (0)