Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for recursive TypeVar defaults (PEP 696) #16878

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,18 @@ def apply_generic_arguments(
# TODO: move apply_poly() logic from checkexpr.py here when new inference
# becomes universally used (i.e. in all passes + in unification).
# With this new logic we can actually *add* some new free variables.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
remaining_tvars: list[TypeVarLikeType] = []
for tv in tvars:
if tv.id in id_to_type:
continue
if not tv.has_default():
remaining_tvars.append(tv)
continue
# TypeVarLike isn't in id_to_type mapping.
# Only expand the TypeVar default here.
typ = expand_type(tv, id_to_type)
assert isinstance(typ, TypeVarLikeType)
remaining_tvars.append(typ)

return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type),
Expand Down
9 changes: 9 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand Down Expand Up @@ -226,6 +227,14 @@ def visit_type_var(self, t: TypeVarType) -> Type:
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
if isinstance(repl, TypeVarType) and repl.has_default():
if (tvar_id := repl.id) in self.recursive_tvar_guard:
return self.recursive_tvar_guard[tvar_id] or repl
self.recursive_tvar_guard[tvar_id] = None
repl = repl.accept(self)
if isinstance(repl, TypeVarType):
repl.default = repl.default.accept(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we .copy_modified() here?

Copy link
Collaborator Author

@cdce8p cdce8p Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't work, unfortunately. This here is used to link nested TypeVarTypes so if it's changed / evaluated in one location the other stays in sync.

An example from the tests

T1 = TypeVar("T1", default=str)
T2 = TypeVar("T2", default=T1)

class ClassD1(Generic[T1, T2]): ...

k = ClassD1()
reveal_type(k)  # should be `ClassD1[str, str]`

For the first pass the type variables are as follows

T1`1 = str
T2`2 = T1`1 = str

The section here now replaces the default for T2 with the tbd of the default of T1:

T1`1 = str
T2`2 = <result of T1`1>

--
With .copy_modified it would be a copy of the T1'1 thus when it's evaluated it's not used for T2. The test would still return

ClassD1[str, T1`1 = str]

At least that's how I understand it 😅

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked at it more closely. It's basically as said above. In particular, it's called in freshend_function_type_vars. We enter this function with linked TypeVars

callee.variables[0] == callee.variables[1].default  # True

Then create new ones for each and link them again in expand_type, so that

tvs[0] == tvs[1].default  # True

# and by extension
fresh.variables[0] == fresh.variables[1].default  # True

mypy/mypy/expandtype.py

Lines 119 to 130 in 517f5ae

def freshen_function_type_vars(callee: F) -> F:
"""Substitute fresh type variables for generic function type variables."""
if isinstance(callee, CallableType):
if not callee.is_generic():
return cast(F, callee)
tvs = []
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
tv = v.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)

self.recursive_tvar_guard[tvar_id] = repl
return repl

def visit_param_spec(self, t: ParamSpecType) -> Type:
Expand Down
9 changes: 9 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,15 @@ class Foo(Bar, Generic[T]): ...
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
for name, tvar_expr in declared_tvars:
tvar_expr_default = tvar_expr.default
if isinstance(tvar_expr_default, UnboundType):
# TODO: - detect out of order and self-referencing TypeVars
# - nested default types, e.g. list[T1]
n = self.lookup_qualified(
tvar_expr_default.name, tvar_expr_default, suppress_errors=True
)
if n is not None and (default := self.tvar_scope.get_binding(n)) is not None:
tvar_expr.default = default
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)
return base_type_exprs, tvar_defs, is_protocol
Expand Down
22 changes: 22 additions & 0 deletions mypy/tvar_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,26 @@
TypeVarTupleType,
TypeVarType,
)
from mypy.typetraverser import TypeTraverserVisitor


class TypeVarLikeNamespaceSetter(TypeTraverserVisitor):
"""Set namespace for all TypeVarLikeTypes types."""

def __init__(self, namespace: str) -> None:
self.namespace = namespace

def visit_type_var(self, t: TypeVarType) -> None:
t.id.namespace = self.namespace
super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> None:
t.id.namespace = self.namespace
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.id.namespace = self.namespace
super().visit_type_var_tuple(t)


class TypeVarLikeScope:
Expand Down Expand Up @@ -88,6 +108,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
i = self.func_id
# TODO: Consider also using namespaces for functions
namespace = ""
tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace))

if isinstance(tvar_expr, TypeVarExpr):
tvar_def: TypeVarLikeType = TypeVarType(
name=name,
Expand Down
6 changes: 3 additions & 3 deletions mypy/typetraverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def visit_type_var(self, t: TypeVarType) -> None:
# Note that type variable values and upper bound aren't treated as
# components, since they are components of the type variable
# definition. We want to traverse everything just once.
pass
t.default.accept(self)

def visit_param_spec(self, t: ParamSpecType) -> None:
pass
t.default.accept(self)

def visit_parameters(self, t: Parameters) -> None:
self.traverse_types(t.arg_types)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
pass
t.default.accept(self)

def visit_literal_type(self, t: LiteralType) -> None:
t.fallback.accept(self)
Expand Down
78 changes: 78 additions & 0 deletions test-data/unit/check-typevar-defaults.test
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,84 @@ def func_c4(
reveal_type(m) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]"
[builtins fixtures/tuple.pyi]

[case testTypeVarDefaultsClassRecursive1]
# flags: --disallow-any-generics
from typing import Generic, TypeVar

T1 = TypeVar("T1", default=str)
T2 = TypeVar("T2", default=T1)
T3 = TypeVar("T3", default=T2)

class ClassD1(Generic[T1, T2]): ...

def func_d1(
a: ClassD1,
b: ClassD1[int],
c: ClassD1[int, float]
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"

k = ClassD1()
reveal_type(k) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
l = ClassD1[int]()
reveal_type(l) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
m = ClassD1[int, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"

class ClassD2(Generic[T1, T2, T3]): ...

def func_d2(
a: ClassD2,
b: ClassD2[int],
c: ClassD2[int, float],
d: ClassD2[int, float, str],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
reveal_type(d) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"

k = ClassD2()
reveal_type(k) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
l = ClassD2[int]()
reveal_type(l) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
m = ClassD2[int, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
n = ClassD2[int, float, str]()
reveal_type(n) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"

[case testTypeVarDefaultsClassRecursiveMultipleFiles]
# flags: --disallow-any-generics
from typing import Generic, TypeVar
from file2 import T as T2

T = TypeVar('T', default=T2)

class ClassG1(Generic[T2, T]):
pass

def func(
a: ClassG1,
b: ClassG1[str],
c: ClassG1[str, float],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
reveal_type(b) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
reveal_type(c) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"

k = ClassG1()
reveal_type(k) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
l = ClassG1[str]()
reveal_type(l) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
m = ClassG1[str, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"

[file file2.py]
from typing import TypeVar
T = TypeVar('T', default=int)

[case testTypeVarDefaultsTypeAlias1]
# flags: --disallow-any-generics
from typing import Any, Dict, List, Tuple, TypeVar, Union
Expand Down
Loading