Skip to content

Commit

Permalink
feat(common): add Topmost and Innermost pattern matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed May 2, 2023
1 parent 2c0fa9e commit 90b48fc
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
41 changes: 39 additions & 2 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,7 @@ class Object(Matcher):
__slots__ = ("type", "field_patterns")

def __init__(self, type, *args, **kwargs):
match_args = getattr(type, "__match_args__", tuple())
kwargs.update(dict(zip(args, match_args)))
kwargs.update(dict(zip(type.__match_args__, args)))
super().__init__(type, frozendict(kwargs))

def match(self, value, context):
Expand Down Expand Up @@ -1198,3 +1197,41 @@ def match(pat: Pattern, value: AnyType, context: dict[str, AnyType] = None):
return NoMatch

return context


class Topmost(Matcher):
"""Traverse the value tree topmost first and match the first value that matches."""

__slots__ = ("searcher", "filter")

def __init__(self, searcher, filter=None):
super().__init__(pattern(searcher), filter)

def match(self, value, context):
result = self.searcher.match(value, context)
if result is not NoMatch:
return result

for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return NoMatch


class Innermost(Matcher):
"""Traverse the value tree innermost first and match the first value that matches."""

__slots__ = ("searcher", "filter")

def __init__(self, searcher, filter=None):
super().__init__(pattern(searcher), filter)

def match(self, value, context):
for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return self.searcher.match(value, context)
73 changes: 72 additions & 1 deletion ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing_extensions import Annotated

from ibis.common.collections import FrozenDict
from ibis.common.graph import Node
from ibis.common.patterns import (
AllOf,
Any,
Expand All @@ -39,6 +40,7 @@
EqualTo,
FrozenDictOf,
Function,
Innermost,
InstanceOf,
IsIn,
LazyInstanceOf,
Expand All @@ -57,6 +59,7 @@
Reference,
SequenceOf,
SubclassOf,
Topmost,
TupleOf,
TypeOf,
ValidationError,
Expand Down Expand Up @@ -387,11 +390,14 @@ def test_mapping_of():

def test_object_pattern():
class Foo:
__match_args__ = ("a", "b")

def __init__(self, a, b):
self.a = a
self.b = b

assert match(Object(Foo, 1, b=2), Foo(1, 2)) == {}
p = Object(Foo, 1, b=2)
assert match(p, Foo(1, 2)) == {}


def test_callable_with():
Expand Down Expand Up @@ -791,3 +797,68 @@ def f(x):
return x > 0

assert pattern(f) == Function(f)


class Term(Node):
def __eq__(self, other):
return type(self) is type(other) and self.__args__ == other.__args__

def __hash__(self):
return hash((self.__class__, self.__args__))


class Lit(Term):
__argnames__ = ("value",)
__match_args__ = ("value",)

def __init__(self, value):
self.value = value

@property
def __args__(self):
return (self.value,)


class Binary(Term):
__argnames__ = ("left", "right")
__match_args__ = ("left", "right")

def __init__(self, left, right):
self.left = left
self.right = right

@property
def __args__(self):
return (self.left, self.right)


class Add(Binary):
pass


class Mul(Binary):
pass


one = Lit(1)
two = Mul(Lit(2), one)

three = Add(one, two)
six = Mul(two, three)
seven = Add(one, six)


def test_topmost_innermost():
inner = Object(Mul, Capture(Any(), "a"), Capture(Any(), "b"))
assert inner.match(six, {}) is six

context = {}
p = Topmost(inner)
m = p.match(seven, context)
assert m is six
assert context == {"a": two, "b": three}

p = Innermost(inner)
m = p.match(seven, context)
assert m is two
assert context == {"a": Lit(2), "b": one}

0 comments on commit 90b48fc

Please sign in to comment.