diff --git a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h index c243e236901d5..6832f31e558c4 100644 --- a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h +++ b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h @@ -67,7 +67,7 @@ class Value; class IntrinsicInst; class raw_ostream; -enum PredicateType { PT_Branch, PT_Assume, PT_Switch }; +enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_PHI }; /// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op /// is the value the constraint applies to (the ssa.copy result). @@ -171,6 +171,19 @@ class PredicateSwitch : public PredicateWithEdge { } }; +class PredicatePHI : public PredicateBase { +public: + BasicBlock *PHIBlock; + SmallVector, 4> IncomingPredicates; + + PredicatePHI(Value *Op, BasicBlock *PHIBB) + : PredicateBase(PT_PHI, Op, nullptr), PHIBlock(PHIBB) {} + PredicatePHI() = delete; + static bool classof(const PredicateBase *PB) { return PB->Type == PT_PHI; } + + LLVM_ABI std::optional getConstraint() const; +}; + /// Encapsulates PredicateInfo, including all data associated with memory /// accesses. class PredicateInfo { diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp index de9deab28750f..91473eb69d168 100644 --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -213,6 +214,8 @@ class PredicateInfoBuilder { // whether it returned a valid result. DenseMap ValueInfoNums; + DenseMap> PHICandidates; + BumpPtrAllocator &Allocator; ValueInfo &getOrCreateValueInfo(Value *); @@ -224,6 +227,10 @@ class PredicateInfoBuilder { SmallVectorImpl &OpsToRename); void processSwitch(SwitchInst *, BasicBlock *, SmallVectorImpl &OpsToRename); + void identifyPHICandidates(SmallVectorImpl &OpsToRename); + void insertPredicatePHIs(Value *Op, BasicBlock *PHIBlock, + SmallVectorImpl &OpsToRename); + void processPredicatePHIs(SmallVectorImpl &OpsToRename); void renameUses(SmallVectorImpl &OpsToRename); void addInfoFor(SmallVectorImpl &OpsToRename, Value *Op, PredicateBase *PB); @@ -452,6 +459,71 @@ void PredicateInfoBuilder::processSwitch( } } +void PredicateInfoBuilder::identifyPHICandidates( + SmallVectorImpl &OpsToRename) { + for (Value *Op : OpsToRename) { + const auto &ValueInfo = getValueInfo(Op); + SmallPtrSet DefiningBlocks; + for (const auto *PInfo : ValueInfo.Infos) { + if (auto *PBranch = dyn_cast(PInfo)) { + DefiningBlocks.insert(PBranch->To); + } else if (auto *PSwitch = dyn_cast(PInfo)) { + DefiningBlocks.insert(PSwitch->To); + } + } + + if (DefiningBlocks.size() > 1) { + SmallPtrSet PHIBlocks(DefiningBlocks.begin(), + DefiningBlocks.end()); + PHICandidates[*PHIBlocks.begin()].push_back(Op); + } + } +} + +void PredicateInfoBuilder::insertPredicatePHIs( + Value *Op, BasicBlock *PHIBlock, SmallVectorImpl &OpsToRename) { + IRBuilder<> Builder(&PHIBlock->front()); + + PHINode *PHI = Builder.CreatePHI(Op->getType(), pred_size(PHIBlock), + Op->getName() + ".predicate.phi"); + PredicatePHI *PPHI = new (Allocator) PredicatePHI(Op, PHIBlock); + + PPHI->RenamedOp = PHI; + + const auto &ValueInfo = getValueInfo(Op); + for (BasicBlock *Pred : predecessors(PHIBlock)) { + Value *IncomingValue = nullptr; + for (const auto *PInfo : ValueInfo.Infos) { + if (auto *PBranch = dyn_cast(PInfo)) { + if (PBranch->From == Pred && PBranch->To == PHIBlock) { + PPHI->IncomingPredicates.push_back( + {Pred, const_cast(PInfo)}); + IncomingValue = PBranch->OriginalOp; + } + } else if (auto *PSwitch = dyn_cast(PInfo)) { + if (PSwitch->From == Pred && PSwitch->To == PHIBlock) { + PPHI->IncomingPredicates.push_back( + {Pred, const_cast(PInfo)}); + IncomingValue = PSwitch->OriginalOp; + } + } + } + PHI->addIncoming(IncomingValue, Pred); + } + + addInfoFor(OpsToRename, Op, PPHI); +} + +void PredicateInfoBuilder::processPredicatePHIs( + SmallVectorImpl &OpsToRename) { + for (const auto &PHICandidate : PHICandidates) { + BasicBlock *PHIBlock = PHICandidate.first; + for (Value *Op : PHICandidate.second) { + insertPredicatePHIs(Op, PHIBlock, OpsToRename); + } + } +} + // Build predicate info for our function void PredicateInfoBuilder::buildPredicateInfo() { DT.updateDFSNumbers(); @@ -478,6 +550,10 @@ void PredicateInfoBuilder::buildPredicateInfo() { if (DT.isReachableFromEntry(II->getParent())) processAssume(II, II->getParent(), OpsToRename); } + + identifyPHICandidates(OpsToRename); + processPredicatePHIs(OpsToRename); + // Now rename all our operations. renameUses(OpsToRename); } @@ -535,6 +611,12 @@ Value *PredicateInfoBuilder::materializeStack(unsigned int &Counter, CreateSSACopy(B, Op, Op->getName() + "." + Twine(Counter++)); PI.PredicateMap.insert({PIC, ValInfo}); Result.Def = PIC; + } else if (isa(ValInfo)) { + auto *PPHI = dyn_cast(ValInfo); + IRBuilder<> B(&*PPHI->PHIBlock->getFirstInsertionPt()); + CallInst *PIC = CreateSSACopy(B, Op); + PI.PredicateMap.insert({PIC, ValInfo}); + Result.Def = PIC; } else { auto *PAssume = dyn_cast(ValInfo); assert(PAssume && @@ -623,6 +705,15 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl &OpsToRename) { OrderedUses.push_back(VD); } } + } else if (const auto *PPHI = dyn_cast(PossibleCopy)) { + VD.LocalNum = LN_First; + auto *DomNode = DT.getNode(PPHI->PHIBlock); + if (DomNode) { + VD.DFSIn = DomNode->getDFSNumIn(); + VD.DFSOut = DomNode->getDFSNumOut(); + VD.PInfo = PossibleCopy; + OrderedUses.push_back(VD); + } } } @@ -773,10 +864,31 @@ std::optional PredicateBase::getConstraint() const { } return {{CmpInst::ICMP_EQ, cast(this)->CaseValue}}; + case PT_PHI: + return cast(this)->getConstraint(); } llvm_unreachable("Unknown predicate type"); } +std::optional PredicatePHI::getConstraint() const { + if (IncomingPredicates.empty()) + return std::nullopt; + + auto FirstConstraint = IncomingPredicates[0].second->getConstraint(); + if (!FirstConstraint) + return std::nullopt; + + for (size_t I = 1; I < IncomingPredicates.size(); ++I) { + auto Constraint = IncomingPredicates[I].second->getConstraint(); + if (!Constraint || Constraint->Predicate != FirstConstraint->Predicate || + Constraint->OtherOp != FirstConstraint->OtherOp) { + return std::nullopt; + } + } + + return FirstConstraint; +} + void PredicateInfo::verifyPredicateInfo() const {} // Replace ssa_copy calls created by PredicateInfo with their operand. @@ -838,6 +950,10 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter { } else if (const auto *PA = dyn_cast(PI)) { OS << "; assume predicate info {" << " Comparison:" << *PA->Condition; + } else if (const auto *PP = dyn_cast(PI)) { + OS << "; phi predicate info { PHIBlock: "; + PP->PHIBlock->printAsOperand(OS); + OS << " IncomingEdges: " << PP->IncomingPredicates.size(); } OS << ", RenamedOp: "; PI->RenamedOp->printAsOperand(OS, false);