Skip to content

Commit

Permalink
Fix crash on lambda in generic context with generic method in body (#…
Browse files Browse the repository at this point in the history
…15155)

Fixes #15060 

The example in the issue contains another case where an erased type may
legitimately appear in a generic function. Namely, when we have a lambda
in a generic context, and its body contains a call to a generic
_method_. First, since we infer the type of lambda in erased context,
some of lambda parameters may get assigned a type containing erased
components. Then, when accessing a generic method on such type we may
get a callable that is both generic and has erased components, thus
causing the crash (actually there are two very similar crashes depending
on the details of the generic method).

Provided that we now have two legitimate cases for erased type appearing
in `expand_type()`, and special-casing (enabling) them would be tricky
(a lot of functions will need to have `allow_erased_callables`), I
propose to simply remove the check.

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and hauntsaninja authored May 2, 2023
1 parent 6f28cc3 commit 9f69bea
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 36 deletions.
19 changes: 5 additions & 14 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def apply_generic_arguments(
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool = False,
allow_erased_callables: bool = False,
) -> CallableType:
"""Apply generic type arguments to a callable type.
Expand Down Expand Up @@ -119,15 +118,9 @@ def apply_generic_arguments(
star_index = callable.arg_kinds.index(ARG_STAR)
callable = callable.copy_modified(
arg_types=(
[
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[:star_index]
]
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
+ [callable.arg_types[star_index]]
+ [
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[star_index + 1 :]
]
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
)
)

Expand Down Expand Up @@ -163,22 +156,20 @@ def apply_generic_arguments(
assert False, "mypy bug: unhandled case applying unpack"
else:
callable = callable.copy_modified(
arg_types=[
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
]
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
)

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type, allow_erased_callables)
type_guard = expand_type(callable.type_guard, id_to_type)
else:
type_guard = None

# The callable may retain some type vars if only some were applied.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]

return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables),
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
)
30 changes: 9 additions & 21 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,25 @@


@overload
def expand_type(
typ: CallableType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> CallableType:
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType:
...


@overload
def expand_type(
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> ProperType:
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType:
...


@overload
def expand_type(
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> Type:
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
...


def expand_type(
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
) -> Type:
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
"""Substitute any type variable references in a type given by a type
environment.
"""
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))
return typ.accept(ExpandTypeVisitor(env))


@overload
Expand Down Expand Up @@ -195,11 +187,8 @@ class ExpandTypeVisitor(TypeVisitor[Type]):

variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value

def __init__(
self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
) -> None:
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
self.allow_erased_callables = allow_erased_callables

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand All @@ -217,13 +206,12 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_erased_type(self, t: ErasedType) -> Type:
if not self.allow_erased_callables:
raise RuntimeError()
# This may happen during type inference if some function argument
# type is a generic callable, and its erased form will appear in inferred
# constraints, then solver may check subtyping between them, which will trigger
# unify_generic_callables(), this is why we can get here. In all other cases it
# is a sign of a bug, since <Erased> should never appear in any stored types.
# unify_generic_callables(), this is why we can get here. Another example is
# when inferring type of lambda in generic context, the lambda body contains
# a generic method in generic class.
return t

def visit_instance(self, t: Instance) -> Type:
Expand Down
2 changes: 1 addition & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,7 @@ def report(*args: Any) -> None:
# (probably also because solver needs subtyping). See also comment in
# ExpandTypeVisitor.visit_erased_type().
applied = mypy.applytype.apply_generic_arguments(
type, non_none_inferred_vars, report, context=target, allow_erased_callables=True
type, non_none_inferred_vars, report, context=target
)
if had_errors:
return None
Expand Down
13 changes: 13 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -2698,3 +2698,16 @@ def func(var: T) -> T:
reveal_type(func(1)) # N: Revealed type is "builtins.int"

[builtins fixtures/tuple.pyi]

[case testGenericLambdaGenericMethodNoCrash]
from typing import TypeVar, Union, Callable, Generic

S = TypeVar("S")
T = TypeVar("T")

def f(x: Callable[[G[T]], int]) -> T: ...

class G(Generic[T]):
def g(self, x: S) -> Union[S, T]: ...

f(lambda x: x.g(0)) # E: Cannot infer type argument 1 of "f"

0 comments on commit 9f69bea

Please sign in to comment.