diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 0b175647d3ef..8ef4ed7af09e 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -275,15 +275,17 @@ def wrapper(*args): def _register_udfs(self, expr: ir.Expr) -> None: node = expr.op() for udf in node.find(ops.ScalarUDF): - udf_name = self.compiler.__sql_name__(udf) - udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype) - udf_return = PySparkType.from_ibis(udf.dtype) - if udf.__input_type__ != InputType.PANDAS: + if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN): raise NotImplementedError( - "Only Pandas UDFs are support in the PySpark backend" + "Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend" ) - spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR) - self._session.udf.register(udf_name, spark_udf) + # register pandas UDFs + if udf.__input_type__ == InputType.PANDAS: + udf_name = self.compiler.__sql_name__(udf) + udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype) + udf_return = PySparkType.from_ibis(udf.dtype) + spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR) + self._session.udf.register(udf_name, spark_udf) for udf in node.find(ops.ElementWiseVectorizedUDF): udf_name = self.compiler.__sql_name__(udf) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 27525f169296..b7b3ca347471 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -18,6 +18,7 @@ from ibis.backends.sql.rewrites import FirstValue, LastValue, p from ibis.common.patterns import replace from ibis.config import options +from ibis.expr.operations.udf import InputType from ibis.util import gen_name @@ -327,6 +328,10 @@ def __sql_name__(self, op) -> str: else: raise TypeError(f"Cannot get SQL name for {type(op).__name__}") + # builtin functions will not modify the name + if getattr(op, "__input_type__", None) == InputType.BUILTIN: + return name + if not name.isidentifier(): # replace invalid characters with underscores name = re.sub("[^0-9a-zA-Z_]", "", name) diff --git a/ibis/backends/pyspark/tests/test_udf.py b/ibis/backends/pyspark/tests/test_udf.py new file mode 100644 index 000000000000..d0f92d836a5c --- /dev/null +++ b/ibis/backends/pyspark/tests/test_udf.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pandas.testing as tm +import pytest + +import ibis + +pytest.importorskip("pyspark") + + +@pytest.fixture +def t(con): + return con.table("basic_table") + + +@pytest.fixture +def df(con): + return con._session.table("basic_table").toPandas() + + +@ibis.udf.scalar.builtin +def repeat(x, n) -> str: ... + + +def test_builtin_udf(t, df): + result = t.mutate(repeated=repeat(t.str_col, 2)).execute() + expected = df.assign(repeated=df.str_col * 2) + tm.assert_frame_equal(result, expected) + + +def test_illegal_udf_type(t): + @ibis.udf.scalar.pyarrow + def my_add_one(x) -> str: + import pyarrow.compute as pac + + return pac.add(pac.binary_length(x), 1) + + expr = t.select(repeated=my_add_one(t.str_col)) + + with pytest.raises( + NotImplementedError, + match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend", + ): + expr.execute()