Skip to content

Commit c7eb02c

Browse files
Ken Odegardjonathanslenders
authored andcommitted
Unpack tempfile's callables
Need to unpack tempfile and tempfile_suffix callables before deciding whether creating a simple or advanced temp file for the editor.
1 parent fc86ff0 commit c7eb02c

File tree

2 files changed

+57
-28
lines changed

2 files changed

+57
-28
lines changed

prompt_toolkit/buffer.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from .history import History, InMemoryHistory
4444
from .search import SearchDirection, SearchState
4545
from .selection import PasteMode, SelectionState, SelectionType
46-
from .utils import Event, to_str
46+
from .utils import Event, call_if_callable, to_str
4747
from .validation import ValidationError, Validator
4848

4949
__all__ = [
@@ -1395,6 +1395,45 @@ def apply_search(self, search_state: SearchState,
13951395
def exit_selection(self) -> None:
13961396
self.selection_state = None
13971397

1398+
def _editor_simple_tempfile(self) -> Tuple[bool, str, str]:
1399+
# Simple (file) tempfile implementation.
1400+
suffix = call_if_callable(self.tempfile_suffix)
1401+
suffix = str(suffix) if suffix else None
1402+
descriptor, filename = tempfile.mkstemp(suffix)
1403+
1404+
os.write(descriptor, self.text.encode('utf-8'))
1405+
os.close(descriptor)
1406+
1407+
# Returning (simple, filename to open, path to remove).
1408+
return True, filename, filename
1409+
1410+
def _editor_complex_tempfile(self) -> Tuple[bool, str, str]:
1411+
# Complex (directory) tempfile implementation.
1412+
headtail = call_if_callable(self.tempfile)
1413+
if not headtail:
1414+
# Revert to simple case.
1415+
return self._editor_simple_tempfile()
1416+
headtail = str(headtail)
1417+
1418+
# Try to make according to tempfile logic.
1419+
head, tail = os.path.split(headtail)
1420+
if os.path.isabs(head):
1421+
head = head[1:]
1422+
1423+
dirpath = tempfile.mkdtemp()
1424+
if head:
1425+
dirpath = os.path.join(dirpath, head)
1426+
# Assume there is no issue creating dirs in this temp dir.
1427+
os.makedirs(dirpath)
1428+
1429+
# Open the filename and write current text.
1430+
filename = os.path.join(dirpath, tail)
1431+
with open(filename, "w", encoding="utf-8") as fh:
1432+
fh.write(self.text)
1433+
1434+
# Returning (complex, filename to open, path to remove).
1435+
return False, filename, dirpath
1436+
13981437
def open_in_editor(self, validate_and_handle: bool = False) -> 'asyncio.Task[None]':
13991438
"""
14001439
Open code in editor.
@@ -1404,33 +1443,11 @@ def open_in_editor(self, validate_and_handle: bool = False) -> 'asyncio.Task[Non
14041443
if self.read_only():
14051444
raise EditReadOnlyBuffer()
14061445

1407-
# Write to temporary file
1408-
if self.tempfile:
1409-
# Try to make according to tempfile logic.
1410-
head, tail = os.path.split(to_str(self.tempfile))
1411-
if os.path.isabs(head):
1412-
head = head[1:]
1413-
1414-
dirpath = tempfile.mkdtemp()
1415-
remove = dirpath
1416-
if head:
1417-
dirpath = os.path.join(dirpath, head)
1418-
# Assume there is no issue creating dirs in this temp dir.
1419-
os.makedirs(dirpath)
1420-
1421-
# Open the filename of interest.
1422-
filename = os.path.join(dirpath, tail)
1423-
descriptor = os.open(filename, os.O_WRONLY|os.O_CREAT)
1446+
# Write current text to temporary file
1447+
if not self.tempfile:
1448+
simple, filename, remove = self._editor_simple_tempfile()
14241449
else:
1425-
# Fallback to tempfile_suffix logic.
1426-
suffix = None
1427-
if self.tempfile_suffix:
1428-
suffix = to_str(self.tempfile_suffix)
1429-
descriptor, filename = tempfile.mkstemp(suffix)
1430-
remove = filename
1431-
1432-
os.write(descriptor, self.text.encode('utf-8'))
1433-
os.close(descriptor)
1450+
simple, filename, remove = self._editor_complex_tempfile()
14341451

14351452
async def run() -> None:
14361453
try:
@@ -1461,7 +1478,10 @@ async def run() -> None:
14611478

14621479
finally:
14631480
# Clean up temp dir/file.
1464-
shutil.rmtree(remove)
1481+
if simple:
1482+
os.unlink(remove)
1483+
else:
1484+
shutil.rmtree(remove)
14651485

14661486
return get_app().create_background_task(run())
14671487

prompt_toolkit/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
'is_windows',
2727
'in_main_thread',
2828
'take_using_weights',
29+
'call_if_callable',
2930
'to_str',
3031
'to_int',
3132
'to_float',
@@ -261,6 +262,14 @@ def take_using_weights(
261262
i += 1
262263

263264

265+
def call_if_callable(value: Union[Callable[[], _T], _T]) -> _T:
266+
" Call if callable, otherwise return as is.. "
267+
if callable(value):
268+
return call_if_callable(value())
269+
else:
270+
return value
271+
272+
264273
def to_str(value: Union[Callable[[], str], str]) -> str:
265274
" Turn callable or string into string. "
266275
if callable(value):

0 commit comments

Comments
 (0)