From b3ec408026ac9581e4058e93d2e1c3a32b84ccfe Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 11 Jan 2025 22:15:29 +0100 Subject: [PATCH 1/6] Create fake named expressions for `match` subject in more cases --- mypy/checker.py | 59 +++++++++++++++----- test-data/unit/check-python310.test | 85 ++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 80de4254766b..0560edcca00c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -65,10 +65,12 @@ CallExpr, ClassDef, ComparisonExpr, + ComplexExpr, Context, ContinueStmt, Decorator, DelStmt, + DictExpr, EllipsisExpr, Expression, ExpressionStmt, @@ -100,6 +102,7 @@ RaiseStmt, RefExpr, ReturnStmt, + SetExpr, StarExpr, Statement, StrExpr, @@ -350,6 +353,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # functions such as open(), etc. plugin: Plugin + # A helper state to produce unique temporary names on demand. + _unique_id: int + def __init__( self, errors: Errors, @@ -413,6 +419,7 @@ def __init__( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + self._unique_id = 0 @property def type_context(self) -> list[Type | None]: @@ -5273,19 +5280,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return def visit_match_stmt(self, s: MatchStmt) -> None: - named_subject: Expression - if isinstance(s.subject, CallExpr): - # Create a dummy subject expression to handle cases where a match statement's subject - # is not a literal value. This lets us correctly narrow types and check exhaustivity - # This is hack! - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id - v = Var(name) - named_subject = NameExpr(name) - named_subject.node = v - else: - named_subject = s.subject - + named_subject = self._make_named_statement_for_match(s.subject) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5362,6 +5357,38 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _make_named_statement_for_match(self, subject: Expression) -> Expression: + """Construct a fake NameExpr for inference if a match clause is complex.""" + expressions_to_preserve = ( + # Already named - we should infer type of it as given + NameExpr, + AssignmentExpr, + # Collection literals defined inline - we want to infer types of variables + # included there, not exprs as a whole + ListExpr, + DictExpr, + TupleExpr, + SetExpr, + # Primitive literals - their type is known, no need to name them + IntExpr, + StrExpr, + BytesExpr, + FloatExpr, + ComplexExpr, + EllipsisExpr, + ) + if isinstance(subject, expressions_to_preserve): + return subject + else: + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! + name = self.new_unique_dummy_name("match") + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + return named_subject + def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: @@ -7715,6 +7742,12 @@ def warn_deprecated_overload_item( if candidate == target: self.warn_deprecated(item.func, context) + def new_unique_dummy_name(self, namespace: str) -> str: + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" + name = f"dummy-{namespace}-{self._unique_id}" + self._unique_id += 1 + return name + class CollectArgTypeVarTypes(TypeTraverserVisitor): """Collects the non-nested argument types in a set.""" diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 616846789c98..20d4fb057de2 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1239,7 +1239,7 @@ def main() -> None: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail] +[case testMatchCapturePatternFromAsyncFunctionReturningUnion] async def func1(arg: bool) -> str | int: ... async def func2(arg: bool) -> bytes | int: ... @@ -2439,3 +2439,86 @@ def foo(x: T) -> T: return out [builtins fixtures/isinstance.pyi] + +[case testMatchFunctionCall] +# flags: --warn-unreachable + +def fn() -> int | str: ... + +match fn(): + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchAttribute] +# flags: --warn-unreachable + +class A: + foo: int | str + +match A().foo: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchOperations] +# flags: --warn-unreachable + +x: int +match -x: + case -1 as s: + reveal_type(s) # N: Revealed type is "Literal[-1]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 + 2: + case 3 as s: + reveal_type(s) # N: Revealed type is "Literal[3]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 > 2: + case True as s: + reveal_type(s) # N: Revealed type is "Literal[True]" + case False as s: + reveal_type(s) # N: Revealed type is "Literal[False]" + case other: + other # E: Statement is unreachable +[builtins fixtures/ops.pyi] + +[case testMatchDictItem] +# flags: --warn-unreachable + +m: dict[str, int | str] +k: str + +match m[k]: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[builtins fixtures/dict.pyi] + +[case testMatchLiteralValuePathological] +# flags: --warn-unreachable + +match 0: + case 0 as i: + reveal_type(i) # N: Revealed type is "Literal[0]?" + case int(i): + i # E: Statement is unreachable + case other: + other # E: Statement is unreachable From 5e0b2594c07ed54661db7f034e3410d6d0c4a006 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 00:40:51 +0200 Subject: [PATCH 2/6] Add the original subject to typemap for inference --- mypy/checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 13985b9cfdaa..0f48ef3be4c6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5565,6 +5565,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + if pattern_map and named_subject in pattern_map: + pattern_map[s.subject] = pattern_map[named_subject] + if else_map and named_subject in else_map: + else_map[s.subject] = else_map[named_subject] pattern_map = self.propagate_up_typemap_info(pattern_map) else_map = self.propagate_up_typemap_info(else_map) self.remove_capture_conflicts(pattern_type.captures, inferred_types) From 85d16bd8d811f8f03b3943daef36f4c302cc1b49 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 00:48:07 +0200 Subject: [PATCH 3/6] Add comment --- mypy/checker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 7223c806739c..3debd929aab0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5452,6 +5452,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + # Maybe the subject type can be inferred from constraints on + # its attribute/item? if pattern_map and named_subject in pattern_map: pattern_map[s.subject] = pattern_map[named_subject] if else_map and named_subject in else_map: From ded98cb649bae1140e3298ea8f142bbfe2d07bc9 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 04:05:07 +0200 Subject: [PATCH 4/6] And we can keep inline collection literals too now since we infer both --- mypy/checker.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3debd929aab0..ccbed78d49ff 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -72,7 +72,6 @@ ContinueStmt, Decorator, DelStmt, - DictExpr, EllipsisExpr, Expression, ExpressionStmt, @@ -106,7 +105,6 @@ RaiseStmt, RefExpr, ReturnStmt, - SetExpr, StarExpr, Statement, StrExpr, @@ -5512,12 +5510,6 @@ def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: # Already named - we should infer type of it as given NameExpr, AssignmentExpr, - # Collection literals defined inline - we want to infer types of variables - # included there, not exprs as a whole - ListExpr, - DictExpr, - TupleExpr, - SetExpr, # Primitive literals - their type is known, no need to name them IntExpr, StrExpr, From 914f6c80f739c50be26d4bc0646b21677498557f Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sun, 3 Aug 2025 15:32:41 +0200 Subject: [PATCH 5/6] Simplify the check: we can name primitive literals just as well, it should not be common --- mypy/checker.py | 14 +------------- test-data/unit/check-python310.test | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 6e8ba47af7a3..217549cbbaf4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -75,7 +75,6 @@ CallExpr, ClassDef, ComparisonExpr, - ComplexExpr, Context, ContinueStmt, Decorator, @@ -5572,19 +5571,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None: def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: """Construct a fake NameExpr for inference if a match clause is complex.""" subject = s.subject - expressions_to_preserve = ( + if isinstance(subject, (NameExpr, AssignmentExpr)): # Already named - we should infer type of it as given - NameExpr, - AssignmentExpr, - # Primitive literals - their type is known, no need to name them - IntExpr, - StrExpr, - BytesExpr, - FloatExpr, - ComplexExpr, - EllipsisExpr, - ) - if isinstance(subject, expressions_to_preserve): return subject elif s.subject_dummy is not None: return s.subject_dummy diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 371cb20d8487..d035215b4512 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -2613,6 +2613,27 @@ match A().foo: case other: other # E: Statement is unreachable +[case testMatchLiteral] +# flags: --warn-unreachable + +def int_literal() -> None: + match 12: + case 1 as s: + reveal_type(s) # N: Revealed type is "Literal[1]" + case int(i): + reveal_type(i) # N: Revealed type is "Literal[12]?" + case other: + other # E: Statement is unreachable + +def str_literal() -> None: + match 'foo': + case 'a' as s: + reveal_type(s) # N: Revealed type is "Literal['a']" + case str(i): + reveal_type(i) # N: Revealed type is "Literal['foo']?" + case other: + other # E: Statement is unreachable + [case testMatchOperations] # flags: --warn-unreachable From 4773a38cce373efdb805d8b574d3e30858e672ac Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sun, 3 Aug 2025 20:27:50 +0200 Subject: [PATCH 6/6] Move this logic closer to binder, stop putting AssignmentExpr directly --- mypy/binder.py | 20 +++++++++++++++++++- mypy/checker.py | 9 ++++++--- test-data/unit/check-python310.test | 2 ++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 2ae58dad1fe0..c95481329a57 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -8,7 +8,16 @@ from mypy.erasetype import remove_instance_last_known_values from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys -from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var +from mypy.nodes import ( + LITERAL_NO, + Expression, + IndexExpr, + MemberExpr, + NameExpr, + RefExpr, + TypeInfo, + Var, +) from mypy.options import Options from mypy.subtypes import is_same_type, is_subtype from mypy.typeops import make_simplified_union @@ -173,6 +182,15 @@ def _get(self, key: Key, index: int = -1) -> CurrentType | None: return self.frames[i].types[key] return None + @classmethod + def can_put_directly(cls, expr: Expression) -> bool: + """Will `.put()` on this expression be successful? + + This is inlined in `.put()` because the logic is rather hot and must be kept + in sync. + """ + return isinstance(expr, (IndexExpr, MemberExpr, NameExpr)) and literal(expr) > LITERAL_NO + def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None: """Directly set the narrowed type of expression (if it supports it). diff --git a/mypy/checker.py b/mypy/checker.py index 217549cbbaf4..821fd4e44e45 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5478,6 +5478,9 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: def visit_match_stmt(self, s: MatchStmt) -> None: named_subject = self._make_named_statement_for_match(s) + # In sync with similar actions elsewhere, narrow the target if + # we are matching an AssignmentExpr + unwrapped_subject = collapse_walrus(s.subject) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5513,9 +5516,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # Maybe the subject type can be inferred from constraints on # its attribute/item? if pattern_map and named_subject in pattern_map: - pattern_map[s.subject] = pattern_map[named_subject] + pattern_map[unwrapped_subject] = pattern_map[named_subject] if else_map and named_subject in else_map: - else_map[s.subject] = else_map[named_subject] + else_map[unwrapped_subject] = else_map[named_subject] pattern_map = self.propagate_up_typemap_info(pattern_map) else_map = self.propagate_up_typemap_info(else_map) self.remove_capture_conflicts(pattern_type.captures, inferred_types) @@ -5571,7 +5574,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: """Construct a fake NameExpr for inference if a match clause is complex.""" subject = s.subject - if isinstance(subject, (NameExpr, AssignmentExpr)): + if self.binder.can_put_directly(subject): # Already named - we should infer type of it as given return subject elif s.subject_dummy is not None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d035215b4512..f264167cb067 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -2179,9 +2179,11 @@ def f() -> None: match x := returns_a_or_none(): case A(): reveal_type(x.a) # N: Revealed type is "builtins.int" + reveal_type(x) # N: Revealed type is "Union[__main__.A, None]" match x := returns_a(): case A(): reveal_type(x.a) # N: Revealed type is "builtins.int" + reveal_type(x) # N: Revealed type is "__main__.A" y = returns_a_or_none() match y: case A():