Commit

feat(pyspark): enable the new scalar UDF API
cpcloud authored and jcrist committed Sep 14, 2023
1 parent dc4dbe8 commit f29a8e7
Showing 3 changed files with 71 additions and 25 deletions.
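
For context, a minimal sketch of what the newly enabled scalar UDF API looks like from the user side against the PySpark backend; the session, connection, and table names below are illustrative and not part of this commit:

import ibis
from ibis import udf
from pyspark.sql import SparkSession


@udf.scalar.pandas
def add_one(x: float) -> float:
    # receives a pandas Series of doubles, returns a pandas Series of doubles
    return x + 1.0


# hypothetical session and table; any table reachable from a PySpark connection works
session = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(session)
t = con.table("functional_alltypes")
expr = t.select(incremented=add_one(t.double_col))
con.execute(expr)
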
25 changes: 23 additions & 2 deletions ibis/backends/pyspark/compiler.py
@@ -9,6 +9,7 @@
 import pyspark
 import pyspark.sql.functions as F
 import pyspark.sql.types as pt
+import toolz
 from packaging.version import parse as vparse
 from pyspark.sql import Window
 from pyspark.sql.functions import PandasUDFType, pandas_udf
@@ -28,6 +29,7 @@
 )
 from ibis.common.collections import frozendict
 from ibis.config import options
+from ibis.expr.operations.udf import InputType
 from ibis.util import any_of, guid


@@ -71,8 +73,12 @@ def translate(self, op, *, scope, timecontext, **kwargs):
         pyspark.sql.DataFrame
             translated PySpark DataFrame or Column object
         """
-
-        if (
+        # TODO(cpcloud): remove the udf instance checking when going to sqlglot
+        if isinstance(op, ops.ScalarUDF):
+            formatter = compile_scalar_udf
+            result = formatter(self, op, scope=scope, timecontext=timecontext, **kwargs)
+            return result
+        elif (
             not isinstance(op, ops.ScalarParameter)
             and (result := scope.get_value(op, timecontext)) is not None
         ):
@@ -1800,6 +1806,21 @@ def compile_reduction_udf(t, op, *, aggcontext=None, **kwargs):
     return src_table.agg(col)


+# NB: this is intentionally not using @compiles because @compiles doesn't
+# handle subclasses of operations
+def compile_scalar_udf(t, op, **kwargs):
+    if op.__input_type__ != InputType.PANDAS:
+        raise NotImplementedError("Only Pandas UDFs are supported in the PySpark backend")
+
+    import pandas as pd
+
+    make_series = partial(pd.Series, dtype=op.dtype.to_pandas())
+    func = toolz.compose(make_series, op.__func__)
+    spark_dtype = PySparkType.from_ibis(op.dtype)
+    spark_udf = pandas_udf(func, spark_dtype, PandasUDFType.SCALAR)
+    return spark_udf(*map(partial(t.translate, **kwargs), op.args))
+
+
 @compiles(ops.SearchedCase)
 def compile_searched_case(t, op, **kwargs):
     existing_when = None
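
A standalone reading of the compile_scalar_udf hunk above: the user's pandas function (op.__func__) is composed with a pandas Series constructor so its output is always coerced to a Series of the declared dtype, and the composed callable is registered with Spark as a SCALAR pandas UDF. A minimal sketch of that pattern outside the compiler, with illustrative names and DoubleType standing in for PySparkType.from_ibis(op.dtype):

from functools import partial

import pandas as pd
import toolz
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import DoubleType


def add_one(s: pd.Series) -> pd.Series:
    # stand-in for op.__func__, the user-defined pandas function
    return s + 1


# stand-in for partial(pd.Series, dtype=op.dtype.to_pandas())
make_series = partial(pd.Series, dtype="float64")

# toolz.compose(f, g)(x) == f(g(x)), so the UDF result is coerced to a Series
func = toolz.compose(make_series, add_one)

# register the composed callable as a scalar pandas UDF
spark_udf = pandas_udf(func, DoubleType(), PandasUDFType.SCALAR)

# spark_udf can then be applied to translated column expressions, e.g.
# df.select(spark_udf(df["double_col"]))
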
6 changes: 4 additions & 2 deletions ibis/backends/tests/test_udf.py
@@ -17,13 +17,13 @@
         "mysql",
         "oracle",
         "pandas",
-        "pyspark",
         "trino",
     ]
 )


 @no_python_udfs
+@mark.notimpl(["pyspark"])
 @mark.notyet(["datafusion"], raises=NotImplementedError)
 def test_udf(batting):
     @udf.scalar.python
@@ -47,6 +47,7 @@ def num_vowels(s: str, include_y: bool = False) -> int:


 @no_python_udfs
+@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
@@ -77,6 +78,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:


 @no_python_udfs
+@mark.notimpl(["pyspark"])
 @mark.notyet(
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
@@ -160,7 +162,7 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
         add_one_pyarrow,
         marks=[
             mark.notyet(
-                ["snowflake", "sqlite"],
+                ["snowflake", "sqlite", "pyspark"],
                 raises=NotImplementedError,
                 reason="backend doesn't support pyarrow UDFs",
             )
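
The marker changes above narrow pyspark's coverage by UDF flavor: the backend is no longer skipped wholesale, but only the pandas flavor is implemented, while the plain-Python and pyarrow flavors stay marked notimpl/notyet. A rough sketch of the three flavors as they appear in these tests; the function bodies are illustrative:

import pyarrow.compute as pc

from ibis import udf


@udf.scalar.python  # row-by-row Python; still notimpl on pyspark
def add_one_python(x: int) -> int:
    return x + 1


@udf.scalar.pandas  # vectorized over pandas Series; enabled for pyspark by this commit
def add_one_pandas(x: float) -> float:
    return x + 1.0


@udf.scalar.pyarrow  # vectorized over pyarrow arrays; pyspark marked notyet
def add_one_pyarrow(x: int) -> int:
    return pc.add(x, 1)
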
65 changes: 44 additions & 21 deletions ibis/backends/tests/test_vectorized_udf.py
@@ -41,16 +41,27 @@ def add_one(s):
     return s + 1


-def create_add_one_udf(result_formatter):
-    return elementwise(input_type=[dt.double], output_type=dt.double)(
-        _format_udf_return_type(add_one, result_formatter)
+def create_add_one_udf(result_formatter, id):
+    @elementwise(input_type=[dt.double], output_type=dt.double)
+    def add_one_legacy(s):
+        return result_formatter(add_one(s))
+
+    @ibis.udf.scalar.pandas
+    def add_one_udf(s: float) -> float:
+        return result_formatter(add_one(s))
+
+    yield param(add_one_legacy, id=f"add_one_legacy_{id}")
+    yield param(
+        add_one_udf,
+        marks=[pytest.mark.notimpl(["pandas", "dask"])],
+        id=f"add_one_modern_{id}",
     )


 add_one_udfs = [
-    create_add_one_udf(result_formatter=lambda v: v),  # pd.Series,
-    create_add_one_udf(result_formatter=lambda v: np.array(v)),  # np.array,
-    create_add_one_udf(result_formatter=lambda v: list(v)),  # list,
+    *create_add_one_udf(result_formatter=lambda v: v, id="series"),
+    *create_add_one_udf(result_formatter=lambda v: np.array(v), id="array"),
+    *create_add_one_udf(result_formatter=lambda v: list(v), id="list"),
 ]


@@ -239,20 +250,34 @@ def quantiles(series, *, quantiles):
     return series.quantile(quantiles)


-def test_elementwise_udf(udf_backend, udf_alltypes, udf_df):
-    add_one_udf = create_add_one_udf(result_formatter=lambda v: v)
-    result = add_one_udf(udf_alltypes["double_col"]).execute()
-    expected = add_one_udf.func(udf_df["double_col"])
+@pytest.mark.parametrize(
+    "udf", create_add_one_udf(result_formatter=lambda v: v, id="series")
+)
+def test_elementwise_udf(udf_backend, udf_alltypes, udf_df, udf):
+    expr = udf(udf_alltypes["double_col"])
+    result = expr.execute()
+
+    expected_func = getattr(expr.op(), "__func__", getattr(udf, "func", None))
+    assert (
+        expected_func is not None
+    ), f"neither __func__ nor func attributes found on {udf} or expr object"
+
+    expected = expected_func(udf_df["double_col"])
     udf_backend.assert_series_equal(result, expected, check_names=False)


 @pytest.mark.parametrize("udf", add_one_udfs)
 def test_elementwise_udf_mutate(udf_backend, udf_alltypes, udf_df, udf):
-    expr = udf_alltypes.mutate(incremented=udf(udf_alltypes["double_col"]))
+    udf_expr = udf(udf_alltypes["double_col"])
+    expr = udf_alltypes.mutate(incremented=udf_expr)
     result = expr.execute()

-    expected = udf_df.assign(incremented=udf.func(udf_df["double_col"]))
+    expected_func = getattr(udf_expr.op(), "__func__", getattr(udf, "func", None))
+    assert (
+        expected_func is not None
+    ), f"neither __func__ nor func attributes found on {udf} or expr object"
+
+    expected = udf_df.assign(incremented=expected_func(udf_df["double_col"]))
     udf_backend.assert_series_equal(result["incremented"], expected["incremented"])


@@ -275,7 +300,7 @@ def test_analytic_udf_mutate(udf_backend, udf_alltypes, udf_df, udf):
     udf_backend.assert_series_equal(result["zscore"], expected["zscore"])


-def test_reduction_udf(udf_backend, udf_alltypes, udf_df):
+def test_reduction_udf(udf_alltypes, udf_df):
     result = calc_mean(udf_alltypes["double_col"]).execute()
     expected = udf_df["double_col"].mean()
     assert result == expected
@@ -306,7 +331,7 @@ def test_reduction_udf_on_empty_data(udf_backend, udf_alltypes):
     udf_backend.assert_frame_equal(result, expected, check_dtype=False)


-def test_output_type_in_list_invalid(udf_backend, udf_alltypes, udf_df):
+def test_output_type_in_list_invalid():
     # Test that an error is raised if UDF output type is wrapped in a list

     with pytest.raises(
@@ -315,7 +340,7 @@ def test_output_type_in_list_invalid(udf_backend, udf_alltypes, udf_df):
     ):

         @elementwise(input_type=[dt.double], output_type=[dt.double])
-        def add_one(s):
+        def _(s):
             return s + 1


@@ -423,14 +448,14 @@ def foo4(v, *args, **kwargs):
     udf_backend.assert_frame_equal(result, expected)


-def test_invalid_kwargs(udf_backend, udf_alltypes):
+def test_invalid_kwargs():
     # Test that defining a UDF with a non-column argument that is not a
     # keyword argument raises an error

     with pytest.raises(TypeError, match=".*must be defined as keyword only.*"):

         @elementwise(input_type=[dt.double], output_type=dt.double)
-        def foo1(v, amount):
+        def _(v, amount):
             return v + 1


@@ -494,9 +519,7 @@ def test_elementwise_udf_overwrite_destruct_and_assign(udf_backend, udf_alltypes):
 @pytest.mark.xfail_version(pyspark=["pyspark<3.1"])
 @pytest.mark.parametrize("method", ["destructure", "unpack"])
 @pytest.mark.skip("dask")
-def test_elementwise_udf_destructure_exact_once(
-    udf_backend, udf_alltypes, method, tmp_path
-):
+def test_elementwise_udf_destructure_exact_once(udf_alltypes, method, tmp_path):
     @elementwise(
         input_type=[dt.double],
         output_type=dt.Struct({"col1": dt.double, "col2": dt.double}),
@@ -541,7 +564,7 @@ def test_elementwise_udf_multiple_overwrite_destruct(udf_backend, udf_alltypes):
     udf_backend.assert_frame_equal(result, expected, check_like=True)


-def test_elementwise_udf_named_destruct(udf_backend, udf_alltypes):
+def test_elementwise_udf_named_destruct(udf_alltypes):
     """Test error when assigning name to a destruct column."""

     add_one_struct_udf = create_add_one_struct_udf(
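
The refactor of create_add_one_udf above switches from returning a single UDF to yielding pytest.param cases that callers splat into parametrize lists, so each formatter exercises both the legacy elementwise API and the new scalar UDF API under distinct test ids. A toy version of that pattern, with illustrative names and a placeholder mark:

import numpy as np
import pytest


def make_add_one_cases(result_formatter, id):
    # one helper yields several parametrize cases, each with its own id and marks
    def add_one(s):
        return result_formatter(s + 1)

    yield pytest.param(add_one, id=f"add_one_plain_{id}")
    yield pytest.param(
        add_one,
        id=f"add_one_marked_{id}",
        marks=[pytest.mark.skip(reason="placeholder for a backend-specific mark")],
    )


# the generators are splatted into one flat list of cases
cases = [
    *make_add_one_cases(lambda v: v, id="identity"),
    *make_add_one_cases(np.asarray, id="array"),
]


@pytest.mark.parametrize("fn", cases)
def test_add_one(fn):
    assert fn(1) == 2
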
