diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index b3e5e99cbaaf..f67661e54396 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -448,23 +448,25 @@ def test_array_slice(backend, start, stop): param({"a": [[1, 2], [4]]}, {"a": [[2, 3], [5]]}, id="no_nulls"), ], ) +@pytest.mark.parametrize( + "func", + [ + lambda x: x + 1, + functools.partial(lambda x, y: x + y, y=1), + ibis._ + 1, + ], +) @pytest.mark.broken( ["risingwave"], raises=AssertionError, reason="TODO(Kexiang): seems a bug", ) -def test_array_map(con, input, output): +def test_array_map(con, input, output, func): t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) expected = pd.Series(output["a"]) - expr = t.select(a=t.a.map(lambda x: x + 1)) - result = con.execute(expr.a) - assert frozenset(map(tuple, result.values)) == frozenset( - map(tuple, expected.values) - ) - - expr = t.select(a=t.a.map(functools.partial(lambda x, y: x + y, y=1))) + expr = t.select(a=t.a.map(func)) result = con.execute(expr.a) assert frozenset(map(tuple, result.values)) == frozenset( map(tuple, expected.values) @@ -512,17 +514,19 @@ def test_array_map(con, input, output): param({"a": [[1, 2], [4]]}, {"a": [[2], [4]]}, id="no_nulls"), ], ) -def test_array_filter(con, input, output): +@pytest.mark.parametrize( + "predicate", + [ + lambda x: x > 1, + functools.partial(lambda x, y: x > y, y=1), + ibis._ > 1, + ], +) +def test_array_filter(con, input, output, predicate): t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) expected = pd.Series(output["a"]) - expr = t.select(a=t.a.filter(lambda x: x > 1)) - result = con.execute(expr.a) - assert frozenset(map(tuple, result.values)) == frozenset( - map(tuple, expected.values) - ) - - expr = t.select(a=t.a.filter(functools.partial(lambda x, y: x > y, y=1))) + expr = t.select(a=t.a.filter(predicate)) result = con.execute(expr.a) 
assert frozenset(map(tuple, result.values)) == frozenset( map(tuple, expected.values) diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 17a1c84e6f3a..add931bfeb2d 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -6,7 +6,7 @@ from public import public import ibis.expr.operations as ops -from ibis.common.deferred import deferrable +from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: @@ -358,13 +358,13 @@ def join(self, sep: str | ir.StringValue) -> ir.StringValue: """ return ops.ArrayStringJoin(sep, self).to_expr() - def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: - """Apply a callable `func` to each element of this array expression. + def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: + """Apply a `func` or `Deferred` to each element of this array expression. Parameters ---------- func - Function to apply to each element of this array + Function or `Deferred` to apply to each element of this array. Returns ------- @@ -374,6 +374,7 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: Examples -------- >>> import ibis + >>> from ibis import _ >>> ibis.options.interactive = True >>> t = ibis.memtable({"a": [[1, None, 2], [4], []]}) >>> t @@ -386,6 +387,22 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: │ [4] │ │ [] │ └──────────────────────┘ + + The most succinct way to use `map` is with `Deferred` expressions: + + >>> t.a.map((_ + 100).cast("float")) + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayMap(a, Cast(Add(_, 100), float64)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├─────────────────────────────────────────┤ + │ [101.0, None, ... 
+1] │ + │ [104.0] │ + │ [] │ + └─────────────────────────────────────────┘ + + You can also use `map` with a lambda function: + >>> t.a.map(lambda x: (x + 100).cast("float")) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ArrayMap(a, Cast(Add(x, 100), float64)) ┃ @@ -426,23 +443,28 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: │ [] │ └────────────────────────┘ """ - name = next(iter(inspect.signature(func).parameters.keys())) + if isinstance(func, Deferred): + name = "_" + else: + name = next(iter(inspect.signature(func).parameters.keys())) parameter = ops.Argument( name=name, shape=self.op().shape, dtype=self.type().value_type ) - return ops.ArrayMap( - self, param=parameter.param, body=func(parameter.to_expr()) - ).to_expr() + if isinstance(func, Deferred): + body = func.resolve(parameter.to_expr()) + else: + body = func(parameter.to_expr()) + return ops.ArrayMap(self, param=parameter.param, body=body).to_expr() def filter( - self, predicate: Callable[[ir.Value], bool | ir.BooleanValue] + self, predicate: Deferred | Callable[[ir.Value], bool | ir.BooleanValue] ) -> ir.ArrayValue: - """Filter array elements using `predicate`. + """Filter array elements using `predicate` function or `Deferred`. 
Parameters ---------- predicate - Function to use to filter array elements + Function or `Deferred` to use to filter array elements Returns ------- @@ -452,6 +474,7 @@ def filter( Examples -------- >>> import ibis + >>> from ibis import _ >>> ibis.options.interactive = True >>> t = ibis.memtable({"a": [[1, None, 2], [4], []]}) >>> t @@ -464,6 +487,22 @@ def filter( │ [4] │ │ [] │ └──────────────────────┘ + + The most succinct way to use `filter` is with `Deferred` expressions: + + >>> t.a.filter(_ > 1) + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ ArrayFilter(a, Greater(_, 1)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [2] │ + │ [4] │ + │ [] │ + └───────────────────────────────┘ + + You can also use `filter` with a lambda function: + >>> t.a.filter(lambda x: x > 1) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ArrayFilter(a, Greater(x, 1)) ┃ @@ -504,15 +543,20 @@ def filter( │ [] │ └───────────────────────────────┘ """ - name = next(iter(inspect.signature(predicate).parameters.keys())) + if isinstance(predicate, Deferred): + name = "_" + else: + name = next(iter(inspect.signature(predicate).parameters.keys())) parameter = ops.Argument( name=name, shape=self.op().shape, dtype=self.type().value_type, ) - return ops.ArrayFilter( - self, param=parameter.param, body=predicate(parameter.to_expr()) - ).to_expr() + if isinstance(predicate, Deferred): + body = predicate.resolve(parameter.to_expr()) + else: + body = predicate(parameter.to_expr()) + return ops.ArrayFilter(self, param=parameter.param, body=body).to_expr() def contains(self, other: ir.Value) -> ir.BooleanValue: """Return whether the array contains `other`.