Skip to content

Introduce temporary named expressions for match subjects #18446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b3ec408
Create fake named expressions for `match` subject in more cases
sterliakov Jan 11, 2025
2da85ea
Merge remote-tracking branch 'upstream/master' into bugfix/st-synthet…
sterliakov Jan 13, 2025
0160d1d
Merge remote-tracking branch 'upstream/master' into bugfix/st-synthet…
sterliakov Jan 13, 2025
01b92c9
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
sterliakov Jan 20, 2025
9967f27
Merge remote-tracking branch 'upstream/master' into bugfix/st-synthet…
sterliakov Feb 1, 2025
6e09fd0
Merge remote-tracking branch 'upstream/master' into bugfix/st-synthet…
sterliakov Mar 29, 2025
5e0b259
Add the original subject to typemap for inference
sterliakov Apr 4, 2025
968a8bd
Merge remote-tracking branch 'upstream/master' into bugfix/st-synthet…
sterliakov Apr 4, 2025
85d16bd
Add comment
sterliakov Apr 4, 2025
ded98cb
And we can keep inline collection literals too now since we infer both
sterliakov Apr 5, 2025
a772dc9
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
sterliakov Jun 21, 2025
a24af26
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
sterliakov Jun 28, 2025
a9f206b
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
sterliakov Jul 14, 2025
8f8cf42
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
sterliakov Jul 21, 2025
0bf5dc9
Merge branch 'master' into bugfix/st-synthetic-named-expr-in-match
ilevkivskyi Aug 3, 2025
914f6c8
Simplify the check: we can name primitive literals just as well, it s…
sterliakov Aug 3, 2025
4773a38
Move this logic closer to binder, stop putting AssignmentExpr directly
sterliakov Aug 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@

from mypy.erasetype import remove_instance_last_known_values
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys
from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var
from mypy.nodes import (
LITERAL_NO,
Expression,
IndexExpr,
MemberExpr,
NameExpr,
RefExpr,
TypeInfo,
Var,
)
from mypy.options import Options
from mypy.subtypes import is_same_type, is_subtype
from mypy.typeops import make_simplified_union
Expand Down Expand Up @@ -173,6 +182,15 @@ def _get(self, key: Key, index: int = -1) -> CurrentType | None:
return self.frames[i].types[key]
return None

@classmethod
def can_put_directly(cls, expr: Expression) -> bool:
    """Report whether `.put()` on this expression would be successful.

    `.put()` carries its own inlined copy of this test because that path
    is rather hot; the two copies must be kept in sync.
    """
    # Only index, member and name expressions are candidates at all.
    if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
        return False
    # Of those, only expressions with a literal level above LITERAL_NO qualify.
    return literal(expr) > LITERAL_NO

def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
"""Directly set the narrowed type of expression (if it supports it).

Expand Down
54 changes: 39 additions & 15 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
# functions such as open(), etc.
plugin: Plugin

# A helper state to produce unique temporary names on demand.
_unique_id: int

def __init__(
self,
errors: Errors,
Expand Down Expand Up @@ -428,6 +431,7 @@ def __init__(
self, self.msg, self.plugin, per_line_checking_time_ns
)
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
self._unique_id = 0

@property
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
Expand Down Expand Up @@ -5473,21 +5477,10 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return

def visit_match_stmt(self, s: MatchStmt) -> None:
named_subject: Expression
if isinstance(s.subject, CallExpr):
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
if s.subject_dummy is None:
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
name = "dummy-match-" + id
v = Var(name)
s.subject_dummy = NameExpr(name)
s.subject_dummy.node = v
named_subject = s.subject_dummy
else:
named_subject = s.subject

named_subject = self._make_named_statement_for_match(s)
# In sync with similar actions elsewhere, narrow the target if
# we are matching an AssignmentExpr
unwrapped_subject = collapse_walrus(s.subject)
with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))

Expand Down Expand Up @@ -5520,6 +5513,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
pattern_map, else_map = conditional_types_to_typemaps(
named_subject, pattern_type.type, pattern_type.rest_type
)
# Maybe the subject type can be inferred from constraints on
# its attribute/item?
if pattern_map and named_subject in pattern_map:
pattern_map[unwrapped_subject] = pattern_map[named_subject]
if else_map and named_subject in else_map:
else_map[unwrapped_subject] = else_map[named_subject]
pattern_map = self.propagate_up_typemap_info(pattern_map)
else_map = self.propagate_up_typemap_info(else_map)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
Expand Down Expand Up @@ -5572,6 +5571,25 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
    """Return the expression used for inference on the subject of match statement `s`.

    If the binder reports that the subject can be put directly, the subject
    itself is returned. Otherwise a fake NameExpr is returned; it is created on
    first use and stored on `s.subject_dummy`, so later calls for the same
    statement get the same object back.
    """
    target = s.subject
    if self.binder.can_put_directly(target):
        # Already named - we should infer the type of it as given.
        return target

    if s.subject_dummy is not None:
        # A dummy was already built for this statement; reuse it.
        return s.subject_dummy

    # Create a dummy subject expression to handle cases where a match statement's
    # subject is not a literal value. This lets us correctly narrow types and
    # check exhaustivity. This is a hack!
    dummy_name = self.new_unique_dummy_name("match")
    dummy_var = Var(dummy_name)
    placeholder = NameExpr(dummy_name)
    placeholder.node = dummy_var
    s.subject_dummy = placeholder
    return placeholder

def _get_recursive_sub_patterns_map(
self, expr: Expression, typ: Type
) -> dict[Expression, Type]:
Expand Down Expand Up @@ -7963,6 +7981,12 @@ def warn_deprecated(self, node: Node | None, context: Context) -> None:
warn = self.msg.note if self.options.report_deprecated_as_note else self.msg.fail
warn(deprecated, context, code=codes.DEPRECATED)

def new_unique_dummy_name(self, namespace: str) -> str:
    """Generate a name that is guaranteed to be unique for this TypeChecker instance.

    The name has the form ``dummy-<namespace>-<counter>``. The counter is
    `self._unique_id`, which is advanced by one on every call, so two calls
    on the same instance never return the same name.
    """
    # Take the current counter value for this name, then advance it.
    serial = self._unique_id
    self._unique_id = serial + 1
    return f"dummy-{namespace}-{serial}"

# leafs

def visit_pass_stmt(self, o: PassStmt, /) -> None:
Expand Down
108 changes: 107 additions & 1 deletion test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def main() -> None:
case a:
reveal_type(a) # N: Revealed type is "builtins.int"

[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail]
[case testMatchCapturePatternFromAsyncFunctionReturningUnion]
async def func1(arg: bool) -> str | int: ...
async def func2(arg: bool) -> bytes | int: ...

Expand Down Expand Up @@ -2179,9 +2179,11 @@ def f() -> None:
match x := returns_a_or_none():
case A():
reveal_type(x.a) # N: Revealed type is "builtins.int"
reveal_type(x) # N: Revealed type is "Union[__main__.A, None]"
match x := returns_a():
case A():
reveal_type(x.a) # N: Revealed type is "builtins.int"
reveal_type(x) # N: Revealed type is "__main__.A"
y = returns_a_or_none()
match y:
case A():
Expand Down Expand Up @@ -2586,6 +2588,110 @@ def fn2(x: Some | int | str) -> None:
pass
[builtins fixtures/dict.pyi]

[case testMatchFunctionCall]
# flags: --warn-unreachable

def fn() -> int | str: ...

match fn():
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchAttribute]
# flags: --warn-unreachable

class A:
foo: int | str

match A().foo:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchLiteral]
# flags: --warn-unreachable

def int_literal() -> None:
match 12:
case 1 as s:
reveal_type(s) # N: Revealed type is "Literal[1]"
case int(i):
reveal_type(i) # N: Revealed type is "Literal[12]?"
case other:
other # E: Statement is unreachable

def str_literal() -> None:
match 'foo':
case 'a' as s:
reveal_type(s) # N: Revealed type is "Literal['a']"
case str(i):
reveal_type(i) # N: Revealed type is "Literal['foo']?"
case other:
other # E: Statement is unreachable

[case testMatchOperations]
# flags: --warn-unreachable

x: int
match -x:
case -1 as s:
reveal_type(s) # N: Revealed type is "Literal[-1]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 + 2:
case 3 as s:
reveal_type(s) # N: Revealed type is "Literal[3]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 > 2:
case True as s:
reveal_type(s) # N: Revealed type is "Literal[True]"
case False as s:
reveal_type(s) # N: Revealed type is "Literal[False]"
case other:
other # E: Statement is unreachable
[builtins fixtures/ops.pyi]

[case testMatchDictItem]
# flags: --warn-unreachable

m: dict[str, int | str]
k: str

match m[k]:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[builtins fixtures/dict.pyi]

[case testMatchLiteralValuePathological]
# flags: --warn-unreachable

match 0:
case 0 as i:
reveal_type(i) # N: Revealed type is "Literal[0]?"
case int(i):
i # E: Statement is unreachable
case other:
other # E: Statement is unreachable

[case testMatchNamedTupleSequence]
from typing import Any, NamedTuple

Expand Down
Loading