Skip to content

Commit

Permalink
feat(patterns): support calling methods on builders like a variable
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Sep 27, 2023
1 parent 3610e52 commit 58b2d0e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
51 changes: 38 additions & 13 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import math
import numbers
import operator
import sys
from abc import abstractmethod
from collections.abc import Callable, Mapping, Sequence
Expand Down Expand Up @@ -336,6 +335,15 @@ def build(self, context: dict):
The constructed object.
"""

def __getattr__(self, name):
return Getattr(self, name)

def __getitem__(self, name):
return Getitem(self, name)

def __call__(self, *args, **kwargs):
return Call(self, *args, **kwargs)


class Variable(Slotted, Builder):
"""Retrieve a value from the context.
Expand All @@ -355,12 +363,6 @@ def __init__(self, name):
def build(self, context):
return context[self]

def __getattr__(self, name):
return Call(operator.attrgetter(name), self)

def __getitem__(self, name):
return Call(operator.itemgetter(name), self)


class Just(Slotted, Builder):
"""Construct exactly the given value.
Expand Down Expand Up @@ -415,6 +417,28 @@ def build(self, context):
return self.func(value, context)


class Getattr(Slotted, Builder):
__slots__ = ("instance", "name")

def __init__(self, instance, name):
super().__init__(instance=builder(instance), name=name)

def build(self, context):
instance = self.instance.build(context)
return getattr(instance, self.name)


class Getitem(Slotted, Builder):
__slots__ = ("instance", "name")

def __init__(self, instance, name):
super().__init__(instance=builder(instance), name=name)

def build(self, context):
instance = self.instance.build(context)
return instance[self.name]


class Call(Slotted, Builder):
"""Pattern that calls a function with the given arguments.
Expand All @@ -431,20 +455,21 @@ class Call(Slotted, Builder):
"""

__slots__ = ("func", "args", "kwargs")
func: Callable
args: tuple
kwargs: FrozenDict
func: Builder
args: tuple[Builder, ...]
kwargs: FrozenDict[str, Builder]

def __init__(self, func, *args, **kwargs):
assert callable(func)
func = func if isinstance(func, Builder) else Just(func)
args = tuple(map(builder, args))
kwargs = frozendict({k: builder(v) for k, v in kwargs.items()})
super().__init__(func=func, args=args, kwargs=kwargs)

def build(self, context):
func = self.func.build(context)
args = tuple(arg.build(context) for arg in self.args)
kwargs = {k: v.build(context) for k, v in self.kwargs.items()}
return self.func(*args, **kwargs)
return func(*args, **kwargs)

def __call__(self, *args, **kwargs):
if self.args or self.kwargs:
Expand All @@ -469,7 +494,7 @@ def namespace(cls, module) -> Namespace:
>>> x = Variable("x")
>>> pattern = c.Negate(x)
>>> pattern
Call(func=<class 'ibis.expr.operations.numeric.Negate'>, args=(Variable(name='x'),), kwargs=FrozenDict({}))
Call(func=Just(value=<class 'ibis.expr.operations.numeric.Negate'>), args=(Variable(name='x'),), kwargs=FrozenDict({}))
>>> pattern.build({x: 5})
<ibis.expr.operations.numeric.Negate object at 0x...>
"""
Expand Down
25 changes: 25 additions & 0 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
Function,
GenericInstanceOf,
GenericSequenceOf,
Getattr,
Getitem,
Innermost,
InstanceOf,
IsIn,
Expand Down Expand Up @@ -156,6 +158,17 @@ def test_variable():
assert p.build(context) == 10


def test_variable_getattr():
v = Variable("v")
p = v.copy
assert p == Getattr(v, "copy")
assert p.build({v: [1, 2, 3]})() == [1, 2, 3]

p = v.copy()
assert p == Call(Getattr(v, "copy"))
assert p.build({v: [1, 2, 3]}) == [1, 2, 3]


def test_pattern_factory_wraps_variable_with_capture():
v = Variable("other")
p = pattern(v)
Expand Down Expand Up @@ -1251,6 +1264,18 @@ def fn(a, b, c=1):
assert c.build({}) == {"a": 1, "b": 2}


def test_getattr():
v = Variable("v")
b = Getattr(v, "b")
assert b.build({v: Foo(1, 2)}) == 2


def test_getitem():
v = Variable("v")
b = Getitem(v, 1)
assert b.build({v: [1, 2, 3]}) == 2


def test_builder():
def fn(x, ctx):
return x + 1
Expand Down

0 comments on commit 58b2d0e

Please sign in to comment.