Skip to content

Commit bd94bcb

Browse files
authored
Try simple-minded call expression cache (#19505)
This gives a modest 1% improvement on self-check (compiled), but it gives almost 40% on `mypy -c "import colour"`. Some comments: * I only cache `CallExpr`, `ListExpr`, and `TupleExpr`. This is not very principled; I found this to be the best balance between rare cases like `colour` and more common cases like self-check. * Caching is fragile within lambdas, so I simply disable it there; it rarely matters anyway. * I cache both the messages and the type map. Surprisingly, the latter only affects a couple of test cases, but I still do this generally for peace of mind. * It looks like there are only three things that require cache invalidation: the binder, partial types, and deferrals. In general, this is a bit scary (as this is a major change), but the perf improvements for slow libraries are also very tempting.
1 parent e40c36c commit bd94bcb

File tree

6 files changed

+84
-5
lines changed

6 files changed

+84
-5
lines changed

mypy/binder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def __init__(self, options: Options) -> None:
138138
# flexible inference of variable types (--allow-redefinition-new).
139139
self.bind_all = options.allow_redefinition_new
140140

141+
# This tracks any externally visible changes in binder to invalidate
142+
# expression caches when needed.
143+
self.version = 0
144+
141145
def _get_id(self) -> int:
142146
self.next_id += 1
143147
return self.next_id
@@ -158,6 +162,7 @@ def push_frame(self, conditional_frame: bool = False) -> Frame:
158162
return f
159163

160164
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
165+
self.version += 1
161166
self.frames[index].types[key] = CurrentType(type, from_assignment)
162167

163168
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
@@ -185,6 +190,7 @@ def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> N
185190
self._put(key, typ, from_assignment)
186191

187192
def unreachable(self) -> None:
193+
self.version += 1
188194
self.frames[-1].unreachable = True
189195

190196
def suppress_unreachable_warnings(self) -> None:

mypy/checker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,6 @@ def reset(self) -> None:
449449
self.binder = ConditionalTypeBinder(self.options)
450450
self._type_maps[1:] = []
451451
self._type_maps[0].clear()
452-
self.temp_type_map = None
453452
self.expr_checker.reset()
454453
self.deferred_nodes = []
455454
self.partial_types = []
@@ -3024,6 +3023,8 @@ def visit_block(self, b: Block) -> None:
30243023
break
30253024
else:
30263025
self.accept(s)
3026+
# Clear expression cache after each statement to avoid unlimited growth.
3027+
self.expr_checker.expr_cache.clear()
30273028

30283029
def should_report_unreachable_issues(self) -> bool:
30293030
return (
@@ -4005,7 +4006,7 @@ def check_multi_assignment_from_union(
40054006
for t, lv in zip(transposed, self.flatten_lvalues(lvalues)):
40064007
# We can access _type_maps directly since temporary type maps are
40074008
# only created within expressions.
4008-
t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form)))
4009+
t.append(self._type_maps[-1].pop(lv, AnyType(TypeOfAny.special_form)))
40094010
union_types = tuple(make_simplified_union(col) for col in transposed)
40104011
for expr, items in assignments.items():
40114012
# Bind a union of types collected in 'assignments' to every expression.
@@ -4664,6 +4665,8 @@ def replace_partial_type(
46644665
) -> None:
46654666
"""Replace the partial type of var with a non-partial type."""
46664667
var.type = new_type
4668+
# Updating a partial type should invalidate expression caches.
4669+
self.binder.version += 1
46674670
del partial_types[var]
46684671
if self.options.allow_redefinition_new:
46694672
# When using --allow-redefinition-new, binder tracks all types of

mypy/checkexpr.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mypy.checkmember import analyze_member_access, has_operator
2020
from mypy.checkstrformat import StringFormatterChecker
2121
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
22-
from mypy.errors import ErrorWatcher, report_internal_error
22+
from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error
2323
from mypy.expandtype import (
2424
expand_type,
2525
expand_type_by_instance,
@@ -355,9 +355,15 @@ def __init__(
355355
type_state.infer_polymorphic = not self.chk.options.old_type_inference
356356

357357
self._arg_infer_context_cache = None
358+
self.expr_cache: dict[
359+
tuple[Expression, Type | None],
360+
tuple[int, Type, list[ErrorInfo], dict[Expression, Type]],
361+
] = {}
362+
self.in_lambda_expr = False
358363

359364
def reset(self) -> None:
360365
self.resolved_type = {}
366+
self.expr_cache.clear()
361367

362368
def visit_name_expr(self, e: NameExpr) -> Type:
363369
"""Type check a name expression.
@@ -5402,6 +5408,8 @@ def find_typeddict_context(
54025408

54035409
def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54045410
"""Type check lambda expression."""
5411+
old_in_lambda = self.in_lambda_expr
5412+
self.in_lambda_expr = True
54055413
self.chk.check_default_args(e, body_is_trivial=False)
54065414
inferred_type, type_override = self.infer_lambda_type_using_context(e)
54075415
if not inferred_type:
@@ -5422,6 +5430,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54225430
ret_type = self.accept(e.expr(), allow_none_return=True)
54235431
fallback = self.named_type("builtins.function")
54245432
self.chk.return_types.pop()
5433+
self.in_lambda_expr = old_in_lambda
54255434
return callable_type(e, fallback, ret_type)
54265435
else:
54275436
# Type context available.
@@ -5434,6 +5443,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type:
54345443
self.accept(e.expr(), allow_none_return=True)
54355444
ret_type = self.chk.lookup_type(e.expr())
54365445
self.chk.return_types.pop()
5446+
self.in_lambda_expr = old_in_lambda
54375447
return replace_callable_return_type(inferred_type, ret_type)
54385448

54395449
def infer_lambda_type_using_context(
@@ -5978,6 +5988,24 @@ def accept(
59785988
typ = self.visit_conditional_expr(node, allow_none_return=True)
59795989
elif allow_none_return and isinstance(node, AwaitExpr):
59805990
typ = self.visit_await_expr(node, allow_none_return=True)
5991+
# Deeply nested generic calls can deteriorate performance dramatically.
5992+
# Although in most cases caching makes little difference, in worst case
5993+
# it avoids exponential complexity.
5994+
# We cannot use cache inside lambdas, because they skip immediate type
5995+
# context, and use enclosing one, see infer_lambda_type_using_context().
5996+
# TODO: consider using cache for more expression kinds.
5997+
elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not (
5998+
self.in_lambda_expr or self.chk.current_node_deferred
5999+
):
6000+
if (node, type_context) in self.expr_cache:
6001+
binder_version, typ, messages, type_map = self.expr_cache[(node, type_context)]
6002+
if binder_version == self.chk.binder.version:
6003+
self.chk.store_types(type_map)
6004+
self.msg.add_errors(messages)
6005+
else:
6006+
typ = self.accept_maybe_cache(node, type_context=type_context)
6007+
else:
6008+
typ = self.accept_maybe_cache(node, type_context=type_context)
59816009
else:
59826010
typ = node.accept(self)
59836011
except Exception as err:
@@ -6008,6 +6036,21 @@ def accept(
60086036
self.in_expression = False
60096037
return result
60106038

6039+
def accept_maybe_cache(self, node: Expression, type_context: Type | None = None) -> Type:
6040+
binder_version = self.chk.binder.version
6041+
# Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc.
6042+
type_map: dict[Expression, Type] = {}
6043+
self.chk._type_maps.append(type_map)
6044+
with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg:
6045+
typ = node.accept(self)
6046+
messages = msg.filtered_errors()
6047+
if binder_version == self.chk.binder.version and not self.chk.current_node_deferred:
6048+
self.expr_cache[(node, type_context)] = (binder_version, typ, messages, type_map)
6049+
self.chk._type_maps.pop()
6050+
self.chk.store_types(type_map)
6051+
self.msg.add_errors(messages)
6052+
return typ
6053+
60116054
def named_type(self, name: str) -> Instance:
60126055
"""Return an instance type with type given by the name and no type
60136056
arguments. Alias for TypeChecker.named_type.

mypy/errors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ class Errors:
390390
# in some cases to avoid reporting huge numbers of errors.
391391
seen_import_error = False
392392

393-
_watchers: list[ErrorWatcher] = []
393+
_watchers: list[ErrorWatcher]
394394

395395
def __init__(
396396
self,
@@ -421,6 +421,7 @@ def initialize(self) -> None:
421421
self.scope = None
422422
self.target_module = None
423423
self.seen_import_error = False
424+
self._watchers = []
424425

425426
def reset(self) -> None:
426427
self.initialize()
@@ -931,7 +932,8 @@ def prefer_simple_messages(self) -> bool:
931932
if self.file in self.ignored_files:
932933
# Errors ignored, so no point generating fancy messages
933934
return True
934-
for _watcher in self._watchers:
935+
if self._watchers:
936+
_watcher = self._watchers[-1]
935937
if _watcher._filter is True and _watcher._filtered is None:
936938
# Errors are filtered
937939
return True

test-data/unit/check-overloading.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6801,3 +6801,26 @@ class D(Generic[T]):
68016801
a: D[str] # E: Type argument "str" of "D" must be a subtype of "C"
68026802
reveal_type(a.f(1)) # N: Revealed type is "builtins.int"
68036803
reveal_type(a.f("x")) # N: Revealed type is "builtins.str"
6804+
6805+
[case testMultiAssignFromUnionInOverloadCached]
6806+
from typing import Iterable, overload, Union, Optional
6807+
6808+
@overload
6809+
def always_bytes(str_or_bytes: None) -> None: ...
6810+
@overload
6811+
def always_bytes(str_or_bytes: Union[str, bytes]) -> bytes: ...
6812+
def always_bytes(str_or_bytes: Union[None, str, bytes]) -> Optional[bytes]:
6813+
pass
6814+
6815+
class Headers:
6816+
def __init__(self, iter: Iterable[tuple[bytes, bytes]]) -> None: ...
6817+
6818+
headers: Union[Headers, dict[Union[str, bytes], Union[str, bytes]], Iterable[tuple[bytes, bytes]]]
6819+
6820+
if isinstance(headers, dict):
6821+
headers = Headers(
6822+
(always_bytes(k), always_bytes(v)) for k, v in headers.items()
6823+
)
6824+
6825+
reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]"
6826+
[builtins fixtures/isinstancelist.pyi]

test-data/unit/fixtures/isinstancelist.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class bool(int): pass
2626
class str:
2727
def __add__(self, x: str) -> str: pass
2828
def __getitem__(self, x: int) -> str: pass
29+
class bytes: pass
2930

3031
T = TypeVar('T')
3132
KT = TypeVar('KT')
@@ -52,6 +53,7 @@ class dict(Mapping[KT, VT]):
5253
def __setitem__(self, k: KT, v: VT) -> None: pass
5354
def __iter__(self) -> Iterator[KT]: pass
5455
def update(self, a: Mapping[KT, VT]) -> None: pass
56+
def items(self) -> Iterable[Tuple[KT, VT]]: pass
5557

5658
class set(Generic[T]):
5759
def __iter__(self) -> Iterator[T]: pass

0 commit comments

Comments
 (0)