Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): support Deferreds in Array.map and .filter #8267

Merged
merged 4 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,23 +448,25 @@ def test_array_slice(backend, start, stop):
param({"a": [[1, 2], [4]]}, {"a": [[2, 3], [5]]}, id="no_nulls"),
],
)
@pytest.mark.parametrize(
"func",
[
lambda x: x + 1,
functools.partial(lambda x, y: x + y, y=1),
ibis._ + 1,
],
)
@pytest.mark.broken(
["risingwave"],
raises=AssertionError,
reason="TODO(Kexiang): seems a bug",
)
def test_array_map(con, input, output):
def test_array_map(con, input, output, func):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.Series(output["a"])

expr = t.select(a=t.a.map(lambda x: x + 1))
result = con.execute(expr.a)
assert frozenset(map(tuple, result.values)) == frozenset(
map(tuple, expected.values)
)

expr = t.select(a=t.a.map(functools.partial(lambda x, y: x + y, y=1)))
expr = t.select(a=t.a.map(func))
result = con.execute(expr.a)
assert frozenset(map(tuple, result.values)) == frozenset(
map(tuple, expected.values)
Expand Down Expand Up @@ -512,17 +514,19 @@ def test_array_map(con, input, output):
param({"a": [[1, 2], [4]]}, {"a": [[2], [4]]}, id="no_nulls"),
],
)
def test_array_filter(con, input, output):
@pytest.mark.parametrize(
"predicate",
[
lambda x: x > 1,
functools.partial(lambda x, y: x > y, y=1),
ibis._ > 1,
],
)
def test_array_filter(con, input, output, predicate):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.Series(output["a"])

expr = t.select(a=t.a.filter(lambda x: x > 1))
result = con.execute(expr.a)
assert frozenset(map(tuple, result.values)) == frozenset(
map(tuple, expected.values)
)

expr = t.select(a=t.a.filter(functools.partial(lambda x, y: x > y, y=1)))
expr = t.select(a=t.a.filter(predicate))
result = con.execute(expr.a)
assert frozenset(map(tuple, result.values)) == frozenset(
map(tuple, expected.values)
Expand Down
74 changes: 59 additions & 15 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from public import public

import ibis.expr.operations as ops
from ibis.common.deferred import deferrable
from ibis.common.deferred import Deferred, deferrable
from ibis.expr.types.generic import Column, Scalar, Value

if TYPE_CHECKING:
Expand Down Expand Up @@ -358,13 +358,13 @@ def join(self, sep: str | ir.StringValue) -> ir.StringValue:
"""
return ops.ArrayStringJoin(sep, self).to_expr()

def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
"""Apply a callable `func` to each element of this array expression.
def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
"""Apply a `func` or `Deferred` to each element of this array expression.

Parameters
----------
func
Function to apply to each element of this array
Function or `Deferred` to apply to each element of this array.

Returns
-------
Expand All @@ -374,6 +374,7 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
Examples
--------
>>> import ibis
>>> from ibis import _
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"a": [[1, None, 2], [4], []]})
>>> t
Expand All @@ -386,6 +387,22 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
│ [4] │
│ [] │
└──────────────────────┘

NickCrews marked this conversation as resolved.
Show resolved Hide resolved
The most succinct way to use `map` is with `Deferred` expressions:

>>> t.a.map((_ + 100).cast("float"))
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Cast(Add(_, 100), float64)) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<float64> │
├─────────────────────────────────────────┤
│ [101.0, None, ... +1] │
│ [104.0] │
│ [] │
└─────────────────────────────────────────┘

You can also use `map` with a lambda function:

>>> t.a.map(lambda x: (x + 100).cast("float"))
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Cast(Add(x, 100), float64)) ┃
Expand Down Expand Up @@ -426,23 +443,28 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
│ [] │
└────────────────────────┘
"""
name = next(iter(inspect.signature(func).parameters.keys()))
if isinstance(func, Deferred):
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
name = "_"
NickCrews marked this conversation as resolved.
Show resolved Hide resolved
else:
name = next(iter(inspect.signature(func).parameters.keys()))
parameter = ops.Argument(
name=name, shape=self.op().shape, dtype=self.type().value_type
)
return ops.ArrayMap(
self, param=parameter.param, body=func(parameter.to_expr())
).to_expr()
if isinstance(func, Deferred):
body = func.resolve(parameter.to_expr())
else:
body = func(parameter.to_expr())
return ops.ArrayMap(self, param=parameter.param, body=body).to_expr()

def filter(
self, predicate: Callable[[ir.Value], bool | ir.BooleanValue]
self, predicate: Deferred | Callable[[ir.Value], bool | ir.BooleanValue]
) -> ir.ArrayValue:
"""Filter array elements using `predicate`.
"""Filter array elements using `predicate` function or `Deferred`.

Parameters
----------
predicate
Function to use to filter array elements
Function or `Deferred` to use to filter array elements

Returns
-------
Expand All @@ -452,6 +474,7 @@ def filter(
Examples
--------
>>> import ibis
>>> from ibis import _
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"a": [[1, None, 2], [4], []]})
>>> t
Expand All @@ -464,6 +487,22 @@ def filter(
│ [4] │
│ [] │
└──────────────────────┘

The most succinct way to use `filter` is with `Deferred` expressions:

>>> t.a.filter(_ > 1)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(_, 1)) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├───────────────────────────────┤
│ [2] │
│ [4] │
│ [] │
└───────────────────────────────┘

You can also use `filter` with a lambda function:

>>> t.a.filter(lambda x: x > 1)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(x, 1)) ┃
Expand Down Expand Up @@ -504,15 +543,20 @@ def filter(
│ [] │
└───────────────────────────────┘
"""
name = next(iter(inspect.signature(predicate).parameters.keys()))
if isinstance(predicate, Deferred):
name = "_"
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
else:
name = next(iter(inspect.signature(predicate).parameters.keys()))
parameter = ops.Argument(
name=name,
shape=self.op().shape,
dtype=self.type().value_type,
)
return ops.ArrayFilter(
self, param=parameter.param, body=predicate(parameter.to_expr())
).to_expr()
if isinstance(predicate, Deferred):
body = predicate.resolve(parameter.to_expr())
else:
body = predicate(parameter.to_expr())
return ops.ArrayFilter(self, param=parameter.param, body=body).to_expr()

def contains(self, other: ir.Value) -> ir.BooleanValue:
"""Return whether the array contains `other`.
Expand Down
Loading