diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 7fe8a940e4c2..babfe0770f35 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None: self.items_set = frozenset(items) self._ctype = "PyObject *" + @staticmethod + def make_simplified_union(items: list[RType]) -> RType: + """Return a normalized union that covers the given items. + + Flatten nested unions and remove duplicate items. + + Overlapping items are *not* simplified. For example, + [object, str] will not be simplified. + """ + items = flatten_nested_unions(items) + assert items + + # Remove duplicate items using set + list to preserve item order + seen = set() + new_items = [] + for item in items: + if item not in seen: + new_items.append(item) + seen.add(item) + if len(new_items) > 1: + return RUnion(new_items) + else: + return new_items[0] + def accept(self, visitor: RTypeVisitor[T]) -> T: return visitor.visit_runion(self) @@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion: return RUnion(types) +def flatten_nested_unions(types: list[RType]) -> list[RType]: + if not any(isinstance(t, RUnion) for t in types): + return types # Fast path + + flat_items: list[RType] = [] + for t in types: + if isinstance(t, RUnion): + flat_items.extend(flatten_nested_unions(t.items)) + else: + flat_items.append(t) + return flat_items + + def optional_value_type(rtype: RType) -> RType | None: """If rtype is the union of none_rprimitive and another type X, return X. diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 6310c25c64fb..792697970785 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -53,6 +53,7 @@ Type, TypeOfAny, UninhabitedType, + UnionType, get_proper_type, ) from mypy.util import split_target @@ -85,6 +86,7 @@ RInstance, RTuple, RType, + RUnion, bitmap_rprimitive, c_int_rprimitive, c_pyssize_t_rprimitive, @@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None: return None def get_sequence_type(self, expr: Expression) -> RType: - target_type = get_proper_type(self.types[expr]) - assert isinstance(target_type, Instance) + return self.get_sequence_type_from_type(self.types[expr]) + + def get_sequence_type_from_type(self, target_type: Type) -> RType: + target_type = get_proper_type(target_type) + if isinstance(target_type, UnionType): + return RUnion.make_simplified_union( + [self.get_sequence_type_from_type(item) for item in target_type.items] + ) + assert isinstance(target_type, Instance), target_type if target_type.type.fullname == "builtins.str": return str_rprimitive else: diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index a108766644ce..dddb35230fd5 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType: elif isinstance(typ, NoneTyp): return none_rprimitive elif isinstance(typ, UnionType): - # Remove redundant items using set + list to preserve item order - seen = set() - items = [] - for item in typ.items: - rtype = self.type_to_rtype(item) - if rtype not in seen: - items.append(rtype) - seen.add(rtype) - if len(items) > 1: - return RUnion(items) - else: - return items[0] + return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items]) elif isinstance(typ, AnyType): return object_rprimitive elif isinstance(typ, TypeType): diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index b82217465fef..cb9687a2f942 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -430,14 +430,20 @@ L5: return 1 [case testSimplifyListUnion] -from typing import List, Union +from typing import List, Union, Optional -def f(a: Union[List[str], List[bytes], int]) -> int: +def narrow(a: Union[List[str], List[bytes], int]) -> int: if isinstance(a, list): return len(a) return a +def loop(a: Union[List[str], List[bytes]]) -> None: + for x in a: + pass +def nested_union(a: Union[List[str], List[Optional[str]]]) -> None: + for x in a: + pass [out] -def f(a): +def narrow(a): a :: union[list, int] r0 :: object r1 :: int32 @@ -465,3 +471,61 @@ L1: L2: r8 = unbox(int, a) return r8 +def loop(a): + a :: list + r0 :: short_int + r1 :: ptr + r2 :: native_int + r3 :: short_int + r4 :: bit + r5 :: object + r6, x :: union[str, bytes] + r7 :: short_int +L0: + r0 = 0 +L1: + r1 = get_element_ptr a ob_size :: PyVarObject + r2 = load_mem r1 :: native_int* + keep_alive a + r3 = r2 << 1 + r4 = r0 < r3 :: signed + if r4 goto L2 else goto L4 :: bool +L2: + r5 = CPyList_GetItemUnsafe(a, r0) + r6 = cast(union[str, bytes], r5) + x = r6 +L3: + r7 = r0 + 2 + r0 = r7 + goto L1 +L4: + return 1 +def nested_union(a): + a :: list + r0 :: short_int + r1 :: ptr + r2 :: native_int + r3 :: short_int + r4 :: bit + r5 :: object + r6, x :: union[str, None] + r7 :: short_int +L0: + r0 = 0 +L1: + r1 = get_element_ptr a ob_size :: PyVarObject + r2 = load_mem r1 :: native_int* + keep_alive a + r3 = r2 << 1 + r4 = r0 < r3 :: signed + if r4 goto L2 else goto L4 :: bool +L2: + r5 = CPyList_GetItemUnsafe(a, r0) + r6 = cast(union[str, None], r5) + x = r6 +L3: + r7 = r0 + 2 + r0 = r7 + goto L1 +L4: + return 1 diff --git a/mypyc/test/test_subtype.py b/mypyc/test/test_typeops.py similarity index 64% rename from mypyc/test/test_subtype.py rename to mypyc/test/test_typeops.py index 4a0d8737c852..f414edd1a2bb 100644 --- a/mypyc/test/test_subtype.py +++ b/mypyc/test/test_typeops.py @@ -1,16 +1,19 @@ -"""Test cases for is_subtype and is_runtime_subtype.""" +"""Test cases for various RType operations.""" from __future__ import annotations import unittest from mypyc.ir.rtypes import ( + RUnion, bit_rprimitive, bool_rprimitive, int32_rprimitive, int64_rprimitive, int_rprimitive, + object_rprimitive, short_int_rprimitive, + str_rprimitive, ) from mypyc.rt_subtype import is_runtime_subtype from mypyc.subtype import is_subtype @@ -50,3 +53,24 @@ def test_bit(self) -> None: def test_bool(self) -> None: assert not is_runtime_subtype(bool_rprimitive, bit_rprimitive) assert not is_runtime_subtype(bool_rprimitive, int_rprimitive) + + +class TestUnionSimplification(unittest.TestCase): + def test_simple_type_result(self) -> None: + assert RUnion.make_simplified_union([int_rprimitive]) == int_rprimitive + + def test_remove_duplicate(self) -> None: + assert RUnion.make_simplified_union([int_rprimitive, int_rprimitive]) == int_rprimitive + + def test_cannot_simplify(self) -> None: + assert RUnion.make_simplified_union( + [int_rprimitive, str_rprimitive, object_rprimitive] + ) == RUnion([int_rprimitive, str_rprimitive, object_rprimitive]) + + def test_nested(self) -> None: + assert RUnion.make_simplified_union( + [int_rprimitive, RUnion([str_rprimitive, int_rprimitive])] + ) == RUnion([int_rprimitive, str_rprimitive]) + assert RUnion.make_simplified_union( + [int_rprimitive, RUnion([str_rprimitive, RUnion([int_rprimitive])])] + ) == RUnion([int_rprimitive, str_rprimitive])