Skip to content

Commit c72558e

Browse files
Use nursery-like concept for background tasks.
1 parent 2de878d commit c72558e

File tree

8 files changed

+95
-82
lines changed

8 files changed

+95
-82
lines changed

prompt_toolkit/application/application.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import time
66
from asyncio import (
77
AbstractEventLoop,
8+
CancelledError,
89
Future,
10+
Task,
911
ensure_future,
1012
get_event_loop,
1113
sleep,
@@ -14,6 +16,7 @@
1416
from traceback import format_tb
1517
from typing import (
1618
Any,
19+
Awaitable,
1720
Callable,
1821
Dict,
1922
FrozenSet,
@@ -134,6 +137,10 @@ class Application(Generic[_AppResult]):
134137
scheduled calls), postpone the rendering max x seconds. '0' means:
135138
don't postpone. '.5' means: try to draw at least twice a second.
136139
140+
:param refresh_interval: Automatically invalidate the UI every so many
141+
seconds. When `None` (the default), only invalidate when `invalidate`
142+
has been called.
143+
137144
Filters:
138145
139146
:param mouse_support: (:class:`~prompt_toolkit.filters.Filter` or
@@ -190,6 +197,7 @@ def __init__(self,
190197
reverse_vi_search_direction: FilterOrBool = False,
191198
min_redraw_interval: Union[float, int, None] = None,
192199
max_render_postpone_time: Union[float, int, None] = .01,
200+
refresh_interval: Optional[float] = None,
193201

194202
on_reset: Optional[ApplicationEventHandler] = None,
195203
on_invalidate: Optional[ApplicationEventHandler] = None,
@@ -238,6 +246,7 @@ def __init__(self,
238246
self.enable_page_navigation_bindings = enable_page_navigation_bindings
239247
self.min_redraw_interval = min_redraw_interval
240248
self.max_render_postpone_time = max_render_postpone_time
249+
self.refresh_interval = refresh_interval
241250

242251
# Events.
243252
self.on_invalidate = Event(self, on_invalidate)
@@ -386,6 +395,8 @@ def reset(self) -> None:
386395

387396
self.exit_style = ''
388397

398+
self.background_tasks: List[Task] = []
399+
389400
self.renderer.reset()
390401
self.key_processor.reset()
391402
self.layout.reset()
@@ -437,7 +448,7 @@ def schedule_redraw() -> None:
437448
async def redraw_in_future() -> None:
438449
await sleep(cast(float, self.min_redraw_interval) - diff)
439450
schedule_redraw()
440-
ensure_future(redraw_in_future())
451+
self.create_background_task(redraw_in_future())
441452
else:
442453
schedule_redraw()
443454
else:
@@ -488,6 +499,19 @@ def run_in_context() -> None:
488499
if self.context is not None:
489500
self.context.run(run_in_context)
490501

502+
def _start_auto_refresh_task(self) -> None:
503+
"""
504+
Start a while/true loop in the background for automatic invalidation of
505+
the UI.
506+
"""
507+
async def auto_refresh():
508+
while True:
509+
await sleep(self.refresh_interval)
510+
self.invalidate()
511+
512+
if self.refresh_interval:
513+
self.create_background_task(auto_refresh())
514+
491515
def _update_invalidate_events(self) -> None:
492516
"""
493517
Make sure to attach 'invalidate' handlers to all invalidate events in
@@ -612,7 +636,7 @@ def read_from_input() -> None:
612636
counter = flush_counter
613637

614638
# Automatically flush keys.
615-
ensure_future(auto_flush_input(counter))
639+
self.create_background_task(auto_flush_input(counter))
616640

617641
async def auto_flush_input(counter: int) -> None:
618642
# Flush input after timeout.
@@ -638,6 +662,7 @@ def flush_input() -> None:
638662
# Draw UI.
639663
self._request_absolute_cursor_position()
640664
self._redraw()
665+
self._start_auto_refresh_task()
641666

642667
has_sigwinch = hasattr(signal, 'SIGWINCH') and in_main_thread()
643668
if has_sigwinch:
@@ -696,6 +721,12 @@ async def _run_async2() -> _AppResult:
696721
try:
697722
result = await _run_async()
698723
finally:
724+
# Wait for the background tasks to be done. This needs to
725+
# go in the finally! If `_run_async` raises
726+
# `KeyboardInterrupt`, we still want to wait for the
727+
# background tasks.
728+
await self.cancel_and_wait_for_background_tasks()
729+
699730
# Set the `_is_running` flag to `False`. Normally this
700731
# happened already in the finally block in `run_async`
701732
# above, but in case of exceptions, that's not always the
@@ -717,9 +748,8 @@ def run(self, pre_run: Optional[Callable[[], None]] = None,
717748
loop = get_event_loop()
718749

719750
def run() -> _AppResult:
720-
f = ensure_future(self.run_async(pre_run=pre_run))
721-
get_event_loop().run_until_complete(f)
722-
return f.result()
751+
coro = self.run_async(pre_run=pre_run)
752+
return get_event_loop().run_until_complete(coro)
723753

724754
def handle_exception(loop, context: Dict[str, Any]) -> None:
725755
" Print the exception, using run_in_terminal. "
@@ -752,6 +782,32 @@ async def in_term() -> None:
752782
else:
753783
return run()
754784

785+
def create_background_task(self, coroutine: Awaitable[None]) -> None:
786+
"""
787+
Start a background task (coroutine) for the running application.
788+
If asyncio had nurseries like Trio, we would create a nursery in
789+
`Application.run_async`, and run the given coroutine in that nursery.
790+
"""
791+
self.background_tasks.append(get_event_loop().create_task(coroutine))
792+
793+
async def cancel_and_wait_for_background_tasks(self) -> None:
794+
"""
795+
Cancel all background tasks, and wait for the cancellation to be done.
796+
If any of the background tasks raised an exception, this will also
797+
propagate the exception.
798+
799+
(If we had nurseries like Trio, this would be the `__aexit__` of a
800+
nursery.)
801+
"""
802+
for task in self.background_tasks:
803+
task.cancel()
804+
805+
for task in self.background_tasks:
806+
try:
807+
await task
808+
except CancelledError:
809+
pass
810+
755811
def cpr_not_supported_callback(self) -> None:
756812
"""
757813
Called when we don't receive the cursor position response in time.

prompt_toolkit/buffer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def _text_changed(self) -> None:
464464
# (This happens on all change events, unlike auto completion, also when
465465
# deleting text.)
466466
if self.validator and self.validate_while_typing():
467-
ensure_future(self._async_validator())
467+
get_app().create_background_task(self._async_validator())
468468

469469
def _cursor_position_changed(self) -> None:
470470
# Remove any complete state.
@@ -1144,11 +1144,11 @@ def insert_text(self, data: str, overwrite: bool = False,
11441144

11451145
# Only complete when "complete_while_typing" is enabled.
11461146
if self.completer and self.complete_while_typing():
1147-
ensure_future(self._async_completer())
1147+
get_app().create_background_task(self._async_completer())
11481148

11491149
# Call auto_suggest.
11501150
if self.auto_suggest:
1151-
ensure_future(self._async_suggester())
1151+
get_app().create_background_task(self._async_suggester())
11521152

11531153
def undo(self) -> None:
11541154
# Pop from the undo-stack until we find a text that if different from
@@ -1484,7 +1484,7 @@ def start_completion(
14841484
# Only one of these options can be selected.
14851485
assert select_first + select_last + insert_common_part <= 1
14861486

1487-
ensure_future(self._async_completer(
1487+
get_app().create_background_task(self._async_completer(
14881488
select_first=select_first,
14891489
select_last=select_last,
14901490
insert_common_part=insert_common_part,

prompt_toolkit/contrib/ssh/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def shell_requested(self) -> bool:
5858
return True
5959

6060
def session_started(self) -> None:
61-
asyncio.ensure_future(self._interact())
61+
asyncio.get_event_loop().create_task(self._interact())
6262

6363
async def _interact(self) -> None:
6464
if self._chan is None:

prompt_toolkit/history.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import datetime
1010
import os
1111
from abc import ABCMeta, abstractmethod
12-
from asyncio import ensure_future
1312
from typing import AsyncGenerator, Iterable, List
1413

14+
from prompt_toolkit.application.current import get_app
15+
1516
from .eventloop import generator_to_async_generator
1617
from .utils import Event
1718

@@ -59,7 +60,7 @@ def start_loading(self) -> None:
5960
" Start loading the history. "
6061
if not self._loading:
6162
self._loading = True
62-
ensure_future(self._start_loading())
63+
get_app().create_background_task(self._start_loading())
6364

6465
def get_item_loaded_event(self) -> Event['History']:
6566
" Event which is triggered when a new item is loaded. "

prompt_toolkit/key_binding/key_processor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
correct callbacks when new key presses are feed through `feed`.
88
"""
99
import weakref
10-
from asyncio import ensure_future, sleep
10+
from asyncio import sleep
1111
from collections import deque
1212
from typing import TYPE_CHECKING, Any, Deque, Generator, List, Optional, Union
1313

@@ -392,7 +392,8 @@ def _start_timeout(self) -> None:
392392
no key was pressed in the meantime, we flush all data in the queue and
393393
call the appropriate key binding handlers.
394394
"""
395-
timeout = get_app().timeoutlen
395+
app = get_app()
396+
timeout = app.timeoutlen
396397

397398
if timeout is None:
398399
return
@@ -415,7 +416,7 @@ def flush_keys() -> None:
415416
# Automatically flush keys.
416417
# (_daemon needs to be set, otherwise, this will hang the
417418
# application for .5 seconds before exiting.)
418-
ensure_future(wait())
419+
app.create_background_task(wait())
419420

420421

421422
class KeyPressEvent:

prompt_toolkit/renderer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Renders the command line on the console.
33
(Redraws parts of the input line that were changed.)
44
"""
5-
from asyncio import FIRST_COMPLETED, Future, ensure_future, sleep, wait
5+
from asyncio import FIRST_COMPLETED, Future, sleep, wait
66
from collections import deque
77
from enum import Enum
88
from typing import (
@@ -16,6 +16,7 @@
1616
Tuple,
1717
)
1818

19+
from prompt_toolkit.application.current import get_app
1920
from prompt_toolkit.data_structures import Point, Size
2021
from prompt_toolkit.filters import FilterOrBool, to_filter
2122
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
@@ -452,7 +453,7 @@ async def timer() -> None:
452453
# Make sure to call this callback in the main thread.
453454
self.cpr_not_supported_callback()
454455

455-
ensure_future(timer())
456+
get_app().create_background_task(timer())
456457

457458
def report_absolute_cursor_row(self, row: int) -> None:
458459
"""
@@ -505,11 +506,11 @@ async def wait_for_timeout() -> None:
505506
# Got timeout, erase queue.
506507
self._waiting_for_cpr_futures = deque()
507508

508-
futures = [
509-
ensure_future(wait_for_responses()),
510-
ensure_future(wait_for_timeout()),
509+
coroutines = [
510+
wait_for_responses(),
511+
wait_for_timeout(),
511512
]
512-
await wait(futures, return_when=FIRST_COMPLETED)
513+
await wait(coroutines, return_when=FIRST_COMPLETED)
513514

514515
def render(self, app: 'Application[Any]', layout: 'Layout',
515516
is_done: bool = False) -> None:

prompt_toolkit/shortcuts/progress_bar/base.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
for item in pb(data):
88
...
99
"""
10-
import contextlib
1110
import datetime
1211
import functools
1312
import os
@@ -16,7 +15,7 @@
1615
import threading
1716
import traceback
1817
from asyncio import (
19-
ensure_future,
18+
CancelledError,
2019
get_event_loop,
2120
new_event_loop,
2221
set_event_loop,
@@ -190,19 +189,19 @@ def width_for_formatter(formatter: Formatter) -> AnyDimension:
190189
])),
191190
style=self.style,
192191
key_bindings=self.key_bindings,
192+
refresh_interval=.3,
193193
color_depth=self.color_depth,
194194
output=self.output,
195195
input=self.input)
196196

197197
# Run application in different thread.
198198
def run() -> None:
199199
set_event_loop(self._app_loop)
200-
with _auto_refresh_context(self.app, .3):
201-
try:
202-
self.app.run()
203-
except BaseException as e:
204-
traceback.print_exc()
205-
print(e)
200+
try:
201+
self.app.run()
202+
except BaseException as e:
203+
traceback.print_exc()
204+
print(e)
206205

207206
ctx: contextvars.Context = contextvars.copy_context()
208207

@@ -230,6 +229,7 @@ def __exit__(self, *a: object) -> None:
230229

231230
if self._thread is not None:
232231
self._thread.join()
232+
self._app_loop.close()
233233

234234
def __call__(self,
235235
data: Optional[Iterable[_T]] = None,
@@ -372,28 +372,3 @@ def time_left(self) -> Optional[datetime.timedelta]:
372372
return None
373373
else:
374374
return self.time_elapsed * (100 - self.percentage) / self.percentage
375-
376-
377-
@contextlib.contextmanager
378-
def _auto_refresh_context(
379-
app: 'Application', refresh_interval: Optional[float] = None
380-
) -> Generator[None, None, None]:
381-
"""
382-
Return a context manager for the auto-refresh loop.
383-
"""
384-
done = False
385-
386-
async def run() -> None:
387-
if refresh_interval:
388-
while not done:
389-
await sleep(refresh_interval)
390-
app.invalidate()
391-
392-
if refresh_interval:
393-
ensure_future(run())
394-
395-
try:
396-
yield
397-
finally:
398-
# Exit.
399-
done = True

0 commit comments

Comments
 (0)