From b91ecf0a19d53e123b9faeb5706d18848fb43fad Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 06:35:54 -0400 Subject: [PATCH] refactor(array-apply): adjust array map and array filter representation for easier non-recursive compilation --- ibis/expr/operations/arrays.py | 41 ++++------ ibis/expr/types/arrays.py | 134 ++++++++++++++++----------------- 2 files changed, 79 insertions(+), 96 deletions(-) diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 1868cf88eb8b..8110903d4429 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Optional from public import public @@ -9,7 +9,7 @@ import ibis.expr.rules as rlz from ibis.common.annotations import attribute from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.operations.core import Argument, Unary, Value +from ibis.expr.operations.core import Unary, Value @public @@ -75,41 +75,26 @@ class ArrayRepeat(Value): shape = rlz.shape_like("args") -class ArrayApply(Value): +@public +class ArrayMap(Value): arg: Value[dt.Array] + body: Value + param: str - @attribute - def parameter(self): - name = next(iter(self.func.__signature__.parameters.keys())) - return name - - @attribute - def result(self): - arg = Argument( - name=self.parameter, - shape=self.arg.shape, - dtype=self.arg.dtype.value_type, - ) - return self.func(arg) - - @attribute - def shape(self): - return self.arg.shape - - -@public -class ArrayMap(ArrayApply): - func: Callable[[Value], Value] + shape = rlz.shape_like("arg") @attribute def dtype(self) -> dt.DataType: - return dt.Array(self.result.dtype) + return dt.Array(self.body.dtype) @public -class ArrayFilter(ArrayApply): - func: Callable[[Value], Value[dt.Boolean]] +class ArrayFilter(Value): + arg: Value[dt.Array] + body: Value[dt.Boolean] + param: str + shape = rlz.shape_like("arg") dtype = rlz.dtype_like("arg") diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 835030aebec2..7fd28fb6879c 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Iterable -import functools +import inspect from public import public import ibis.expr.operations as ops @@ -384,15 +384,15 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: │ [] │ └──────────────────────┘ >>> t.a.map(lambda x: (x + 100).cast("float")) - ┏━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayMap(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├───────────────────────┤ - │ [101.0, None, ... +1] │ - │ [104.0] │ - │ [] │ - └───────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayMap(a, Cast(Add(x, 100), float64)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├─────────────────────────────────────────┤ + │ [101.0, None, ... +1] │ + │ [104.0] │ + │ [] │ + └─────────────────────────────────────────┘ `.map()` also supports more complex callables like `functools.partial` and lambdas with closures @@ -403,33 +403,32 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: ... >>> add2 = partial(add, y=2) >>> t.a.map(add2) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayMap(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [3, None, ... +1] │ - │ [6] │ - │ [] │ - └──────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayMap(a, Add(x, 2)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├────────────────────────┤ + │ [3, None, ... +1] │ + │ [6] │ + │ [] │ + └────────────────────────┘ >>> y = 2 >>> t.a.map(lambda x: x + y) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayMap(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [3, None, ... +1] │ - │ [6] │ - │ [] │ - └──────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayMap(a, Add(x, 2)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├────────────────────────┤ + │ [3, None, ... +1] │ + │ [6] │ + │ [] │ + └────────────────────────┘ """ - - @functools.wraps(func) - def wrapped(x, **kwargs): - return func(x.to_expr(), **kwargs) - - return ops.ArrayMap(self, func=wrapped).to_expr() + param = next(iter(inspect.signature(func).parameters.keys())) + parameter = ops.Argument( + name=param, shape=self.op().shape, dtype=self.type().value_type + ).to_expr() + return ops.ArrayMap(self, param=param, body=func(parameter)).to_expr() def filter( self, predicate: Callable[[ir.Value], bool | ir.BooleanValue] @@ -462,15 +461,15 @@ def filter( │ [] │ └──────────────────────┘ >>> t.a.filter(lambda x: x > 1) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayFilter(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [2] │ - │ [4] │ - │ [] │ - └──────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayFilter(a, Greater(x, 1)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [2] │ + │ [4] │ + │ [] │ + └───────────────────────────────┘ `.filter()` also supports more complex callables like `functools.partial` and lambdas with closures @@ -481,33 +480,32 @@ def filter( ... >>> gt1 = partial(gt, y=1) >>> t.a.filter(gt1) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayFilter(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [2] │ - │ [4] │ - │ [] │ - └──────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayFilter(a, Greater(x, 1)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [2] │ + │ [4] │ + │ [] │ + └───────────────────────────────┘ >>> y = 1 >>> t.a.filter(lambda x: x > y) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayFilter(a) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [2] │ - │ [4] │ - │ [] │ - └──────────────────────┘ + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayFilter(a, Greater(x, 1)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [2] │ + │ [4] │ + │ [] │ + └───────────────────────────────┘ """ - - @functools.wraps(predicate) - def wrapped(x, **kwargs): - return predicate(x.to_expr(), **kwargs) - - return ops.ArrayFilter(self, func=wrapped).to_expr() + param = next(iter(inspect.signature(predicate).parameters.keys())) + parameter = ops.Argument( + name=param, shape=self.op().shape, dtype=self.type().value_type + ).to_expr() + return ops.ArrayFilter(self, param=param, body=predicate(parameter)).to_expr() def contains(self, other: ir.Value) -> ir.BooleanValue: """Return whether the array contains `other`. @@ -983,7 +981,7 @@ def array(values: Iterable[V], type: str | dt.DataType | None = None) -> ArrayVa │ [3, 42] │ └──────────────────────┘ """ - if any(isinstance(value, Column) for value in values): + if any(isinstance(value, Value) for value in values): return ops.ArrayColumn(values).to_expr() else: try: