Skip to content

refactor visit_conditional_expr to fix ternary behavior (option 2) #19562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,13 @@ def check_func_def(
new_frame: Frame | None = None
for frame in old_binder.frames:
for key, narrowed_type in frame.types.items():
key_var = extract_var_from_literal_hash(key)
# get the variable from the key, considering that it might be a
# nested MemberExpr like Foo.attr1.attr2.attr3
_key = key
while _key[0] == "Member":
_key = _key[1]
key_var = extract_var_from_literal_hash(_key)

if key_var is not None and not self.is_var_redefined_in_outer_context(
key_var, defn.line
):
Expand Down
68 changes: 15 additions & 53 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@
get_type_vars,
is_literal_type_like,
make_simplified_union,
simple_literal_type,
true_only,
try_expanding_sum_type_to_union,
try_getting_str_literals,
Expand Down Expand Up @@ -5892,7 +5891,7 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None:

def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type:
self.accept(e.cond)
ctx = self.type_context[-1]
ctx: Type | None = self.type_context[-1]

# Gain type information from isinstance if it is there
# but only for the current expression
Expand All @@ -5903,63 +5902,26 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
elif else_map is None:
self.msg.redundant_condition_in_if(True, e.cond)

if ctx is None:
# When no context is provided, compute each branch individually, and
# use the union of the results as artificial context. Important for:
# - testUnificationDict
# - testConditionalExpressionWithEmpty
ctx_if_type = self.analyze_cond_branch(
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
)
ctx_else_type = self.analyze_cond_branch(
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
)
ctx = make_simplified_union([ctx_if_type, ctx_else_type])

if_type = self.analyze_cond_branch(
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
)

# we want to keep the narrowest value of if_type for union'ing the branches
# however, it would be silly to pass a literal as a type context. Pass the
# underlying fallback type instead.
if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type

# Analyze the right branch using full type context and store the type
full_context_else_type = self.analyze_cond_branch(
else_type = self.analyze_cond_branch(
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
)

if not mypy.checker.is_valid_inferred_type(if_type, self.chk.options):
# Analyze the right branch disregarding the left branch.
else_type = full_context_else_type
# we want to keep the narrowest value of else_type for union'ing the branches
# however, it would be silly to pass a literal as a type context. Pass the
# underlying fallback type instead.
else_type_fallback = simple_literal_type(get_proper_type(else_type)) or else_type

# If it would make a difference, re-analyze the left
# branch using the right branch's type as context.
if ctx is None or not is_equivalent(else_type_fallback, ctx):
# TODO: If it's possible that the previous analysis of
# the left branch produced errors that are avoided
# using this context, suppress those errors.
if_type = self.analyze_cond_branch(
if_map,
e.if_expr,
context=else_type_fallback,
allow_none_return=allow_none_return,
)

elif if_type_fallback == ctx:
# There is no point re-running the analysis if if_type is equal to ctx.
# That would be an exact duplicate of the work we just did.
# This optimization is particularly important to avoid exponential blowup with nested
# if/else expressions: https://github.com/python/mypy/issues/9591
# TODO: would checking for is_proper_subtype also work and cover more cases?
else_type = full_context_else_type
else:
# Analyze the right branch in the context of the left
# branch's type.
else_type = self.analyze_cond_branch(
else_map,
e.else_expr,
context=if_type_fallback,
allow_none_return=allow_none_return,
)

# In most cases using if_type as a context for right branch gives better inferred types.
# This is however not the case for literal types, so use the full context instead.
if is_literal_type_like(full_context_else_type) and not is_literal_type_like(else_type):
else_type = full_context_else_type

res: Type = make_simplified_union([if_type, else_type])
if has_uninhabited_component(res) and not isinstance(
get_proper_type(self.type_context[-1]), UnionType
Expand Down
134 changes: 134 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -2946,6 +2946,140 @@ reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word'
reveal_type(C().word) # N: Revealed type is "Literal['word']"
[builtins fixtures/tuple.pyi]

[case testStringLiteralTernary]
def test(b: bool) -> None:
l = "foo" if b else "bar"
reveal_type(l) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]

[case testintLiteralTernary]
def test(b: bool) -> None:
l = 0 if b else 1
reveal_type(l) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testStringIntUnionTernary]
def test(b: bool) -> None:
l = 1 if b else "a"
reveal_type(l) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testListComprehensionTernary]
# gh-19534
def test(b: bool) -> None:
l = [1] if b else ["a"]
reveal_type(l) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]"
[builtins fixtures/list.pyi]

[case testSetComprehensionTernary]
# gh-19534
def test(b: bool) -> None:
s = {1} if b else {"a"}
reveal_type(s) # N: Revealed type is "Union[builtins.set[builtins.int], builtins.set[builtins.str]]"
[builtins fixtures/set.pyi]

[case testDictComprehensionTernary]
# gh-19534
def test(b: bool) -> None:
d = {1:1} if "" else {"a": "a"}
reveal_type(d) # N: Revealed type is "Union[builtins.dict[builtins.int, builtins.int], builtins.dict[builtins.str, builtins.str]]"
[builtins fixtures/dict.pyi]

[case testLambdaTernary]
from typing import TypeVar, Union, Callable, reveal_type

NOOP = lambda: None
class A: pass
class B:
attr: Union[A, None]

def test_static(x: Union[A, None]) -> None:
def foo(t: A) -> None: ...

l1: Callable[[], object] = (lambda: foo(x)) if x is not None else NOOP
r1: Callable[[], object] = NOOP if x is None else (lambda: foo(x))
l2 = (lambda: foo(x)) if x is not None else NOOP
r2 = NOOP if x is None else (lambda: foo(x))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"

def test_generic(x: Union[A, None]) -> None:
T = TypeVar("T")
def bar(t: T) -> T: return t

l1: Callable[[], None] = (lambda: bar(x)) if x is None else NOOP
r1: Callable[[], None] = NOOP if x is not None else (lambda: bar(x))
l2 = (lambda: bar(x)) if x is None else NOOP
r2 = NOOP if x is not None else (lambda: bar(x))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"


[case testLambdaTernaryIndirectAttribute]
# fails due to binder issue inside `check_func_def`
# gh-19561
from typing import TypeVar, Union, Callable, reveal_type

NOOP = lambda: None
class A: pass
class B:
attr: Union[A, None]

def test_static_with_attr(x: B) -> None:
def foo(t: A) -> None: ...

l1: Callable[[], None] = (lambda: foo(x.attr)) if x.attr is not None else NOOP
r1: Callable[[], None] = NOOP if x.attr is None else (lambda: foo(x.attr))
l2 = (lambda: foo(x.attr)) if x.attr is not None else NOOP
r2 = NOOP if x.attr is None else (lambda: foo(x.attr))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"

def test_generic_with_attr(x: B) -> None:
T = TypeVar("T")
def bar(t: T) -> T: return t

l1: Callable[[], None] = (lambda: bar(x.attr)) if x.attr is None else NOOP
r1: Callable[[], None] = NOOP if x.attr is not None else (lambda: bar(x.attr))
l2 = (lambda: bar(x.attr)) if x.attr is None else NOOP
r2 = NOOP if x.attr is not None else (lambda: bar(x.attr))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"

[case testLambdaTernaryDoubleIndirectAttribute]
# fails due to binder issue inside `check_func_def`
# gh-19561
from typing import TypeVar, Union, Callable, reveal_type

NOOP = lambda: None
class A: pass
class B:
attr: Union[A, None]
class C:
attr: B

def test_static_with_attr(x: C) -> None:
def foo(t: A) -> None: ...

l1: Callable[[], None] = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP
r1: Callable[[], None] = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr))
l2 = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP
r2 = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"

def test_generic_with_attr(x: C) -> None:
T = TypeVar("T")
def bar(t: T) -> T: return t

l1: Callable[[], None] = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP
r1: Callable[[], None] = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr))
l2 = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP
r2 = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr))
reveal_type(l2) # N: Revealed type is "def ()"
reveal_type(r2) # N: Revealed type is "def ()"


[case testLiteralTernaryUnionNarrowing]
from typing import Literal, Optional

Expand Down
7 changes: 4 additions & 3 deletions test-data/unit/check-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ reveal_type(l) # N: Revealed type is "builtins.list[typing.Generator[builtins.s
[builtins fixtures/list.pyi]

[case testNoneListTernary]
x = [None] if "" else [1] # E: List item 0 has incompatible type "int"; expected "None"
# gh-19534
x = [None] if "" else [1]
reveal_type(x) # N: Revealed type is "Union[builtins.list[None], builtins.list[builtins.int]]"
[builtins fixtures/list.pyi]

[case testListIncompatibleErrorMessage]
Expand Down Expand Up @@ -1107,10 +1109,9 @@ class C:
a: Optional[str]

def attribute_narrowing(c: C) -> None:
# This case is not supported, since we can't keep track of assignments to attributes.
c.a = "x"
def nested() -> str:
return c.a # E: Incompatible return value type (got "Optional[str]", expected "str")
return c.a
nested()

def assignment_in_for(x: Optional[str]) -> None:
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-python313.test
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,21 @@ reveal_type(A1().x) # N: Revealed type is "builtins.int"
reveal_type(A2().x) # N: Revealed type is "builtins.int"
reveal_type(A3().x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]


[case testTernaryOperatorWithTypeVarDefault]
# gh-18817

class Ok[T, E = None]:
def __init__(self, value: T) -> None:
self._value = value

class Err[E, T = None]:
def __init__(self, value: E) -> None:
self._value = value

type Result[T, E] = Ok[T, E] | Err[E, T]

class Bar[U]:
def foo(data: U, cond: bool) -> Result[U, str]:
return Ok(data) if cond else Err("Error")