Skip to content

Commit

Permalink
Allow None vs TypeVar overlap for overloads (python#15846)
Browse files Browse the repository at this point in the history
Fixes python#8881 

This is technically unsafe, and I remember we explicitly discussed this
a while ago, but related use cases turn out to be more common than I
expected (judging by how popular the issue is). Also the fix is really
simple.

---------

Co-authored-by: Ivan Levkivskyi <ilevkivskyi@hopper.com>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
3 people authored Aug 14, 2023
1 parent a1fcad5 commit 854a9f8
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 29 deletions.
24 changes: 20 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7216,22 +7216,32 @@ def is_unsafe_overlapping_overload_signatures(
#
# This discrepancy is unfortunately difficult to get rid of, so we repeat the
# checks twice in both directions for now.
#
# Note that we ignore possible overlap between type variables and None. This
# is technically unsafe, but unsafety is tiny and this prevents some common
# use cases like:
# @overload
# def foo(x: None) -> None: ..
# @overload
# def foo(x: T) -> Foo[T]: ...
return is_callable_compatible(
signature,
other,
is_compat=is_overlapping_types_no_promote_no_uninhabited,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
ignore_return=False,
check_args_covariantly=True,
allow_partial_overlap=True,
no_unify_none=True,
) or is_callable_compatible(
other,
signature,
is_compat=is_overlapping_types_no_promote_no_uninhabited,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
ignore_return=False,
check_args_covariantly=False,
allow_partial_overlap=True,
no_unify_none=True,
)


Expand Down Expand Up @@ -7717,12 +7727,18 @@ def is_subtype_no_promote(left: Type, right: Type) -> bool:
return is_subtype(left, right, ignore_promotions=True)


def is_overlapping_types_no_promote_no_uninhabited(left: Type, right: Type) -> bool:
def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool:
# For the purpose of unsafe overload checks we consider list[<nothing>] and list[int]
# non-overlapping. This is consistent with how we treat list[int] and list[str] as
# non-overlapping, despite [] belongs to both. Also this will prevent false positives
# for failed type inference during unification.
return is_overlapping_types(left, right, ignore_promotions=True, ignore_uninhabited=True)
return is_overlapping_types(
left,
right,
ignore_promotions=True,
ignore_uninhabited=True,
prohibit_none_typevar_overlap=True,
)


def is_private(node_name: str) -> bool:
Expand Down
86 changes: 69 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,11 @@ def check_overload_call(
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
erased_targets: list[CallableType] | None = None
unioned_result: tuple[Type, Type] | None = None

# Determine whether we need to encourage union math. This should be generally safe,
# as union math infers better results in the vast majority of cases, but it is very
# computationally intensive.
none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets)
union_interrupted = False # did we try all union combinations?
if any(self.real_union(arg) for arg in arg_types):
try:
Expand All @@ -2421,6 +2426,7 @@ def check_overload_call(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
)
except TooManyUnions:
Expand Down Expand Up @@ -2453,8 +2459,10 @@ def check_overload_call(
# If any of checks succeed, stop early.
if inferred_result is not None and unioned_result is not None:
# Both unioned and direct checks succeeded, choose the more precise type.
if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance(
get_proper_type(inferred_result[0]), AnyType
if (
is_subtype(inferred_result[0], unioned_result[0])
and not isinstance(get_proper_type(inferred_result[0]), AnyType)
and not none_type_var_overlap
):
return inferred_result
return unioned_result
Expand Down Expand Up @@ -2504,7 +2512,8 @@ def check_overload_call(
callable_name=callable_name,
object_type=object_type,
)
if union_interrupted:
# Do not show the extra error if the union math was forced.
if union_interrupted and not none_type_var_overlap:
self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context)
return result

Expand Down Expand Up @@ -2659,6 +2668,44 @@ def overload_erased_call_targets(
matches.append(typ)
return matches

def possible_none_type_var_overlap(
self, arg_types: list[Type], plausible_targets: list[CallableType]
) -> bool:
"""Heuristic to determine whether we need to try forcing union math.
This is needed to avoid greedy type variable match in situations like this:
@overload
def foo(x: None) -> None: ...
@overload
def foo(x: T) -> list[T]: ...
x: int | None
foo(x)
we want this call to infer list[int] | None, not list[int | None].
"""
if not plausible_targets or not arg_types:
return False
has_optional_arg = False
for arg_type in get_proper_types(arg_types):
if not isinstance(arg_type, UnionType):
continue
for item in get_proper_types(arg_type.items):
if isinstance(item, NoneType):
has_optional_arg = True
break
if not has_optional_arg:
return False

min_prefix = min(len(c.arg_types) for c in plausible_targets)
for i in range(min_prefix):
if any(
isinstance(get_proper_type(c.arg_types[i]), NoneType) for c in plausible_targets
) and any(
isinstance(get_proper_type(c.arg_types[i]), TypeVarType) for c in plausible_targets
):
return True
return False

def union_overload_result(
self,
plausible_targets: list[CallableType],
Expand All @@ -2668,6 +2715,7 @@ def union_overload_result(
arg_names: Sequence[str | None] | None,
callable_name: str | None,
object_type: Type | None,
none_type_var_overlap: bool,
context: Context,
level: int = 0,
) -> list[tuple[Type, Type]] | None:
Expand Down Expand Up @@ -2707,20 +2755,23 @@ def union_overload_result(

# Step 3: Try a direct match before splitting to avoid unnecessary union splits
# and save performance.
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)):
# We only return non-unions soon, to avoid greedy match.
return [direct]
if not none_type_var_overlap:
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(
get_proper_type(direct[0]), (UnionType, AnyType)
):
# We only return non-unions soon, to avoid greedy match.
return [direct]

# Step 4: Split the first remaining union type in arguments into items and
# try to match each item individually (recursive).
Expand All @@ -2738,6 +2789,7 @@ def union_overload_result(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
level + 1,
)
Expand Down
15 changes: 13 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@ def is_callable_compatible(
check_args_covariantly: bool = False,
allow_partial_overlap: bool = False,
strict_concatenate: bool = False,
no_unify_none: bool = False,
) -> bool:
"""Is the left compatible with the right, using the provided compatibility check?
Expand Down Expand Up @@ -1415,7 +1416,9 @@ def g(x: int) -> int: ...
# (below) treats type variables on the two sides as independent.
if left.variables:
# Apply generic type variables away in left via type inference.
unified = unify_generic_callable(left, right, ignore_return=ignore_return)
unified = unify_generic_callable(
left, right, ignore_return=ignore_return, no_unify_none=no_unify_none
)
if unified is None:
return False
left = unified
Expand All @@ -1427,7 +1430,9 @@ def g(x: int) -> int: ...
# So, we repeat the above checks in the opposite direction. This also
# lets us preserve the 'symmetry' property of allow_partial_overlap.
if allow_partial_overlap and right.variables:
unified = unify_generic_callable(right, left, ignore_return=ignore_return)
unified = unify_generic_callable(
right, left, ignore_return=ignore_return, no_unify_none=no_unify_none
)
if unified is not None:
right = unified

Expand Down Expand Up @@ -1687,6 +1692,8 @@ def unify_generic_callable(
target: NormalizedCallableType,
ignore_return: bool,
return_constraint_direction: int | None = None,
*,
no_unify_none: bool = False,
) -> NormalizedCallableType | None:
"""Try to unify a generic callable type with another callable type.
Expand All @@ -1708,6 +1715,10 @@ def unify_generic_callable(
type.ret_type, target.ret_type, return_constraint_direction
)
constraints.extend(c)
if no_unify_none:
constraints = [
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
]
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
if None in inferred_vars:
return None
Expand Down
39 changes: 33 additions & 6 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -2185,36 +2185,63 @@ def bar2(*x: int) -> int: ...
[builtins fixtures/tuple.pyi]

[case testOverloadDetectsPossibleMatchesWithGenerics]
from typing import overload, TypeVar, Generic
# flags: --strict-optional
from typing import overload, TypeVar, Generic, Optional, List

T = TypeVar('T')
# The examples below are unsafe, but it is a quite common pattern
# so we ignore the possibility of type variables taking value `None`
# for the purpose of overload overlap checks.

@overload
def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(x: None, y: None) -> str: ...
@overload
def foo(x: T, y: T) -> int: ...
def foo(x): ...

oi: Optional[int]
reveal_type(foo(None, None)) # N: Revealed type is "builtins.str"
reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]"

@overload
def foo_list(x: None) -> None: ...
@overload
def foo_list(x: T) -> List[T]: ...
def foo_list(x): ...

reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]"

# What if 'T' is 'object'?
@overload
def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(x: None, y: int) -> str: ...
@overload
def bar(x: T, y: T) -> int: ...
def bar(x, y): ...

class Wrapper(Generic[T]):
@overload
def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(self, x: None, y: None) -> str: ...
@overload
def foo(self, x: T, y: None) -> int: ...
def foo(self, x): ...

@overload
def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(self, x: None, y: int) -> str: ...
@overload
def bar(self, x: T, y: T) -> int: ...
def bar(self, x, y): ...

@overload
def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
@overload
def baz(x: T, y: T) -> int: ...
def baz(x): ...
[builtins fixtures/tuple.pyi]

[case testOverloadFlagsPossibleMatches]
from wrapper import *
[file wrapper.pyi]
Expand Down Expand Up @@ -3996,7 +4023,7 @@ T = TypeVar('T')

class FakeAttribute(Generic[T]):
@overload
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ...
@overload
def dummy(self, instance: T, owner: Type[T]) -> int: ...
def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ...
Expand Down

0 comments on commit 854a9f8

Please sign in to comment.