diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 6b3b3b767354..2f599ee9df57 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -920,8 +920,7 @@ class Object(Matcher): __slots__ = ("type", "field_patterns") def __init__(self, type, *args, **kwargs): - match_args = getattr(type, "__match_args__", tuple()) - kwargs.update(dict(zip(args, match_args))) + kwargs.update(dict(zip(type.__match_args__, args))) super().__init__(type, frozendict(kwargs)) def match(self, value, context): @@ -1198,3 +1197,41 @@ def match(pat: Pattern, value: AnyType, context: dict[str, AnyType] = None): return NoMatch return context + + +class Topmost(Matcher): + """Traverse the value tree topmost first and match the first value that matches.""" + + __slots__ = ("searcher", "filter") + + def __init__(self, searcher, filter=None): + super().__init__(pattern(searcher), filter) + + def match(self, value, context): + result = self.searcher.match(value, context) + if result is not NoMatch: + return result + + for child in value.__children__(self.filter): + result = self.match(child, context) + if result is not NoMatch: + return result + + return NoMatch + + +class Innermost(Matcher): + """Traverse the value tree innermost first and match the first value that matches.""" + + __slots__ = ("searcher", "filter") + + def __init__(self, searcher, filter=None): + super().__init__(pattern(searcher), filter) + + def match(self, value, context): + for child in value.__children__(self.filter): + result = self.match(child, context) + if result is not NoMatch: + return result + + return self.searcher.match(value, context) diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index f220a30f0ce1..49085fbd401b 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -24,6 +24,7 @@ from typing_extensions import Annotated from ibis.common.collections import FrozenDict +from ibis.common.graph import Node from ibis.common.patterns import ( AllOf, Any, @@ -39,6 +40,7 @@ EqualTo, FrozenDictOf, Function, + Innermost, InstanceOf, IsIn, LazyInstanceOf, @@ -57,6 +59,7 @@ Reference, SequenceOf, SubclassOf, + Topmost, TupleOf, TypeOf, ValidationError, @@ -387,11 +390,14 @@ def test_mapping_of(): def test_object_pattern(): class Foo: + __match_args__ = ("a", "b") + def __init__(self, a, b): self.a = a self.b = b - assert match(Object(Foo, 1, b=2), Foo(1, 2)) == {} + p = Object(Foo, 1, b=2) + assert match(p, Foo(1, 2)) == {} def test_callable_with(): @@ -791,3 +797,68 @@ def f(x): return x > 0 assert pattern(f) == Function(f) + + +class Term(Node): + def __eq__(self, other): + return type(self) is type(other) and self.__args__ == other.__args__ + + def __hash__(self): + return hash((self.__class__, self.__args__)) + + +class Lit(Term): + __argnames__ = ("value",) + __match_args__ = ("value",) + + def __init__(self, value): + self.value = value + + @property + def __args__(self): + return (self.value,) + + +class Binary(Term): + __argnames__ = ("left", "right") + __match_args__ = ("left", "right") + + def __init__(self, left, right): + self.left = left + self.right = right + + @property + def __args__(self): + return (self.left, self.right) + + +class Add(Binary): + pass + + +class Mul(Binary): + pass + + +one = Lit(1) +two = Mul(Lit(2), one) + +three = Add(one, two) +six = Mul(two, three) +seven = Add(one, six) + + +def test_topmost_innermost(): + inner = Object(Mul, Capture(Any(), "a"), Capture(Any(), "b")) + assert inner.match(six, {}) is six + + context = {} + p = Topmost(inner) + m = p.match(seven, context) + assert m is six + assert context == {"a": two, "b": three} + + p = Innermost(inner) + m = p.match(seven, context) + assert m is two + assert context == {"a": Lit(2), "b": one}