From 661adb756800ecc40fabbe62e9339efd253aff4e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 15 Nov 2023 10:20:25 +0000 Subject: [PATCH] Fix crash on strict-equality with recursive types (#16483) Fixes https://github.com/python/mypy/issues/16473 Potentially we can turn this helper function into a proper visitor, but I don't think it is worth it as of right now. --------- Co-authored-by: Alex Waygood --- mypy/checkexpr.py | 21 +++++++++++++----- mypy/meet.py | 12 +++++++++- test-data/unit/check-expressions.test | 32 +++++++++++++++++++++++++++ test-data/unit/fixtures/list.pyi | 1 + 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c87d1f6cd31c..da61833bbe5b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3617,8 +3617,9 @@ def dangerous_comparison( self, left: Type, right: Type, - original_container: Type | None = None, *, + original_container: Type | None = None, + seen_types: set[tuple[Type, Type]] | None = None, prefer_literal: bool = True, ) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. @@ -3639,6 +3640,12 @@ def dangerous_comparison( if not self.chk.options.strict_equality: return False + if seen_types is None: + seen_types = set() + if (left, right) in seen_types: + return False + seen_types.add((left, right)) + left, right = get_proper_types((left, right)) # We suppress the error if there is a custom __eq__() method on either @@ -3694,17 +3701,21 @@ def dangerous_comparison( abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet") left = map_instance_to_supertype(left, abstract_set) right = map_instance_to_supertype(right, abstract_set) - return self.dangerous_comparison(left.args[0], right.args[0]) + return self.dangerous_comparison( + left.args[0], right.args[0], seen_types=seen_types + ) elif left.type.has_base("typing.Mapping") and right.type.has_base("typing.Mapping"): # Similar to above: Mapping ignores the classes, it just compares items. abstract_map = self.chk.lookup_typeinfo("typing.Mapping") left = map_instance_to_supertype(left, abstract_map) right = map_instance_to_supertype(right, abstract_map) return self.dangerous_comparison( - left.args[0], right.args[0] - ) or self.dangerous_comparison(left.args[1], right.args[1]) + left.args[0], right.args[0], seen_types=seen_types + ) or self.dangerous_comparison(left.args[1], right.args[1], seen_types=seen_types) elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name: - return self.dangerous_comparison(left.args[0], right.args[0]) + return self.dangerous_comparison( + left.args[0], right.args[0], seen_types=seen_types + ) elif left_name in OVERLAPPING_BYTES_ALLOWLIST and right_name in ( OVERLAPPING_BYTES_ALLOWLIST ): diff --git a/mypy/meet.py b/mypy/meet.py index 610185d6bbbf..df8b960cdf3f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -262,6 +262,7 @@ def is_overlapping_types( ignore_promotions: bool = False, prohibit_none_typevar_overlap: bool = False, ignore_uninhabited: bool = False, + seen_types: set[tuple[Type, Type]] | None = None, ) -> bool: """Can a value of type 'left' also be of type 'right' or vice-versa? @@ -275,18 +276,27 @@ def is_overlapping_types( # A type guard forces the new type even if it doesn't overlap the old. return True + if seen_types is None: + seen_types = set() + if (left, right) in seen_types: + return True + if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType): + seen_types.add((left, right)) + left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: """Encode the kind of overlapping check to perform. - This function mostly exists so we don't have to repeat keyword arguments everywhere.""" + This function mostly exists, so we don't have to repeat keyword arguments everywhere. + """ return is_overlapping_types( left, right, ignore_promotions=ignore_promotions, prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, ignore_uninhabited=ignore_uninhabited, + seen_types=seen_types.copy(), ) # We should never encounter this type. diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 4ac5512580d2..8fe68365e5ac 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2378,6 +2378,38 @@ assert a == b [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityWithRecursiveMapTypes] +# flags: --strict-equality +from typing import Dict + +R = Dict[str, R] + +a: R +b: R +assert a == b + +R2 = Dict[int, R2] +c: R2 +assert a == c # E: Non-overlapping equality check (left operand type: "Dict[str, R]", right operand type: "Dict[int, R2]") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityWithRecursiveListTypes] +# flags: --strict-equality +from typing import List, Union + +R = List[Union[str, R]] + +a: R +b: R +assert a == b + +R2 = List[Union[int, R2]] +c: R2 +assert a == c +[builtins fixtures/list.pyi] +[typing fixtures/typing-full.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any") diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 90fbabe8bc92..3dcdf18b2faa 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -6,6 +6,7 @@ T = TypeVar('T') class object: def __init__(self) -> None: pass + def __eq__(self, other: object) -> bool: pass class type: pass class ellipsis: pass