Skip to content

Commit 555edf3

Browse files
Introduce temporary named expressions for match subjects (#18446)
Fixes #18440. Fixes #17230. Fixes #16650. Improves behavior in #14731 (but still thinks that match is non-exhaustive, "missing return" false positive remains). #16503 did this specifically for `CallExpr`, but that isn't the only kind of such statements. I propose to expand this for more general expressions and believe that a blacklist is more reasonable here: we do **not** want to introduce a temporary name only for certain expressions that either are already named or can be used to infer contained variables (inline tuple/list/dict/set literals). Writing logic to generate a name for every other kind of expression would be quite cumbersome - I circumvent this by using a simple counter to generate unique names on demand. --------- Co-authored-by: Ivan Levkivskyi <[email protected]>
1 parent 268c837 commit 555edf3

File tree

3 files changed

+165
-17
lines changed

3 files changed

+165
-17
lines changed

mypy/binder.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,16 @@
88

99
from mypy.erasetype import remove_instance_last_known_values
1010
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys
11-
from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var
11+
from mypy.nodes import (
12+
LITERAL_NO,
13+
Expression,
14+
IndexExpr,
15+
MemberExpr,
16+
NameExpr,
17+
RefExpr,
18+
TypeInfo,
19+
Var,
20+
)
1221
from mypy.options import Options
1322
from mypy.subtypes import is_same_type, is_subtype
1423
from mypy.typeops import make_simplified_union
@@ -173,6 +182,15 @@ def _get(self, key: Key, index: int = -1) -> CurrentType | None:
173182
return self.frames[i].types[key]
174183
return None
175184

185+
@classmethod
186+
def can_put_directly(cls, expr: Expression) -> bool:
187+
"""Will `.put()` on this expression be successful?
188+
189+
This is inlined in `.put()` because the logic is rather hot and must be kept
190+
in sync.
191+
"""
192+
return isinstance(expr, (IndexExpr, MemberExpr, NameExpr)) and literal(expr) > LITERAL_NO
193+
176194
def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
177195
"""Directly set the narrowed type of expression (if it supports it).
178196

mypy/checker.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
358358
# functions such as open(), etc.
359359
plugin: Plugin
360360

361+
# A helper state to produce unique temporary names on demand.
362+
_unique_id: int
363+
361364
def __init__(
362365
self,
363366
errors: Errors,
@@ -428,6 +431,7 @@ def __init__(
428431
self, self.msg, self.plugin, per_line_checking_time_ns
429432
)
430433
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
434+
self._unique_id = 0
431435

432436
@property
433437
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
@@ -5476,21 +5480,10 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
54765480
return
54775481

54785482
def visit_match_stmt(self, s: MatchStmt) -> None:
5479-
named_subject: Expression
5480-
if isinstance(s.subject, CallExpr):
5481-
# Create a dummy subject expression to handle cases where a match statement's subject
5482-
# is not a literal value. This lets us correctly narrow types and check exhaustivity
5483-
# This is hack!
5484-
if s.subject_dummy is None:
5485-
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
5486-
name = "dummy-match-" + id
5487-
v = Var(name)
5488-
s.subject_dummy = NameExpr(name)
5489-
s.subject_dummy.node = v
5490-
named_subject = s.subject_dummy
5491-
else:
5492-
named_subject = s.subject
5493-
5483+
named_subject = self._make_named_statement_for_match(s)
5484+
# In sync with similar actions elsewhere, narrow the target if
5485+
# we are matching an AssignmentExpr
5486+
unwrapped_subject = collapse_walrus(s.subject)
54945487
with self.binder.frame_context(can_skip=False, fall_through=0):
54955488
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
54965489

@@ -5523,6 +5516,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55235516
pattern_map, else_map = conditional_types_to_typemaps(
55245517
named_subject, pattern_type.type, pattern_type.rest_type
55255518
)
5519+
# Maybe the subject type can be inferred from constraints on
5520+
# its attribute/item?
5521+
if pattern_map and named_subject in pattern_map:
5522+
pattern_map[unwrapped_subject] = pattern_map[named_subject]
5523+
if else_map and named_subject in else_map:
5524+
else_map[unwrapped_subject] = else_map[named_subject]
55265525
pattern_map = self.propagate_up_typemap_info(pattern_map)
55275526
else_map = self.propagate_up_typemap_info(else_map)
55285527
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
@@ -5575,6 +5574,25 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55755574
with self.binder.frame_context(can_skip=False, fall_through=2):
55765575
pass
55775576

5577+
def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
5578+
"""Construct a fake NameExpr for inference if a match clause is complex."""
5579+
subject = s.subject
5580+
if self.binder.can_put_directly(subject):
5581+
# Already named - we should infer type of it as given
5582+
return subject
5583+
elif s.subject_dummy is not None:
5584+
return s.subject_dummy
5585+
else:
5586+
# Create a dummy subject expression to handle cases where a match statement's subject
5587+
# is not a literal value. This lets us correctly narrow types and check exhaustivity
5588+
# This is hack!
5589+
name = self.new_unique_dummy_name("match")
5590+
v = Var(name)
5591+
named_subject = NameExpr(name)
5592+
named_subject.node = v
5593+
s.subject_dummy = named_subject
5594+
return named_subject
5595+
55785596
def _get_recursive_sub_patterns_map(
55795597
self, expr: Expression, typ: Type
55805598
) -> dict[Expression, Type]:
@@ -7966,6 +7984,12 @@ def warn_deprecated(self, node: Node | None, context: Context) -> None:
79667984
warn = self.msg.note if self.options.report_deprecated_as_note else self.msg.fail
79677985
warn(deprecated, context, code=codes.DEPRECATED)
79687986

7987+
def new_unique_dummy_name(self, namespace: str) -> str:
7988+
"""Generate a name that is guaranteed to be unique for this TypeChecker instance."""
7989+
name = f"dummy-{namespace}-{self._unique_id}"
7990+
self._unique_id += 1
7991+
return name
7992+
79697993
# leafs
79707994

79717995
def visit_pass_stmt(self, o: PassStmt, /) -> None:

test-data/unit/check-python310.test

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ def main() -> None:
13021302
case a:
13031303
reveal_type(a) # N: Revealed type is "builtins.int"
13041304

1305-
[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail]
1305+
[case testMatchCapturePatternFromAsyncFunctionReturningUnion]
13061306
async def func1(arg: bool) -> str | int: ...
13071307
async def func2(arg: bool) -> bytes | int: ...
13081308

@@ -2179,9 +2179,11 @@ def f() -> None:
21792179
match x := returns_a_or_none():
21802180
case A():
21812181
reveal_type(x.a) # N: Revealed type is "builtins.int"
2182+
reveal_type(x) # N: Revealed type is "Union[__main__.A, None]"
21822183
match x := returns_a():
21832184
case A():
21842185
reveal_type(x.a) # N: Revealed type is "builtins.int"
2186+
reveal_type(x) # N: Revealed type is "__main__.A"
21852187
y = returns_a_or_none()
21862188
match y:
21872189
case A():
@@ -2586,6 +2588,110 @@ def fn2(x: Some | int | str) -> None:
25862588
pass
25872589
[builtins fixtures/dict.pyi]
25882590

2591+
[case testMatchFunctionCall]
2592+
# flags: --warn-unreachable
2593+
2594+
def fn() -> int | str: ...
2595+
2596+
match fn():
2597+
case str(s):
2598+
reveal_type(s) # N: Revealed type is "builtins.str"
2599+
case int(i):
2600+
reveal_type(i) # N: Revealed type is "builtins.int"
2601+
case other:
2602+
other # E: Statement is unreachable
2603+
2604+
[case testMatchAttribute]
2605+
# flags: --warn-unreachable
2606+
2607+
class A:
2608+
foo: int | str
2609+
2610+
match A().foo:
2611+
case str(s):
2612+
reveal_type(s) # N: Revealed type is "builtins.str"
2613+
case int(i):
2614+
reveal_type(i) # N: Revealed type is "builtins.int"
2615+
case other:
2616+
other # E: Statement is unreachable
2617+
2618+
[case testMatchLiteral]
2619+
# flags: --warn-unreachable
2620+
2621+
def int_literal() -> None:
2622+
match 12:
2623+
case 1 as s:
2624+
reveal_type(s) # N: Revealed type is "Literal[1]"
2625+
case int(i):
2626+
reveal_type(i) # N: Revealed type is "Literal[12]?"
2627+
case other:
2628+
other # E: Statement is unreachable
2629+
2630+
def str_literal() -> None:
2631+
match 'foo':
2632+
case 'a' as s:
2633+
reveal_type(s) # N: Revealed type is "Literal['a']"
2634+
case str(i):
2635+
reveal_type(i) # N: Revealed type is "Literal['foo']?"
2636+
case other:
2637+
other # E: Statement is unreachable
2638+
2639+
[case testMatchOperations]
2640+
# flags: --warn-unreachable
2641+
2642+
x: int
2643+
match -x:
2644+
case -1 as s:
2645+
reveal_type(s) # N: Revealed type is "Literal[-1]"
2646+
case int(s):
2647+
reveal_type(s) # N: Revealed type is "builtins.int"
2648+
case other:
2649+
other # E: Statement is unreachable
2650+
2651+
match 1 + 2:
2652+
case 3 as s:
2653+
reveal_type(s) # N: Revealed type is "Literal[3]"
2654+
case int(s):
2655+
reveal_type(s) # N: Revealed type is "builtins.int"
2656+
case other:
2657+
other # E: Statement is unreachable
2658+
2659+
match 1 > 2:
2660+
case True as s:
2661+
reveal_type(s) # N: Revealed type is "Literal[True]"
2662+
case False as s:
2663+
reveal_type(s) # N: Revealed type is "Literal[False]"
2664+
case other:
2665+
other # E: Statement is unreachable
2666+
[builtins fixtures/ops.pyi]
2667+
2668+
[case testMatchDictItem]
2669+
# flags: --warn-unreachable
2670+
2671+
m: dict[str, int | str]
2672+
k: str
2673+
2674+
match m[k]:
2675+
case str(s):
2676+
reveal_type(s) # N: Revealed type is "builtins.str"
2677+
case int(i):
2678+
reveal_type(i) # N: Revealed type is "builtins.int"
2679+
case other:
2680+
other # E: Statement is unreachable
2681+
2682+
[builtins fixtures/dict.pyi]
2683+
2684+
[case testMatchLiteralValuePathological]
2685+
# flags: --warn-unreachable
2686+
2687+
match 0:
2688+
case 0 as i:
2689+
reveal_type(i) # N: Revealed type is "Literal[0]?"
2690+
case int(i):
2691+
i # E: Statement is unreachable
2692+
case other:
2693+
other # E: Statement is unreachable
2694+
25892695
[case testMatchNamedTupleSequence]
25902696
from typing import Any, NamedTuple
25912697

0 commit comments

Comments
 (0)