diff --git a/mypy/constraints.py b/mypy/constraints.py index 0524e38f9643..b61d882da3c4 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -328,6 +328,18 @@ def _infer_constraints( if isinstance(template, TypeVarType): return [Constraint(template, direction, actual)] + if ( + isinstance(actual, TypeVarType) + and not actual.id.is_meta_var() + and direction == SUPERTYPE_OF + ): + # Unless template is also a type variable (or a union that contains one), using the upper + # bound for inference will usually give better result for actual that is a type variable. + if not isinstance(template, UnionType) or not any( + isinstance(t, TypeVarType) for t in template.items + ): + actual = get_proper_type(actual.upper_bound) + # Now handle the case of either template or actual being a Union. # For a Union to be a subtype of another type, every item of the Union # must be a subtype of that type, so concatenate the constraints. diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index caa44cb40ad4..870417ca87c0 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3687,6 +3687,36 @@ reveal_type(f(g)) # N: Revealed type is "Tuple[Never, Never]" \ # E: Argument 1 to "f" has incompatible type "Callable[[VarArg(str)], None]"; expected "Call[Never]" [builtins fixtures/list.pyi] +[case testInferenceAgainstTypeVarActualBound] +from typing import Callable, TypeVar + +T = TypeVar("T") +S = TypeVar("S") +def test(f: Callable[[T], S]) -> Callable[[T], S]: ... + +F = TypeVar("F", bound=Callable[..., object]) +def dec(f: F) -> F: + reveal_type(test(f)) # N: Revealed type is "def (Any) -> builtins.object" + return f + +[case testInferenceAgainstTypeVarActualUnionBound] +from typing import Protocol, TypeVar, Union + +T_co = TypeVar("T_co", covariant=True) +class SupportsFoo(Protocol[T_co]): + def foo(self) -> T_co: ... + +class A: + def foo(self) -> A: ... +class B: + def foo(self) -> B: ... + +def foo(f: SupportsFoo[T_co]) -> T_co: ... + +ABT = TypeVar("ABT", bound=Union[A, B]) +def simpler(k: ABT): + foo(k) + [case testInferenceWorksWithEmptyCollectionsNested] from typing import List, TypeVar, NoReturn T = TypeVar('T')