Skip to content

Commit

Permalink
[mypyc] Fixes to union simplification (#14364)
Browse files Browse the repository at this point in the history
Flatten nested unions before simplifying unions.

Simplify item type unions in loops. This fixes a crash introduced in
#14363.
  • Loading branch information
JukkaL authored Dec 29, 2022
1 parent e51fb56 commit 0070071
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 18 deletions.
37 changes: 37 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None:
self.items_set = frozenset(items)
self._ctype = "PyObject *"

@staticmethod
def make_simplified_union(items: list[RType]) -> RType:
    """Build a normalized union that covers all of the given items.

    Nested unions are flattened and duplicate items are dropped, keeping
    the first occurrence of each item.

    Overlapping items are *not* merged. For example, [object, str] is
    left unchanged.

    If only one distinct item remains, that item is returned directly
    instead of being wrapped in an RUnion.
    """
    flattened = flatten_nested_unions(items)
    assert flattened

    # Dict keys are unique and keep insertion order, so this removes
    # duplicates while preserving the order of first occurrences.
    unique_items = list(dict.fromkeys(flattened))
    if len(unique_items) > 1:
        return RUnion(unique_items)
    return unique_items[0]

def accept(self, visitor: RTypeVisitor[T]) -> T:
return visitor.visit_runion(self)

Expand All @@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion:
return RUnion(types)


def flatten_nested_unions(types: list[RType]) -> list[RType]:
    """Expand every RUnion in types into its items, recursively.

    Items keep their relative order. When no item is an RUnion, the
    input list itself is returned rather than a copy.
    """
    if all(not isinstance(typ, RUnion) for typ in types):
        # Fast path: nothing to flatten
        return types

    result: list[RType] = []
    for typ in types:
        # A union contributes its (recursively flattened) items; anything
        # else contributes just itself.
        nested = flatten_nested_unions(typ.items) if isinstance(typ, RUnion) else [typ]
        result += nested
    return result


def optional_value_type(rtype: RType) -> RType | None:
"""If rtype is the union of none_rprimitive and another type X, return X.
Expand Down
13 changes: 11 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Type,
TypeOfAny,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.util import split_target
Expand Down Expand Up @@ -85,6 +86,7 @@
RInstance,
RTuple,
RType,
RUnion,
bitmap_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
Expand Down Expand Up @@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None:
return None

def get_sequence_type(self, expr: Expression) -> RType:
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance)
return self.get_sequence_type_from_type(self.types[expr])

def get_sequence_type_from_type(self, target_type: Type) -> RType:
target_type = get_proper_type(target_type)
if isinstance(target_type, UnionType):
return RUnion.make_simplified_union(
[self.get_sequence_type_from_type(item) for item in target_type.items]
)
assert isinstance(target_type, Instance), target_type
if target_type.type.fullname == "builtins.str":
return str_rprimitive
else:
Expand Down
13 changes: 1 addition & 12 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType:
elif isinstance(typ, NoneTyp):
return none_rprimitive
elif isinstance(typ, UnionType):
# Remove redundant items using set + list to preserve item order
seen = set()
items = []
for item in typ.items:
rtype = self.type_to_rtype(item)
if rtype not in seen:
items.append(rtype)
seen.add(rtype)
if len(items) > 1:
return RUnion(items)
else:
return items[0]
return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
elif isinstance(typ, AnyType):
return object_rprimitive
elif isinstance(typ, TypeType):
Expand Down
70 changes: 67 additions & 3 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,20 @@ L5:
return 1

[case testSimplifyListUnion]
from typing import List, Union
from typing import List, Union, Optional

def f(a: Union[List[str], List[bytes], int]) -> int:
def narrow(a: Union[List[str], List[bytes], int]) -> int:
if isinstance(a, list):
return len(a)
return a
def loop(a: Union[List[str], List[bytes]]) -> None:
for x in a:
pass
def nested_union(a: Union[List[str], List[Optional[str]]]) -> None:
for x in a:
pass
[out]
def f(a):
def narrow(a):
a :: union[list, int]
r0 :: object
r1 :: int32
Expand Down Expand Up @@ -465,3 +471,61 @@ L1:
L2:
r8 = unbox(int, a)
return r8
def loop(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, bytes]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, bytes], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1
def nested_union(a):
a :: list
r0 :: short_int
r1 :: ptr
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: union[str, None]
r7 :: short_int
L0:
r0 = 0
L1:
r1 = get_element_ptr a ob_size :: PyVarObject
r2 = load_mem r1 :: native_int*
keep_alive a
r3 = r2 << 1
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyList_GetItemUnsafe(a, r0)
r6 = cast(union[str, None], r5)
x = r6
L3:
r7 = r0 + 2
r0 = r7
goto L1
L4:
return 1
26 changes: 25 additions & 1 deletion mypyc/test/test_subtype.py → mypyc/test/test_typeops.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Test cases for is_subtype and is_runtime_subtype."""
"""Test cases for various RType operations."""

from __future__ import annotations

import unittest

from mypyc.ir.rtypes import (
RUnion,
bit_rprimitive,
bool_rprimitive,
int32_rprimitive,
int64_rprimitive,
int_rprimitive,
object_rprimitive,
short_int_rprimitive,
str_rprimitive,
)
from mypyc.rt_subtype import is_runtime_subtype
from mypyc.subtype import is_subtype
Expand Down Expand Up @@ -50,3 +53,24 @@ def test_bit(self) -> None:
def test_bool(self) -> None:
    """bool is not a runtime subtype of bit or of int."""
    for target in (bit_rprimitive, int_rprimitive):
        assert not is_runtime_subtype(bool_rprimitive, target)


class TestUnionSimplification(unittest.TestCase):
    """Checks for RUnion.make_simplified_union."""

    def test_simple_type_result(self) -> None:
        # A single item comes back as is, not wrapped in a union.
        result = RUnion.make_simplified_union([int_rprimitive])
        assert result == int_rprimitive

    def test_remove_duplicate(self) -> None:
        # Repeated items collapse to one, which is then returned as is.
        result = RUnion.make_simplified_union([int_rprimitive, int_rprimitive])
        assert result == int_rprimitive

    def test_cannot_simplify(self) -> None:
        # Distinct items are all kept in order, even when they overlap.
        result = RUnion.make_simplified_union(
            [int_rprimitive, str_rprimitive, object_rprimitive]
        )
        expected = RUnion([int_rprimitive, str_rprimitive, object_rprimitive])
        assert result == expected

    def test_nested(self) -> None:
        # One level of nesting: the inner union is flattened and the
        # repeated int item is dropped.
        one_level = RUnion([str_rprimitive, int_rprimitive])
        result = RUnion.make_simplified_union([int_rprimitive, one_level])
        assert result == RUnion([int_rprimitive, str_rprimitive])

        # Two levels of nesting give the same result.
        two_levels = RUnion([str_rprimitive, RUnion([int_rprimitive])])
        result = RUnion.make_simplified_union([int_rprimitive, two_levels])
        assert result == RUnion([int_rprimitive, str_rprimitive])

0 comments on commit 0070071

Please sign in to comment.