Skip to content

Commit 43c8307

Browse files
committed
[Coroutines] CoroElide enhancement
Fix regression of CoroElide pass when the current function is a coroutine. Differential Revision: https://reviews.llvm.org/D71663
1 parent 2b5a897 commit 43c8307

File tree

2 files changed

+176
-18
lines changed

2 files changed

+176
-18
lines changed

llvm/lib/Transforms/Coroutines/CoroElide.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "llvm/Transforms/Coroutines/CoroElide.h"
1010
#include "CoroInternal.h"
11+
#include "llvm/ADT/DenseMap.h"
1112
#include "llvm/Analysis/AliasAnalysis.h"
1213
#include "llvm/Analysis/InstructionSimplify.h"
1314
#include "llvm/IR/Dominators.h"
@@ -27,8 +28,9 @@ struct Lowerer : coro::LowererBase {
2728
SmallVector<CoroBeginInst *, 1> CoroBegins;
2829
SmallVector<CoroAllocInst *, 1> CoroAllocs;
2930
SmallVector<CoroSubFnInst *, 4> ResumeAddr;
30-
SmallVector<CoroSubFnInst *, 4> DestroyAddr;
31+
DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
3132
SmallVector<CoroFreeInst *, 1> CoroFrees;
33+
CoroSuspendInst *CoroFinalSuspend;
3234

3335
Lowerer(Module &M) : LowererBase(M) {}
3436

@@ -146,33 +148,62 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
146148
if (CoroAllocs.empty())
147149
return false;
148150

149-
// Check that for every coro.begin there is a coro.destroy directly
150-
// referencing the SSA value of that coro.begin along a non-exceptional path.
151+
// Check that for every coro.begin there is at least one coro.destroy directly
152+
// referencing the SSA value of that coro.begin along each
153+
// non-exceptional path.
151154
// If the value escaped, then coro.destroy would have been referencing a
152155
// memory ___location storing that value and not the virtual register.
153156

154-
// First gather all of the non-exceptional terminators for the function.
155157
SmallPtrSet<Instruction *, 8> Terminators;
156-
for (BasicBlock &B : *F) {
157-
auto *TI = B.getTerminator();
158-
if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
159-
!isa<UnreachableInst>(TI))
160-
Terminators.insert(TI);
158+
bool HasMultiPred = false;
159+
// First gather all of the non-exceptional terminators for the function.
160+
// Consider the final coro.suspend as the real terminator when the current
161+
// function is a coroutine.
162+
if (CoroFinalSuspend) {
163+
// If the block of the final coro.suspend has more than one predecessor,
164+
// then there is one resume path and the others are exceptional paths,
165+
// consider these predecessors as terminators.
166+
BasicBlock *FinalBB = CoroFinalSuspend->getParent();
167+
if (FinalBB->hasNPredecessorsOrMore(2)) {
168+
HasMultiPred = true;
169+
for (auto *B : predecessors(FinalBB))
170+
Terminators.insert(B->getTerminator());
171+
} else
172+
Terminators.insert(CoroFinalSuspend);
173+
} else {
174+
for (BasicBlock &B : *F) {
175+
auto *TI = B.getTerminator();
176+
if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
177+
!isa<UnreachableInst>(TI))
178+
Terminators.insert(TI);
179+
}
161180
}
162181

163182
// Filter out the coro.destroy that lie along exceptional paths.
164183
SmallPtrSet<CoroSubFnInst *, 4> DAs;
165-
for (CoroSubFnInst *DA : DestroyAddr) {
166-
for (Instruction *TI : Terminators) {
167-
if (DT.dominates(DA, TI)) {
168-
DAs.insert(DA);
169-
break;
184+
SmallPtrSet<Instruction *, 2> TIs;
185+
SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
186+
for (auto &It : DestroyAddr) {
187+
for (CoroSubFnInst *DA : It.second) {
188+
for (Instruction *TI : Terminators) {
189+
if (DT.dominates(DA, TI)) {
190+
if (HasMultiPred)
191+
TIs.insert(TI);
192+
else
193+
DAs.insert(DA);
194+
break;
195+
}
170196
}
171197
}
198+
// If every predecessor's terminator is dominated by a coro.destroy that
199+
// references the same coro.begin, record that coro.begin.
200+
if (TIs.size() == Terminators.size()) {
201+
ReferencedCoroBegins.insert(It.first);
202+
TIs.clear();
203+
}
172204
}
173205

174206
// Find all the coro.begin referenced by coro.destroy along happy paths.
175-
SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
176207
for (CoroSubFnInst *DA : DAs) {
177208
if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame()))
178209
ReferencedCoroBegins.insert(CB);
@@ -188,12 +219,22 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
188219

189220
void Lowerer::collectPostSplitCoroIds(Function *F) {
190221
CoroIds.clear();
191-
for (auto &I : instructions(F))
222+
CoroFinalSuspend = nullptr;
223+
for (auto &I : instructions(F)) {
192224
if (auto *CII = dyn_cast<CoroIdInst>(&I))
193225
if (CII->getInfo().isPostSplit())
194226
// If it is the coroutine itself, don't touch it.
195227
if (CII->getCoroutine() != CII->getFunction())
196228
CoroIds.push_back(CII);
229+
230+
if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
231+
if (CSI->isFinal()) {
232+
if (!CoroFinalSuspend)
233+
CoroFinalSuspend = CSI;
234+
else
235+
report_fatal_error("Only one suspend point can be marked as final");
236+
}
237+
}
197238
}
198239

199240
bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
@@ -226,7 +267,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
226267
ResumeAddr.push_back(II);
227268
break;
228269
case CoroSubFnInst::DestroyIndex:
229-
DestroyAddr.push_back(II);
270+
DestroyAddr[CB].push_back(II);
230271
break;
231272
default:
232273
llvm_unreachable("unexpected coro.subfn.addr constant");
@@ -249,7 +290,8 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
249290
Resumers,
250291
ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
251292

252-
replaceWithConstant(DestroyAddrConstant, DestroyAddr);
293+
for (auto &It : DestroyAddr)
294+
replaceWithConstant(DestroyAddrConstant, It.second);
253295

254296
if (ShouldElide) {
255297
auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));

llvm/test/Transforms/Coroutines/coro-heap-elide.ll

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,120 @@ entry:
8484
ret void
8585
}
8686

87+
; CHECK-LABEL: @callResume_with_coro_suspend_1(
88+
define void @callResume_with_coro_suspend_1() {
89+
entry:
90+
; CHECK: alloca %f.frame
91+
; CHECK-NOT: coro.begin
92+
; CHECK-NOT: CustomAlloc
93+
; CHECK: call void @may_throw()
94+
%hdl = call i8* @f()
95+
96+
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
97+
%0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
98+
%1 = bitcast i8* %0 to void (i8*)*
99+
call fastcc void %1(i8* %hdl)
100+
%2 = call token @llvm.coro.save(i8* %hdl)
101+
%3 = call i8 @llvm.coro.suspend(token %2, i1 false)
102+
switch i8 %3, label %coro.ret [
103+
i8 0, label %final.suspend
104+
i8 1, label %cleanups
105+
]
106+
107+
; CHECK-LABEL: final.suspend:
108+
final.suspend:
109+
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
110+
%4 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
111+
%5 = bitcast i8* %4 to void (i8*)*
112+
call fastcc void %5(i8* %hdl)
113+
%6 = call token @llvm.coro.save(i8* %hdl)
114+
%7 = call i8 @llvm.coro.suspend(token %6, i1 true)
115+
switch i8 %7, label %coro.ret [
116+
i8 0, label %coro.ret
117+
i8 1, label %cleanups
118+
]
119+
120+
; CHECK-LABEL: cleanups:
121+
cleanups:
122+
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
123+
%8 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
124+
%9 = bitcast i8* %8 to void (i8*)*
125+
call fastcc void %9(i8* %hdl)
126+
br label %coro.ret
127+
128+
; CHECK-LABEL: coro.ret:
129+
coro.ret:
130+
; CHECK-NEXT: ret void
131+
ret void
132+
}
133+
134+
; CHECK-LABEL: @callResume_with_coro_suspend_2(
135+
define void @callResume_with_coro_suspend_2() personality i8* null {
136+
entry:
137+
; CHECK: alloca %f.frame
138+
; CHECK-NOT: coro.begin
139+
; CHECK-NOT: CustomAlloc
140+
; CHECK: call void @may_throw()
141+
%hdl = call i8* @f()
142+
143+
%0 = call token @llvm.coro.save(i8* %hdl)
144+
; CHECK: invoke fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
145+
%1 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
146+
%2 = bitcast i8* %1 to void (i8*)*
147+
invoke fastcc void %2(i8* %hdl)
148+
to label %invoke.cont1 unwind label %lpad
149+
150+
; CHECK-LABEL: invoke.cont1:
151+
invoke.cont1:
152+
%3 = call i8 @llvm.coro.suspend(token %0, i1 false)
153+
switch i8 %3, label %coro.ret [
154+
i8 0, label %final.ready
155+
i8 1, label %cleanups
156+
]
157+
158+
; CHECK-LABEL: lpad:
159+
lpad:
160+
%4 = landingpad { i8*, i32 }
161+
catch i8* null
162+
; CHECK: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
163+
%5 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
164+
%6 = bitcast i8* %5 to void (i8*)*
165+
call fastcc void %6(i8* %hdl)
166+
br label %final.suspend
167+
168+
; CHECK-LABEL: final.ready:
169+
final.ready:
170+
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
171+
%7 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
172+
%8 = bitcast i8* %7 to void (i8*)*
173+
call fastcc void %8(i8* %hdl)
174+
br label %final.suspend
175+
176+
; CHECK-LABEL: final.suspend:
177+
final.suspend:
178+
%9 = call token @llvm.coro.save(i8* %hdl)
179+
%10 = call i8 @llvm.coro.suspend(token %9, i1 true)
180+
switch i8 %10, label %coro.ret [
181+
i8 0, label %coro.ret
182+
i8 1, label %cleanups
183+
]
184+
185+
; CHECK-LABEL: cleanups:
186+
cleanups:
187+
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
188+
%11 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
189+
%12 = bitcast i8* %11 to void (i8*)*
190+
call fastcc void %12(i8* %hdl)
191+
br label %coro.ret
192+
193+
; CHECK-LABEL: coro.ret:
194+
coro.ret:
195+
; CHECK-NEXT: ret void
196+
ret void
197+
}
198+
199+
200+
87201
; CHECK-LABEL: @callResume_PR34897_no_elision(
88202
define void @callResume_PR34897_no_elision(i1 %cond) {
89203
; CHECK-LABEL: entry:
@@ -161,3 +275,5 @@ declare i8* @llvm.coro.free(token, i8*)
161275
declare i8* @llvm.coro.begin(token, i8*)
162276
declare i8* @llvm.coro.frame(token)
163277
declare i8* @llvm.coro.subfn.addr(i8*, i8)
278+
declare i8 @llvm.coro.suspend(token, i1)
279+
declare token @llvm.coro.save(i8*)

0 commit comments

Comments
 (0)