Skip to content

Commit

Permalink
feat(polars): implement support for builtin aggregate udfs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 14, 2023
1 parent 0367069 commit c383f62
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
9 changes: 9 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,3 +1202,12 @@ def execute_scalar_udf(op, **kw):
raise NotImplementedError(
f"UDF input type {input_type} not supported for Polars"
)


@translate.register(ops.AggUDF)
def execute_agg_udf(op, **kw):
args = (arg for name, arg in zip(op.argnames, op.args) if name != "where")
first, *rest = map(partial(translate, **kw), args)
if (where := op.where) is not None:
first = first.filter(translate(where, **kw))
return getattr(first, op.__func_name__)(*rest)
17 changes: 16 additions & 1 deletion ibis/backends/polars/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_multiple_argument_udf(alltypes):
@pytest.mark.parametrize(
("value", "expected"), [(8, 2), (27, 3), (7, 7 ** (1.0 / 3.0))]
)
def test_builtin(con, value, expected):
def test_builtin_scalar_udf(con, value, expected):
@udf.scalar.builtin
def cbrt(a: float) -> float:
...
Expand Down Expand Up @@ -105,3 +105,18 @@ def test_multiple_argument_scalar_udf(alltypes, func):
expected = (df.smallint_col + df.int_col).astype("int64")

tm.assert_series_equal(result, expected.rename("tmp"))


def test_builtin_agg_udf(con):
@udf.agg.builtin
def approx_n_unique(a, where: bool = True) -> int:
...

ft = con.tables.functional_alltypes
expr = approx_n_unique(ft.string_col)
result = con.execute(expr)
assert result == 10

expr = approx_n_unique(ft.string_col, where=ft.string_col == "1")
result = con.execute(expr)
assert result == 1

0 comments on commit c383f62

Please sign in to comment.