Skip to content

Commit

Permalink
feat(pyspark): builtin udf support
Browse files Browse the repository at this point in the history
  • Loading branch information
祖弼 committed May 14, 2024
1 parent 49c6ce3 commit 7591e29
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 7 deletions.
16 changes: 9 additions & 7 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,17 @@ def wrapper(*args):
def _register_udfs(self, expr: ir.Expr) -> None:
node = expr.op()
for udf in node.find(ops.ScalarUDF):
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
udf_return = PySparkType.from_ibis(udf.dtype)
if udf.__input_type__ != InputType.PANDAS:
if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):

Check warning on line 278 in ibis/backends/pyspark/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/__init__.py#L278

Added line #L278 was not covered by tests
raise NotImplementedError(
"Only Pandas UDFs are support in the PySpark backend"
"Only Builtin UDFs and Pandas UDFs are support in the PySpark backend"
)
spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
self._session.udf.register(udf_name, spark_udf)
# register pandas UDFs
if udf.__input_type__ == InputType.PANDAS:
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
udf_return = PySparkType.from_ibis(udf.dtype)
spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
self._session.udf.register(udf_name, spark_udf)

Check warning on line 288 in ibis/backends/pyspark/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/__init__.py#L284-L288

Added lines #L284 - L288 were not covered by tests

for udf in node.find(ops.ElementWiseVectorizedUDF):
udf_name = self.compiler.__sql_name__(udf)
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType

Check warning on line 21 in ibis/backends/pyspark/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/compiler.py#L21

Added line #L21 was not covered by tests
from ibis.util import gen_name


Expand Down Expand Up @@ -327,6 +328,10 @@ def __sql_name__(self, op) -> str:
else:
raise TypeError(f"Cannot get SQL name for {type(op).__name__}")

# builtin functions will not modify the name
if hasattr(op, "__input_type__") and op.__input_type__ == InputType.BUILTIN:
return name

Check warning on line 333 in ibis/backends/pyspark/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/compiler.py#L333

Added line #L333 was not covered by tests

if not name.isidentifier():
# replace invalid characters with underscores
name = re.sub("[^0-9a-zA-Z_]", "", name)
Expand Down
39 changes: 39 additions & 0 deletions ibis/backends/pyspark/tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

Check warning on line 1 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L1

Added line #L1 was not covered by tests

import pandas.testing as tm
import pytest

Check warning on line 4 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L3-L4

Added lines #L3 - L4 were not covered by tests

import ibis

Check warning on line 6 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L6

Added line #L6 was not covered by tests

pytest.importorskip("pyspark")

Check warning on line 8 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L8

Added line #L8 was not covered by tests


@pytest.fixture

Check warning on line 11 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L11

Added line #L11 was not covered by tests
def t(con):
    """Look up ``basic_table`` through ``con.table`` and return the result."""
    table_name = "basic_table"
    return con.table(table_name)

Check warning on line 13 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L13

Added line #L13 was not covered by tests


@pytest.fixture

Check warning on line 16 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L16

Added line #L16 was not covered by tests
def df(con):
    """Fetch ``basic_table`` from ``con._session`` and convert it with ``toPandas()``."""
    session_table = con._session.table("basic_table")
    return session_table.toPandas()

Check warning on line 18 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L18

Added line #L18 was not covered by tests


def test_builtin_udf(t, df):
    """Run a builtin scalar UDF named ``repeat`` and compare against pandas.

    The UDF is declared with ``ibis.udf.scalar.builtin`` and has no Python
    body. The expected frame is built by multiplying ``df.str_col`` by 2
    (presumably a string column, so this repeats each value twice -- confirm
    against the ``basic_table`` fixture data).
    """

    @ibis.udf.scalar.builtin
    def repeat(x, n) -> str: ...

    # Execute the ibis expression first, then build the pandas reference.
    expr = t.mutate(repeated=repeat(t.str_col, 2))
    actual = expr.execute()
    doubled = df.str_col * 2
    tm.assert_frame_equal(actual, df.assign(repeated=doubled))

Check warning on line 27 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L25-L27

Added lines #L25 - L27 were not covered by tests


def test_illegal_udf_type(t):
    """Expect ``NotImplementedError`` when a ``scalar.python`` UDF is executed.

    Both the UDF declaration and the ``execute()`` call sit inside the
    ``pytest.raises`` context, so an error raised at either point satisfies
    the check.
    """
    expected_error = pytest.raises(
        NotImplementedError,
        match="Only Builtin UDFs and Pandas UDFs are support in the PySpark backend",
    )
    with expected_error:

        @ibis.udf.scalar.python()
        def repeat(x, n) -> str: ...

        expr = t.mutate(repeated=repeat(t.str_col, 2))
        expr.execute()

Check warning on line 39 in ibis/backends/pyspark/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/tests/test_udf.py#L39

Added line #L39 was not covered by tests

0 comments on commit 7591e29

Please sign in to comment.