Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for decorated overloads #15898

Merged
merged 6 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,30 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.visit_decorator(defn.items[0])
for fdef in defn.items:
assert isinstance(fdef, Decorator)
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
if defn.is_property:
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
else:
# Perform full check for real overloads to infer type of all decorated
# overload variants.
self.visit_decorator_inner(fdef, allow_empty=True)
if fdef.func.abstract_status in (IS_ABSTRACT, IMPLICITLY_ABSTRACT):
num_abstract += 1
if num_abstract not in (0, len(defn.items)):
self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn)
if defn.impl:
defn.impl.accept(self)
if not defn.is_property:
self.check_overlapping_overloads(defn)
if defn.type is None:
item_types = []
for item in defn.items:
assert isinstance(item, Decorator)
item_type = self.extract_callable_type(item.var.type, item)
if item_type is not None:
item_types.append(item_type)
if item_types:
defn.type = Overloaded(item_types)
# Check override validity after we analyzed current definition.
if defn.info:
found_method_base_classes = self.check_method_override(defn)
if (
Expand All @@ -653,10 +670,35 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.msg.no_overridable_method(defn.name, defn)
self.check_explicit_override_decorator(defn, found_method_base_classes, defn.impl)
self.check_inplace_operator_method(defn)
if not defn.is_property:
self.check_overlapping_overloads(defn)
return None

def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
"""Get type as seen by an overload item caller."""
inner_type = get_proper_type(inner_type)
outer_type: CallableType | None = None
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
outer_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=ctx,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
outer_type = inner_call
if outer_type is None:
self.msg.not_callable(inner_type, ctx)
return outer_type

def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# At this point we should have set the impl already, and all remaining
# items are decorators
Expand All @@ -680,40 +722,20 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

# This can happen if we've got an overload with a different
# decorator or if the implementation is untyped -- we gave up on the types.
inner_type = get_proper_type(inner_type)
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
impl_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=defn.impl,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
impl_type = inner_call
if impl_type is None:
self.msg.not_callable(inner_type, defn.impl)
impl_type = self.extract_callable_type(inner_type, defn.impl)

is_descriptor_get = defn.info and defn.name == "__get__"
for i, item in enumerate(defn.items):
# TODO overloads involving decorators
assert isinstance(item, Decorator)
sig1 = self.function_type(item.func)
assert isinstance(sig1, CallableType)
sig1 = self.extract_callable_type(item.var.type, item)
if sig1 is None:
continue

for j, item2 in enumerate(defn.items[i + 1 :]):
assert isinstance(item2, Decorator)
sig2 = self.function_type(item2.func)
assert isinstance(sig2, CallableType)
sig2 = self.extract_callable_type(item2.var.type, item2)
if sig2 is None:
continue

if not are_argument_counts_overlapping(sig1, sig2):
continue
Expand Down Expand Up @@ -4751,17 +4773,20 @@ def visit_decorator(self, e: Decorator) -> None:
e.var.type = AnyType(TypeOfAny.special_form)
e.var.is_ready = True
return
self.visit_decorator_inner(e)

def visit_decorator_inner(self, e: Decorator, allow_empty: bool = False) -> None:
if self.recurse_into_functions:
with self.tscope.function_scope(e.func):
self.check_func_item(e.func, name=e.func.name)
self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty)

# Process decorators from the inside out to determine decorated signature, which
# may be different from the declared signature.
sig: Type = self.function_type(e.func)
for d in reversed(e.decorators):
if refers_to_fullname(d, OVERLOAD_NAMES):
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
if not allow_empty:
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig, context=e)
Expand All @@ -4788,6 +4813,8 @@ def visit_decorator(self, e: Decorator) -> None:
self.msg.fail("Too many arguments for property", e)
self.check_incompatible_property_override(e)
# For overloaded functions we already checked override for overload as a whole.
if allow_empty:
return
if e.func.info and not e.func.is_dynamic() and not e.is_overload:
found_method_base_classes = self.check_method_override(e)
if (
Expand Down
73 changes: 67 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,13 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
elif isinstance(node, FuncDef):
# Reference to a global function.
result = function_type(node, self.named_type("builtins.function"))
elif isinstance(node, OverloadedFuncDef) and node.type is not None:
# node.type is None when there are multiple definitions of a function
# and it's decorated by something that is not typing.overload
# TODO: use a dummy Overloaded instead of AnyType in this case
# like we do in mypy.types.function_type()?
result = node.type
elif isinstance(node, OverloadedFuncDef):
if node.type is None:
if self.chk.in_checked_function() and node.items:
self.chk.handle_cannot_determine_type(node.name, e)
result = AnyType(TypeOfAny.from_error)
else:
result = node.type
elif isinstance(node, TypeInfo):
# Reference to a type object.
if node.typeddict_type:
Expand Down Expand Up @@ -1337,6 +1338,55 @@ def transform_callee_type(

return callee

def is_generic_decorator_overload_call(
self, callee_type: CallableType, args: list[Expression]
) -> Overloaded | None:
"""Check if this looks like an application of a generic function to overload argument."""
assert callee_type.variables
if len(callee_type.arg_types) != 1 or len(args) != 1:
# TODO: can we handle more general cases?
return None
if not isinstance(get_proper_type(callee_type.arg_types[0]), CallableType):
return None
if not isinstance(get_proper_type(callee_type.ret_type), CallableType):
return None
with self.chk.local_type_map():
with self.msg.filter_errors():
arg_type = get_proper_type(self.accept(args[0], type_context=None))
if isinstance(arg_type, Overloaded):
return arg_type
return None

def handle_decorator_overload_call(
self, callee_type: CallableType, overloaded: Overloaded, ctx: Context
) -> tuple[Type, Type] | None:
"""Type-check application of a generic callable to an overload.

We check call on each individual overload item, and then combine results into a new
overload. This function should be only used if callee_type takes and returns a Callable.
"""
result = []
inferred_args = []
for item in overloaded.items:
arg = TempNode(typ=item)
with self.msg.filter_errors() as err:
item_result, inferred_arg = self.check_call(callee_type, [arg], [ARG_POS], ctx)
if err.has_new_errors():
# This overload doesn't match.
continue
p_item_result = get_proper_type(item_result)
if not isinstance(p_item_result, CallableType):
continue
p_inferred_arg = get_proper_type(inferred_arg)
if not isinstance(p_inferred_arg, CallableType):
continue
inferred_args.append(p_inferred_arg)
result.append(p_item_result)
if not result or not inferred_args:
# None of the overload matched (or overload was initially malformed).
return None
return Overloaded(result), Overloaded(inferred_args)

def check_call_expr_with_callee_type(
self,
callee_type: Type,
Expand Down Expand Up @@ -1451,6 +1501,17 @@ def check_call(
callee = get_proper_type(callee)

if isinstance(callee, CallableType):
if callee.variables:
overloaded = self.is_generic_decorator_overload_call(callee, args)
if overloaded is not None:
# Special casing for inline application of generic callables to overloads.
# Supporting general case would be tricky, but this should cover 95% of cases.
overloaded_result = self.handle_decorator_overload_call(
callee, overloaded, context
)
if overloaded_result is not None:
return overloaded_result

return self.check_callable_call(
callee,
args,
Expand Down
12 changes: 11 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,17 @@ def analyze_instance_member_access(
return analyze_var(name, first_item.var, typ, info, mx)
if mx.is_lvalue:
mx.msg.cant_assign_to_method(mx.context)
signature = function_type(method, mx.named_type("builtins.function"))
if not isinstance(method, OverloadedFuncDef):
signature = function_type(method, mx.named_type("builtins.function"))
else:
if method.type is None:
# Overloads may be not ready if they are decorated. Handle this in same
# manner as we would handle a regular decorated function: defer if possible.
if not mx.no_deferral and method.items:
mx.not_ready_callback(method.name, mx.context)
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
Expand Down
11 changes: 10 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,16 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
elif not non_overload_indexes:
self.handle_missing_overload_implementation(defn)

if types:
if types and not any(
# If some overload items are decorated with other decorators, then
# the overload type will be determined during type checking.
isinstance(it, Decorator) and len(it.decorators) > 1
for it in defn.items
):
# TODO: should we enforce decorated overloads consistency somehow?
# Some existing code uses both styles:
# * Put decorator only on implementation, use "effective" types in overloads
# * Put decorator everywhere, use "bare" types in overloads.
defn.type = Overloaded(types)
defn.type.line = defn.line

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3062,10 +3062,10 @@ def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]:
reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]"
reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]"
reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`6) -> S`6"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`8) -> S`8"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`9) -> S`9"
reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`16) -> builtins.list[S`16]"
dec4(lambda x: x) # E: Incompatible return value type (got "S", expected "List[object]")
[builtins fixtures/list.pyi]

Expand Down
3 changes: 1 addition & 2 deletions test-data/unit/check-newsemanal.test
Original file line number Diff line number Diff line change
Expand Up @@ -3207,8 +3207,7 @@ class User:
self.first_name = value

def __init__(self, name: str) -> None:
self.name = name # E: Cannot assign to a method \
# E: Incompatible types in assignment (expression has type "str", variable has type "Callable[..., Any]")
self.name = name # E: Cannot assign to a method

[case testNewAnalyzerMemberNameMatchesTypedDict]
from typing import Union, Any
Expand Down
27 changes: 27 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6613,3 +6613,30 @@ def struct(__cols: Union[List[S], Tuple[S, ...]]) -> int: ...
def struct(*cols: Union[S, Union[List[S], Tuple[S, ...]]]) -> int:
pass
[builtins fixtures/tuple.pyi]

[case testRegularGenericDecoratorOverload]
from typing import Callable, overload, TypeVar, List

S = TypeVar("S")
T = TypeVar("T")
def transform(func: Callable[[S], List[T]]) -> Callable[[S], T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
28 changes: 28 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1646,3 +1646,31 @@ def bar(b: B[P]) -> A[Concatenate[int, P]]:
# N: Got: \
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecDecoratorOverload]
from typing import Callable, overload, TypeVar, List
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")
def transform(func: Callable[P, List[T]]) -> Callable[P, T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/functools.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Callable, Any, Mapping
from typing import Generic, TypeVar, Callable, Any, Mapping, overload

_T = TypeVar("_T")

Expand Down