From c383f62cc699a548640099160ef9093aace5d203 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Wed, 13 Sep 2023 06:44:41 -0400
Subject: [PATCH] feat(polars): implement support for builtin aggregate udfs

---
 ibis/backends/polars/compiler.py       |  9 +++++++++
 ibis/backends/polars/tests/test_udf.py | 17 ++++++++++++++++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index bc69f1cdfc90..37d03f41377e 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -1202,3 +1202,12 @@ def execute_scalar_udf(op, **kw):
         raise NotImplementedError(
             f"UDF input type {input_type} not supported for Polars"
         )
+
+
+@translate.register(ops.AggUDF)
+def execute_agg_udf(op, **kw):
+    args = (arg for name, arg in zip(op.argnames, op.args) if name != "where")
+    first, *rest = map(partial(translate, **kw), args)
+    if (where := op.where) is not None:
+        first = first.filter(translate(where, **kw))
+    return getattr(first, op.__func_name__)(*rest)
diff --git a/ibis/backends/polars/tests/test_udf.py b/ibis/backends/polars/tests/test_udf.py
index fe73b5168b79..9d9aa696ff78 100644
--- a/ibis/backends/polars/tests/test_udf.py
+++ b/ibis/backends/polars/tests/test_udf.py
@@ -54,7 +54,7 @@ def test_multiple_argument_udf(alltypes):
 @pytest.mark.parametrize(
     ("value", "expected"), [(8, 2), (27, 3), (7, 7 ** (1.0 / 3.0))]
 )
-def test_builtin(con, value, expected):
+def test_builtin_scalar_udf(con, value, expected):
     @udf.scalar.builtin
     def cbrt(a: float) -> float:
         ...
@@ -105,3 +105,18 @@ def test_multiple_argument_scalar_udf(alltypes, func):
     expected = (df.smallint_col + df.int_col).astype("int64")
 
     tm.assert_series_equal(result, expected.rename("tmp"))
+
+
+def test_builtin_agg_udf(con):
+    @udf.agg.builtin
+    def approx_n_unique(a, where: bool = True) -> int:
+        ...
+
+    ft = con.tables.functional_alltypes
+    expr = approx_n_unique(ft.string_col)
+    result = con.execute(expr)
+    assert result == 10
+
+    expr = approx_n_unique(ft.string_col, where=ft.string_col == "1")
+    result = con.execute(expr)
+    assert result == 1