diff --git a/mypy/checker.py b/mypy/checker.py index 6da537fad5cba..b8a1e98130712 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -26,7 +26,7 @@ from typing_extensions import TypeAlias as _TypeAlias import mypy.checkexpr -from mypy import errorcodes as codes, message_registry, nodes, operators +from mypy import errorcodes as codes, join, message_registry, nodes, operators from mypy.binder import ConditionalTypeBinder, Frame, get_declaration from mypy.checkmember import ( MemberContext, @@ -699,6 +699,21 @@ def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> Callab ) if isinstance(inner_call, CallableType): outer_type = inner_call + elif isinstance(inner_type, UnionType): + union_type = make_simplified_union(inner_type.items) + if isinstance(union_type, UnionType): + items = [] + for item in union_type.items: + callable_item = self.extract_callable_type(item, ctx) + if callable_item is None: + break + items.append(callable_item) + else: + joined_type = get_proper_type(join.join_type_list(items)) + if isinstance(joined_type, CallableType): + outer_type = joined_type + else: + return self.extract_callable_type(union_type, ctx) if outer_type is None: self.msg.not_callable(inner_type, ctx) return outer_type diff --git a/mypy/join.py b/mypy/join.py index 7e0ff301ebf8a..782b4fbebd7b8 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import overload +from typing import Sequence, overload import mypy.typeops from mypy.maptype import map_instance_to_supertype @@ -853,7 +853,7 @@ def object_or_any_from_type(typ: ProperType) -> ProperType: return AnyType(TypeOfAny.implementation_artifact) -def join_type_list(types: list[Type]) -> Type: +def join_type_list(types: Sequence[Type]) -> Type: if not types: # This is a little arbitrary but reasonable. Any empty tuple should be compatible # with all variable length tuples, and this makes it possible. diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index 5af5dfc8e4698..30ab36abef01d 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -324,3 +324,25 @@ p(bar, 1, "a", 3.0) # OK p(bar, 1, "a", 3.0, kwarg="asdf") # OK p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]" [builtins fixtures/dict.pyi] + +[case testFunctoolsPartialUnion] +import functools +from typing import Any, Callable, Union + +cls1: Any +cls2: Union[Any, Any] +reveal_type(functools.partial(cls1, 2)()) # N: Revealed type is "Any" +reveal_type(functools.partial(cls2, 2)()) # N: Revealed type is "Any" + +fn1: Union[Callable[[int], int], Callable[[int], int]] +reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int" + +fn2: Union[Callable[[int], int], Callable[[int], str]] +reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object" + +fn3: Union[Callable[[int], int], str] +reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \ + # E: "Union[Callable[[int], int], str]" not callable \ + # N: Revealed type is "builtins.int" \ + # E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]" +[builtins fixtures/tuple.pyi]