Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Fixes to union simplification #14364

Merged
merged 3 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None:
self.items_set = frozenset(items)
self._ctype = "PyObject *"

@staticmethod
def make_simplified_union(items: list[RType]) -> RType:
"""Return a normalized union that covers the given items.

Flatten nested unions and remove duplicate items.

Overlapping items are *not* simplified. For example,
[object, str] will not be simplified.
"""
items = flatten_nested_unions(items)
assert items

# Remove duplicate items using set + list to preserve item order
seen = set()
new_items = []
for item in items:
if item not in seen:
new_items.append(item)
seen.add(item)
if len(new_items) > 1:
return RUnion(new_items)
else:
return new_items[0]

def accept(self, visitor: RTypeVisitor[T]) -> T:
return visitor.visit_runion(self)

Expand All @@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion:
return RUnion(types)


def flatten_nested_unions(types: list[RType]) -> list[RType]:
if not any(isinstance(t, RUnion) for t in types):
return types # Fast path

flat_items: list[RType] = []
for t in types:
if isinstance(t, RUnion):
flat_items.extend(flatten_nested_unions(t.items))
else:
flat_items.append(t)
return flat_items


def optional_value_type(rtype: RType) -> RType | None:
"""If rtype is the union of none_rprimitive and another type X, return X.

Expand Down
13 changes: 11 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Type,
TypeOfAny,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.util import split_target
Expand Down Expand Up @@ -85,6 +86,7 @@
RInstance,
RTuple,
RType,
RUnion,
bitmap_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
Expand Down Expand Up @@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None:
return None

def get_sequence_type(self, expr: Expression) -> RType:
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance)
return self.get_sequence_type_from_type(self.types[expr])

def get_sequence_type_from_type(self, target_type: Type) -> RType:
target_type = get_proper_type(target_type)
if isinstance(target_type, UnionType):
return RUnion.make_simplified_union(
[self.get_sequence_type_from_type(item) for item in target_type.items]
)
assert isinstance(target_type, Instance), target_type
if target_type.type.fullname == "builtins.str":
return str_rprimitive
else:
Expand Down
13 changes: 1 addition & 12 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType:
elif isinstance(typ, NoneTyp):
return none_rprimitive
elif isinstance(typ, UnionType):
# Remove redundant items using set + list to preserve item order
seen = set()
items = []
for item in typ.items:
rtype = self.type_to_rtype(item)
if rtype not in seen:
items.append(rtype)
seen.add(rtype)
if len(items) > 1:
return RUnion(items)
else:
return items[0]
return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
elif isinstance(typ, AnyType):
return object_rprimitive
elif isinstance(typ, TypeType):
Expand Down
70 changes: 67 additions & 3 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,20 @@ L5:
return 1

[case testSimplifyListUnion]
from typing import List, Union
from typing import List, Union, Optional

def f(a: Union[List[str], List[bytes], int]) -> int:
def narrow(a: Union[List[str], List[bytes], int]) -> int:
if isinstance(a, list):
return len(a)
return a
def loop(a: Union[List[str], List[bytes]]) -> None:
for x in a:
pass
def nested_union(a: Union[List[str], List[Optional[str]]]) -> None:
for x in a:
pass
[out]
def f(a):
def narrow(a):
a :: union[list, int]
r0 :: object
r1 :: int32
Expand Down Expand Up @@ -465,3 +471,61 @@ L1:
L2:
r8 = unbox(int, a)
return r8
def loop(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, bytes]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, bytes], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1
def nested_union(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, None]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, None], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1
26 changes: 25 additions & 1 deletion mypyc/test/test_subtype.py → mypyc/test/test_typeops.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Test cases for is_subtype and is_runtime_subtype."""
"""Test cases for various RType operations."""

from __future__ import annotations

import unittest

from mypyc.ir.rtypes import (
RUnion,
bit_rprimitive,
bool_rprimitive,
int32_rprimitive,
int64_rprimitive,
int_rprimitive,
object_rprimitive,
short_int_rprimitive,
str_rprimitive,
)
from mypyc.rt_subtype import is_runtime_subtype
from mypyc.subtype import is_subtype
Expand Down Expand Up @@ -50,3 +53,24 @@ def test_bit(self) -> None:
def test_bool(self) -> None:
assert not is_runtime_subtype(bool_rprimitive, bit_rprimitive)
assert not is_runtime_subtype(bool_rprimitive, int_rprimitive)


class TestUnionSimplification(unittest.TestCase):
def test_simple_type_result(self) -> None:
assert RUnion.make_simplified_union([int_rprimitive]) == int_rprimitive

def test_remove_duplicate(self) -> None:
assert RUnion.make_simplified_union([int_rprimitive, int_rprimitive]) == int_rprimitive

def test_cannot_simplify(self) -> None:
assert RUnion.make_simplified_union(
[int_rprimitive, str_rprimitive, object_rprimitive]
) == RUnion([int_rprimitive, str_rprimitive, object_rprimitive])

def test_nested(self) -> None:
assert RUnion.make_simplified_union(
[int_rprimitive, RUnion([str_rprimitive, int_rprimitive])]
) == RUnion([int_rprimitive, str_rprimitive])
assert RUnion.make_simplified_union(
[int_rprimitive, RUnion([str_rprimitive, RUnion([int_rprimitive])])]
) == RUnion([int_rprimitive, str_rprimitive])