# [SPARK-44836][PYTHON] Refactor Arrow Python UDTF
### What changes were proposed in this pull request?

Refactors Arrow Python UDTF.

### Why are the changes needed?

An Arrow Python UDTF does not need to be redefined at creation time; the Arrow-specific wrapping can instead be handled in `worker.py`.
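The user-facing API is unchanged. As a minimal sketch (illustrative handler name and schema, assuming Spark 3.5+ where `useArrow` is available):

```python
from pyspark.sql.functions import udtf

# Illustrative handler: `useArrow=True` opts in to the Arrow-optimized path,
# but the class itself is a plain Python UDTF handler either way.
@udtf(returnType="x: int, doubled: int", useArrow=True)
class DoubleIt:
    def eval(self, x: int):
        yield x, x * 2
```

Before this change, such a handler was rewrapped into a vectorized class at creation time; with it, the same class is shipped to the executor as-is and `worker.py` does the adapting.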

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The existing tests.
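(For reference, the relevant suites appear to be `pyspark.sql.tests.test_udtf` and its Spark Connect parity tests, presumably runnable via `python/run-tests --testnames pyspark.sql.tests.test_udtf`.)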

Closes apache#42520 from ueshin/issues/SPARK-44836/refactor_arrow_udtf.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
ueshin authored and ragnarok56 committed Mar 2, 2024
1 parent 05911e1 commit bf51587
Showing 3 changed files with 71 additions and 126 deletions.
40 changes: 16 additions & 24 deletions python/pyspark/sql/connect/udtf.py
@@ -83,34 +83,26 @@ def _create_py_udtf(
                 else:
                     raise e
 
-    # Create a regular Python UDTF and check for invalid handler class.
-    regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
-
-    if not arrow_enabled:
-        return regular_udtf
-
-    from pyspark.sql.pandas.utils import (
-        require_minimum_pandas_version,
-        require_minimum_pyarrow_version,
-    )
-
-    try:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-    except ImportError as e:
-        warnings.warn(
-            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
-            f"Falling back to using regular Python UDTFs.",
-            UserWarning,
-        )
-        return regular_udtf
-
-    from pyspark.sql.udtf import _vectorize_udtf
-
-    vectorized_udtf = _vectorize_udtf(cls)
-    return _create_udtf(
-        vectorized_udtf, returnType, name, PythonEvalType.SQL_ARROW_TABLE_UDF, deterministic
-    )
+    eval_type: int = PythonEvalType.SQL_TABLE_UDF
+
+    if arrow_enabled:
+        from pyspark.sql.pandas.utils import (
+            require_minimum_pandas_version,
+            require_minimum_pyarrow_version,
+        )
+
+        try:
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
+        except ImportError as e:
+            warnings.warn(
+                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
+                f"Falling back to using regular Python UDTFs.",
+                UserWarning,
+            )
+
+    return _create_udtf(cls, returnType, name, eval_type, deterministic)
 
 
 class UserDefinedTableFunction:
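The new control flow in this file reduces to: default to the regular eval type, and upgrade to the Arrow one only when the optional dependencies import cleanly. A standalone sketch of that pattern (hypothetical function name and stand-in constants, not Spark API):

```python
import warnings

# Stand-in values for PythonEvalType.SQL_TABLE_UDF / SQL_ARROW_TABLE_UDF.
SQL_TABLE_UDF = 300
SQL_ARROW_TABLE_UDF = 301

def pick_eval_type(arrow_enabled: bool) -> int:
    eval_type = SQL_TABLE_UDF
    if arrow_enabled:
        try:
            import pandas   # noqa: F401
            import pyarrow  # noqa: F401
            eval_type = SQL_ARROW_TABLE_UDF
        except ImportError as e:
            # Degrade gracefully rather than failing UDTF creation.
            warnings.warn(f"Arrow optimization cannot be enabled: {e}", UserWarning)
    return eval_type
```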
114 changes: 18 additions & 96 deletions python/pyspark/sql/udtf.py
@@ -19,11 +19,10 @@
 """
 import pickle
 from dataclasses import dataclass
-from functools import wraps
 import inspect
 import sys
 import warnings
-from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable
+from typing import Any, Type, TYPE_CHECKING, Optional, Union
 
 from py4j.java_gateway import JavaObject
@@ -112,107 +111,30 @@ def _create_py_udtf(
             if isinstance(value, str) and value.lower() == "true":
                 arrow_enabled = True
 
-    # Create a regular Python UDTF and check for invalid handler class.
-    regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
-
-    if not arrow_enabled:
-        return regular_udtf
-
-    # Return the regular UDTF if the required dependencies are not satisfied.
-    try:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-    except ImportError as e:
-        warnings.warn(
-            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
-            f"Falling back to using regular Python UDTFs.",
-            UserWarning,
-        )
-        return regular_udtf
-
-    # Return the vectorized UDTF.
-    vectorized_udtf = _vectorize_udtf(cls)
+    eval_type: int = PythonEvalType.SQL_TABLE_UDF
+
+    if arrow_enabled:
+        # Return the regular UDTF if the required dependencies are not satisfied.
+        try:
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
+        except ImportError as e:
+            warnings.warn(
+                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
+                f"Falling back to using regular Python UDTFs.",
+                UserWarning,
+            )
+
     return _create_udtf(
-        cls=vectorized_udtf,
+        cls=cls,
         returnType=returnType,
         name=name,
-        evalType=PythonEvalType.SQL_ARROW_TABLE_UDF,
-        deterministic=regular_udtf.deterministic,
+        evalType=eval_type,
+        deterministic=deterministic,
     )
 
 
-def _vectorize_udtf(cls: Type) -> Type:
-    """Vectorize a Python UDTF handler class."""
-    import pandas as pd
-
-    # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
-    def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]:
-        @wraps(f)
-        def evaluate(*a: Any, **kw: Any) -> Any:
-            try:
-                return f(*a, **kw)
-            except Exception as e:
-                raise PySparkRuntimeError(
-                    error_class="UDTF_EXEC_ERROR",
-                    message_parameters={"method_name": f.__name__, "error": str(e)},
-                )
-
-        return evaluate
-
-    class VectorizedUDTF:
-        def __init__(self) -> None:
-            self.func = cls()
-
-        if hasattr(cls, "analyze") and isinstance(
-            inspect.getattr_static(cls, "analyze"), staticmethod
-        ):
-
-            @staticmethod
-            def analyze(*args: AnalyzeArgument, **kwargs: AnalyzeArgument) -> AnalyzeResult:
-                return cls.analyze(*args, **kwargs)
-
-        def eval(self, *args: pd.Series, **kwargs: pd.Series) -> Iterator[pd.DataFrame]:
-            if len(args) == 0 and len(kwargs) == 0:
-                yield pd.DataFrame(wrap_func(self.func.eval)())
-            else:
-                # Create tuples from the input pandas Series, each tuple
-                # represents a row across all Series.
-                keys = list(kwargs.keys())
-                len_args = len(args)
-                row_tuples = zip(*args, *[kwargs[key] for key in keys])
-                for row in row_tuples:
-                    res = wrap_func(self.func.eval)(
-                        *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)}
-                    )
-                    if res is not None and not isinstance(res, Iterable):
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_RETURN_NOT_ITERABLE",
-                            message_parameters={
-                                "type": type(res).__name__,
-                            },
-                        )
-                    yield pd.DataFrame(res)
-
-        if hasattr(cls, "terminate"):
-
-            def terminate(self) -> Iterator[pd.DataFrame]:
-                yield pd.DataFrame(wrap_func(self.func.terminate)())
-
-    vectorized_udtf = VectorizedUDTF
-    vectorized_udtf.__name__ = cls.__name__
-    vectorized_udtf.__module__ = cls.__module__
-    vectorized_udtf.__doc__ = cls.__doc__
-    vectorized_udtf.__init__.__doc__ = cls.__init__.__doc__
-    vectorized_udtf.eval.__doc__ = getattr(cls, "eval").__doc__
-    if hasattr(cls, "terminate"):
-        getattr(vectorized_udtf, "terminate").__doc__ = getattr(cls, "terminate").__doc__
-
-    if hasattr(vectorized_udtf, "analyze"):
-        getattr(vectorized_udtf, "analyze").__doc__ = getattr(cls, "analyze").__doc__
-
-    return vectorized_udtf
-
-
 def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
     """Validate the handler class of a UDTF."""
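Deleting `_vectorize_udtf` does not drop its error handling; the same wrapping reappears in `worker.py` below. The pattern in isolation (hypothetical wrapper name; `PySparkRuntimeError` and the `UDTF_EXEC_ERROR` error class are the ones used in the diff):

```python
from pyspark.errors import PySparkRuntimeError

def wrap_exec_error(f):
    # Re-raise any exception from the handler method as UDTF_EXEC_ERROR,
    # so users see a consistent PySpark error instead of a raw traceback.
    def func(*a, **kw):
        try:
            return f(*a, **kw)
        except Exception as e:
            raise PySparkRuntimeError(
                error_class="UDTF_EXEC_ERROR",
                message_parameters={"method_name": f.__name__, "error": str(e)},
            )
    return func
```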
43 changes: 37 additions & 6 deletions python/pyspark/worker.py
@@ -23,7 +23,7 @@
 import time
 from inspect import getfullargspec
 import json
-from typing import Iterable, Iterator
+from typing import Any, Iterable, Iterator
 
 import traceback
 import faulthandler
@@ -593,12 +593,12 @@ def read_udtf(pickleSer, infile, eval_type):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
 
         def wrap_arrow_udtf(f, return_type):
+            import pandas as pd
+
             arrow_return_type = to_arrow_type(return_type)
             return_type_size = len(return_type)
 
             def verify_result(result):
-                import pandas as pd
-
                 if not isinstance(result, pd.DataFrame):
                     raise PySparkTypeError(
                         error_class="INVALID_ARROW_UDTF_RETURN_TYPE",
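Hoisting `import pandas as pd` to the top of `wrap_arrow_udtf` is what lets the new `evaluate` in the next hunk reference `pd` in its signature and body, instead of the import living inside `verify_result` alone.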
@@ -628,9 +628,40 @@ def verify_result(result):
                     )
                 return result
 
-            return lambda *a, **kw: map(
-                lambda res: (res, arrow_return_type), map(verify_result, f(*a, **kw))
-            )
+            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+            def func(*a: Any, **kw: Any) -> Any:
+                try:
+                    return f(*a, **kw)
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_EXEC_ERROR",
+                        message_parameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
+            def evaluate(*args: pd.Series, **kwargs: pd.Series):
+                if len(args) == 0 and len(kwargs) == 0:
+                    yield verify_result(pd.DataFrame(func())), arrow_return_type
+                else:
+                    # Create tuples from the input pandas Series, each tuple
+                    # represents a row across all Series.
+                    keys = list(kwargs.keys())
+                    len_args = len(args)
+                    row_tuples = zip(*args, *[kwargs[key] for key in keys])
+                    for row in row_tuples:
+                        res = func(
+                            *row[:len_args],
+                            **{key: row[len_args + i] for i, key in enumerate(keys)},
+                        )
+                        if res is not None and not isinstance(res, Iterable):
+                            raise PySparkRuntimeError(
+                                error_class="UDTF_RETURN_NOT_ITERABLE",
+                                message_parameters={
+                                    "type": type(res).__name__,
+                                },
+                            )
+                        yield verify_result(pd.DataFrame(res)), arrow_return_type
+
+            return evaluate
 
         eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type)
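The row fan-out in `evaluate` is easy to verify in isolation. A toy sketch (hypothetical column name `y` and toy data) of the same zip-based iteration:

```python
import pandas as pd

args = (pd.Series([1, 2]),)          # positional input columns
kwargs = {"y": pd.Series([10, 20])}  # keyword input columns

keys = list(kwargs.keys())
len_args = len(args)
# zip walks all Series in lockstep, yielding one tuple per row.
for row in zip(*args, *[kwargs[key] for key in keys]):
    positional = row[:len_args]                                       # values 1, then 2
    keyword = {key: row[len_args + i] for i, key in enumerate(keys)}  # y=10, then y=20
    print(positional, keyword)
```

Each `(positional, keyword)` pair is what `func` (the wrapped `eval`) receives for a single row.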
