Skip to content

Commit

Permalink
Lenient handling of trivial Callable suffixes (#15913)
Browse files Browse the repository at this point in the history
Fixes #15734
Fixes #15188
Fixes #14321
Fixes #13107 (plain Callable was
already working, this fixes the protocol example)
Fixes #16058

It looks like treating trivial suffixes (especially for erased
callables) as "whatever works" is a right thing, because it reflects the
whole idea of why we normally check subtyping with respect to an e.g.
erased type. As you can see this fixes a bunch of issues. Note it was
necessary to make couple more tweaks to make everything work smoothly:
* Adjust self-type erasure level in `checker.py` to match other places.
* Explicitly allow `Callable` as a `self`/`cls` annotation (actually I
am not sure we need to keep this check at all, since we now have good
inference for self-types, and we check they are safe either at
definition site or at call site).
  • Loading branch information
ilevkivskyi authored Sep 14, 2023
1 parent b327557 commit f41e24c
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 11 deletions.
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,9 @@ def check_func_def(
):
if defn.is_class or defn.name == "__new__":
ref_type = mypy.types.TypeType.make_normalized(ref_type)
erased = get_proper_type(erase_to_bound(arg_type))
# This level of erasure matches the one in checkmember.check_self_arg(),
# better keep these two checks consistent.
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
if not is_subtype(ref_type, erased, ignore_type_params=True):
if (
isinstance(erased, Instance)
Expand Down
2 changes: 2 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ def f(self: S) -> T: ...
return functype
else:
selfarg = get_proper_type(item.arg_types[0])
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))):
new_items.append(item)
elif isinstance(selfarg, ParamSpecType):
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,9 @@ def report_protocol_problems(
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
# Always show detailed message for ParamSpec
or subtype.type.has_param_spec_type
or supertype.type.has_param_spec_type
):
type_name = format_type(subtype, self.options, module_names=True)
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
Expand Down
19 changes: 16 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool:
)


def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool:
param_star = param.var_arg()
param_star2 = param.kw_arg()
return (
param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2]
and param_star is not None
and isinstance(get_proper_type(param_star.typ), AnyType)
and param_star2 is not None
and isinstance(get_proper_type(param_star2.typ), AnyType)
)


def are_parameters_compatible(
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
Expand All @@ -1540,6 +1552,7 @@ def are_parameters_compatible(
# Treat "def _(*a: Any, **kw: Any) -> X" similarly to "Callable[..., X]"
if are_trivial_parameters(right):
return True
trivial_suffix = is_trivial_suffix(right)

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
Expand Down Expand Up @@ -1570,7 +1583,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
if right_arg is None:
return False
if left_arg is None:
return not allow_partial_overlap
return not allow_partial_overlap and not trivial_suffix
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
Expand All @@ -1594,7 +1607,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None:
if right_star is not None and not trivial_suffix:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1621,7 +1634,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None:
if right_star2 is not None and not trivial_suffix:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
4 changes: 4 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def supported_self_type(typ: ProperType) -> bool:
"""
if isinstance(typ, TypeType):
return supported_self_type(typ.item)
if isinstance(typ, CallableType):
# Special case: allow class callable instead of Type[...] as cls annotation,
# as well as callable self for callback protocols.
return True
return isinstance(typ, TypeVarType) or (
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
)
Expand Down
31 changes: 31 additions & 0 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,34 @@ a: A
a() # E: Missing positional argument "other" in call to "__call__" of "A"
a(a)
a(lambda: None)

[case testCallableSubtypingTrivialSuffix]
from typing import Any, Protocol

class Call(Protocol):
def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ...

def f1() -> None: ...
a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f2(x: str) -> None: ...
a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f3(y: int) -> None: ...
a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f4(x: int) -> None: ...
a4: Call = f4

def f5(x: int, y: int) -> None: ...
a5: Call = f5

def f6(x: int, y: int = 0) -> None: ...
a6: Call = f6

def f7(x: int, *, y: int) -> None: ...
a7: Call = f7

def f8(x: int, *args: int, **kwargs: str) -> None: ...
a8: Call = f8
[builtins fixtures/tuple.pyi]
12 changes: 6 additions & 6 deletions test-data/unit/check-modules.test
Original file line number Diff line number Diff line change
Expand Up @@ -3193,29 +3193,29 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa"
import b
[file a.py]
class Foo:
def frobnicate(self, x, *args, **kwargs): pass
def frobnicate(self, x: str, *args, **kwargs): pass
[file b.py]
from a import Foo
class Bar(Foo):
def frobnicate(self) -> None: pass
[file b.py.2]
from a import Foo
class Bar(Foo):
def frobnicate(self, *args) -> None: pass
def frobnicate(self, *args: int) -> None: pass
[file b.py.3]
from a import Foo
class Bar(Foo):
def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know
def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know
[builtins fixtures/dict.pyi]
[out1]
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
tmp/b.py:3: note: Superclass:
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: Subclass:
tmp/b.py:3: note: def frobnicate(self) -> None
[out2]
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
tmp/b.py:3: note: Superclass:
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: Subclass:
tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None
tmp/b.py:3: note: def frobnicate(self, *args: int) -> None
139 changes: 138 additions & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,12 @@ class A(Protocol[P]):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \
# N: Following member(s) of "A[P]" have conflicts: \
# N: Expected: \
# N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
# N: Got: \
# N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingValidNonStrict]
Expand Down Expand Up @@ -1825,6 +1830,138 @@ c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed
reveal_type(c) # N: Revealed type is "__main__.C[Any]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateSelfType]
from typing import Callable
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
class A:
def __init__(self, a_param_1: str) -> None: ...

@classmethod
def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]:
def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A:
return cls(*args, **kwargs)
return new_constructor

@classmethod
def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]:
def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A:
return cls("my_special_str", *args, **kwargs)
return new_constructor

reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A"
reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackProtocol]
from typing import Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class Path: ...

class Function(Protocol[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateKeywordOnly]
from typing import Callable, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R")

class Path: ...

def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackApply]
from typing import Callable, Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class FuncType(Protocol[P]):
def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ...

def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def my_f(x: int, s: str, d: bool) -> str: ...
forwarder1(my_f, True) # OK
forwarder2(my_f, True) # OK
forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool"
forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecCallbackProtocolSelf]
from typing import Callable, Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

Params = ParamSpec("Params")
Result = TypeVar("Result", covariant=True)

class FancyMethod(Protocol):
def __call__(self, arg1: int, arg2: str) -> bool: ...
def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ...
def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ...

m: FancyMethod
reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool"
reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceCallableAgainstAny]
from typing import Callable, TypeVar, Any
from typing_extensions import ParamSpec, Concatenate

_P = ParamSpec("_P")
_R = TypeVar("_R")

class A: ...
a = A()

def a_func(
func: Callable[Concatenate[A, _P], _R],
) -> Callable[Concatenate[Any, _P], _R]:
def wrapper(__a: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return func(a, *args, **kwargs)
return wrapper

def test(a, *args): ...
x: Any
y: object

a_func(test)
x = a_func(test)
y = a_func(test)
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceWithCallbackProtocol]
from typing import Protocol, Callable, ParamSpec

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/paramspec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class object:

class function: ...
class ellipsis: ...
class classmethod: ...

class type:
def __init__(self, *a: object) -> None: ...
Expand Down

0 comments on commit f41e24c

Please sign in to comment.