Skip to content

[PredicateInfo] Support existing PredicateType by adding PredicatePHI when needing introduction of phi nodes #151132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion llvm/include/llvm/Transforms/Utils/PredicateInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Value;
class IntrinsicInst;
class raw_ostream;

enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_PHI };

/// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op
/// is the value the constraint applies to (the ssa.copy result).
Expand Down Expand Up @@ -171,6 +171,19 @@ class PredicateSwitch : public PredicateWithEdge {
}
};

class PredicatePHI : public PredicateBase {
public:
BasicBlock *PHIBlock;
SmallVector<std::pair<BasicBlock *, PredicateBase *>, 4> IncomingPredicates;

PredicatePHI(Value *Op, BasicBlock *PHIBB)
: PredicateBase(PT_PHI, Op, nullptr), PHIBlock(PHIBB) {}
PredicatePHI() = delete;
static bool classof(const PredicateBase *PB) { return PB->Type == PT_PHI; }

LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
};

/// Encapsulates PredicateInfo, including all data associated with memory
/// accesses.
class PredicateInfo {
Expand Down
116 changes: 116 additions & 0 deletions llvm/lib/Transforms/Utils/PredicateInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/AssemblyAnnotationWriter.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
Expand Down Expand Up @@ -213,6 +214,8 @@ class PredicateInfoBuilder {
// whether it returned a valid result.
DenseMap<Value *, unsigned int> ValueInfoNums;

DenseMap<BasicBlock *, SmallVector<Value *, 4>> PHICandidates;

BumpPtrAllocator &Allocator;

ValueInfo &getOrCreateValueInfo(Value *);
Expand All @@ -224,6 +227,10 @@ class PredicateInfoBuilder {
SmallVectorImpl<Value *> &OpsToRename);
void processSwitch(SwitchInst *, BasicBlock *,
SmallVectorImpl<Value *> &OpsToRename);
void identifyPHICandidates(SmallVectorImpl<Value *> &OpsToRename);
void insertPredicatePHIs(Value *Op, BasicBlock *PHIBlock,
SmallVectorImpl<Value *> &OpsToRename);
void processPredicatePHIs(SmallVectorImpl<Value *> &OpsToRename);
void renameUses(SmallVectorImpl<Value *> &OpsToRename);
void addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op,
PredicateBase *PB);
Expand Down Expand Up @@ -452,6 +459,71 @@ void PredicateInfoBuilder::processSwitch(
}
}

void PredicateInfoBuilder::identifyPHICandidates(
SmallVectorImpl<Value *> &OpsToRename) {
for (Value *Op : OpsToRename) {
const auto &ValueInfo = getValueInfo(Op);
SmallPtrSet<BasicBlock *, 4> DefiningBlocks;
for (const auto *PInfo : ValueInfo.Infos) {
if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
DefiningBlocks.insert(PBranch->To);
} else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
DefiningBlocks.insert(PSwitch->To);
}
}

if (DefiningBlocks.size() > 1) {
SmallPtrSet<BasicBlock *, 8> PHIBlocks(DefiningBlocks.begin(),
DefiningBlocks.end());
PHICandidates[*PHIBlocks.begin()].push_back(Op);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is what needs to be fixed, for now I just pushed to make sure the insertion functionality and other logics work well, which it does:

PredicateInfo for function: h5diff
define noundef i32 @h5diff(i32 %0, i1 %1) local_unnamed_addr {
  %cond = icmp eq i32 %0, 0
  br i1 %1, label %3, label %4

3:                                                ; preds = %2
  store i32 1, ptr @p, align 4
; Has predicate info
; branch predicate info { TrueEdge: 1 Comparison:  %cond = icmp eq i32 %0, 0 Edge: [label %3,label %5], RenamedOp: %cond }
  %cond.0 = call i1 @llvm.ssa.copy.i1(i1 %cond)
; Has predicate info
; branch predicate info { TrueEdge: 1 Comparison:  %cond = icmp eq i32 %0, 0 Edge: [label %3,label %5], RenamedOp: %0 }
  %.0 = call i32 @llvm.ssa.copy.i32(i32 %0)
  br i1 %cond, label %5, label %common.ret

common.ret:                                       ; preds = %5, %4, %3
  ret i32 0

4:                                                ; preds = %2
  store i32 2, ptr @p, align 4
; Has predicate info
; branch predicate info { TrueEdge: 1 Comparison:  %cond = icmp eq i32 %0, 0 Edge: [label %4,label %5], RenamedOp: %cond }
  %cond.1 = call i1 @llvm.ssa.copy.i1(i1 %cond)
; Has predicate info
; branch predicate info { TrueEdge: 1 Comparison:  %cond = icmp eq i32 %0, 0 Edge: [label %4,label %5], RenamedOp: %0 }
  %.1 = call i32 @llvm.ssa.copy.i32(i32 %0)
  br i1 %cond, label %5, label %common.ret

5:                                                ; preds = %4, %3
  %.predicate.phi = phi i32 [ %.1, %4 ], [ %.0, %3 ]
  %cond.predicate.phi = phi i1 [ %cond.1, %4 ], [ %cond.0, %3 ]
; Has predicate info
; phi predicate info { PHIBlock: label %5 IncomingEdges: 2, RenamedOp: %0 }
  %6 = call i32 @llvm.ssa.copy.i32(i32 %0)
  store i32 %6, ptr @p, align 4
  br label %common.ret
}

}
}
}

void PredicateInfoBuilder::insertPredicatePHIs(
Value *Op, BasicBlock *PHIBlock, SmallVectorImpl<Value *> &OpsToRename) {
IRBuilder<> Builder(&PHIBlock->front());

PHINode *PHI = Builder.CreatePHI(Op->getType(), pred_size(PHIBlock),
Op->getName() + ".predicate.phi");
PredicatePHI *PPHI = new (Allocator) PredicatePHI(Op, PHIBlock);

PPHI->RenamedOp = PHI;

const auto &ValueInfo = getValueInfo(Op);
for (BasicBlock *Pred : predecessors(PHIBlock)) {
Value *IncomingValue = nullptr;
for (const auto *PInfo : ValueInfo.Infos) {
if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
if (PBranch->From == Pred && PBranch->To == PHIBlock) {
PPHI->IncomingPredicates.push_back(
{Pred, const_cast<PredicateBase *>(PInfo)});
IncomingValue = PBranch->OriginalOp;
}
} else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
if (PSwitch->From == Pred && PSwitch->To == PHIBlock) {
PPHI->IncomingPredicates.push_back(
{Pred, const_cast<PredicateBase *>(PInfo)});
IncomingValue = PSwitch->OriginalOp;
}
}
}
PHI->addIncoming(IncomingValue, Pred);
}

addInfoFor(OpsToRename, Op, PPHI);
}

void PredicateInfoBuilder::processPredicatePHIs(
SmallVectorImpl<Value *> &OpsToRename) {
for (const auto &PHICandidate : PHICandidates) {
BasicBlock *PHIBlock = PHICandidate.first;
for (Value *Op : PHICandidate.second) {
insertPredicatePHIs(Op, PHIBlock, OpsToRename);
}
}
}

// Build predicate info for our function
void PredicateInfoBuilder::buildPredicateInfo() {
DT.updateDFSNumbers();
Expand All @@ -478,6 +550,10 @@ void PredicateInfoBuilder::buildPredicateInfo() {
if (DT.isReachableFromEntry(II->getParent()))
processAssume(II, II->getParent(), OpsToRename);
}

identifyPHICandidates(OpsToRename);
processPredicatePHIs(OpsToRename);

// Now rename all our operations.
renameUses(OpsToRename);
}
Expand Down Expand Up @@ -535,6 +611,12 @@ Value *PredicateInfoBuilder::materializeStack(unsigned int &Counter,
CreateSSACopy(B, Op, Op->getName() + "." + Twine(Counter++));
PI.PredicateMap.insert({PIC, ValInfo});
Result.Def = PIC;
} else if (isa<PredicatePHI>(ValInfo)) {
auto *PPHI = dyn_cast<PredicatePHI>(ValInfo);
IRBuilder<> B(&*PPHI->PHIBlock->getFirstInsertionPt());
CallInst *PIC = CreateSSACopy(B, Op);
PI.PredicateMap.insert({PIC, ValInfo});
Result.Def = PIC;
} else {
auto *PAssume = dyn_cast<PredicateAssume>(ValInfo);
assert(PAssume &&
Expand Down Expand Up @@ -623,6 +705,15 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl<Value *> &OpsToRename) {
OrderedUses.push_back(VD);
}
}
} else if (const auto *PPHI = dyn_cast<PredicatePHI>(PossibleCopy)) {
VD.LocalNum = LN_First;
auto *DomNode = DT.getNode(PPHI->PHIBlock);
if (DomNode) {
VD.DFSIn = DomNode->getDFSNumIn();
VD.DFSOut = DomNode->getDFSNumOut();
VD.PInfo = PossibleCopy;
OrderedUses.push_back(VD);
}
}
}

Expand Down Expand Up @@ -773,10 +864,31 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
}

return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
case PT_PHI:
return cast<PredicatePHI>(this)->getConstraint();
}
llvm_unreachable("Unknown predicate type");
}

std::optional<PredicateConstraint> PredicatePHI::getConstraint() const {
if (IncomingPredicates.empty())
return std::nullopt;

auto FirstConstraint = IncomingPredicates[0].second->getConstraint();
if (!FirstConstraint)
return std::nullopt;

for (size_t I = 1; I < IncomingPredicates.size(); ++I) {
auto Constraint = IncomingPredicates[I].second->getConstraint();
if (!Constraint || Constraint->Predicate != FirstConstraint->Predicate ||
Constraint->OtherOp != FirstConstraint->OtherOp) {
return std::nullopt;
}
}

return FirstConstraint;
}

void PredicateInfo::verifyPredicateInfo() const {}

// Replace ssa_copy calls created by PredicateInfo with their operand.
Expand Down Expand Up @@ -838,6 +950,10 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
} else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
OS << "; assume predicate info {"
<< " Comparison:" << *PA->Condition;
} else if (const auto *PP = dyn_cast<PredicatePHI>(PI)) {
OS << "; phi predicate info { PHIBlock: ";
PP->PHIBlock->printAsOperand(OS);
OS << " IncomingEdges: " << PP->IncomingPredicates.size();
}
OS << ", RenamedOp: ";
PI->RenamedOp->printAsOperand(OS, false);
Expand Down
Loading