From b8e463de82feb32b32140ba91b50678c822f7c8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 17 Oct 2023 15:56:54 +0200 Subject: [PATCH] refactor(patterns): support more flexible sequence matching - improve the implementation of `PatternSequence` and related supplemental patterns - `TupleOf` pattern now aligns with `*Of` naming to match a sequence of a single pattern rather than a tuple of patterns - remove redundant object creation logic for alternative but more performant implementations of certain patterns - support the unpacking syntax to turn patterns into a `SomeOf` pattern - short-circuit `SequenceOf(Any())` to not traverse over the input sequence - remove unused `PatternMapping` pattern --- ibis/backends/clickhouse/compiler/core.py | 4 +- ibis/common/patterns.py | 352 +++++++++++----------- ibis/common/tests/test_patterns.py | 283 +++++++++-------- 3 files changed, 322 insertions(+), 317 deletions(-) diff --git a/ibis/backends/clickhouse/compiler/core.py b/ibis/backends/clickhouse/compiler/core.py index 7b00566458c3..ad7fba431718 100644 --- a/ibis/backends/clickhouse/compiler/core.py +++ b/ibis/backends/clickhouse/compiler/core.py @@ -90,7 +90,7 @@ def fn(node, _, **kwargs): # replace the right side of InColumn into a scalar subquery for sql # backends - replace_in_column_with_table_array_view = p.InColumn(..., y) >> _.copy( + replace_in_column_with_table_array_view = p.InColumn(options=y) >> _.copy( options=c.TableArrayView( c.Selection(table=lambda _, y: find_first_base_table(y), selections=(y,)) ), @@ -98,7 +98,7 @@ def fn(node, _, **kwargs): # replace any checks against an empty right side of the IN operation with # `False` - replace_empty_in_values_with_false = p.InValues(..., ()) >> c.Literal( + replace_empty_in_values_with_false = p.InValues(options=()) >> c.Literal( False, dtype="bool" ) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index f5db3b47e1e5..180e8104e8ee 100644 --- 
a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -6,7 +6,6 @@ from collections.abc import Callable, Mapping, Sequence from enum import Enum from inspect import Parameter -from itertools import chain from typing import ( Annotated, ForwardRef, @@ -161,20 +160,18 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: # variadic tuples differently, e.g. tuple[int, ...] is a variadic # tuple of integers, while tuple[int] is a tuple with a single int first, *rest = args - # TODO(kszucs): consider to support the same SequenceOf path if args - # has a single element, e.g. tuple[int] since annotation a single - # element tuple is not common OR use typing.Sequence for annotating - # instead of tuple[T, ...] OR have a VarTupleOf pattern if rest == [Ellipsis]: - inners = cls.from_typehint(first) + return TupleOf(cls.from_typehint(first)) else: - inners = tuple(map(cls.from_typehint, args)) - return TupleOf(inners) + return PatternList(map(cls.from_typehint, args), type=origin) elif issubclass(origin, Sequence): # construct a validator for the sequence elements where all elements # must be of the same type, e.g. Sequence[int] is a sequence of ints (value_inner,) = map(cls.from_typehint, args) - return SequenceOf(value_inner, type=origin) + if allow_coercion and issubclass(origin, Coercible): + return GenericSequenceOf(value_inner, type=origin) + else: + return SequenceOf(value_inner, type=origin) elif issubclass(origin, Mapping): # construct a validator for the mapping keys and values, e.g. # Mapping[str, int] is a mapping with string keys and int values @@ -277,6 +274,9 @@ def __rmatmul__(self, name: str) -> Capture: """ return Capture(name, self) + def __iter__(self) -> SomeOf: + yield SomeOf(self) + class Is(Slotted, Pattern): """Pattern that matches a value against a reference value. 
@@ -419,9 +419,6 @@ def match(self, value, context): return NoMatch -If = Check - - class DeferredCheck(Slotted, Pattern): __slots__ = ("resolver",) resolver: Resolver @@ -495,9 +492,6 @@ def describe(self, plural=False): return repr(self.value) -Eq = EqualTo - - class DeferredEqualTo(Slotted, Pattern): """Pattern that checks a value equals to the given value. @@ -777,9 +771,6 @@ def __call__(self, *args, **kwargs): return Object(self.type, *args, **kwargs) -As = CoercedTo - - class GenericCoercedTo(Slotted, Pattern): """Force a value to have a particular generic Python type. @@ -1069,9 +1060,6 @@ def match(self, value, context): return NoMatch -In = IsIn - - class SequenceOf(Slotted, Pattern): """Pattern that matches if all of the items in a sequence match a given pattern. @@ -1091,27 +1079,7 @@ class SequenceOf(Slotted, Pattern): item: Pattern type: type - @classmethod - def __create__( - cls, - item, - type: type = tuple, - exactly: Optional[int] = None, - at_least: Optional[int] = None, - at_most: Optional[int] = None, - ): - if ( - exactly is not None - or at_least is not None - or at_most is not None - or issubclass(type, Coercible) - ): - return GenericSequenceOf( - item, type=type, exactly=exactly, at_least=at_least, at_most=at_most - ) - return super().__create__(item, type=type) - - def __init__(self, item, type=tuple): + def __init__(self, item, type=list): super().__init__(item=pattern(item), type=type) def describe(self, plural=False): @@ -1123,12 +1091,16 @@ def match(self, values, context): if not is_iterable(values): return NoMatch - result = [] - for item in values: - item = self.item.match(item, context) - if item is NoMatch: - return NoMatch - result.append(item) + if self.item == _any: + # optimization to avoid unnecessary iteration + result = values + else: + result = [] + for item in values: + item = self.item.match(item, context) + if item is NoMatch: + return NoMatch + result.append(item) return self.type(result) @@ -1141,7 +1113,7 @@ class 
GenericSequenceOf(Slotted, Pattern): item The pattern to match against each item in the sequence. type - The type to coerce the sequence to. Defaults to tuple. + The type to coerce the sequence to. Defaults to list. exactly The exact length of the sequence. at_least @@ -1155,48 +1127,33 @@ class GenericSequenceOf(Slotted, Pattern): type: Pattern length: Length - @classmethod - def __create__( - cls, - item: Pattern, - type: type = tuple, - exactly: Optional[int] = None, - at_least: Optional[int] = None, - at_most: Optional[int] = None, - ): - if ( - exactly is None - and at_least is None - and at_most is None - and not issubclass(type, Coercible) - ): - return SequenceOf(item, type=type) - else: - return super().__create__(item, type, exactly, at_least, at_most) - def __init__( self, item: Pattern, - type: type = tuple, + type: type = list, exactly: Optional[int] = None, at_least: Optional[int] = None, at_most: Optional[int] = None, ): item = pattern(item) type = CoercedTo(type) - length = Length(at_least=at_least, at_most=at_most) + length = Length(exactly=exactly, at_least=at_least, at_most=at_most) super().__init__(item=item, type=type, length=length) def match(self, values, context): if not is_iterable(values): return NoMatch - result = [] - for value in values: - value = self.item.match(value, context) - if value is NoMatch: - return NoMatch - result.append(value) + if self.item == _any: + # optimization to avoid unnecessary iteration + result = values + else: + result = [] + for value in values: + value = self.item.match(value, context) + if value is NoMatch: + return NoMatch + result.append(value) result = self.type.match(result, context) if result is NoMatch: @@ -1205,52 +1162,6 @@ def match(self, values, context): return self.length.match(result, context) -class TupleOf(Slotted, Pattern): - """Pattern that matches if the respective items in a tuple match the given patterns. 
- - Parameters - ---------- - fields - The patterns to match the respective items in the tuple. - """ - - __slots__ = ("fields",) - fields: tuple[Pattern, ...] - - @classmethod - def __create__(cls, fields): - if not isinstance(fields, tuple): - return SequenceOf(fields, tuple) - return super().__create__(fields) - - def __init__(self, fields): - fields = tuple(map(pattern, fields)) - super().__init__(fields=fields) - - def describe(self, plural=False): - fields = ", ".join(f.describe(plural=False) for f in self.fields) - if plural: - return f"tuples of ({fields})" - else: - return f"a tuple of ({fields})" - - def match(self, values, context): - if not is_iterable(values): - return NoMatch - - if len(values) != len(self.fields): - return NoMatch - - result = [] - for pattern, value in zip(self.fields, values): - value = pattern.match(value, context) - if value is NoMatch: - return NoMatch - result.append(value) - - return tuple(result) - - class GenericMappingOf(Slotted, Pattern): """Pattern that matches if all of the keys and values match the given patterns. @@ -1443,47 +1354,138 @@ def match(self, value, context): return fn -class PatternSequence(Slotted, Pattern): - # TODO(kszucs): add a length optimization to not even try to match if the - # length of the sequence is lower than the length of the pattern sequence +class SomeOf(Slotted, Pattern): + __slots__ = ("pattern", "delimiter") - __slots__ = ("pattern_window",) - pattern_window: tuple[tuple[Pattern, Pattern], ...] 
+ @classmethod + def __create__(cls, *args, **kwargs): + if len(args) == 1: + return super().__create__(*args, **kwargs) + else: + return SomeChunksOf(*args, **kwargs) - def __init__(self, patterns): - current_patterns = [ - SequenceOf(_any) if p is Ellipsis else pattern(p) for p in patterns - ] - following_patterns = chain(current_patterns[1:], [Not(_any)]) - pattern_window = tuple(zip(current_patterns, following_patterns)) - super().__init__(pattern_window=pattern_window) + def __init__(self, item, **kwargs): + pattern = GenericSequenceOf(item, **kwargs) + delimiter = pattern.item + super().__init__(pattern=pattern, delimiter=delimiter) - def match(self, value, context): - it = RewindableIterator(value) - result = [] + def match(self, values, context): + return self.pattern.match(values, context) - if not self.pattern_window: - try: - next(it) - except StopIteration: - return result + +class SomeChunksOf(Slotted, Pattern): + """Pattern that unpacks a value into its elements. + + Designed to be used inside a `PatternList` pattern with the `*` syntax. 
+ """ + + __slots__ = ("pattern", "delimiter") + + def __init__(self, *args, **kwargs): + pattern = GenericSequenceOf(PatternList(args), **kwargs) + delimiter = pattern.item.patterns[0] + super().__init__(pattern=pattern, delimiter=delimiter) + + def chunk(self, values, context): + chunk = [] + for item in values: + if self.delimiter.match(item, context) is NoMatch: + chunk.append(item) else: + if chunk: # only yield if there are items in the chunk + yield chunk + chunk = [item] # start a new chunk with the delimiter + if chunk: + yield chunk + + def match(self, values, context): + chunks = self.chunk(values, context) + result = self.pattern.match(chunks, context) + if result is NoMatch: + return NoMatch + else: + return sum(result, []) + + +def _maybe_unwrap_capture(obj): + return obj.pattern if isinstance(obj, Capture) else obj + + +class PatternList(Slotted, Pattern): + """Pattern that matches if the respective items in a tuple match the given patterns. + + Parameters + ---------- + fields + The patterns to match the respective items in the tuple. + """ + + __slots__ = ("patterns", "type") + patterns: tuple[Pattern, ...] 
+ type: type + + @classmethod + def __create__(cls, patterns, type=list): + patterns = tuple(map(pattern, patterns)) + for pat in patterns: + pat = _maybe_unwrap_capture(pat) + if isinstance(pat, (SomeOf, SomeChunksOf)): + return VariadicPatternList(patterns, type) + return super().__create__(patterns, type) + + def __init__(self, patterns, type): + patterns = tuple(map(pattern, patterns)) + super().__init__(patterns=patterns, type=type) + + def describe(self, plural=False): + patterns = ", ".join(f.describe(plural=False) for f in self.patterns) + if plural: + return f"tuples of ({patterns})" + else: + return f"a tuple of ({patterns})" + + def match(self, values, context): + if not is_iterable(values): + return NoMatch + + if len(values) != len(self.patterns): + return NoMatch + + result = [] + for pattern, value in zip(self.patterns, values): + value = pattern.match(value, context) + if value is NoMatch: return NoMatch + result.append(value) - for current, following in self.pattern_window: - original = current + return self.type(result) + + +class VariadicPatternList(Slotted, Pattern): + __slots__ = ("patterns", "type") + patterns: tuple[Pattern, ...] 
+ type: type + + def __init__(self, patterns, type=list): + patterns = tuple(map(pattern, patterns)) + super().__init__(patterns=patterns, type=type) + + def match(self, value, context): + if not self.patterns: + return NoMatch if value else [] + + it = RewindableIterator(value) + result = [] - if isinstance(current, Capture): - current = current.pattern - if isinstance(following, Capture): - following = following.pattern + following_patterns = self.patterns[1:] + (Nothing(),) + for current, following in zip(self.patterns, following_patterns): + original = current + current = _maybe_unwrap_capture(current) + following = _maybe_unwrap_capture(following) - if isinstance(current, (SequenceOf, GenericSequenceOf, PatternSequence)): - if isinstance(following, (SequenceOf, GenericSequenceOf)): - following = following.item - elif isinstance(following, PatternSequence): - # first pattern to match from the pattern window - following = following.pattern_window[0][0] + if isinstance(current, (SomeOf, SomeChunksOf)): + if isinstance(following, (SomeOf, SomeChunksOf)): + following = following.delimiter matches = [] while True: @@ -1517,32 +1519,7 @@ def match(self, value, context): else: result.append(res) - return result - - -class PatternMapping(Slotted, Pattern): - __slots__ = ("keys", "values") - keys: PatternSequence - values: PatternSequence - - def __init__(self, patterns): - keys = PatternSequence(list(map(pattern, patterns.keys()))) - values = PatternSequence(list(map(pattern, patterns.values()))) - super().__init__(keys=keys, values=values) - - def match(self, value, context): - if not isinstance(value, Mapping): - return NoMatch - - keys = value.keys() - if (keys := self.keys.match(keys, context)) is NoMatch: - return NoMatch - - values = value.values() - if (values := self.values.match(values, context)) is NoMatch: - return NoMatch - - return dict(zip(keys, values)) + return self.type(result) def NoneOf(*args) -> Pattern: @@ -1555,6 +1532,11 @@ def ListOf(pattern): 
return SequenceOf(pattern, type=list) +def TupleOf(pattern): + """Match a variable-length tuple of items matching the given pattern.""" + return SequenceOf(pattern, type=tuple) + + def DictOf(key_pattern, value_pattern): """Match a dictionary with keys and values matching the given patterns.""" return MappingOf(key_pattern, value_pattern, type=dict) @@ -1601,13 +1583,13 @@ def pattern(obj: AnyType) -> Pattern: elif isinstance(obj, (Deferred, Resolver)): return Capture(obj) elif isinstance(obj, Mapping): - return PatternMapping(obj) + raise TypeError("Cannot create a pattern from a mapping") elif isinstance(obj, type): return InstanceOf(obj) elif get_origin(obj): return Pattern.from_typehint(obj, allow_coercion=False) elif is_iterable(obj): - return PatternSequence(obj) + return PatternList(obj) elif callable(obj): return Custom(obj) else: @@ -1657,3 +1639,9 @@ def match( IsTruish = Check(lambda x: bool(x)) IsNumber = InstanceOf(numbers.Number) & ~InstanceOf(bool) IsString = InstanceOf(str) + +As = CoercedTo +Eq = EqualTo +In = IsIn +If = Check +Some = SomeOf diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 091230bc8e29..ad9d78a4077d 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -57,10 +57,10 @@ Object, Option, Pattern, - PatternMapping, - PatternSequence, + PatternList, Replace, SequenceOf, + Some, SubclassOf, TupleOf, TypeOf, @@ -148,6 +148,12 @@ def test_pattern_factory_wraps_variable_with_capture(): assert ctx == {"other": 10} +def test_match_on_ellipsis(): + assert match(..., 1) == 1 + assert match(..., [1, 2, 3]) == [1, 2, 3] + assert match(..., (1, 2, 3)) == (1, 2, 3) + + def test_capture(): ctx = {} @@ -484,23 +490,15 @@ class MyList(list, Coercible): def __coerce__(cls, value, T=...): return cls(value) - p = SequenceOf(InstanceOf(str), MyList) - assert isinstance(p, GenericSequenceOf) - assert p == GenericSequenceOf(InstanceOf(str), MyList) + p = 
GenericSequenceOf(InstanceOf(str), MyList) assert p.match(["foo", "bar"], context={}) == MyList(["foo", "bar"]) assert p.match("string", context={}) is NoMatch - p = SequenceOf(InstanceOf(str), tuple, at_least=1) - assert isinstance(p, GenericSequenceOf) + p = GenericSequenceOf(InstanceOf(str), tuple, at_least=1) assert p == GenericSequenceOf(InstanceOf(str), tuple, at_least=1) assert p.match(("foo", "bar"), context={}) == ("foo", "bar") assert p.match([], context={}) is NoMatch - p = GenericSequenceOf(InstanceOf(str), list) - assert isinstance(p, SequenceOf) - assert p == SequenceOf(InstanceOf(str), list) - assert p.match(("foo", "bar"), context={}) == ["foo", "bar"] - def test_list_of(): p = ListOf(InstanceOf(str)) @@ -512,21 +510,17 @@ def test_list_of(): assert p.describe(plural=True) == "lists of strs" -def test_tuple_of(): - p = TupleOf((InstanceOf(str), InstanceOf(int), InstanceOf(float))) - assert p.match(("foo", 1, 1.0), context={}) == ("foo", 1, 1.0) - assert p.match(["foo", 1, 1.0], context={}) == ("foo", 1, 1.0) +def test_pattern_sequence(): + p = PatternList((InstanceOf(str), InstanceOf(int), InstanceOf(float))) + assert p.match(("foo", 1, 1.0), context={}) == ["foo", 1, 1.0] + assert p.match(["foo", 1, 1.0], context={}) == ["foo", 1, 1.0] assert p.match(1, context={}) is NoMatch assert p.describe() == "a tuple of (a str, an int, a float)" assert p.describe(plural=True) == "tuples of (a str, an int, a float)" - p = TupleOf(InstanceOf(str)) - assert p == SequenceOf(InstanceOf(str), tuple) - assert p.match(("foo", "bar"), context={}) == ("foo", "bar") - assert p.match(["foo"], context={}) == ("foo",) - assert p.match(1, context={}) is NoMatch - assert p.describe() == "a tuple of strs" - assert p.describe(plural=True) == "tuples of strs" + p = PatternList((InstanceOf(str),)) + assert p.match(("foo",), context={}) == ["foo"] + assert p.match(("foo", "bar"), context={}) is NoMatch def test_mapping_of(): @@ -676,21 +670,59 @@ def g(a: int, b: str, c: str = 
"0"): def test_pattern_list(): - p = PatternSequence([1, 2, InstanceOf(int), ...]) + p = PatternList([1, 2, InstanceOf(int), Some(...)]) assert p.match([1, 2, 3, 4, 5], context={}) == [1, 2, 3, 4, 5] assert p.match([1, 2, 3, 4, 5, 6], context={}) == [1, 2, 3, 4, 5, 6] assert p.match([1, 2, 3, 4], context={}) == [1, 2, 3, 4] assert p.match([1, 2, "3", 4], context={}) is NoMatch # subpattern is a simple pattern - p = PatternSequence([1, 2, CoercedTo(int), ...]) + p = PatternList([1, 2, CoercedTo(int), Some(...)]) assert p.match([1, 2, 3.0, 4.0, 5.0], context={}) == [1, 2, 3, 4.0, 5.0] # subpattern is a sequence - p = PatternSequence([1, 2, 3, SequenceOf(CoercedTo(int), at_least=1)]) + p = PatternList([1, 2, 3, Some(CoercedTo(int), at_least=1)]) assert p.match([1, 2, 3, 4.0, 5.0], context={}) == [1, 2, 3, 4, 5] +def test_pattern_list_from_tuple_typehint(): + p = Pattern.from_typehint(tuple[str, int, float]) + assert p == PatternList( + [InstanceOf(str), InstanceOf(int), InstanceOf(float)], type=tuple + ) + assert p.match(["foo", 1, 2.0], context={}) == ("foo", 1, 2.0) + assert p.match(("foo", 1, 2.0), context={}) == ("foo", 1, 2.0) + assert p.match(["foo", 1], context={}) is NoMatch + assert p.match(["foo", 1, 2.0, 3.0], context={}) is NoMatch + + class MyTuple(tuple): + pass + + p = Pattern.from_typehint(MyTuple[int, bool]) + assert p == PatternList([InstanceOf(int), InstanceOf(bool)], type=MyTuple) + assert p.match([1, True], context={}) == MyTuple([1, True]) + assert p.match(MyTuple([1, True]), context={}) == MyTuple([1, True]) + assert p.match([1, 2], context={}) is NoMatch + + +def test_pattern_list_unpack(): + integer = pattern(int) + floating = pattern(float) + + assert match([1, 2, *floating], [1, 2, 3]) is NoMatch + assert match([1, 2, *floating], [1, 2, 3.0]) == [1, 2, 3.0] + assert match([1, 2, *floating], [1, 2, 3.0, 4.0]) == [1, 2, 3.0, 4.0] + assert match([1, *floating, *integer], [1, 2.0, 3.0, 4]) == [1, 2.0, 3.0, 4] + assert match([1, *floating, 
*integer], [1, 2.0, 3.0, 4, 5]) == [ + 1, + 2.0, + 3.0, + 4, + 5, + ] + assert match([1, *floating, *integer], [1, 2.0, 3, 4.0]) is NoMatch + + def test_matching(): assert match("foo", "foo") == "foo" assert match("foo", "bar") is NoMatch @@ -728,14 +760,14 @@ def __eq__(self, other): def test_replace_in_nested_object_pattern(): # simple example using reference to replace a value b = Variable("b") - p = Object(Foo, 1, b=Replace(..., b)) + p = Object(Foo, 1, b=Replace(Any(), b)) f = p.match(Foo(1, 2), {"b": 3}) assert f.a == 1 assert f.b == 3 # nested example using reference to replace a value d = Variable("d") - p = Object(Foo, 1, b=Object(Bar, 2, d=Replace(..., d))) + p = Object(Foo, 1, b=Object(Bar, 2, d=Replace(Any(), d))) g = p.match(Foo(1, Bar(2, 3)), {"d": 4}) assert g.b.c == 2 assert g.b.d == 4 @@ -795,148 +827,131 @@ def test_matching_sequence_pattern(): assert match([], []) == [] assert match([], [1]) is NoMatch - assert match([1, 2, 3, 4, ...], list(range(1, 9))) == list(range(1, 9)) - assert match([1, 2, 3, 4, ...], list(range(1, 3))) is NoMatch - assert match([1, 2, 3, 4, ...], list(range(1, 5))) == list(range(1, 5)) - assert match([1, 2, 3, 4, ...], list(range(1, 6))) == list(range(1, 6)) + assert match([1, 2, 3, 4, Some(...)], list(range(1, 9))) == list(range(1, 9)) + assert match([1, 2, 3, 4, Some(...)], list(range(1, 3))) is NoMatch + assert match([1, 2, 3, 4, Some(...)], list(range(1, 5))) == list(range(1, 5)) + assert match([1, 2, 3, 4, Some(...)], list(range(1, 6))) == list(range(1, 6)) - assert match([..., 3, 4], list(range(5))) == list(range(5)) - assert match([..., 3, 4], list(range(3))) is NoMatch + assert match([Some(...), 3, 4], list(range(5))) == list(range(5)) + assert match([Some(...), 3, 4], list(range(3))) is NoMatch - assert match([0, 1, ..., 4], list(range(5))) == list(range(5)) - assert match([0, 1, ..., 4], list(range(4))) is NoMatch + assert match([0, 1, Some(...), 4], list(range(5))) == list(range(5)) + assert match([0, 1, 
Some(...), 4], list(range(4))) is NoMatch - assert match([...], list(range(5))) == list(range(5)) - assert match([..., 2, 3, 4, ...], list(range(8))) == list(range(8)) + assert match([Some(...)], list(range(5))) == list(range(5)) + assert match([Some(...), 2, 3, 4, Some(...)], list(range(8))) == list(range(8)) def test_matching_sequence_with_captures(): - assert match([1, 2, 3, 4, SequenceOf(...)], v := list(range(1, 9))) == v - assert ( - match([1, 2, 3, 4, "rest" @ SequenceOf(...)], v := list(range(1, 9)), ctx := {}) - == v - ) - assert ctx == {"rest": (5, 6, 7, 8)} + v = list(range(1, 9)) + assert match([1, 2, 3, 4, Some(...)], v) == v + assert match([1, 2, 3, 4, "rest" @ Some(...)], v, ctx := {}) == v + assert ctx == {"rest": [5, 6, 7, 8]} v = list(range(5)) - assert match([0, 1, x @ SequenceOf(...), 4], v, ctx := {}) == v - assert ctx == {"x": (2, 3)} - assert match([0, 1, "var" @ SequenceOf(...), 4], v, ctx := {}) == v - assert ctx == {"var": (2, 3)} + assert match([0, 1, x @ Some(...), 4], v, ctx := {}) == v + assert ctx == {"x": [2, 3]} + assert match([0, 1, "var" @ Some(...), 4], v, ctx := {}) == v + assert ctx == {"var": [2, 3]} p = [ 0, 1, - "ints" @ SequenceOf(InstanceOf(int)), - "floats" @ SequenceOf(InstanceOf(float)), + "ints" @ Some(int), + Some("last_float" @ InstanceOf(float)), 6, ] v = [0, 1, 2, 3, 4.0, 5.0, 6] assert match(p, v, ctx := {}) == v - assert ctx == {"ints": (2, 3), "floats": (4.0, 5.0)} + assert ctx == {"ints": [2, 3], "last_float": 5.0} def test_matching_sequence_remaining(): - Seq = SequenceOf - IsInt = InstanceOf(int) - three = [1, 2, 3] four = [1, 2, 3, 4] five = [1, 2, 3, 4, 5] - assert match([1, 2, 3, Seq(IsInt, at_least=1)], four) == four - assert match([1, 2, 3, Seq(IsInt, at_least=1)], three) is NoMatch - assert match([1, 2, 3, Seq(IsInt)], three) == three - assert match([1, 2, 3, Seq(IsInt, at_most=1)], three) == three - assert match([1, 2, 3, Seq(IsInt & Between(0, 10))], five) == five - assert match([1, 2, 3, Seq(IsInt & 
Between(0, 4))], five) is NoMatch - assert match([1, 2, 3, Seq(IsInt, at_least=2)], four) is NoMatch - assert match([1, 2, 3, "res" @ Seq(IsInt, at_least=2)], five, ctx := {}) == five - assert ctx == {"res": (4, 5)} + assert match([1, 2, 3, Some(int, at_least=1)], four) == four + assert match([1, 2, 3, Some(int, at_least=1)], three) is NoMatch + assert match([1, 2, 3, Some(int)], three) == three + assert match([1, 2, 3, Some(int, at_most=1)], three) == three + assert match([1, 2, 3, Some(InstanceOf(int) & Between(0, 10))], five) == five + assert match([1, 2, 3, Some(InstanceOf(int) & Between(0, 4))], five) is NoMatch + assert match([1, 2, 3, Some(int, at_least=2)], four) is NoMatch + assert match([1, 2, 3, "res" @ Some(int, at_least=2)], five, ctx := {}) == five + assert ctx == {"res": [4, 5]} def test_matching_sequence_complicated(): - pattern = [ + pat = [ 1, - "a" @ ListOf(InstanceOf(int) & Check(lambda x: x < 10)), + "a" @ Some(InstanceOf(int) & Check(lambda x: x < 10)), 4, - "b" @ SequenceOf(...), + "b" @ Some(...), 8, 9, ] expected = { "a": [2, 3], - "b": (5, 6, 7), + "b": [5, 6, 7], } - assert match(pattern, range(1, 10), ctx := {}) == list(range(1, 10)) - assert ctx == expected - - pattern = [0, "pairs" @ PatternSequence([-1, -2]), 3] - expected = {"pairs": [-1, -2]} - assert match(pattern, [0, -1, -2, 3], ctx := {}) == [0, -1, -2, 3] - assert ctx == expected - - pattern = [ - 0, - "first" @ PatternSequence([1, 2]), - "second" @ PatternSequence([4, 5]), - 3, - ] - expected = {"first": [1, 2], "second": [4, 5]} - assert match(pattern, [0, 1, 2, 4, 5, 3], ctx := {}) == [0, 1, 2, 4, 5, 3] + assert match(pat, range(1, 10), ctx := {}) == list(range(1, 10)) assert ctx == expected - pattern = [1, 2, "remaining" @ SequenceOf(...)] - expected = {"remaining": (3, 4, 5, 6, 7, 8, 9)} - assert match(pattern, range(1, 10), ctx := {}) == list(range(1, 10)) + pat = [1, 2, Capture("remaining", Some(...))] + expected = {"remaining": [3, 4, 5, 6, 7, 8, 9]} + assert match(pat, 
range(1, 10), ctx := {}) == list(range(1, 10)) assert ctx == expected - assert match([0, SequenceOf([1, 2]), 3], v := [0, [1, 2], [1, 2], 3]) == v - - -def test_pattern_map(): - assert PatternMapping({}).match({}, context={}) == {} - assert PatternMapping({}).match({1: 2}, context={}) is NoMatch - - -def test_matching_mapping(): - assert match({}, {}) == {} - assert match({}, {1: 2}) is NoMatch + v = [0, [1, 2, "3"], [1, 2, "4"], 3] + assert match([0, Some([1, 2, str]), 3], v) == v - assert match({1: 2}, {1: 2}) == {1: 2} - assert match({1: 2}, {1: 3}) is NoMatch - - assert match({}, 3) is NoMatch - ctx = {} - assert match({"a": "capture" @ InstanceOf(int)}, {"a": 1}, ctx) == {"a": 1} - assert ctx == {"capture": 1} - - p = { - "a": "capture" @ InstanceOf(int), - "b": InstanceOf(float), - ...: InstanceOf(str), - } - ctx = {} - assert match(p, {"a": 1, "b": 2.0, "c": "foo"}, ctx) == { - "a": 1, - "b": 2.0, - "c": "foo", - } - assert ctx == {"capture": 1} - assert match(p, {"a": 1, "b": 2.0, "c": 3}) is NoMatch - p = { - "a": "capture" @ InstanceOf(int), - "b": InstanceOf(float), - "rest" @ SequenceOf(...): InstanceOf(str), - } +def test_pattern_sequence_with_nested_some(): ctx = {} - assert match(p, {"a": 1, "b": 2.0, "c": "foo"}, ctx) == { - "a": 1, - "b": 2.0, - "c": "foo", - } - assert ctx == {"capture": 1, "rest": ("c",)} + res = match([0, "subseq" @ Some(1, 2), 3], [0, 1, 2, 1, 2, 3], ctx) + assert res == [0, 1, 2, 1, 2, 3] + assert ctx == {"subseq": [1, 2, 1, 2]} + + assert match([0, Some(1), 2, 3], [0, 2, 3]) == [0, 2, 3] + assert match([0, Some(1, at_least=1), 2, 3], [0, 2, 3]) is NoMatch + assert match([0, Some(1, at_least=1), 2, 3], [0, 1, 2, 3]) == [0, 1, 2, 3] + assert match([0, Some(1, at_least=2), 2, 3], [0, 1, 2, 3]) is NoMatch + assert match([0, Some(1, at_least=2), 2, 3], [0, 1, 1, 2, 3]) == [0, 1, 1, 2, 3] + assert match([0, Some(1, at_most=2), 2, 3], [0, 1, 1, 2, 3]) == [0, 1, 1, 2, 3] + assert match([0, Some(1, at_most=1), 2, 3], [0, 1, 1, 2, 3]) 
is NoMatch + assert match([0, Some(1, exactly=1), 2, 3], [0, 2, 3]) is NoMatch + assert match([0, Some(1, exactly=1), 2, 3], [0, 1, 2, 3]) == [0, 1, 2, 3] + assert match([0, Some(1, exactly=0), 2, 3], [0, 2, 3]) == [0, 2, 3] + assert match([0, Some(1, exactly=0), 2, 3], [0, 1, 2, 3]) is NoMatch + + assert match([0, Some(1, Some(2)), 3], [0, 3]) == [0, 3] + assert match([0, Some(1, Some(2)), 3], [0, 1, 3]) == [0, 1, 3] + assert match([0, Some(1, Some(2)), 3], [0, 1, 2, 3]) == [0, 1, 2, 3] + assert match([0, Some(1, Some(2)), 3], [0, 1, 2, 2, 3]) == [0, 1, 2, 2, 3] + assert match([0, Some(1, Some(2)), 3], [0, 1, 2, 2, 2, 3]) == [0, 1, 2, 2, 2, 3] + assert match([0, Some(1, Some(2)), 3], [0, 1, 2, 1, 2, 2, 3]) == [ + 0, + 1, + 2, + 1, + 2, + 2, + 3, + ] + assert match([0, Some(1, Some(2), at_least=1), 3], [0, 1, 2, 3]) == [0, 1, 2, 3] + assert match([0, Some(1, Some(2), at_least=1), 3], [0, 1, 3]) == [0, 1, 3] + assert match([0, Some(1, Some(2, at_least=2), at_least=1), 3], [0, 1, 3]) is NoMatch + assert ( + match([0, Some(1, Some(2, at_least=2), at_least=1), 3], [0, 1, 2, 3]) is NoMatch + ) + assert match([0, Some(1, Some(2, at_least=2), at_least=1), 3], [0, 1, 2, 2, 3]) == [ + 0, + 1, + 2, + 2, + 3, + ] @pytest.mark.parametrize( @@ -953,7 +968,7 @@ def test_matching_mapping(): (IsIn(("a", "b")), "b", "b"), (IsIn({"a", "b", "c"}), "c", "c"), (TupleOf(InstanceOf(int)), (1, 2, 3), (1, 2, 3)), - (TupleOf((InstanceOf(int), InstanceOf(str))), (1, "a"), (1, "a")), + (PatternList((InstanceOf(int), InstanceOf(str))), (1, "a"), [1, "a"]), (ListOf(InstanceOf(str)), ["a", "b"], ["a", "b"]), (AnyOf(InstanceOf(str), InstanceOf(int)), "foo", "foo"), (AnyOf(InstanceOf(str), InstanceOf(int)), 7, 7), @@ -1022,7 +1037,9 @@ def test_pattern_decorator(): (list[int], SequenceOf(InstanceOf(int), list)), ( tuple[int, float, str], - TupleOf((InstanceOf(int), InstanceOf(float), InstanceOf(str))), + PatternList( + (InstanceOf(int), InstanceOf(float), InstanceOf(str)), type=tuple + ), ), 
(tuple[int, ...], TupleOf(InstanceOf(int))), ( @@ -1141,7 +1158,7 @@ def test_pattern_coercible_sequence_type(): assert s.match([1, 2, 3], context={}) == (PlusOne(2), PlusOne(3), PlusOne(4)) s = Pattern.from_typehint(DoubledList[PlusOne]) - assert s == SequenceOf(CoercedTo(PlusOne), type=DoubledList) + assert s == GenericSequenceOf(CoercedTo(PlusOne), type=DoubledList) assert s.match([1, 2, 3], context={}) == DoubledList( [PlusOne(2), PlusOne(3), PlusOne(4), PlusOne(2), PlusOne(3), PlusOne(4)] ) @@ -1178,7 +1195,7 @@ def f(x): assert pattern(List[int]) == ListOf(InstanceOf(int)) # noqa: UP006 # spelled out sequences construct a more advanced pattern sequence - assert pattern([int, str, 1]) == PatternSequence( + assert pattern([int, str, 1]) == PatternList( [InstanceOf(int), InstanceOf(str), EqualTo(1)] )