feat(pyspark): builtin udf support (#9191)
Co-authored-by: 祖弼 <ning.ln@alipay.com>
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
3 people authored May 14, 2024
1 parent 49ecf8d commit 142c105
Showing 3 changed files with 58 additions and 7 deletions.
ibis/backends/pyspark/__init__.py (9 additions, 7 deletions)
@@ -275,15 +275,17 @@ def wrapper(*args):
     def _register_udfs(self, expr: ir.Expr) -> None:
         node = expr.op()
         for udf in node.find(ops.ScalarUDF):
-            udf_name = self.compiler.__sql_name__(udf)
-            udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
-            udf_return = PySparkType.from_ibis(udf.dtype)
-            if udf.__input_type__ != InputType.PANDAS:
+            if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):
                 raise NotImplementedError(
-                    "Only Pandas UDFs are support in the PySpark backend"
+                    "Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend"
                 )
-            spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
-            self._session.udf.register(udf_name, spark_udf)
+            # register pandas UDFs
+            if udf.__input_type__ == InputType.PANDAS:
+                udf_name = self.compiler.__sql_name__(udf)
+                udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
+                udf_return = PySparkType.from_ibis(udf.dtype)
+                spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
+                self._session.udf.register(udf_name, spark_udf)
 
         for udf in node.find(ops.ElementWiseVectorizedUDF):
             udf_name = self.compiler.__sql_name__(udf)
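With this change a builtin scalar UDF bypasses session registration entirely, while pandas UDFs keep the existing pandas_udf wrapping and _session.udf.register() path. A minimal usage sketch, assuming a pre-built SparkSession named session and the basic_table/str_col names used in the tests below:

import ibis

# builtin: resolves to Spark SQL's native repeat(); nothing is registered on the session
@ibis.udf.scalar.builtin
def repeat(x, n) -> str: ...

# pandas: still wrapped with pandas_udf() and registered on the SparkSession
@ibis.udf.scalar.pandas
def shout(x: str) -> str:
    return x.str.upper()

con = ibis.pyspark.connect(session)  # `session` is an assumed, existing SparkSession
t = con.table("basic_table")
t.mutate(repeated=repeat(t.str_col, 2), loud=shout(t.str_col)).execute()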
ibis/backends/pyspark/compiler.py (5 additions, 0 deletions)
@@ -18,6 +18,7 @@
 from ibis.backends.sql.rewrites import FirstValue, LastValue, p
 from ibis.common.patterns import replace
 from ibis.config import options
+from ibis.expr.operations.udf import InputType
 from ibis.util import gen_name
 
 
@@ -327,6 +328,10 @@ def __sql_name__(self, op) -> str:
         else:
             raise TypeError(f"Cannot get SQL name for {type(op).__name__}")
 
+        # builtin functions will not modify the name
+        if getattr(op, "__input_type__", None) == InputType.BUILTIN:
+            return name
+
         if not name.isidentifier():
             # replace invalid characters with underscores
             name = re.sub("[^0-9a-zA-Z_]", "", name)
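The early return means a builtin UDF keeps its declared name, so the compiler can emit a direct call to the corresponding Spark SQL function rather than a sanitized, session-registered name. A rough sketch of the effect (the exact rendered SQL may differ):

import ibis

@ibis.udf.scalar.builtin
def repeat(x, n) -> str: ...

t = ibis.table({"str_col": "string"}, name="basic_table")
expr = t.select(repeated=repeat(t.str_col, 2))
print(ibis.to_sql(expr, dialect="pyspark"))
# roughly: SELECT REPEAT(`str_col`, 2) AS `repeated` FROM `basic_table`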
ibis/backends/pyspark/tests/test_udf.py (new file, 44 additions)
@@ -0,0 +1,44 @@
from __future__ import annotations

import pandas.testing as tm
import pytest

import ibis

pytest.importorskip("pyspark")


@pytest.fixture
def t(con):
return con.table("basic_table")


@pytest.fixture
def df(con):
return con._session.table("basic_table").toPandas()


@ibis.udf.scalar.builtin
def repeat(x, n) -> str: ...


def test_builtin_udf(t, df):
result = t.mutate(repeated=repeat(t.str_col, 2)).execute()
expected = df.assign(repeated=df.str_col * 2)
tm.assert_frame_equal(result, expected)


def test_illegal_udf_type(t):
@ibis.udf.scalar.pyarrow
def my_add_one(x) -> str:
import pyarrow.compute as pac

return pac.add(pac.binary_length(x), 1)

expr = t.select(repeated=my_add_one(t.str_col))

with pytest.raises(
NotImplementedError,
match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend",
):
expr.execute()

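A note on the expectation in test_builtin_udf: multiplying a pandas string column by an integer repeats each element, which is why df.str_col * 2 mirrors Spark SQL's repeat(str_col, 2). A quick standalone check of that assumption:

import pandas as pd

s = pd.Series(["a", "bc"])
# pandas string multiplication repeats each element, like SQL repeat(col, 2)
assert (s * 2).tolist() == ["aa", "bcbc"]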