Skip to content

Commit ff9ac33

Browse files
committed
[ARM][MVE] Validate tail predication values
Iterate through the loop and check that the observable values produced are the same whether tail predication happens or not. We want to find out if the tail-predicated version of this loop will produce the same values as the loop in its original form. For this to be true, the newly inserted implicit predication must not change the the (observable) results. We're doing this because many instructions in the loop will not be predicated and so the conversion from VPT predication to tail predication can result in different values being produced, because of falsely predicated lanes not being updated in the converted form. A masked load, whether through VPT or tail predication, will write zeros to any of the falsely predicated bytes. So, from the loads, we know that the false lanes are zeroed and here we're trying to track that those false lanes remain zero, or where they change, the differences are masked away by their user(s). All MVE loads and stores have to be predicated, so we know that any load operands, or stored results are equivalent already. Other explicitly predicated instructions will perform the same operation in the original loop and the tail-predicated form too. Because of this, we can insert loads, stores and other predicated instructions into our KnownFalseZeros set and build from there. Differential Revision: https://reviews.llvm.org/D75452
1 parent 0c28a09 commit ff9ac33

File tree

4 files changed

+881
-5
lines changed

4 files changed

+881
-5
lines changed

llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,20 +508,31 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
508508
return true;
509509
}
510510

511+
static bool isVectorPredicated(MachineInstr *MI) {
512+
int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
513+
return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
514+
}
515+
516+
static bool isRegInClass(const MachineOperand &MO,
517+
const TargetRegisterClass *Class) {
518+
return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
519+
}
520+
511521
bool LowOverheadLoop::ValidateLiveOuts() const {
512522
// Collect Q-regs that are live in the exit blocks. We don't collect scalars
513523
// because they won't be affected by lane predication.
514524
const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
515525
SmallSet<Register, 2> LiveOuts;
516-
SmallVector<MachineBasicBlock*, 2> ExitBlocks;
526+
SmallVector<MachineBasicBlock *, 2> ExitBlocks;
517527
ML.getExitBlocks(ExitBlocks);
518528
for (auto *MBB : ExitBlocks)
519529
for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
520530
if (QPRs->contains(RegMask.PhysReg))
521531
LiveOuts.insert(RegMask.PhysReg);
522532

523533
// Collect the instructions in the loop body that define the live-out values.
524-
SmallPtrSet<MachineInstr*, 2> LiveMIs;
534+
SmallPtrSet<MachineInstr *, 2> LiveMIs;
535+
assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
525536
MachineBasicBlock *MBB = ML.getHeader();
526537
for (auto Reg : LiveOuts)
527538
if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
@@ -534,12 +545,98 @@ bool LowOverheadLoop::ValidateLiveOuts() const {
534545
// equivalent when we perform the predication transformation; so we know that
535546
// any VPT predicated instruction is predicated upon VCTP. Any live-out
536547
// instruction needs to be predicated, so check this here.
537-
for (auto *MI : LiveMIs) {
538-
int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
539-
if (PIdx == -1 || MI->getOperand(PIdx+1).getReg() != ARM::VPR)
548+
for (auto *MI : LiveMIs)
549+
if (!isVectorPredicated(MI))
540550
return false;
551+
552+
// We want to find out if the tail-predicated version of this loop will
553+
// produce the same values as the loop in its original form. For this to
554+
// be true, the newly inserted implicit predication must not change the
555+
// the (observable) results.
556+
// We're doing this because many instructions in the loop will not be
557+
// predicated and so the conversion from VPT predication to tail-predication
558+
// can result in different values being produced; due to the tail-predication
559+
// preventing many instructions from updating their falsely predicated
560+
// lanes. This analysis assumes that all the instructions perform lane-wise
561+
// operations and don't perform any exchanges.
562+
// A masked load, whether through VPT or tail predication, will write zeros
563+
// to any of the falsely predicated bytes. So, from the loads, we know that
564+
// the false lanes are zeroed and here we're trying to track that those false
565+
// lanes remain zero, or where they change, the differences are masked away
566+
// by their user(s).
567+
// All MVE loads and stores have to be predicated, so we know that any load
568+
// operands, or stored results are equivalent already. Other explicitly
569+
// predicated instructions will perform the same operation in the original
570+
// loop and the tail-predicated form too. Because of this, we can insert
571+
// loads, stores and other predicated instructions into our KnownFalseZeros
572+
// set and build from there.
573+
SetVector<MachineInstr *> UnknownFalseLanes;
574+
SmallPtrSet<MachineInstr *, 4> KnownFalseZeros;
575+
for (auto &MI : *MBB) {
576+
const MCInstrDesc &MCID = MI.getDesc();
577+
uint64_t Flags = MCID.TSFlags;
578+
if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
579+
continue;
580+
581+
if (isVectorPredicated(&MI)) {
582+
KnownFalseZeros.insert(&MI);
583+
continue;
584+
}
585+
586+
if (MI.getNumDefs() == 0)
587+
continue;
588+
589+
// Only evaluate instructions which produce a single value.
590+
assert((MI.getNumDefs() == 1 && MI.defs().begin()->isReg()) &&
591+
"Expected no more than one register def");
592+
593+
Register DefReg = MI.defs().begin()->getReg();
594+
for (auto &MO : MI.operands()) {
595+
if (!isRegInClass(MO, QPRs) || !MO.isUse() || MO.getReg() != DefReg)
596+
continue;
597+
598+
// If this instruction overwrites one of its operands, and that register
599+
// has known lanes, then this instruction also has known predicated false
600+
// lanes.
601+
if (auto *OpDef = RDA.getMIOperand(&MI, MO)) {
602+
if (KnownFalseZeros.count(OpDef)) {
603+
KnownFalseZeros.insert(&MI);
604+
break;
605+
}
606+
}
607+
}
608+
if (!KnownFalseZeros.count(&MI))
609+
UnknownFalseLanes.insert(&MI);
541610
}
542611

612+
auto HasKnownUsers = [this](MachineInstr *MI, const MachineOperand &MO,
613+
SmallPtrSetImpl<MachineInstr *> &Knowns) {
614+
SmallPtrSet<MachineInstr *, 2> Uses;
615+
RDA.getGlobalUses(MI, MO.getReg(), Uses);
616+
for (auto *Use : Uses) {
617+
if (Use != MI && !Knowns.count(Use))
618+
return false;
619+
}
620+
return true;
621+
};
622+
623+
// Now for all the unknown values, see if they're only consumed by known
624+
// instructions. Visit in reverse so that we can start at the values being
625+
// stored and then we can work towards the leaves, hopefully adding more
626+
// instructions to KnownFalseZeros.
627+
for (auto *MI : reverse(UnknownFalseLanes)) {
628+
for (auto &MO : MI->operands()) {
629+
if (!isRegInClass(MO, QPRs) || !MO.isDef())
630+
continue;
631+
if (!HasKnownUsers(MI, MO, KnownFalseZeros)) {
632+
LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
633+
<< TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
634+
return false;
635+
}
636+
}
637+
// Any unknown false lanes have been masked away by the user(s).
638+
KnownFalseZeros.insert(MI);
639+
}
543640
return true;
544641
}
545642

0 commit comments

Comments
 (0)