Skip to content

Commit

Permalink
Properly support union of TypedDicts as dict literal context (#14505)
Browse files Browse the repository at this point in the history
Fixes #14481 (regression)
Fixes #13274
Fixes #8533

Most notably, if literal matches multiple items in union, it is not an
error, it is only an error if it matches none of them, so I adjust the
error message accordingly.

An import caveat is that an unrelated error like `{"key": 42 + "no"}`
can cause no item to match (an hence an extra error), but I think it is
fine, since we still show the actual error, and avoiding this would
require some dirty hacks.

Also note there was an (obvious) bug in one of the fixtures, that caused
one of repros not repro in tests, fixing it required tweaking an
unrelated test.
  • Loading branch information
ilevkivskyi authored Jan 23, 2023
1 parent cb14d6f commit d841859
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 33 deletions.
56 changes: 33 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4188,6 +4188,17 @@ def fast_dict_type(self, e: DictExpr) -> Type | None:
self.resolved_type[e] = dt
return dt

def check_typeddict_literal_in_context(
self, e: DictExpr, typeddict_context: TypedDictType
) -> Type:
orig_ret_type = self.check_typeddict_call_with_dict(
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
)
ret_type = get_proper_type(orig_ret_type)
if isinstance(ret_type, TypedDictType):
return ret_type.copy_modified()
return typeddict_context.copy_modified()

def visit_dict_expr(self, e: DictExpr) -> Type:
"""Type check a dict expression.
Expand All @@ -4197,15 +4208,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
# an error, but returns the TypedDict type that matches the literal it found
# that would cause a second error when that TypedDict type is returned upstream
# to avoid the second error, we always return TypedDict type that was requested
typeddict_context = self.find_typeddict_context(self.type_context[-1], e)
if typeddict_context:
orig_ret_type = self.check_typeddict_call_with_dict(
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
)
ret_type = get_proper_type(orig_ret_type)
if isinstance(ret_type, TypedDictType):
return ret_type.copy_modified()
return typeddict_context.copy_modified()
typeddict_contexts = self.find_typeddict_context(self.type_context[-1], e)
if typeddict_contexts:
if len(typeddict_contexts) == 1:
return self.check_typeddict_literal_in_context(e, typeddict_contexts[0])
# Multiple items union, check if at least one of them matches cleanly.
for typeddict_context in typeddict_contexts:
with self.msg.filter_errors() as err, self.chk.local_type_map() as tmap:
ret_type = self.check_typeddict_literal_in_context(e, typeddict_context)
if err.has_new_errors():
continue
self.chk.store_types(tmap)
return ret_type
# No item matched without an error, so we can't unambiguously choose the item.
self.msg.typeddict_context_ambiguous(typeddict_contexts, e)

# fast path attempt
dt = self.fast_dict_type(e)
Expand Down Expand Up @@ -4271,26 +4287,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type:

def find_typeddict_context(
self, context: Type | None, dict_expr: DictExpr
) -> TypedDictType | None:
) -> list[TypedDictType]:
context = get_proper_type(context)
if isinstance(context, TypedDictType):
return context
return [context]
elif isinstance(context, UnionType):
items = []
for item in context.items:
item_context = self.find_typeddict_context(item, dict_expr)
if item_context is not None and self.match_typeddict_call_with_dict(
item_context, dict_expr, dict_expr
):
items.append(item_context)
if len(items) == 1:
# Only one union item is valid TypedDict for the given dict_expr, so use the
# context as it's unambiguous.
return items[0]
if len(items) > 1:
self.msg.typeddict_context_ambiguous(items, dict_expr)
item_contexts = self.find_typeddict_context(item, dict_expr)
for item_context in item_contexts:
if self.match_typeddict_call_with_dict(item_context, dict_expr, dict_expr):
items.append(item_context)
return items
# No TypedDict type in context.
return None
return []

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
Expand Down
4 changes: 3 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,9 @@ def typeddict_key_not_found(

def typeddict_context_ambiguous(self, types: list[TypedDictType], context: Context) -> None:
formatted_types = ", ".join(list(format_type_distinctly(*types)))
self.fail(f"Type of TypedDict is ambiguous, could be any of ({formatted_types})", context)
self.fail(
f"Type of TypedDict is ambiguous, none of ({formatted_types}) matches cleanly", context
)

def typeddict_key_cannot_be_deleted(
self, typ: TypedDictType, item_name: str, context: Context
Expand Down
96 changes: 91 additions & 5 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -895,15 +895,25 @@ c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'}
reveal_type(c) # N: Revealed type is "Union[TypedDict('__main__.A', {'@type': Literal['a-type'], 'a': builtins.str}), TypedDict('__main__.B', {'@type': Literal['b-type'], 'b': builtins.int})]"
[builtins fixtures/dict.pyi]

[case testTypedDictUnionAmbiguousCase]
[case testTypedDictUnionAmbiguousCaseBothMatch]
from typing import Union, Mapping, Any, cast
from typing_extensions import TypedDict, Literal

A = TypedDict('A', {'@type': Literal['a-type'], 'a': str})
B = TypedDict('B', {'@type': Literal['a-type'], 'a': str})
A = TypedDict('A', {'@type': Literal['a-type'], 'value': str})
B = TypedDict('B', {'@type': Literal['b-type'], 'value': str})

c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'}
[builtins fixtures/dict.pyi]

[case testTypedDictUnionAmbiguousCaseNoMatch]
from typing import Union, Mapping, Any, cast
from typing_extensions import TypedDict, Literal

c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'} # E: Type of TypedDict is ambiguous, could be any of ("A", "B") \
# E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]")
A = TypedDict('A', {'@type': Literal['a-type'], 'value': int})
B = TypedDict('B', {'@type': Literal['b-type'], 'value': int})

c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'} # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
# E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]")
[builtins fixtures/dict.pyi]

-- Use dict literals
Expand Down Expand Up @@ -2786,3 +2796,79 @@ TDC = TypedDict("TDC", {"val": int, "next": Optional[Self]}) # E: Self type can

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testUnionOfEquivalentTypedDictsInferred]
from typing import TypedDict, Dict

D = TypedDict("D", {"foo": int}, total=False)

def f(d: Dict[str, D]) -> None:
args = d["a"]
args.update(d.get("b", {})) # OK
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testUnionOfEquivalentTypedDictsDeclared]
from typing import TypedDict, Union

class A(TypedDict, total=False):
name: str
class B(TypedDict, total=False):
name: str

def foo(data: Union[A, B]) -> None: ...
foo({"name": "Robert"}) # OK
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testUnionOfEquivalentTypedDictsEmpty]
from typing import TypedDict, Union

class Foo(TypedDict, total=False):
foo: str
class Bar(TypedDict, total=False):
bar: str

def foo(body: Union[Foo, Bar] = {}) -> None: # OK
...
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testUnionOfEquivalentTypedDictsDistinct]
from typing import TypedDict, Union, Literal

class A(TypedDict):
type: Literal['a']
value: bool
class B(TypedDict):
type: Literal['b']
value: str

Response = Union[A, B]
def method(message: Response) -> None: ...

method({'type': 'a', 'value': True}) # OK
method({'type': 'b', 'value': 'abc'}) # OK
method({'type': 'a', 'value': 'abc'}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
# E: Argument 1 to "method" has incompatible type "Dict[str, str]"; expected "Union[A, B]"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testUnionOfEquivalentTypedDictsNested]
from typing import TypedDict, Union

class A(TypedDict, total=False):
foo: C
class B(TypedDict, total=False):
foo: D
class C(TypedDict, total=False):
c: str
class D(TypedDict, total=False):
d: str

def foo(data: Union[A, B]) -> None: ...
foo({"foo": {"c": "foo"}}) # OK
foo({"foo": {"e": "foo"}}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \
# E: Argument 1 to "foo" has incompatible type "Dict[str, Dict[str, str]]"; expected "Union[A, B]"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]
6 changes: 3 additions & 3 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -971,14 +971,14 @@ if x:
[builtins fixtures/dict.pyi]
[out]

[case testUnpackUnionNoCrashOnPartialNoneList]
[case testUnpackUnionNoCrashOnPartialList]
# flags: --strict-optional
from typing import Dict, Tuple, List, Any

a: Any
d: Dict[str, Tuple[List[Tuple[str, str]], str]]
x, _ = d.get(a, ([], []))
reveal_type(x) # N: Revealed type is "Union[builtins.list[Tuple[builtins.str, builtins.str]], builtins.list[<nothing>]]"
x, _ = d.get(a, ([], ""))
reveal_type(x) # N: Revealed type is "builtins.list[Tuple[builtins.str, builtins.str]]"

for y in x: pass
[builtins fixtures/dict.pyi]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class dict(Mapping[KT, VT]):
@overload
def get(self, k: KT) -> Optional[VT]: pass
@overload
def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass
def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass
def __len__(self) -> int: ...

class int: # for convenience
Expand Down

0 comments on commit d841859

Please sign in to comment.