Skip to content

Commit 8e9bdd9

Browse files
Use ExitStack in Application.run()
1 parent d3dcbe8 commit 8e9bdd9

File tree

1 file changed

+100
-59
lines changed

1 file changed

+100
-59
lines changed

src/prompt_toolkit/application/application.py

Lines changed: 100 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
set_event_loop,
1616
sleep,
1717
)
18-
from contextlib import contextmanager
18+
from contextlib import ExitStack, contextmanager
1919
from subprocess import Popen
2020
from traceback import format_tb
2121
from typing import (
@@ -29,6 +29,7 @@
2929
Generic,
3030
Hashable,
3131
Iterable,
32+
Iterator,
3233
List,
3334
Optional,
3435
Tuple,
@@ -670,12 +671,7 @@ async def run_async(
670671
# See: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/1553
671672
handle_sigint = False
672673

673-
async def _run_async() -> _AppResult:
674-
"Coroutine."
675-
loop = get_event_loop()
676-
f = loop.create_future()
677-
self.future = f # XXX: make sure to set this before calling '_redraw'.
678-
self.loop = loop
674+
async def _run_async(f: "asyncio.Future[_AppResult]") -> _AppResult:
679675
self.context = contextvars.copy_context()
680676

681677
# Counter for cancelling 'flush' timeouts. Every time when a key is
@@ -790,70 +786,115 @@ def flush_input() -> None:
790786
# Store unprocessed input as typeahead for next time.
791787
store_typeahead(self.input, self.key_processor.empty_queue())
792788

793-
return cast(_AppResult, result)
789+
return result
790+
791+
@contextmanager
792+
def get_loop() -> Iterator[AbstractEventLoop]:
793+
loop = get_event_loop()
794+
self.loop = loop
794795

795-
async def _run_async2() -> _AppResult:
796-
self._is_running = True
797796
try:
798-
# Make sure to set `_invalidated` to `False` to begin with,
799-
# otherwise we're not going to paint anything. This can happen if
800-
# this application had run before on a different event loop, and a
801-
# paint was scheduled using `call_soon_threadsafe` with
802-
# `max_postpone_time`.
803-
self._invalidated = False
804-
805-
loop = get_event_loop()
806-
807-
if handle_sigint:
808-
loop.add_signal_handler(
809-
signal.SIGINT,
810-
lambda *_: loop.call_soon_threadsafe(
811-
self.key_processor.send_sigint
812-
),
813-
)
797+
yield loop
798+
finally:
799+
self.loop = None
814800

815-
if set_exception_handler:
816-
previous_exc_handler = loop.get_exception_handler()
817-
loop.set_exception_handler(self._handle_exception)
801+
@contextmanager
802+
def set_is_running() -> Iterator[None]:
803+
self._is_running = True
804+
try:
805+
yield
806+
finally:
807+
self._is_running = False
818808

819-
# Set slow_callback_duration.
820-
original_slow_callback_duration = loop.slow_callback_duration
821-
loop.slow_callback_duration = slow_callback_duration
809+
@contextmanager
810+
def set_handle_sigint(loop: AbstractEventLoop) -> Iterator[None]:
811+
if handle_sigint:
812+
loop.add_signal_handler(
813+
signal.SIGINT,
814+
lambda *_: loop.call_soon_threadsafe(
815+
self.key_processor.send_sigint
816+
),
817+
)
818+
try:
819+
yield
820+
finally:
821+
loop.remove_signal_handler(signal.SIGINT)
822+
else:
823+
yield
822824

825+
@contextmanager
826+
def set_exception_handler_ctx(loop: AbstractEventLoop) -> Iterator[None]:
827+
if set_exception_handler:
828+
previous_exc_handler = loop.get_exception_handler()
829+
loop.set_exception_handler(self._handle_exception)
823830
try:
824-
with set_app(self), self._enable_breakpointhook():
825-
try:
826-
result = await _run_async()
827-
finally:
828-
# Wait for the background tasks to be done. This needs to
829-
# go in the finally! If `_run_async` raises
830-
# `KeyboardInterrupt`, we still want to wait for the
831-
# background tasks.
832-
await self.cancel_and_wait_for_background_tasks()
833-
834-
# Also remove the Future again. (This brings the
835-
# application back to its initial state, where it also
836-
# doesn't have a Future.)
837-
self.future = None
838-
839-
return result
831+
yield
840832
finally:
841-
if set_exception_handler:
842-
loop.set_exception_handler(previous_exc_handler)
833+
loop.set_exception_handler(previous_exc_handler)
843834

844-
if handle_sigint:
845-
loop.remove_signal_handler(signal.SIGINT)
835+
else:
836+
yield
846837

847-
# Reset slow_callback_duration.
848-
loop.slow_callback_duration = original_slow_callback_duration
838+
@contextmanager
839+
def set_callback_duration(loop: AbstractEventLoop) -> Iterator[None]:
840+
# Set slow_callback_duration.
841+
original_slow_callback_duration = loop.slow_callback_duration
842+
loop.slow_callback_duration = slow_callback_duration
843+
try:
844+
yield
845+
finally:
846+
# Reset slow_callback_duration.
847+
loop.slow_callback_duration = original_slow_callback_duration
848+
849+
@contextmanager
850+
def create_future(
851+
loop: AbstractEventLoop,
852+
) -> "Iterator[asyncio.Future[_AppResult]]":
853+
f = loop.create_future()
854+
self.future = f # XXX: make sure to set this before calling '_redraw'.
849855

856+
try:
857+
yield f
850858
finally:
851-
# Set the `_is_running` flag to `False`. Normally this happened
852-
# already in the finally block in `run_async` above, but in
853-
# case of exceptions, that's not always the case.
854-
self._is_running = False
859+
# Also remove the Future again. (This brings the
860+
# application back to its initial state, where it also
861+
# doesn't have a Future.)
862+
self.future = None
863+
864+
with ExitStack() as stack:
865+
stack.enter_context(set_is_running())
866+
867+
# Make sure to set `_invalidated` to `False` to begin with,
868+
# otherwise we're not going to paint anything. This can happen if
869+
# this application had run before on a different event loop, and a
870+
# paint was scheduled using `call_soon_threadsafe` with
871+
# `max_postpone_time`.
872+
self._invalidated = False
873+
874+
loop = stack.enter_context(get_loop())
855875

856-
return await _run_async2()
876+
stack.enter_context(set_handle_sigint(loop))
877+
stack.enter_context(set_exception_handler_ctx(loop))
878+
stack.enter_context(set_callback_duration(loop))
879+
stack.enter_context(set_app(self))
880+
stack.enter_context(self._enable_breakpointhook())
881+
882+
f = stack.enter_context(create_future(loop))
883+
884+
try:
885+
return await _run_async(f)
886+
finally:
887+
# Wait for the background tasks to be done. This needs to
888+
# go in the finally! If `_run_async` raises
889+
# `KeyboardInterrupt`, we still want to wait for the
890+
# background tasks.
891+
await self.cancel_and_wait_for_background_tasks()
892+
893+
# The `ExitStack` above is defined in typeshed in a way that it can
894+
# swallow exceptions. Without next line, mypy would think that there's
895+
# a possibility we don't return here. See:
896+
# https://github.com/python/mypy/issues/7726
897+
assert False, "unreachable"
857898

858899
def run(
859900
self,

0 commit comments

Comments
 (0)