Skip to content

Commit

Permalink
Basic support for decorated overloads (#15898)
Browse files Browse the repository at this point in the history
Fixes #15737
Fixes #12844
Fixes #12716

My goal was to fix the `ParamSpec` issues, but it turns out decorated
overloads were not supported at all. Namely:
* Decorators on overload items were ignored, caller would see original
undecorated item types
* Overload item overlap checks were performed for original types, while
arguably we should use decorated types
* Overload items completeness w.r.t. to implementation was checked with
decorated implementation, and undecorated items

Here I add basic support using same logic as for regular decorated
functions: initially set type to `None` and defer callers until
definition is type-checked. Note this results in few more `Cannot
determine type` in case of other errors, but I think it is fine.

Note I also add special-casing for "inline" applications of generic
functions to overload arguments. This use case was mentioned few times
alongside overloads. The general fix would be tricky, and my
special-casing should cover typical use cases.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] authored Aug 18, 2023
1 parent b3d0937 commit fa84534
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 45 deletions.
91 changes: 59 additions & 32 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,30 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.visit_decorator(defn.items[0])
for fdef in defn.items:
assert isinstance(fdef, Decorator)
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
if defn.is_property:
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
else:
# Perform full check for real overloads to infer type of all decorated
# overload variants.
self.visit_decorator_inner(fdef, allow_empty=True)
if fdef.func.abstract_status in (IS_ABSTRACT, IMPLICITLY_ABSTRACT):
num_abstract += 1
if num_abstract not in (0, len(defn.items)):
self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn)
if defn.impl:
defn.impl.accept(self)
if not defn.is_property:
self.check_overlapping_overloads(defn)
if defn.type is None:
item_types = []
for item in defn.items:
assert isinstance(item, Decorator)
item_type = self.extract_callable_type(item.var.type, item)
if item_type is not None:
item_types.append(item_type)
if item_types:
defn.type = Overloaded(item_types)
# Check override validity after we analyzed current definition.
if defn.info:
found_method_base_classes = self.check_method_override(defn)
if (
Expand All @@ -653,10 +670,35 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.msg.no_overridable_method(defn.name, defn)
self.check_explicit_override_decorator(defn, found_method_base_classes, defn.impl)
self.check_inplace_operator_method(defn)
if not defn.is_property:
self.check_overlapping_overloads(defn)
return None

def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
"""Get type as seen by an overload item caller."""
inner_type = get_proper_type(inner_type)
outer_type: CallableType | None = None
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
outer_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=ctx,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
outer_type = inner_call
if outer_type is None:
self.msg.not_callable(inner_type, ctx)
return outer_type

def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# At this point we should have set the impl already, and all remaining
# items are decorators
Expand All @@ -680,40 +722,20 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

# This can happen if we've got an overload with a different
# decorator or if the implementation is untyped -- we gave up on the types.
inner_type = get_proper_type(inner_type)
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, CallableType):
impl_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=defn.impl,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
)
if isinstance(inner_call, CallableType):
impl_type = inner_call
if impl_type is None:
self.msg.not_callable(inner_type, defn.impl)
impl_type = self.extract_callable_type(inner_type, defn.impl)

is_descriptor_get = defn.info and defn.name == "__get__"
for i, item in enumerate(defn.items):
# TODO overloads involving decorators
assert isinstance(item, Decorator)
sig1 = self.function_type(item.func)
assert isinstance(sig1, CallableType)
sig1 = self.extract_callable_type(item.var.type, item)
if sig1 is None:
continue

for j, item2 in enumerate(defn.items[i + 1 :]):
assert isinstance(item2, Decorator)
sig2 = self.function_type(item2.func)
assert isinstance(sig2, CallableType)
sig2 = self.extract_callable_type(item2.var.type, item2)
if sig2 is None:
continue

if not are_argument_counts_overlapping(sig1, sig2):
continue
Expand Down Expand Up @@ -4751,17 +4773,20 @@ def visit_decorator(self, e: Decorator) -> None:
e.var.type = AnyType(TypeOfAny.special_form)
e.var.is_ready = True
return
self.visit_decorator_inner(e)

def visit_decorator_inner(self, e: Decorator, allow_empty: bool = False) -> None:
if self.recurse_into_functions:
with self.tscope.function_scope(e.func):
self.check_func_item(e.func, name=e.func.name)
self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty)

# Process decorators from the inside out to determine decorated signature, which
# may be different from the declared signature.
sig: Type = self.function_type(e.func)
for d in reversed(e.decorators):
if refers_to_fullname(d, OVERLOAD_NAMES):
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
if not allow_empty:
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig, context=e)
Expand All @@ -4788,6 +4813,8 @@ def visit_decorator(self, e: Decorator) -> None:
self.msg.fail("Too many arguments for property", e)
self.check_incompatible_property_override(e)
# For overloaded functions we already checked override for overload as a whole.
if allow_empty:
return
if e.func.info and not e.func.is_dynamic() and not e.is_overload:
found_method_base_classes = self.check_method_override(e)
if (
Expand Down
73 changes: 67 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,13 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
elif isinstance(node, FuncDef):
# Reference to a global function.
result = function_type(node, self.named_type("builtins.function"))
elif isinstance(node, OverloadedFuncDef) and node.type is not None:
# node.type is None when there are multiple definitions of a function
# and it's decorated by something that is not typing.overload
# TODO: use a dummy Overloaded instead of AnyType in this case
# like we do in mypy.types.function_type()?
result = node.type
elif isinstance(node, OverloadedFuncDef):
if node.type is None:
if self.chk.in_checked_function() and node.items:
self.chk.handle_cannot_determine_type(node.name, e)
result = AnyType(TypeOfAny.from_error)
else:
result = node.type
elif isinstance(node, TypeInfo):
# Reference to a type object.
if node.typeddict_type:
Expand Down Expand Up @@ -1337,6 +1338,55 @@ def transform_callee_type(

return callee

def is_generic_decorator_overload_call(
self, callee_type: CallableType, args: list[Expression]
) -> Overloaded | None:
"""Check if this looks like an application of a generic function to overload argument."""
assert callee_type.variables
if len(callee_type.arg_types) != 1 or len(args) != 1:
# TODO: can we handle more general cases?
return None
if not isinstance(get_proper_type(callee_type.arg_types[0]), CallableType):
return None
if not isinstance(get_proper_type(callee_type.ret_type), CallableType):
return None
with self.chk.local_type_map():
with self.msg.filter_errors():
arg_type = get_proper_type(self.accept(args[0], type_context=None))
if isinstance(arg_type, Overloaded):
return arg_type
return None

def handle_decorator_overload_call(
self, callee_type: CallableType, overloaded: Overloaded, ctx: Context
) -> tuple[Type, Type] | None:
"""Type-check application of a generic callable to an overload.
We check call on each individual overload item, and then combine results into a new
overload. This function should be only used if callee_type takes and returns a Callable.
"""
result = []
inferred_args = []
for item in overloaded.items:
arg = TempNode(typ=item)
with self.msg.filter_errors() as err:
item_result, inferred_arg = self.check_call(callee_type, [arg], [ARG_POS], ctx)
if err.has_new_errors():
# This overload doesn't match.
continue
p_item_result = get_proper_type(item_result)
if not isinstance(p_item_result, CallableType):
continue
p_inferred_arg = get_proper_type(inferred_arg)
if not isinstance(p_inferred_arg, CallableType):
continue
inferred_args.append(p_inferred_arg)
result.append(p_item_result)
if not result or not inferred_args:
# None of the overload matched (or overload was initially malformed).
return None
return Overloaded(result), Overloaded(inferred_args)

def check_call_expr_with_callee_type(
self,
callee_type: Type,
Expand Down Expand Up @@ -1451,6 +1501,17 @@ def check_call(
callee = get_proper_type(callee)

if isinstance(callee, CallableType):
if callee.variables:
overloaded = self.is_generic_decorator_overload_call(callee, args)
if overloaded is not None:
# Special casing for inline application of generic callables to overloads.
# Supporting general case would be tricky, but this should cover 95% of cases.
overloaded_result = self.handle_decorator_overload_call(
callee, overloaded, context
)
if overloaded_result is not None:
return overloaded_result

return self.check_callable_call(
callee,
args,
Expand Down
12 changes: 11 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,17 @@ def analyze_instance_member_access(
return analyze_var(name, first_item.var, typ, info, mx)
if mx.is_lvalue:
mx.msg.cant_assign_to_method(mx.context)
signature = function_type(method, mx.named_type("builtins.function"))
if not isinstance(method, OverloadedFuncDef):
signature = function_type(method, mx.named_type("builtins.function"))
else:
if method.type is None:
# Overloads may be not ready if they are decorated. Handle this in same
# manner as we would handle a regular decorated function: defer if possible.
if not mx.no_deferral and method.items:
mx.not_ready_callback(method.name, mx.context)
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
Expand Down
11 changes: 10 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,16 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
elif not non_overload_indexes:
self.handle_missing_overload_implementation(defn)

if types:
if types and not any(
# If some overload items are decorated with other decorators, then
# the overload type will be determined during type checking.
isinstance(it, Decorator) and len(it.decorators) > 1
for it in defn.items
):
# TODO: should we enforce decorated overloads consistency somehow?
# Some existing code uses both styles:
# * Put decorator only on implementation, use "effective" types in overloads
# * Put decorator everywhere, use "bare" types in overloads.
defn.type = Overloaded(types)
defn.type.line = defn.line

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3062,10 +3062,10 @@ def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]:
reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]"
reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]"
reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`6) -> S`6"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`8) -> S`8"
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`9) -> S`9"
reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]"
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`16) -> builtins.list[S`16]"
dec4(lambda x: x) # E: Incompatible return value type (got "S", expected "List[object]")
[builtins fixtures/list.pyi]

Expand Down
3 changes: 1 addition & 2 deletions test-data/unit/check-newsemanal.test
Original file line number Diff line number Diff line change
Expand Up @@ -3207,8 +3207,7 @@ class User:
self.first_name = value

def __init__(self, name: str) -> None:
self.name = name # E: Cannot assign to a method \
# E: Incompatible types in assignment (expression has type "str", variable has type "Callable[..., Any]")
self.name = name # E: Cannot assign to a method

[case testNewAnalyzerMemberNameMatchesTypedDict]
from typing import Union, Any
Expand Down
27 changes: 27 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6613,3 +6613,30 @@ def struct(__cols: Union[List[S], Tuple[S, ...]]) -> int: ...
def struct(*cols: Union[S, Union[List[S], Tuple[S, ...]]]) -> int:
pass
[builtins fixtures/tuple.pyi]

[case testRegularGenericDecoratorOverload]
from typing import Callable, overload, TypeVar, List

S = TypeVar("S")
T = TypeVar("T")
def transform(func: Callable[[S], List[T]]) -> Callable[[S], T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
28 changes: 28 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1646,3 +1646,31 @@ def bar(b: B[P]) -> A[Concatenate[int, P]]:
# N: Got: \
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecDecoratorOverload]
from typing import Callable, overload, TypeVar, List
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")
def transform(func: Callable[P, List[T]]) -> Callable[P, T]: ...

@overload
def foo(x: int) -> List[float]: ...
@overload
def foo(x: str) -> List[str]: ...
def foo(x): ...

reveal_type(transform(foo)) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"

@transform
@overload
def bar(x: int) -> List[float]: ...
@transform
@overload
def bar(x: str) -> List[str]: ...
@transform
def bar(x): ...

reveal_type(bar) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"
[builtins fixtures/paramspec.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/functools.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Callable, Any, Mapping
from typing import Generic, TypeVar, Callable, Any, Mapping, overload

_T = TypeVar("_T")

Expand Down

0 comments on commit fa84534

Please sign in to comment.