From becbf411872336675eb642f9bd50529c4ed6aaa8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 14 Sep 2023 09:49:45 -0400 Subject: [PATCH] feat(polars): implement new UDF API --- ibis/backends/polars/compiler.py | 31 +++++++++++++++++-- ibis/backends/polars/tests/test_udf.py | 43 ++++++++++++++++++++++++++ ibis/backends/tests/test_udf.py | 5 +-- 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 31d7a6ca12c7..fcff20e46959 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1167,10 +1167,37 @@ def execute_count_distinct_star(op, **kw): return arg.n_unique() +_UDF_INVOKERS = { + # Convert polars series into a list + # -> map the function element by element + # -> convert back to a polars series + InputType.PYTHON: lambda func, dtype, args: pl.Series( + map(func, *(arg.to_list() for arg in args)), + dtype=dtype_to_polars(dtype), + ), + # Convert polars series into a pyarrow array + # -> invoke the function on the pyarrow array + # -> cast the result to match the ibis dtype + # -> convert back to a polars series + InputType.PYARROW: lambda func, dtype, args: pl.from_arrow( + func(*(arg.to_arrow() for arg in args)).cast(dtype.to_pyarrow()), + ), +} + + @translate.register(ops.ScalarUDF) def execute_scalar_udf(op, **kw): - if op.__input_type__ == InputType.BUILTIN: + if (input_type := op.__input_type__) in (InputType.PYARROW, InputType.PYTHON): + dtype = op.dtype + return pl.map_batches( + exprs=[translate(arg, **kw) for arg in op.args], + function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype), + return_dtype=dtype_to_polars(dtype), + ) + elif input_type == InputType.BUILTIN: first, *rest = map(translate, op.args) return getattr(first, op.__func_name__)(*rest) else: - raise NotImplementedError("Only builtin scalar UDFs are supported for polars") + raise NotImplementedError( + f"UDF input type {input_type} not supported for Polars" + ) diff --git a/ibis/backends/polars/tests/test_udf.py b/ibis/backends/polars/tests/test_udf.py index 8693c2a14797..fe73b5168b79 100644 --- a/ibis/backends/polars/tests/test_udf.py +++ b/ibis/backends/polars/tests/test_udf.py @@ -62,3 +62,46 @@ def cbrt(a: float) -> float: expr = cbrt(value) result = con.execute(expr) assert pytest.approx(result) == expected + + +@udf.scalar.pyarrow +def string_length(x: str) -> int: + return pc.cast(pc.multiply(pc.utf8_length(x), 2), target_type="int64") + + +@udf.scalar.python +def string_length_python(x: str) -> int: + return len(x) * 2 + + +@udf.scalar.pyarrow +def add(x: int, y: int) -> int: + return pc.add(x, y) + + +@udf.scalar.python +def add_python(x: int, y: int) -> int: + return x + y + + +@pytest.mark.parametrize("func", [string_length, string_length_python]) +def test_scalar_udf(alltypes, func): + data_string_col = alltypes.date_string_col.execute() + expected = data_string_col.str.len() * 2 + + expr = func(alltypes.date_string_col) + assert isinstance(expr, ir.Column) + + result = expr.execute() + tm.assert_series_equal(result, expected, check_names=False) + + +@pytest.mark.parametrize("func", [add, add_python]) +def test_multiple_argument_scalar_udf(alltypes, func): + expr = func(alltypes.smallint_col, alltypes.int_col).name("tmp") + result = expr.execute() + + df = alltypes[["smallint_col", "int_col"]].execute() + expected = (df.smallint_col + df.int_col).astype("int64") + + tm.assert_series_equal(result, expected.rename("tmp")) diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index d09a37f3eaa0..bb12383d2a81 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -17,7 +17,6 @@ "mysql", "oracle", "pandas", - "polars", "pyspark", "trino", ] @@ -51,6 +50,7 @@ def num_vowels(s: str, include_y: bool = False) -> int: @mark.notyet( ["postgres"], raises=TypeError, reason="postgres only supports map" ) +@mark.notimpl(["polars"]) @mark.notyet(["datafusion"], raises=NotImplementedError) @mark.notyet( ["sqlite"], @@ -80,6 +80,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]: @mark.notyet( ["postgres"], raises=TypeError, reason="postgres only supports map" ) +@mark.notimpl(["polars"]) @mark.notyet(["datafusion"], raises=NotImplementedError) @mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types") def test_map_merge_udf(batting): @@ -149,7 +150,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type add_one_pandas, marks=[ mark.notyet( - ["duckdb", "datafusion", "sqlite"], + ["duckdb", "datafusion", "polars", "sqlite"], raises=NotImplementedError, reason="backend doesn't support pandas UDFs", ),