Skip to content

Commit

Permalink
feat(duckdb): implement ops.ArrayMap
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 16, 2023
1 parent 49e5f7a commit 063602d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
23 changes: 22 additions & 1 deletion ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import operator
from functools import partial
from typing import Any, Mapping

import numpy as np
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.sql.functions import FunctionElement, GenericFunction

import ibis.expr.operations as ops
from ibis.backends.base.sql import alchemy
Expand Down Expand Up @@ -224,6 +225,24 @@ def _translate_case(t, op, *, value):
)


class list_apply(FunctionElement):
pass


@compiles(list_apply, "duckdb")
def compiles_list_apply(element, compiler, **kw):
*args, signature, result = map(partial(compiler.process, **kw), element.clauses)
return f"list_apply({', '.join(args)}, {signature} -> {result})"


def _array_map(t, op):
return list_apply(
t.translate(op.arg),
sa.literal_column(f"({', '.join(op.signature)})"),
t.translate(op.result),
)


operation_registry.update(
{
ops.ArrayColumn: (
Expand Down Expand Up @@ -309,6 +328,8 @@ def _translate_case(t, op, *, value):
ops.SimpleCase: _simple_case,
ops.StartsWith: fixed_arity(sa.func.prefix, 2),
ops.EndsWith: fixed_arity(sa.func.suffix, 2),
ops.ArrayMap: _array_map,
ops.Argument: lambda _, op: sa.literal_column(op.name),
}
)

Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,11 @@ def test_array_slice(con, start, stop):
{'sliced': array_types.y.execute().map(lambda x: x[start:stop])}
)
tm.assert_frame_equal(result, expected)


def test_array_map(backend):
t = ibis.memtable({"a": [[1, None, 2], None, [4]]})
expr = t.select(a=t.a.map(lambda x: x + 1))
result = expr.execute()
expected = pd.DataFrame({"a": [[2, None, 3], None, [5]]})
backend.assert_frame_equal(result, expected)
17 changes: 13 additions & 4 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from typing import Callable

from public import public
Expand Down Expand Up @@ -101,12 +102,20 @@ class ArrayMap(Value):

@attribute.default
def result(self):
shape = rlz.Shape.COLUMNAR
dtype = self.arg.output_dtype.value_type
arg = Argument(shape=shape, dtype=dtype).to_expr()
expr = self.func(arg)
arg = self.arg
shape = arg.output_shape
dtype = arg.output_dtype.value_type
args = [
Argument(name=name, shape=shape, dtype=dtype).to_expr()
for name in self.signature
]
expr = self.func(*args)
return expr.op()

@property
def signature(self):
return list(inspect.signature(self.func).parameters.keys())

@attribute.default
def output_dtype(self):
return dt.Array(self.result.output_dtype)
Expand Down
1 change: 1 addition & 0 deletions ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def to_expr(self):

@public
class Argument(Value):
name = rlz.instance_of(str)
shape = rlz.instance_of(rlz.Shape)
dtype = rlz.datatype

Expand Down

0 comments on commit 063602d

Please sign in to comment.