Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support recursive TypedDicts #13373

Merged
merged 15 commits into from
Aug 11, 2022
15 changes: 11 additions & 4 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,17 @@ def lookup_fully_qualified_alias(
if isinstance(node, TypeAlias):
return node
elif isinstance(node, TypeInfo):
if node.tuple_alias:
return node.tuple_alias
alias = TypeAlias.from_tuple_type(node)
node.tuple_alias = alias
if node.special_alias:
# Already fixed up.
return node.special_alias
if node.tuple_type:
alias = TypeAlias.from_tuple_type(node)
elif node.typeddict_type:
alias = TypeAlias.from_typeddict_type(node)
else:
assert allow_missing
return missing_alias()
node.special_alias = alias
return alias
else:
# Looks like a missing TypeAlias during an initial daemon load, put something there
Expand Down
44 changes: 36 additions & 8 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,7 @@ class is generic then it will be a type constructor of higher kind.
"bases",
"_promote",
"tuple_type",
"tuple_alias",
"special_alias",
"is_named_tuple",
"typeddict_type",
"is_newtype",
Expand Down Expand Up @@ -2795,8 +2795,16 @@ class is generic then it will be a type constructor of higher kind.
# It is useful for plugins to add their data to save in the cache.
metadata: Dict[str, JsonDict]

# Store type alias representing this type (for named tuples).
tuple_alias: Optional["TypeAlias"]
# Store type alias representing this type (for named tuples and TypedDicts).
# Although definitions of these types are stored in symbol tables as TypeInfo,
# when a type analyzer will find them, it should construct a TupleType, or
# a TypedDict type. However, we can't use the plain types, since if the definition
# is recursive, this will create an actual recursive structure of types (i.e. as
# internal Python objects) causing infinite recursions everywhere during type checking.
# To overcome this, we create a TypeAlias node, that will point to these types.
# We store this node in the `special_alias` attribute, because it must be the same node
# in case we are doing multiple semantic analysis passes.
special_alias: Optional["TypeAlias"]

FLAGS: Final = [
"is_abstract",
Expand Down Expand Up @@ -2844,7 +2852,7 @@ def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> No
self._promote = []
self.alt_promote = None
self.tuple_type = None
self.tuple_alias = None
self.special_alias = None
self.is_named_tuple = False
self.typeddict_type = None
self.is_newtype = False
Expand Down Expand Up @@ -2976,13 +2984,22 @@ def direct_base_classes(self) -> "List[TypeInfo]":
return [base.type for base in self.bases]

def update_tuple_type(self, typ: "mypy.types.TupleType") -> None:
"""Update tuple_type and tuple_alias as needed."""
"""Update tuple_type and special_alias as needed."""
self.tuple_type = typ
alias = TypeAlias.from_tuple_type(self)
if not self.tuple_alias:
self.tuple_alias = alias
if not self.special_alias:
self.special_alias = alias
else:
self.tuple_alias.target = alias.target
self.special_alias.target = alias.target

def update_typeddict_type(self, typ: "mypy.types.TypedDictType") -> None:
"""Update typeddict_type and special_alias as needed."""
self.typeddict_type = typ
alias = TypeAlias.from_typeddict_type(self)
if not self.special_alias:
self.special_alias = alias
else:
self.special_alias.target = alias.target

def __str__(self) -> str:
"""Return a string representation of the type.
Expand Down Expand Up @@ -3283,6 +3300,17 @@ def from_tuple_type(cls, info: TypeInfo) -> "TypeAlias":
info.column,
)

@classmethod
def from_typeddict_type(cls, info: TypeInfo) -> "TypeAlias":
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
assert info.typeddict_type
return TypeAlias(
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.fullname,
info.line,
info.column,
)

@property
def name(self) -> str:
return self._fullname.split(".")[-1]
Expand Down
47 changes: 34 additions & 13 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,17 +1378,7 @@ def analyze_class(self, defn: ClassDef) -> None:
self.mark_incomplete(defn.name, defn)
return

is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
if is_typeddict:
for decorator in defn.decorators:
decorator.accept(self)
if isinstance(decorator, RefExpr):
if decorator.fullname in FINAL_DECORATOR_NAMES:
self.fail("@final cannot be used with TypedDict", decorator)
if info is None:
self.mark_incomplete(defn.name, defn)
else:
self.prepare_class_def(defn, info)
if self.analyze_typeddict_classdef(defn):
return

if self.analyze_namedtuple_classdef(defn):
Expand Down Expand Up @@ -1423,6 +1413,28 @@ def analyze_class_body_common(self, defn: ClassDef) -> None:
self.apply_class_plugin_hooks(defn)
self.leave_class()

def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
if (
defn.info
and defn.info.typeddict_type
and not has_placeholder(defn.info.typeddict_type)
):
# This is a valid TypedDict, and it is fully analyzed.
return True
is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
if is_typeddict:
for decorator in defn.decorators:
decorator.accept(self)
if isinstance(decorator, RefExpr):
if decorator.fullname in FINAL_DECORATOR_NAMES:
self.fail("@final cannot be used with TypedDict", decorator)
if info is None:
self.mark_incomplete(defn.name, defn)
else:
self.prepare_class_def(defn, info)
return True
return False

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
"""Check if this class can define a named tuple."""
if (
Expand Down Expand Up @@ -1840,7 +1852,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
self.fail("Class has two incompatible bases derived from tuple", defn)
defn.has_incompatible_baseclass = True
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
if info.special_alias and has_placeholder(info.special_alias.target):
self.defer(force_progress=True)
info.update_tuple_type(base)

Expand Down Expand Up @@ -2660,7 +2672,11 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
"""Check if s defines a typed dict."""
if isinstance(s.rvalue, CallExpr) and isinstance(s.rvalue.analyzed, TypedDictExpr):
return True # This is a valid and analyzed typed dict definition, nothing to do here.
if s.rvalue.analyzed.info.typeddict_type and not has_placeholder(
s.rvalue.analyzed.info.typeddict_type
):
# This is a valid and analyzed typed dict definition, nothing to do here.
return True
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
return False
lvalue = s.lvalues[0]
Expand Down Expand Up @@ -5504,6 +5520,11 @@ def defer(self, debug_context: Optional[Context] = None, force_progress: bool =
"""
assert not self.final_iteration, "Must not defer during final iteration"
if force_progress:
# Usually, we report progress if we have replaced a placeholder node
# with an actual valid node. However, sometimes we need to update an
# existing node *in-place*. For example, this is used by type aliases
# in context of forward references and/or recursive aliases, and in
# similar situations (recursive named tuples etc).
self.progress = True
self.deferred = True
# Store debug info for this deferral.
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def build_namedtuple_typeinfo(
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
info.is_named_tuple = True
tuple_base = TupleType(types, fallback)
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.defer(force_progress=True)
info.update_tuple_type(tuple_base)
info.line = line
Expand Down
7 changes: 5 additions & 2 deletions mypy/semanal_newtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool:

old_type, should_defer = self.check_newtype_args(var_name, call, s)
old_type = get_proper_type(old_type)
if not call.analyzed:
if not isinstance(call.analyzed, NewTypeExpr):
call.analyzed = NewTypeExpr(var_name, old_type, line=call.line, column=call.column)
else:
call.analyzed.old_type = old_type
if old_type is None:
if should_defer:
# Base type is not ready.
Expand Down Expand Up @@ -230,6 +232,7 @@ def build_newtype_typeinfo(
existing_info: Optional[TypeInfo],
) -> TypeInfo:
info = existing_info or self.api.basic_new_typeinfo(name, base_type, line)
info.bases = [base_type] # Update in case there were nested placeholders.
info.is_newtype = True

# Add __init__ method
Expand All @@ -250,7 +253,7 @@ def build_newtype_typeinfo(
init_func._fullname = info.fullname + ".__init__"
info.names["__init__"] = SymbolTableNode(MDEF, init_func)

if info.tuple_type and has_placeholder(info.tuple_type):
if has_placeholder(old_type) or info.tuple_type and has_placeholder(info.tuple_type):
self.api.defer(force_progress=True)
return info

Expand Down
49 changes: 39 additions & 10 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TypeInfo,
)
from mypy.options import Options
from mypy.semanal_shared import SemanticAnalyzerInterface
from mypy.semanal_shared import SemanticAnalyzerInterface, has_placeholder
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type
from mypy.types import TPDICT_NAMES, AnyType, RequiredType, Type, TypedDictType, TypeOfAny

Expand Down Expand Up @@ -66,6 +66,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
if base_expr.fullname in TPDICT_NAMES or self.is_typeddict(base_expr):
possible = True
if possible:
existing_info = None
if isinstance(defn.analyzed, TypedDictExpr):
existing_info = defn.analyzed.info
if (
len(defn.base_type_exprs) == 1
and isinstance(defn.base_type_exprs[0], RefExpr)
Expand All @@ -76,7 +79,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
if fields is None:
return True, None # Defer
info = self.build_typeddict_typeinfo(
defn.name, fields, types, required_keys, defn.line
defn.name, fields, types, required_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand Down Expand Up @@ -128,7 +131,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
keys.extend(new_keys)
types.extend(new_types)
required_keys.update(new_required_keys)
info = self.build_typeddict_typeinfo(defn.name, keys, types, required_keys, defn.line)
info = self.build_typeddict_typeinfo(
defn.name, keys, types, required_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
Expand Down Expand Up @@ -173,7 +178,12 @@ def analyze_typeddict_classdef_fields(
if stmt.type is None:
types.append(AnyType(TypeOfAny.unannotated))
else:
analyzed = self.api.anal_type(stmt.type, allow_required=True)
analyzed = self.api.anal_type(
stmt.type,
allow_required=True,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
if analyzed is None:
return None, [], set() # Need to defer
types.append(analyzed)
Expand Down Expand Up @@ -232,7 +242,7 @@ def check_typeddict(
name, items, types, total, ok = res
if not ok:
# Error. Construct dummy return value.
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line)
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line, None)
else:
if var_name is not None and name != var_name:
self.fail(
Expand All @@ -254,7 +264,12 @@ def check_typeddict(
types = [ # unwrap Required[T] to just T
t.item if isinstance(t, RequiredType) else t for t in types # type: ignore[misc]
]
info = self.build_typeddict_typeinfo(name, items, types, required_keys, call.line)
existing_info = None
if isinstance(node.analyzed, TypedDictExpr):
existing_info = node.analyzed.info
info = self.build_typeddict_typeinfo(
name, items, types, required_keys, call.line, existing_info
)
info.line = node.line
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
if name != var_name or is_func_scope:
Expand Down Expand Up @@ -357,7 +372,12 @@ def parse_typeddict_fields_with_types(
else:
self.fail_typeddict_arg("Invalid field type", field_type_expr)
return [], [], False
analyzed = self.api.anal_type(type, allow_required=True)
analyzed = self.api.anal_type(
type,
allow_required=True,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
if analyzed is None:
return None
types.append(analyzed)
Expand All @@ -370,7 +390,13 @@ def fail_typeddict_arg(
return "", [], [], True, False

def build_typeddict_typeinfo(
self, name: str, items: List[str], types: List[Type], required_keys: Set[str], line: int
self,
name: str,
items: List[str],
types: List[Type],
required_keys: Set[str],
line: int,
existing_info: Optional[TypeInfo],
) -> TypeInfo:
# Prefer typing then typing_extensions if available.
fallback = (
Expand All @@ -379,8 +405,11 @@ def build_typeddict_typeinfo(
or self.api.named_type_or_none("mypy_extensions._TypedDict", [])
)
assert fallback is not None
info = self.api.basic_new_typeinfo(name, fallback, line)
info.typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.defer(force_progress=True)
info.update_typeddict_type(typeddict_type)
return info

# Helpers
Expand Down
16 changes: 8 additions & 8 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def replacement_map_from_symbol_table(
node.node.names, new_node.node.names, prefix
)
replacements.update(type_repl)
if node.node.tuple_alias and new_node.node.tuple_alias:
replacements[new_node.node.tuple_alias] = node.node.tuple_alias
if node.node.special_alias and new_node.node.special_alias:
replacements[new_node.node.special_alias] = node.node.special_alias
return replacements


Expand Down Expand Up @@ -338,10 +338,10 @@ def fixup(self, node: SN) -> SN:
new = self.replacements[node]
skip_slots: Tuple[str, ...] = ()
if isinstance(node, TypeInfo) and isinstance(new, TypeInfo):
# Special case: tuple_alias is not exposed in symbol tables, but may appear
# Special case: special_alias is not exposed in symbol tables, but may appear
# in external types (e.g. named tuples), so we need to update it manually.
skip_slots = ("tuple_alias",)
replace_object_state(new.tuple_alias, node.tuple_alias)
skip_slots = ("special_alias",)
replace_object_state(new.special_alias, node.special_alias)
replace_object_state(new, node, skip_slots=skip_slots)
return cast(SN, new)
return node
Expand Down Expand Up @@ -372,8 +372,8 @@ def process_type_info(self, info: Optional[TypeInfo]) -> None:
self.fixup_type(target)
self.fixup_type(info.tuple_type)
self.fixup_type(info.typeddict_type)
if info.tuple_alias:
self.fixup_type(info.tuple_alias.target)
if info.special_alias:
self.fixup_type(info.special_alias.target)
info.defn.info = self.fixup(info)
replace_nodes_in_symbol_table(info.names, self.replacements)
for i, item in enumerate(info.mro):
Expand Down Expand Up @@ -547,7 +547,7 @@ def replace_nodes_in_symbol_table(
new = replacements[node.node]
old = node.node
# Needed for TypeInfo, see comment in fixup() above.
replace_object_state(new, old, skip_slots=("tuple_alias",))
replace_object_state(new, old, skip_slots=("special_alias",))
node.node = new
if isinstance(node.node, (Var, TypeAlias)):
# Handle them here just in case these aren't exposed through the AST.
Expand Down
Loading