diff --git a/mypy/binder.py b/mypy/binder.py index 2ae58dad1fe0..c95481329a57 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -8,7 +8,16 @@ from mypy.erasetype import remove_instance_last_known_values from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys -from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var +from mypy.nodes import ( + LITERAL_NO, + Expression, + IndexExpr, + MemberExpr, + NameExpr, + RefExpr, + TypeInfo, + Var, +) from mypy.options import Options from mypy.subtypes import is_same_type, is_subtype from mypy.typeops import make_simplified_union @@ -173,6 +182,15 @@ def _get(self, key: Key, index: int = -1) -> CurrentType | None: return self.frames[i].types[key] return None + @classmethod + def can_put_directly(cls, expr: Expression) -> bool: + """Will `.put()` on this expression be successful? + + This is inlined in `.put()` because the logic is rather hot and must be kept + in sync. + """ + return isinstance(expr, (IndexExpr, MemberExpr, NameExpr)) and literal(expr) > LITERAL_NO + def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None: """Directly set the narrowed type of expression (if it supports it). diff --git a/mypy/checker.py b/mypy/checker.py index 35a67d188311..821fd4e44e45 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -358,6 +358,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): # functions such as open(), etc. plugin: Plugin + # A helper state to produce unique temporary names on demand. + _unique_id: int + def __init__( self, errors: Errors, @@ -428,6 +431,7 @@ def __init__( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + self._unique_id = 0 @property def expr_checker(self) -> mypy.checkexpr.ExpressionChecker: @@ -5473,21 +5477,10 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return def visit_match_stmt(self, s: MatchStmt) -> None: - named_subject: Expression - if isinstance(s.subject, CallExpr): - # Create a dummy subject expression to handle cases where a match statement's subject - # is not a literal value. This lets us correctly narrow types and check exhaustivity - # This is hack! - if s.subject_dummy is None: - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id - v = Var(name) - s.subject_dummy = NameExpr(name) - s.subject_dummy.node = v - named_subject = s.subject_dummy - else: - named_subject = s.subject - + named_subject = self._make_named_statement_for_match(s) + # In sync with similar actions elsewhere, narrow the target if + # we are matching an AssignmentExpr + unwrapped_subject = collapse_walrus(s.subject) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5520,6 +5513,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + # Maybe the subject type can be inferred from constraints on + # its attribute/item? + if pattern_map and named_subject in pattern_map: + pattern_map[unwrapped_subject] = pattern_map[named_subject] + if else_map and named_subject in else_map: + else_map[unwrapped_subject] = else_map[named_subject] pattern_map = self.propagate_up_typemap_info(pattern_map) else_map = self.propagate_up_typemap_info(else_map) self.remove_capture_conflicts(pattern_type.captures, inferred_types) @@ -5572,6 +5571,25 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: + """Construct a fake NameExpr for inference if a match clause is complex.""" + subject = s.subject + if self.binder.can_put_directly(subject): + # Already named - we should infer type of it as given + return subject + elif s.subject_dummy is not None: + return s.subject_dummy + else: + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! + name = self.new_unique_dummy_name("match") + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + s.subject_dummy = named_subject + return named_subject + def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: @@ -7963,6 +7981,12 @@ def warn_deprecated(self, node: Node | None, context: Context) -> None: warn = self.msg.note if self.options.report_deprecated_as_note else self.msg.fail warn(deprecated, context, code=codes.DEPRECATED) + def new_unique_dummy_name(self, namespace: str) -> str: + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" + name = f"dummy-{namespace}-{self._unique_id}" + self._unique_id += 1 + return name + # leafs def visit_pass_stmt(self, o: PassStmt, /) -> None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 80fd64fa3569..f264167cb067 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1302,7 +1302,7 @@ def main() -> None: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail] +[case testMatchCapturePatternFromAsyncFunctionReturningUnion] async def func1(arg: bool) -> str | int: ... async def func2(arg: bool) -> bytes | int: ... @@ -2179,9 +2179,11 @@ def f() -> None: match x := returns_a_or_none(): case A(): reveal_type(x.a) # N: Revealed type is "builtins.int" + reveal_type(x) # N: Revealed type is "Union[__main__.A, None]" match x := returns_a(): case A(): reveal_type(x.a) # N: Revealed type is "builtins.int" + reveal_type(x) # N: Revealed type is "__main__.A" y = returns_a_or_none() match y: case A(): @@ -2586,6 +2588,110 @@ def fn2(x: Some | int | str) -> None: pass [builtins fixtures/dict.pyi] +[case testMatchFunctionCall] +# flags: --warn-unreachable + +def fn() -> int | str: ... + +match fn(): + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchAttribute] +# flags: --warn-unreachable + +class A: + foo: int | str + +match A().foo: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchLiteral] +# flags: --warn-unreachable + +def int_literal() -> None: + match 12: + case 1 as s: + reveal_type(s) # N: Revealed type is "Literal[1]" + case int(i): + reveal_type(i) # N: Revealed type is "Literal[12]?" + case other: + other # E: Statement is unreachable + +def str_literal() -> None: + match 'foo': + case 'a' as s: + reveal_type(s) # N: Revealed type is "Literal['a']" + case str(i): + reveal_type(i) # N: Revealed type is "Literal['foo']?" + case other: + other # E: Statement is unreachable + +[case testMatchOperations] +# flags: --warn-unreachable + +x: int +match -x: + case -1 as s: + reveal_type(s) # N: Revealed type is "Literal[-1]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 + 2: + case 3 as s: + reveal_type(s) # N: Revealed type is "Literal[3]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 > 2: + case True as s: + reveal_type(s) # N: Revealed type is "Literal[True]" + case False as s: + reveal_type(s) # N: Revealed type is "Literal[False]" + case other: + other # E: Statement is unreachable +[builtins fixtures/ops.pyi] + +[case testMatchDictItem] +# flags: --warn-unreachable + +m: dict[str, int | str] +k: str + +match m[k]: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[builtins fixtures/dict.pyi] + +[case testMatchLiteralValuePathological] +# flags: --warn-unreachable + +match 0: + case 0 as i: + reveal_type(i) # N: Revealed type is "Literal[0]?" + case int(i): + i # E: Statement is unreachable + case other: + other # E: Statement is unreachable + [case testMatchNamedTupleSequence] from typing import Any, NamedTuple