diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index af2f58fecc65..a7eb47b37aea 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -2,7 +2,6 @@ import math import numbers -import operator import sys from abc import abstractmethod from collections.abc import Callable, Mapping, Sequence @@ -336,6 +335,15 @@ def build(self, context: dict): The constructed object. """ + def __getattr__(self, name): + return Getattr(self, name) + + def __getitem__(self, name): + return Getitem(self, name) + + def __call__(self, *args, **kwargs): + return Call(self, *args, **kwargs) + class Variable(Slotted, Builder): """Retrieve a value from the context. @@ -355,12 +363,6 @@ def __init__(self, name): def build(self, context): return context[self] - def __getattr__(self, name): - return Call(operator.attrgetter(name), self) - - def __getitem__(self, name): - return Call(operator.itemgetter(name), self) - class Just(Slotted, Builder): """Construct exactly the given value. @@ -415,6 +417,28 @@ def build(self, context): return self.func(value, context) +class Getattr(Slotted, Builder): + __slots__ = ("instance", "name") + + def __init__(self, instance, name): + super().__init__(instance=builder(instance), name=name) + + def build(self, context): + instance = self.instance.build(context) + return getattr(instance, self.name) + + +class Getitem(Slotted, Builder): + __slots__ = ("instance", "name") + + def __init__(self, instance, name): + super().__init__(instance=builder(instance), name=name) + + def build(self, context): + instance = self.instance.build(context) + return instance[self.name] + + class Call(Slotted, Builder): """Pattern that calls a function with the given arguments. @@ -431,20 +455,21 @@ class Call(Slotted, Builder): """ __slots__ = ("func", "args", "kwargs") - func: Callable - args: tuple - kwargs: FrozenDict + func: Builder + args: tuple[Builder, ...] + kwargs: FrozenDict[str, Builder] def __init__(self, func, *args, **kwargs): - assert callable(func) + func = func if isinstance(func, Builder) else Just(func) args = tuple(map(builder, args)) kwargs = frozendict({k: builder(v) for k, v in kwargs.items()}) super().__init__(func=func, args=args, kwargs=kwargs) def build(self, context): + func = self.func.build(context) args = tuple(arg.build(context) for arg in self.args) kwargs = {k: v.build(context) for k, v in self.kwargs.items()} - return self.func(*args, **kwargs) + return func(*args, **kwargs) def __call__(self, *args, **kwargs): if self.args or self.kwargs: @@ -469,7 +494,7 @@ def namespace(cls, module) -> Namespace: >>> x = Variable("x") >>> pattern = c.Negate(x) >>> pattern - Call(func=, args=(Variable(name='x'),), kwargs=FrozenDict({})) + Call(func=Just(value=), args=(Variable(name='x'),), kwargs=FrozenDict({})) >>> pattern.build({x: 5}) """ diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 5f509bba19c0..8591acd50ece 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -44,6 +44,8 @@ Function, GenericInstanceOf, GenericSequenceOf, + Getattr, + Getitem, Innermost, InstanceOf, IsIn, @@ -156,6 +158,17 @@ def test_variable(): assert p.build(context) == 10 +def test_variable_getattr(): + v = Variable("v") + p = v.copy + assert p == Getattr(v, "copy") + assert p.build({v: [1, 2, 3]})() == [1, 2, 3] + + p = v.copy() + assert p == Call(Getattr(v, "copy")) + assert p.build({v: [1, 2, 3]}) == [1, 2, 3] + + def test_pattern_factory_wraps_variable_with_capture(): v = Variable("other") p = pattern(v) @@ -1251,6 +1264,18 @@ def fn(a, b, c=1): assert c.build({}) == {"a": 1, "b": 2} +def test_getattr(): + v = Variable("v") + b = Getattr(v, "b") + assert b.build({v: Foo(1, 2)}) == 2 + + +def test_getitem(): + v = Variable("v") + b = Getitem(v, 1) + assert b.build({v: [1, 2, 3]}) == 2 + + def test_builder(): def fn(x, ctx): return x + 1