Skip to content

Commit

Permalink
fix(patterns): PatternList should keep the original pattern's type
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Nov 6, 2023
1 parent b8e463d commit 6552639
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
9 changes: 6 additions & 3 deletions ibis/common/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ibis.common.bases import Final, FrozenSlotted, Hashable, Immutable, Slotted
from ibis.common.collections import FrozenDict
from ibis.common.typing import Coercible, CoercionError
from ibis.util import PseudoHashable, is_iterable
from ibis.util import PseudoHashable


class Resolver(Coercible, Hashable):
Expand Down Expand Up @@ -519,9 +519,12 @@ def resolver(obj):
elif isinstance(obj, collections.abc.Mapping):
# allow nesting deferred patterns in dicts
return Mapping(obj)
elif is_iterable(obj):
elif isinstance(obj, collections.abc.Sequence):
# allow nesting deferred patterns in tuples/lists
return Sequence(obj)
if isinstance(obj, (str, bytes)):
return Just(obj)
else:
return Sequence(obj)
elif isinstance(obj, type):
return Just(obj)
elif callable(obj):
Expand Down
12 changes: 9 additions & 3 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,15 +1426,18 @@ class PatternList(Slotted, Pattern):

@classmethod
def __create__(cls, patterns, type=list):
if patterns == ():
return EqualTo(patterns)

patterns = tuple(map(pattern, patterns))
for pat in patterns:
pat = _maybe_unwrap_capture(pat)
if isinstance(pat, (SomeOf, SomeChunksOf)):
return VariadicPatternList(patterns, type)

return super().__create__(patterns, type)

def __init__(self, patterns, type):
patterns = tuple(map(pattern, patterns))
super().__init__(patterns=patterns, type=type)

def describe(self, plural=False):
Expand Down Expand Up @@ -1584,12 +1587,15 @@ def pattern(obj: AnyType) -> Pattern:
return Capture(obj)
elif isinstance(obj, Mapping):
raise TypeError("Cannot create a pattern from a mapping")
elif isinstance(obj, Sequence):
if isinstance(obj, (str, bytes)):
return EqualTo(obj)
else:
return PatternList(obj, type=type(obj))
elif isinstance(obj, type):
return InstanceOf(obj)
elif get_origin(obj):
return Pattern.from_typehint(obj, allow_coercion=False)
elif is_iterable(obj):
return PatternList(obj)
elif callable(obj):
return Custom(obj)
else:
Expand Down
5 changes: 5 additions & 0 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,11 @@ def test_matching_sequence_pattern():
assert match([Some(...), 2, 3, 4, Some(...)], list(range(8))) == list(range(8))


def test_matching_sequence_pattern_keeps_original_type():
assert match([1, 2, 3, 4, Some(...)], tuple(range(1, 9))) == list(range(1, 9))
assert match((1, 2, 3, Some(...)), [1, 2, 3, 4, 5]) == (1, 2, 3, 4, 5)


def test_matching_sequence_with_captures():
v = list(range(1, 9))
assert match([1, 2, 3, 4, Some(...)], v) == v
Expand Down

0 comments on commit 6552639

Please sign in to comment.