feat(pyspark): add support for pyarrow and python UDFs (#9753)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
jstammers and cpcloud authored Aug 2, 2024
1 parent f8bea7e commit 02a1d48
Showing 3 changed files with 47 additions and 16 deletions.
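What the change enables, as exercised by the new tests below: scalar UDFs declared with @ibis.udf.scalar.python or @ibis.udf.scalar.pyarrow now execute on the PySpark backend (pyarrow UDFs require pyspark >= 3.5). A minimal sketch of the user-facing API; the connect() call and the memtable are illustrative assumptions, not part of this diff:

    import ibis

    # Assumes the backend can create or reuse a local SparkSession.
    con = ibis.pyspark.connect()

    @ibis.udf.scalar.python  # plain Python UDF, evaluated row by row
    def py_repeat(x: str, n: int) -> str:
        return x * n

    t = ibis.memtable({"str_col": ["a", "b"]})
    expr = t.mutate(repeated=py_repeat(t.str_col, 2))
    print(con.execute(expr))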
28 changes: 18 additions & 10 deletions ibis/backends/pyspark/__init__.py
@@ -46,7 +46,7 @@
 from ibis.expr.api import Watermark
 
 PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")
-
+PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")
 ConnectionMode = Literal["streaming", "batch"]


@@ -359,18 +359,26 @@ def wrapper(*args):
     def _register_udfs(self, expr: ir.Expr) -> None:
         node = expr.op()
         for udf in node.find(ops.ScalarUDF):
-            if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):
-                raise NotImplementedError(
-                    "Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend"
-                )
-            # register pandas UDFs
+            udf_name = self.compiler.__sql_name__(udf)
+            udf_return = PySparkType.from_ibis(udf.dtype)
             if udf.__input_type__ == InputType.PANDAS:
-                udf_name = self.compiler.__sql_name__(udf)
                 udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
-                udf_return = PySparkType.from_ibis(udf.dtype)
                 spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
-                self._session.udf.register(udf_name, spark_udf)
-
+            elif udf.__input_type__ == InputType.PYTHON:
+                udf_func = udf.__func__
+                spark_udf = F.udf(udf_func, udf_return)
+            elif udf.__input_type__ == InputType.PYARROW:
+                # raise not implemented error if running on pyspark < 3.5
+                if PYSPARK_LT_35:
+                    raise NotImplementedError(
+                        "pyarrow UDFs are only supported in pyspark >= 3.5"
+                    )
+                udf_func = udf.__func__
+                spark_udf = F.udf(udf_func, udf_return, useArrow=True)
+            else:
+                # Builtin functions don't need to be registered
+                continue
+            self._session.udf.register(udf_name, spark_udf)
         for udf in node.find(ops.ElementWiseVectorizedUDF):
             udf_name = self.compiler.__sql_name__(udf)
             udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)
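For reference, the three branches above map ibis input types onto stock PySpark registration calls. A standalone sketch of the same mapping, with illustrative function names:

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StringType

    spark = SparkSession.builder.getOrCreate()

    def repeat2(s: str) -> str:
        return s * 2

    # InputType.PYTHON: plain row-at-a-time UDF.
    spark.udf.register("repeat2_py", F.udf(repeat2, StringType()))

    # InputType.PYARROW: same callable, Arrow-backed serialization (pyspark >= 3.5 only).
    spark.udf.register("repeat2_arrow", F.udf(repeat2, StringType(), useArrow=True))

    # InputType.PANDAS: the callable receives and returns pandas Series
    # (legacy SCALAR pandas UDF style, as used in the backend code above).
    spark.udf.register(
        "repeat2_pandas",
        F.pandas_udf(lambda s: s * 2, StringType(), F.PandasUDFType.SCALAR),
    )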
27 changes: 26 additions & 1 deletion ibis/backends/pyspark/tests/test_udf.py
@@ -4,6 +4,7 @@
 import pytest
 
 import ibis
+from ibis.backends.pyspark import PYSPARK_LT_35
 
 pytest.importorskip("pyspark")

@@ -22,12 +23,36 @@ def df(con):
 def repeat(x, n) -> str: ...
 
 
+@ibis.udf.scalar.python
+def py_repeat(x: str, n: int) -> str:
+    return x * n
+
+
+@ibis.udf.scalar.pyarrow
+def pyarrow_repeat(x: str, n: int) -> str:
+    return x * n
+
+
 def test_builtin_udf(t, df):
     result = t.mutate(repeated=repeat(t.str_col, 2)).execute()
     expected = df.assign(repeated=df.str_col * 2)
     tm.assert_frame_equal(result, expected)
 
 
+def test_python_udf(t, df):
+    result = t.mutate(repeated=py_repeat(t.str_col, 2)).execute()
+    expected = df.assign(repeated=df.str_col * 2)
+    tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.xfail(PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
+def test_pyarrow_udf(t, df):
+    result = t.mutate(repeated=pyarrow_repeat(t.str_col, 2)).execute()
+    expected = df.assign(repeated=df.str_col * 2)
+    tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.xfail(not PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
 def test_illegal_udf_type(t):
     @ibis.udf.scalar.pyarrow
     def my_add_one(x) -> str:
@@ -39,6 +64,6 @@ def my_add_one(x) -> str:
 
     with pytest.raises(
         NotImplementedError,
-        match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend",
+        match="pyarrow UDFs are only supported in pyspark >= 3.5",
    ):
         expr.execute()
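The two gates are complementary: on pyspark < 3.5, test_pyarrow_udf is expected to fail because registration raises NotImplementedError; on 3.5 and later, test_illegal_udf_type is the one expected to fail, since the pyarrow UDF now registers cleanly and nothing raises. The imported PYSPARK_LT_35 flag mirrors the backend's version probe; a sketch, assuming the packaging-based vparse alias the backend uses:

    import pyspark
    from packaging.version import parse as vparse

    # The same comparison the backend uses to define PYSPARK_LT_35.
    PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")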
8 changes: 3 additions & 5 deletions ibis/backends/tests/test_udf.py
@@ -34,7 +34,6 @@
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(["datafusion"], raises=NotImplementedError)
 def test_udf(batting):
     @udf.scalar.python
@@ -59,7 +58,6 @@ def num_vowels(s: str, include_y: bool = False) -> int:
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
@@ -89,7 +87,6 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
@@ -174,10 +171,11 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
             add_one_pyarrow,
             marks=[
                 mark.notyet(
-                    ["snowflake", "sqlite", "pyspark", "flink"],
+                    ["snowflake", "sqlite", "flink"],
                     raises=NotImplementedError,
                     reason="backend doesn't support pyarrow UDFs",
-                )
+                ),
+                mark.xfail_version(pyspark=["pyspark<3.5"]),
             ],
         ),
     ],
