Skip to content

Commit

Permalink
feat(polars): implement new UDF API
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 14, 2023
1 parent 169d889 commit becbf41
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
31 changes: 29 additions & 2 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,10 +1167,37 @@ def execute_count_distinct_star(op, **kw):
return arg.n_unique()


_UDF_INVOKERS = {
# Convert polars series into a list
# -> map the function element by element
# -> convert back to a polars series
InputType.PYTHON: lambda func, dtype, args: pl.Series(
map(func, *(arg.to_list() for arg in args)),
dtype=dtype_to_polars(dtype),
),
# Convert polars series into a pyarrow array
# -> invoke the function on the pyarrow array
# -> cast the result to match the ibis dtype
# -> convert back to a polars series
InputType.PYARROW: lambda func, dtype, args: pl.from_arrow(
func(*(arg.to_arrow() for arg in args)).cast(dtype.to_pyarrow()),
),
}


@translate.register(ops.ScalarUDF)
def execute_scalar_udf(op, **kw):
if op.__input_type__ == InputType.BUILTIN:
if (input_type := op.__input_type__) in (InputType.PYARROW, InputType.PYTHON):
dtype = op.dtype
return pl.map_batches(
exprs=[translate(arg, **kw) for arg in op.args],
function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype),
return_dtype=dtype_to_polars(dtype),
)
elif input_type == InputType.BUILTIN:
first, *rest = map(translate, op.args)
return getattr(first, op.__func_name__)(*rest)
else:
raise NotImplementedError("Only builtin scalar UDFs are supported for polars")
raise NotImplementedError(
f"UDF input type {input_type} not supported for Polars"
)
43 changes: 43 additions & 0 deletions ibis/backends/polars/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,46 @@ def cbrt(a: float) -> float:
expr = cbrt(value)
result = con.execute(expr)
assert pytest.approx(result) == expected


@udf.scalar.pyarrow
def string_length(x: str) -> int:
return pc.cast(pc.multiply(pc.utf8_length(x), 2), target_type="int64")


@udf.scalar.python
def string_length_python(x: str) -> int:
return len(x) * 2


@udf.scalar.pyarrow
def add(x: int, y: int) -> int:
return pc.add(x, y)


@udf.scalar.python
def add_python(x: int, y: int) -> int:
return x + y


@pytest.mark.parametrize("func", [string_length, string_length_python])
def test_scalar_udf(alltypes, func):
data_string_col = alltypes.date_string_col.execute()
expected = data_string_col.str.len() * 2

expr = func(alltypes.date_string_col)
assert isinstance(expr, ir.Column)

result = expr.execute()
tm.assert_series_equal(result, expected, check_names=False)


@pytest.mark.parametrize("func", [add, add_python])
def test_multiple_argument_scalar_udf(alltypes, func):
expr = func(alltypes.smallint_col, alltypes.int_col).name("tmp")
result = expr.execute()

df = alltypes[["smallint_col", "int_col"]].execute()
expected = (df.smallint_col + df.int_col).astype("int64")

tm.assert_series_equal(result, expected.rename("tmp"))
5 changes: 3 additions & 2 deletions ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"mysql",
"oracle",
"pandas",
"polars",
"pyspark",
"trino",
]
Expand Down Expand Up @@ -51,6 +50,7 @@ def num_vowels(s: str, include_y: bool = False) -> int:
@mark.notyet(
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(
["sqlite"],
Expand Down Expand Up @@ -80,6 +80,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
@mark.notyet(
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
def test_map_merge_udf(batting):
Expand Down Expand Up @@ -149,7 +150,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
add_one_pandas,
marks=[
mark.notyet(
["duckdb", "datafusion", "sqlite"],
["duckdb", "datafusion", "polars", "sqlite"],
raises=NotImplementedError,
reason="backend doesn't support pandas UDFs",
),
Expand Down

0 comments on commit becbf41

Please sign in to comment.