Skip to content

Commit

Permalink
refactor(array-apply): adjust array map and array filter representati…
Browse files Browse the repository at this point in the history
…on for easier non-recursive compilation
  • Loading branch information
cpcloud authored and kszucs committed Sep 26, 2023
1 parent 054ebae commit b91ecf0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 96 deletions.
41 changes: 13 additions & 28 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Optional
from typing import Optional

from public import public

Expand All @@ -9,7 +9,7 @@
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Argument, Unary, Value
from ibis.expr.operations.core import Unary, Value


@public
Expand Down Expand Up @@ -75,41 +75,26 @@ class ArrayRepeat(Value):
shape = rlz.shape_like("args")


class ArrayApply(Value):
@public
class ArrayMap(Value):
arg: Value[dt.Array]
body: Value
param: str

@attribute
def parameter(self):
name = next(iter(self.func.__signature__.parameters.keys()))
return name

@attribute
def result(self):
arg = Argument(
name=self.parameter,
shape=self.arg.shape,
dtype=self.arg.dtype.value_type,
)
return self.func(arg)

@attribute
def shape(self):
return self.arg.shape


@public
class ArrayMap(ArrayApply):
func: Callable[[Value], Value]
shape = rlz.shape_like("arg")

@attribute
def dtype(self) -> dt.DataType:
return dt.Array(self.result.dtype)
return dt.Array(self.body.dtype)


@public
class ArrayFilter(ArrayApply):
func: Callable[[Value], Value[dt.Boolean]]
class ArrayFilter(Value):
arg: Value[dt.Array]
body: Value[dt.Boolean]
param: str

shape = rlz.shape_like("arg")
dtype = rlz.dtype_like("arg")


Expand Down
134 changes: 66 additions & 68 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Iterable
import functools
import inspect
from public import public

import ibis.expr.operations as ops
Expand Down Expand Up @@ -384,15 +384,15 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
│ [] │
└──────────────────────┘
>>> t.a.map(lambda x: (x + 100).cast("float"))
┏━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a)
┡━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<float64> │
├───────────────────────┤
│ [101.0, None, ... +1] │
│ [104.0] │
│ [] │
└───────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayMap(a, Cast(Add(x, 100), float64))
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
│ array<float64>
├─────────────────────────────────────────
│ [101.0, None, ... +1]
│ [104.0]
│ []
└─────────────────────────────────────────
`.map()` also supports more complex callables like `functools.partial`
and lambdas with closures
Expand All @@ -403,33 +403,32 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
...
>>> add2 = partial(add, y=2)
>>> t.a.map(add2)
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a)
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────┤
│ [3, None, ... +1] │
│ [6] │
│ [] │
└──────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayMap(a, Add(x, 2))
┡━━━━━━━━━━━━━━━━━━━━━━━━
│ array<int64>
├────────────────────────
│ [3, None, ... +1]
│ [6]
│ []
└────────────────────────
>>> y = 2
>>> t.a.map(lambda x: x + y)
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a)
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────┤
│ [3, None, ... +1] │
│ [6] │
│ [] │
└──────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayMap(a, Add(x, 2))
┡━━━━━━━━━━━━━━━━━━━━━━━━
│ array<int64>
├────────────────────────
│ [3, None, ... +1]
│ [6]
│ []
└────────────────────────
"""

@functools.wraps(func)
def wrapped(x, **kwargs):
return func(x.to_expr(), **kwargs)

return ops.ArrayMap(self, func=wrapped).to_expr()
param = next(iter(inspect.signature(func).parameters.keys()))
parameter = ops.Argument(
name=param, shape=self.op().shape, dtype=self.type().value_type
).to_expr()
return ops.ArrayMap(self, param=param, body=func(parameter)).to_expr()

def filter(
self, predicate: Callable[[ir.Value], bool | ir.BooleanValue]
Expand Down Expand Up @@ -462,15 +461,15 @@ def filter(
│ [] │
└──────────────────────┘
>>> t.a.filter(lambda x: x > 1)
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a)
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayFilter(a, Greater(x, 1))
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
│ array<int64>
├───────────────────────────────
│ [2]
│ [4]
│ []
└───────────────────────────────
`.filter()` also supports more complex callables like `functools.partial`
and lambdas with closures
Expand All @@ -481,33 +480,32 @@ def filter(
...
>>> gt1 = partial(gt, y=1)
>>> t.a.filter(gt1)
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a)
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayFilter(a, Greater(x, 1))
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
│ array<int64>
├───────────────────────────────
│ [2]
│ [4]
│ []
└───────────────────────────────
>>> y = 1
>>> t.a.filter(lambda x: x > y)
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a)
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┃ ArrayFilter(a, Greater(x, 1))
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
│ array<int64>
├───────────────────────────────
│ [2]
│ [4]
│ []
└───────────────────────────────
"""

@functools.wraps(predicate)
def wrapped(x, **kwargs):
return predicate(x.to_expr(), **kwargs)

return ops.ArrayFilter(self, func=wrapped).to_expr()
param = next(iter(inspect.signature(predicate).parameters.keys()))
parameter = ops.Argument(
name=param, shape=self.op().shape, dtype=self.type().value_type
).to_expr()
return ops.ArrayFilter(self, param=param, body=predicate(parameter)).to_expr()

def contains(self, other: ir.Value) -> ir.BooleanValue:
"""Return whether the array contains `other`.
Expand Down Expand Up @@ -983,7 +981,7 @@ def array(values: Iterable[V], type: str | dt.DataType | None = None) -> ArrayVa
│ [3, 42] │
└──────────────────────┘
"""
if any(isinstance(value, Column) for value in values):
if any(isinstance(value, Value) for value in values):
return ops.ArrayColumn(values).to_expr()
else:
try:
Expand Down

0 comments on commit b91ecf0

Please sign in to comment.