[SPARK-43082][CONNECT][PYTHON] Arrow-optimized Python UDFs in Spark Connect

### What changes were proposed in this pull request?
Implement Arrow-optimized Python UDFs in Spark Connect.

Please see #39384 for the motivation and performance improvements of Arrow-optimized Python UDFs.

### Why are the changes needed?
Parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?
Yes. In the Spark Connect Python client, users can:

1. Set the `useArrow` parameter to `True` to enable Arrow optimization for a specific Python UDF.

```python
>>> df = spark.range(2)
>>> df.select(udf(lambda x : x + 1, useArrow=True)('id')).show()
+------------+
|<lambda>(id)|
+------------+
|           1|
|           2|
+------------+

# ArrowEvalPython indicates Arrow optimization
>>> df.select(udf(lambda x : x + 1, useArrow=True)('id')).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#18 AS <lambda>(id)#16]
+- ArrowEvalPython [<lambda>(id#14L)#15], [pythonUDF0#18], 200
   +- *(1) Range (0, 2, step=1, splits=1)
```

2. Enable the `spark.sql.execution.pythonUDF.arrow.enabled` Spark conf to make all Python UDFs Arrow-optimized.

```python
>>> spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)
>>> df.select(udf(lambda x : x + 1)('id')).show()
+------------+
|<lambda>(id)|
+------------+
|           1|
|           2|
+------------+

# ArrowEvalPython indicates Arrow optimization
>>> df.select(udf(lambda x : x + 1)('id')).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#30 AS <lambda>(id)#28]
+- ArrowEvalPython [<lambda>(id#26L)#27], [pythonUDF0#30], 200
   +- *(1) Range (0, 2, step=1, splits=1)

```
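
Note that Arrow optimization only applies when the UDF takes at least one argument and returns an atomic type; for composite return types (`ArrayType`, `MapType`, `StructType`) the regular pickled UDF is kept and a `UserWarning` is raised. A minimal sketch of the fallback, assuming an active session:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

# Atomic return type with at least one argument: eligible for Arrow optimization.
plus_one = udf(lambda x: x + 1, "int", useArrow=True)

# Composite return type: warns "Arrow optimization for Python UDFs cannot be
# enabled." and falls back to a regular (pickled) Python UDF.
to_array = udf(lambda x: [x], ArrayType(IntegerType()), useArrow=True)
```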

### How was this patch tested?
Parity unit tests; see the new `python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py` below, which reuses the Arrow Python UDF tests against Spark Connect.

Closes #40725 from xinrong-meng/connect_arrow_py_udf.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
xinrong-meng authored and HyukjinKwon committed Apr 22, 2023
1 parent fece7ed commit f29502a
Showing 6 changed files with 171 additions and 90 deletions.
12 changes: 9 additions & 3 deletions python/pyspark/sql/connect/functions.py
@@ -50,7 +50,7 @@
LambdaFunction,
UnresolvedNamedLambdaVariable,
)
from pyspark.sql.connect.udf import _create_udf
from pyspark.sql.connect.udf import _create_py_udf
from pyspark.sql import functions as pysparkfuncs
from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType

@@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
def udf(
f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
returnType: "DataTypeOrString" = StringType(),
useArrow: Optional[bool] = None,
) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
from pyspark.rdd import PythonEvalType

@@ -2469,10 +2470,15 @@
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(
_create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
_create_py_udf,
returnType=return_type,
evalType=PythonEvalType.SQL_BATCHED_UDF,
useArrow=useArrow,
)
else:
return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
return _create_py_udf(
f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
)


udf.__doc__ = pysparkfuncs.udf.__doc__
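
Both call styles of `udf` now route through `_create_py_udf`: when the first argument is a string or `DataType` (the decorator form), it is treated as the return type and a `functools.partial` is returned. A short sketch of the two forms, assuming the same signature as vanilla `pyspark.sql.functions.udf`:

```python
from pyspark.sql.functions import udf

# Direct form: f is the callable, the return type follows.
plus_one = udf(lambda x: x + 1, "int", useArrow=True)

# Decorator form: the first argument is the return type, so udf(...) returns
# a functools.partial that later receives the decorated function as f.
@udf("string", useArrow=True)
def to_upper(s):
    return s.upper()
```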
46 changes: 45 additions & 1 deletion python/pyspark/sql/connect/udf.py
@@ -23,6 +23,8 @@

import sys
import functools
import warnings
from inspect import getfullargspec
from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union

from pyspark.rdd import PythonEvalType
@@ -33,7 +35,7 @@
)
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.types import UnparsedDataType
from pyspark.sql.types import DataType, StringType
from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration


@@ -47,6 +49,48 @@
from pyspark.sql.types import StringType


def _create_py_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
evalType: int,
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
from pyspark.sql.udf import _create_arrow_py_udf
from pyspark.sql.connect.session import _active_spark_session

if _active_spark_session is None:
is_arrow_enabled = False
else:
is_arrow_enabled = (
_active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
if useArrow is None
else useArrow
)

regular_udf = _create_udf(f, returnType, evalType)
return_type = regular_udf.returnType
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
is_output_atomic_type = (
not isinstance(return_type, StructType)
and not isinstance(return_type, MapType)
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled:
if is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(regular_udf)
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
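
The conf lookup above requires an active Connect session. A standalone sketch mirroring the resolution logic (the helper name `resolve_arrow_enabled` is hypothetical); note that, as written, an explicit `useArrow` value is also bypassed when no session is active:

```python
from typing import Optional

def resolve_arrow_enabled(active_session, use_arrow: Optional[bool]) -> bool:
    # Mirrors _create_py_udf above: with no session the conf cannot be read,
    # so Arrow optimization stays off regardless of the useArrow argument.
    if active_session is None:
        return False
    if use_arrow is None:
        conf = active_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")
        return conf == "true"
    return use_arrow
```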
48 changes: 48 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
@@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin


class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFParityTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(ArrowPythonUDFParityTests, cls).tearDownClass()


if __name__ == "__main__":
import unittest
from pyspark.sql.tests.connect.test_parity_arrow_python_udf import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
21 changes: 15 additions & 6 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -31,12 +31,7 @@
@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class PythonUDFArrowTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(PythonUDFArrowTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_broadcast_in_udf(self):
super(PythonUDFArrowTests, self).test_broadcast_in_udf()
@@ -118,6 +113,20 @@ def test_use_arrow(self):
self.assertEquals(row_false[0], "[1, 2, 3]")


class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(PythonUDFArrowTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(PythonUDFArrowTests, cls).tearDownClass()


if __name__ == "__main__":
from pyspark.sql.tests.test_arrow_python_udf import * # noqa: F401

41 changes: 0 additions & 41 deletions python/pyspark/sql/tests/test_udf.py
@@ -838,47 +838,6 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")


def test_use_arrow(self):
# useArrow=True
row_true = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=True)("array"),
)
.first()
)
# The input is a NumPy array when the Arrow optimization is on.
self.assertEquals(row_true[0], "[1 2 3]")

# useArrow=None
row_none = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=None)("array"),
)
.first()
)

# useArrow=False
row_false = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=False)("array"),
)
.first()
)
self.assertEquals(row_false[0], row_none[0]) # "[1, 2, 3]"


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
if SparkSession._instantiatedSession is not None:
93 changes: 54 additions & 39 deletions python/pyspark/sql/udf.py
@@ -75,6 +75,7 @@ def _create_udf(
name: Optional[str] = None,
deterministic: bool = True,
) -> "UserDefinedFunctionLike":
"""Create a regular(non-Arrow-optimized) Python UDF."""
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -88,6 +89,7 @@ def _create_py_udf(
evalType: int,
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
"""Create a regular/Arrow-optimized Python UDF."""
# The following table shows the results when the type coercion in Arrow is needed, that is,
# when the user-specified return type(SQL Type) of the UDF and the actual instance(Python
# Value(Type)) that the UDF returns are different.
@@ -138,49 +140,62 @@ def _create_py_udf(
and not isinstance(return_type, MapType)
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf # type: ignore[attr-defined]

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
raise NotImplementedError(
"Struct input type are not supported with Arrow optimization "
"enabled in Python UDFs. Disable "
"'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
)
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = (
f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
)
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, returnType, None)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
pudf.evalType = regular_udf.evalType
return pudf
if is_arrow_enabled:
if is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(regular_udf)
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_arrow_py_udf(regular_udf): # type: ignore
"""Create an Arrow-optimized Python UDF out of a regular Python UDF."""
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf

f = regular_udf.func
return_type = regular_udf.returnType

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
raise NotImplementedError(
"Struct input type are not supported with Arrow optimization "
"enabled in Python UDFs. Disable "
"'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
)
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, return_type, None)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
pudf.evalType = regular_udf.evalType
return pudf
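
The `result_func` coercion exists because the raw return value of a StringType/BinaryType UDF must convert cleanly to an Arrow string or binary column. A standalone sketch of the failure it avoids, assuming pandas, numpy, and pyarrow are installed:

```python
import numpy as np
import pandas as pd
import pyarrow as pa

# A StringType UDF returning a non-string object (here a numpy array):
raw = pd.Series([np.array([1, 2, 3])])
coerced = pd.Series([str(np.array([1, 2, 3]))])  # what result_func produces

pa.Array.from_pandas(coerced, type=pa.string())  # ok: ["[1 2 3]"]
# pa.Array.from_pandas(raw, type=pa.string())    # raises an Arrow type error
```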


class UserDefinedFunction:
"""
User defined function in Python
