From 02a1d48bb3284303c56fde6e5c5b5a16fef2dc9c Mon Sep 17 00:00:00 2001
From: Jimmy Stammers
Date: Fri, 2 Aug 2024 16:39:30 +0100
Subject: [PATCH] feat(pyspark): add support for pyarrow and python UDFs (#9753)

Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
---
 ibis/backends/pyspark/__init__.py       | 28 ++++++++++++++++---------
 ibis/backends/pyspark/tests/test_udf.py | 27 +++++++++++++++++++++++-
 ibis/backends/tests/test_udf.py         |  8 +++----
 3 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index fc20b2bb9a39..1226d1110f74 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -46,7 +46,7 @@
     from ibis.expr.api import Watermark
 
 PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")
-
+PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")
 ConnectionMode = Literal["streaming", "batch"]
 
@@ -359,18 +359,26 @@ def wrapper(*args):
     def _register_udfs(self, expr: ir.Expr) -> None:
         node = expr.op()
         for udf in node.find(ops.ScalarUDF):
-            if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):
-                raise NotImplementedError(
-                    "Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend"
-                )
-            # register pandas UDFs
+            udf_name = self.compiler.__sql_name__(udf)
+            udf_return = PySparkType.from_ibis(udf.dtype)
             if udf.__input_type__ == InputType.PANDAS:
-                udf_name = self.compiler.__sql_name__(udf)
                 udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
-                udf_return = PySparkType.from_ibis(udf.dtype)
                 spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
-                self._session.udf.register(udf_name, spark_udf)
-
+            elif udf.__input_type__ == InputType.PYTHON:
+                udf_func = udf.__func__
+                spark_udf = F.udf(udf_func, udf_return)
+            elif udf.__input_type__ == InputType.PYARROW:
+                # raise not implemented error if running on pyspark < 3.5
+                if PYSPARK_LT_35:
+                    raise NotImplementedError(
+                        "pyarrow UDFs are only supported in pyspark >= 3.5"
+                    )
+                udf_func = udf.__func__
+                spark_udf = F.udf(udf_func, udf_return, useArrow=True)
+            else:
+                # Builtin functions don't need to be registered
+                continue
+            self._session.udf.register(udf_name, spark_udf)
         for udf in node.find(ops.ElementWiseVectorizedUDF):
             udf_name = self.compiler.__sql_name__(udf)
             udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)
diff --git a/ibis/backends/pyspark/tests/test_udf.py b/ibis/backends/pyspark/tests/test_udf.py
index d0f92d836a5c..d5c80ac27c35 100644
--- a/ibis/backends/pyspark/tests/test_udf.py
+++ b/ibis/backends/pyspark/tests/test_udf.py
@@ -4,6 +4,7 @@
 import pytest
 
 import ibis
+from ibis.backends.pyspark import PYSPARK_LT_35
 
 pytest.importorskip("pyspark")
 
@@ -22,12 +23,36 @@ def df(con):
 def repeat(x, n) -> str: ...
 
 
+@ibis.udf.scalar.python
+def py_repeat(x: str, n: int) -> str:
+    return x * n
+
+
+@ibis.udf.scalar.pyarrow
+def pyarrow_repeat(x: str, n: int) -> str:
+    return x * n
+
+
 def test_builtin_udf(t, df):
     result = t.mutate(repeated=repeat(t.str_col, 2)).execute()
     expected = df.assign(repeated=df.str_col * 2)
     tm.assert_frame_equal(result, expected)
 
 
+def test_python_udf(t, df):
+    result = t.mutate(repeated=py_repeat(t.str_col, 2)).execute()
+    expected = df.assign(repeated=df.str_col * 2)
+    tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.xfail(PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
+def test_pyarrow_udf(t, df):
+    result = t.mutate(repeated=pyarrow_repeat(t.str_col, 2)).execute()
+    expected = df.assign(repeated=df.str_col * 2)
+    tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.xfail(not PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
 def test_illegal_udf_type(t):
     @ibis.udf.scalar.pyarrow
     def my_add_one(x) -> str:
@@ -39,6 +64,6 @@ def my_add_one(x) -> str:
 
     with pytest.raises(
         NotImplementedError,
-        match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend",
+        match="pyarrow UDFs are only supported in pyspark >= 3.5",
     ):
         expr.execute()
diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py
index fa5d7f34e7f0..4fc2e8898cfe 100644
--- a/ibis/backends/tests/test_udf.py
+++ b/ibis/backends/tests/test_udf.py
@@ -34,7 +34,6 @@
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(["datafusion"], raises=NotImplementedError)
 def test_udf(batting):
     @udf.scalar.python
@@ -59,7 +58,6 @@ def num_vowels(s: str, include_y: bool = False) -> int:
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map"
 )
@@ -89,7 +87,6 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
 
 @no_python_udfs
 @cloudpickle_version_mismatch
-@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map"
 )
@@ -174,10 +171,11 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
             add_one_pyarrow,
             marks=[
                 mark.notyet(
-                    ["snowflake", "sqlite", "pyspark", "flink"],
+                    ["snowflake", "sqlite", "flink"],
                     raises=NotImplementedError,
                     reason="backend doesn't support pyarrow UDFs",
-                )
+                ),
+                mark.xfail_version(pyspark=["pyspark<3.5"]),
             ],
         ),
     ],
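
For reviewers, a minimal end-to-end sketch of what this patch enables. It is not part of the change itself: the local Spark session, the `udf_demo` view, and the `str_col` column are illustrative assumptions, and the UDF bodies mirror the new tests above. The pyarrow variant requires PySpark 3.5+.

```python
from pyspark.sql import SparkSession

import ibis

# Illustrative setup (assumed names): a local Spark session and a temp view.
spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.createDataFrame([("a",), ("bb",)], ["str_col"]).createOrReplaceTempView("udf_demo")

con = ibis.pyspark.connect(spark)
t = con.table("udf_demo")


@ibis.udf.scalar.python
def py_repeat(x: str, n: int) -> str:
    # Registered on the session via F.udf: a plain row-at-a-time Python UDF.
    return x * n


@ibis.udf.scalar.pyarrow
def pyarrow_repeat(x: str, n: int) -> str:
    # Registered via F.udf(..., useArrow=True); on PySpark < 3.5 execution
    # raises NotImplementedError instead.
    return x * n


result = t.mutate(
    py=py_repeat(t.str_col, 2),
    pa=pyarrow_repeat(t.str_col, 2),
).execute()
print(result)
```

Builtin UDFs (like `repeat` in the tests) take the `continue` branch in `_register_udfs` and are never registered, since they map directly to native Spark SQL functions.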