[SPARK-44836][PYTHON] Refactor Arrow Python UDTF #42520

40 changes: 16 additions & 24 deletions python/pyspark/sql/connect/udtf.py
@@ -83,34 +83,26 @@ def _create_py_udtf(
             else:
                 raise e

-    # Create a regular Python UDTF and check for invalid handler class.
-    regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
-
-    if not arrow_enabled:
-        return regular_udtf
-
-    from pyspark.sql.pandas.utils import (
-        require_minimum_pandas_version,
-        require_minimum_pyarrow_version,
-    )
-
-    try:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-    except ImportError as e:
-        warnings.warn(
-            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
-            f"Falling back to using regular Python UDTFs.",
-            UserWarning,
-        )
-        return regular_udtf
-
-    from pyspark.sql.udtf import _vectorize_udtf
-
-    vectorized_udtf = _vectorize_udtf(cls)
-    return _create_udtf(
-        vectorized_udtf, returnType, name, PythonEvalType.SQL_ARROW_TABLE_UDF, deterministic
-    )
+    eval_type: int = PythonEvalType.SQL_TABLE_UDF
+
+    if arrow_enabled:
+        from pyspark.sql.pandas.utils import (
+            require_minimum_pandas_version,
+            require_minimum_pyarrow_version,
+        )
+
+        try:
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
+        except ImportError as e:
+            warnings.warn(
+                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
+                f"Falling back to using regular Python UDTFs.",
+                UserWarning,
+            )
+
+    return _create_udtf(cls, returnType, name, eval_type, deterministic)


 class UserDefinedTableFunction:
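In python/pyspark/sql/connect/udtf.py the two return paths are folded into one: the eval type is resolved first, and a single _create_udtf call is made at the end, so the Arrow fallback no longer needs its own return statement. A minimal standalone sketch of that control flow; the constants and the choose_eval_type helper below are illustrative stand-ins, not the real PySpark definitions:

import warnings

# Illustrative stand-ins for the PythonEvalType constants; the values are made up.
SQL_TABLE_UDF = 0
SQL_ARROW_TABLE_UDF = 1


def choose_eval_type(arrow_enabled: bool) -> int:
    """Default to the regular UDTF eval type; upgrade to the Arrow eval type
    only when the optional pandas/pyarrow dependencies can be imported."""
    eval_type = SQL_TABLE_UDF
    if arrow_enabled:
        try:
            import pandas  # noqa: F401
            import pyarrow  # noqa: F401

            eval_type = SQL_ARROW_TABLE_UDF
        except ImportError as e:
            warnings.warn(
                f"Arrow optimization for Python UDTFs cannot be enabled: {e}. "
                "Falling back to using regular Python UDTFs.",
                UserWarning,
            )
    return eval_type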
114 changes: 18 additions & 96 deletions python/pyspark/sql/udtf.py
@@ -19,11 +19,10 @@
"""
import pickle
from dataclasses import dataclass
from functools import wraps
import inspect
import sys
import warnings
from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable
from typing import Any, Type, TYPE_CHECKING, Optional, Union

from py4j.java_gateway import JavaObject

@@ -112,107 +111,30 @@ def _create_py_udtf(
             if isinstance(value, str) and value.lower() == "true":
                 arrow_enabled = True

-    # Create a regular Python UDTF and check for invalid handler class.
-    regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
-
-    if not arrow_enabled:
-        return regular_udtf
-
-    # Return the regular UDTF if the required dependencies are not satisfied.
-    try:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-    except ImportError as e:
-        warnings.warn(
-            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
-            f"Falling back to using regular Python UDTFs.",
-            UserWarning,
-        )
-        return regular_udtf
+    eval_type: int = PythonEvalType.SQL_TABLE_UDF

-    # Return the vectorized UDTF.
-    vectorized_udtf = _vectorize_udtf(cls)
+    if arrow_enabled:
+        # Return the regular UDTF if the required dependencies are not satisfied.
+        try:
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
+        except ImportError as e:
+            warnings.warn(
+                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
+                f"Falling back to using regular Python UDTFs.",
+                UserWarning,
+            )
+
     return _create_udtf(
-        cls=vectorized_udtf,
+        cls=cls,
         returnType=returnType,
         name=name,
-        evalType=PythonEvalType.SQL_ARROW_TABLE_UDF,
-        deterministic=regular_udtf.deterministic,
+        evalType=eval_type,
+        deterministic=deterministic,
     )


-def _vectorize_udtf(cls: Type) -> Type:
-    """Vectorize a Python UDTF handler class."""
-    import pandas as pd
-
-    # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
-    def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]:
-        @wraps(f)
-        def evaluate(*a: Any, **kw: Any) -> Any:
-            try:
-                return f(*a, **kw)
-            except Exception as e:
-                raise PySparkRuntimeError(
-                    error_class="UDTF_EXEC_ERROR",
-                    message_parameters={"method_name": f.__name__, "error": str(e)},
-                )
-
-        return evaluate
-
-    class VectorizedUDTF:
-        def __init__(self) -> None:
-            self.func = cls()
-
-        if hasattr(cls, "analyze") and isinstance(
-            inspect.getattr_static(cls, "analyze"), staticmethod
-        ):
-
-            @staticmethod
-            def analyze(*args: AnalyzeArgument, **kwargs: AnalyzeArgument) -> AnalyzeResult:
-                return cls.analyze(*args, **kwargs)
-
-        def eval(self, *args: pd.Series, **kwargs: pd.Series) -> Iterator[pd.DataFrame]:
-            if len(args) == 0 and len(kwargs) == 0:
-                yield pd.DataFrame(wrap_func(self.func.eval)())
-            else:
-                # Create tuples from the input pandas Series, each tuple
-                # represents a row across all Series.
-                keys = list(kwargs.keys())
-                len_args = len(args)
-                row_tuples = zip(*args, *[kwargs[key] for key in keys])
-                for row in row_tuples:
-                    res = wrap_func(self.func.eval)(
-                        *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)}
-                    )
-                    if res is not None and not isinstance(res, Iterable):
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_RETURN_NOT_ITERABLE",
-                            message_parameters={
-                                "type": type(res).__name__,
-                            },
-                        )
-                    yield pd.DataFrame(res)
-
-        if hasattr(cls, "terminate"):
-
-            def terminate(self) -> Iterator[pd.DataFrame]:
-                yield pd.DataFrame(wrap_func(self.func.terminate)())
-
-    vectorized_udtf = VectorizedUDTF
-    vectorized_udtf.__name__ = cls.__name__
-    vectorized_udtf.__module__ = cls.__module__
-    vectorized_udtf.__doc__ = cls.__doc__
-    vectorized_udtf.__init__.__doc__ = cls.__init__.__doc__
-    vectorized_udtf.eval.__doc__ = getattr(cls, "eval").__doc__
-    if hasattr(cls, "terminate"):
-        getattr(vectorized_udtf, "terminate").__doc__ = getattr(cls, "terminate").__doc__
-
-    if hasattr(vectorized_udtf, "analyze"):
-        getattr(vectorized_udtf, "analyze").__doc__ = getattr(cls, "analyze").__doc__
-
-    return vectorized_udtf
-
-
 def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
     """Validate the handler class of a UDTF."""

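With _vectorize_udtf deleted from python/pyspark/sql/udtf.py, the driver no longer builds a VectorizedUDTF proxy class around the handler; it only picks the eval type, and the pandas row handling now lives in the worker. From the user's side nothing changes. A hedged usage sketch follows: it assumes an active SparkSession named spark, SplitWords is a made-up handler, and useArrow is the keyword of the udtf decorator in the Spark version this PR targets:

from pyspark.sql.functions import lit, udtf


# A plain Python handler class; after this refactor the very same class backs
# both the regular and the Arrow-optimized eval types (no driver-side wrapping).
@udtf(returnType="word: string, length: int", useArrow=True)
class SplitWords:
    def eval(self, text: str):
        for w in text.split():
            yield w, len(w)


# Assumption: `spark` is an active SparkSession, so calling the UDTF below
# returns a DataFrame that can be shown.
SplitWords(lit("refactor arrow python udtf")).show()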
43 changes: 37 additions & 6 deletions python/pyspark/worker.py
@@ -23,7 +23,7 @@
 import time
 from inspect import getfullargspec
 import json
-from typing import Iterable, Iterator
+from typing import Any, Iterable, Iterator

 import traceback
 import faulthandler
@@ -593,12 +593,12 @@ def read_udtf(pickleSer, infile, eval_type):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:

         def wrap_arrow_udtf(f, return_type):
+            import pandas as pd
+
             arrow_return_type = to_arrow_type(return_type)
             return_type_size = len(return_type)

             def verify_result(result):
-                import pandas as pd
-
                 if not isinstance(result, pd.DataFrame):
                     raise PySparkTypeError(
                         error_class="INVALID_ARROW_UDTF_RETURN_TYPE",
@@ -628,9 +628,40 @@ def verify_result(result):
                 )
                 return result

-            return lambda *a, **kw: map(
-                lambda res: (res, arrow_return_type), map(verify_result, f(*a, **kw))
-            )
+            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+            def func(*a: Any, **kw: Any) -> Any:
+                try:
+                    return f(*a, **kw)
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_EXEC_ERROR",
+                        message_parameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
+            def evaluate(*args: pd.Series, **kwargs: pd.Series):
+                if len(args) == 0 and len(kwargs) == 0:
+                    yield verify_result(pd.DataFrame(func())), arrow_return_type
+                else:
+                    # Create tuples from the input pandas Series, each tuple
+                    # represents a row across all Series.
+                    keys = list(kwargs.keys())
+                    len_args = len(args)
+                    row_tuples = zip(*args, *[kwargs[key] for key in keys])
+                    for row in row_tuples:
+                        res = func(
+                            *row[:len_args],
+                            **{key: row[len_args + i] for i, key in enumerate(keys)},
+                        )
+                        if res is not None and not isinstance(res, Iterable):
+                            raise PySparkRuntimeError(
+                                error_class="UDTF_RETURN_NOT_ITERABLE",
+                                message_parameters={
+                                    "type": type(res).__name__,
+                                },
+                            )
+                        yield verify_result(pd.DataFrame(res)), arrow_return_type
+
+            return evaluate

         eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type)

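The worker-side evaluate generator above is where the per-row work now happens: each input column arrives as a pandas Series, zip reassembles the rows, the exception-wrapping func runs once per row, and every result is verified and emitted as a pandas DataFrame paired with the Arrow return type. A self-contained sketch of just that zipping pattern, outside Spark and without the error wrapping; all names here are illustrative:

import pandas as pd


def rows_to_frames(eval_fn, *cols: pd.Series, **named_cols: pd.Series):
    """Rebuild rows from per-column pandas Series and call a UDTF-style eval
    function once per row, yielding one DataFrame per row (mirrors the shape
    of the new worker-side evaluate loop)."""
    keys = list(named_cols.keys())
    n_pos = len(cols)
    for row in zip(*cols, *[named_cols[k] for k in keys]):
        res = eval_fn(*row[:n_pos], **{k: row[n_pos + i] for i, k in enumerate(keys)})
        yield pd.DataFrame(res)


# Example handler: yields (word, length) pairs for each input string.
def split_words(text):
    return [(w, len(w)) for w in text.split()]


for frame in rows_to_frames(split_words, pd.Series(["a bb", "ccc dddd"])):
    print(frame)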