Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up make_simplified_union, fix recursive tuple crash #15128

Merged
merged 8 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def visit_instance(self, left: Instance) -> bool:
# dynamic base classes correctly, see #5456.
return not isinstance(self.right, NoneType)
right = self.right
if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum:
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
if isinstance(right, Instance):
if type_state.is_cached_subtype_check(self._subtype_kind, left, right):
Expand Down Expand Up @@ -753,7 +753,9 @@ def visit_tuple_type(self, left: TupleType) -> bool:
# for isinstance(x, tuple), though it's unclear why.
return True
return all(self._is_subtype(li, iter_type) for li in left.items)
elif self._is_subtype(mypy.typeops.tuple_fallback(left), right):
elif self._is_subtype(left.partial_fallback, right) and self._is_subtype(
mypy.typeops.tuple_fallback(left), right
):
return True
return False
elif isinstance(right, TupleType):
Expand Down
5 changes: 1 addition & 4 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None:
[fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]),
)
self.assert_simplified_union(
[fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
UnionType([fx.lit_str1, fx.lit_str1_inst]),
)
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1)

def assert_simplified_union(self, original: list[Type], union: Type) -> None:
assert_equal(make_simplified_union(original), union)
Expand Down
159 changes: 70 additions & 89 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,6 @@ def callable_corresponding_argument(
return by_name if by_name is not None else by_pos


def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None:
    """Build a hashable key that identifies a simple literal type.

    Two types that produce equal keys can be treated as duplicates when
    simplifying a union, so callers only need to compare keys for equality.

    A key is produced for:
      * a LiteralType whose fallback is an enum or "builtins.str"
        -> ("literal", <value>, <fallback fullname>)
      * an Instance whose last_known_value holds a string value
        -> ("instance", <value>, <type fullname>)

    Anything else is not a simple literal and yields None.
    """
    if isinstance(t, LiteralType):
        fallback_info = t.fallback.type
        if fallback_info.is_enum or fallback_info.fullname == "builtins.str":
            # The supported fallbacks are expected to carry a string value.
            assert isinstance(t.value, str)
            return ("literal", t.value, fallback_info.fullname)
    if isinstance(t, Instance):
        known = t.last_known_value
        if known is not None and isinstance(known.value, str):
            return ("instance", known.value, t.type.fullname)
    return None


def simple_literal_type(t: ProperType | None) -> Instance | None:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
Expand All @@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None:


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str"
if isinstance(t, Instance):
Expand Down Expand Up @@ -500,68 +480,80 @@ def make_simplified_union(
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
from mypy.subtypes import is_proper_subtype

removed: set[int] = set()
seen: set[tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
# would arguably be cleaner, however it breaks down when simplifying the Union of two
# different enum types as try_expanding_sum_type_to_union works recursively and will
# trigger intermediate simplifications that would render the fast path useless
for i, item in enumerate(items):
proper_item = get_proper_type(item)
if i in removed:
continue
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
k = simple_literal_value_key(proper_item)
if k is not None:
if k in seen:
removed.add(i)
# The first pass through this loop, we check if later items are subtypes of earlier items.
# The second pass through this loop, we check if earlier items are subtypes of later items
# (by reversing the remaining items)
for _direction in range(2):
new_items: list[Type] = []
# seen is a map from a type to its index in new_items
seen: dict[ProperType, int] = {}
unduplicated_literal_fallbacks: set[Instance] | None = None
for ti in items:
proper_ti = get_proper_type(ti)

# UninhabitedType is always redundant
if isinstance(proper_ti, UninhabitedType):
continue

# NB: one would naively expect that it would be safe to skip the slow path
# always for literals. One would be sorely mistaken. Indeed, some simplifications
# such as that of None/Optional when strict optional is false, do require that we
# proceed with the slow path. Thankfully, all literals will have the same subtype
# relationship to non-literal types, so we only need to do that walk for the first
# literal, which keeps the fast path fast even in the presence of a mixture of
# literals and other types.
safe_skip = len(seen) > 0
seen.add(k)
if safe_skip:
continue

# Keep track of the truthiness info for deleted subtypes which can be relevant
cbt = cbf = False
for j, tj in enumerate(items):
proper_tj = get_proper_type(tj)
if (
i == j
# avoid further checks if this item was already marked redundant.
or j in removed
# if the current item is a simple literal then this simplification loop can
# safely skip all other simple literals as two literals will only ever be
# subtypes of each other if they are equal, which is already handled above.
# However, if the current item is not a literal, it might plausibly be a
# supertype of other literals in the union, so we must check them again.
# This is an important optimization as is_proper_subtype is pretty expensive.
or (k is not None and is_simple_literal(proper_tj))
):
continue
# actual redundancy checks (XXX?)
if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype(
tj, item, keep_erased_types=keep_erased, ignore_promotions=True
duplicate_index = -1
# Quickly check if we've seen this type
if proper_ti in seen:
duplicate_index = seen[proper_ti]
elif (
isinstance(proper_ti, LiteralType)
and unduplicated_literal_fallbacks is not None
and proper_ti.fallback in unduplicated_literal_fallbacks
):
# We found a redundant item in the union.
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false
# if deleted subtypes had more general truthiness, use that
if not item.can_be_true and cbt:
items[i] = true_or_false(item)
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)
# This is an optimisation for unions with many LiteralType
# We've already checked for exact duplicates. This means that any super type of
# the LiteralType must be a super type of its fallback. If we've gone through
# the expensive loop below and found no super type for a previous LiteralType
# with the same fallback, we can skip doing that work again and just add the type
# to new_items
pass
else:
# If not, check if we've seen a supertype of this type
for j, tj in enumerate(new_items):
tj = get_proper_type(tj)
# If tj is an Instance with a last_known_value, do not remove proper_ti
# (unless it's an instance with the same last_known_value)
if (
isinstance(tj, Instance)
and tj.last_known_value is not None
and not (
isinstance(proper_ti, Instance)
and tj.last_known_value == proper_ti.last_known_value
)
):
continue

if is_proper_subtype(
proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
):
duplicate_index = j
break
if duplicate_index != -1:
# If deleted subtypes had more general truthiness, use that
orig_item = new_items[duplicate_index]
if not orig_item.can_be_true and ti.can_be_true:
new_items[duplicate_index] = true_or_false(orig_item)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we only change can_be_true, or is it ok to only adjust can_be_false?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question. I'm going to preserve the existing behaviour for now and explore this in another PR. These code paths don't have many unit tests (see #15094 and #15098), so I want to be careful here.

elif not orig_item.can_be_false and ti.can_be_false:
new_items[duplicate_index] = true_or_false(orig_item)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.

else:
# We have a non-duplicate item, add it to new_items
seen[proper_ti] = len(new_items)
new_items.append(ti)
if isinstance(proper_ti, LiteralType):
if unduplicated_literal_fallbacks is None:
unduplicated_literal_fallbacks = set()
unduplicated_literal_fallbacks.add(proper_ti.fallback)

return [items[i] for i in range(len(items)) if i not in removed]
items = new_items
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
if len(items) <= 1:
break
items.reverse()
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved

return items


def _get_type_special_method_bool_ret_type(t: Type) -> Type | None:
Expand Down Expand Up @@ -992,17 +984,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
return False


def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
    """Tell whether `specific` may be considered redundant next to `general`.

    The answer is True when any of the following holds:
      * `general` is not an Instance, or it carries no last_known_value;
      * `specific` is an Instance with the same last_known_value as `general`;
      * `specific` is an UninhabitedType.

    Otherwise `general` is an Instance with a last_known_value that `specific`
    does not share, and the answer is False.
    """
    general_has_known_value = (
        isinstance(general, Instance) and general.last_known_value is not None
    )
    if not general_has_known_value:
        return True

    shares_known_value = (
        isinstance(specific, Instance)
        and specific.last_known_value == general.last_known_value
    )
    if shares_known_value or isinstance(specific, UninhabitedType):
        return True
    return False


def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequence[Type]]:
"""Separate literals from other members in a union type."""
literal_items = []
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-type-aliases.test
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,19 @@ class C(Generic[T]):
def test(cls) -> None:
cls.attr
[builtins fixtures/classmethod.pyi]

[case testRecursiveAliasTuple]
from typing_extensions import Literal, TypeAlias
from typing import Tuple, Union

Expr: TypeAlias = Union[
Tuple[Literal[123], int],
Tuple[Literal[456], "Expr"],
]

def eval(e: Expr) -> int:
if e[0] == 123:
return e[1]
elif e[0] == 456:
return -eval(e[1])
[builtins fixtures/dict.pyi]