[SPARK-45746][PYTHON] Return specific error messages if UDTF 'analyze' or 'eval' method accepts or returns wrong values

### What changes were proposed in this pull request?

This PR adds checks to return specific error messages if any Python UDTF `analyze` or `eval` method accepts or returns wrong values.

The improved error messages cover the following cases (a minimal sketch of one such failure follows this list):
* If the `__init__` method takes more arguments than `self` and `analyze_result`.
* If the UDTF call passes more or fewer arguments than `analyze` or `eval` expects (and the method does not use `*args` or `**kwargs`).
* If the `analyze` method returns an object other than a `StructType` in the `AnalyzeResult` `schema` field.
* If there are extra optional `AnalyzeResult` fields relating to input table arguments (e.g. `withSinglePartition`) but the `analyze` method received no input table argument.
* If the `analyze` method returns a list of strings for the optional `partitionBy` field of the `AnalyzeResult` instead of a list of `PartitioningColumn` objects.
* If the `AnalyzeResult` is missing the `schema` argument entirely.
* If the TVF call uses keyword arguments but the `analyze` or `eval` method does not accept arguments with those keywords (or `**kwargs`).
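As an illustrative sketch (adapted from the updated tests; assumes an active Spark session), here is one of these failure modes, a UDTF call that passes more arguments than `eval` declares:

```python
from pyspark.sql.functions import lit, udtf

@udtf(returnType="a: int, b: int")
class TestUDTF:
    def eval(self, a: int):
        yield a, a + 1

# Two arguments for an 'eval' that declares only one: with this PR the failure
# is reported as UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE ("too many
# positional arguments") instead of surfacing a raw Python TypeError.
TestUDTF(lit(1), lit(2)).collect()
```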

### Why are the changes needed?

This helps users understand how to easily fix their user-defined table functions if they are malformed.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds test coverage.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#43611 from dtenedor/fix-more-udtf-errors.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
dtenedor authored and ueshin committed Nov 29, 2023
1 parent 8bd9c8c commit f5e4e84
Showing 12 changed files with 939 additions and 309 deletions.
15 changes: 15 additions & 0 deletions python/pyspark/errors/error_classes.py
@@ -798,6 +798,21 @@
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
]
},
"UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD" : {
"message" : [
"Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function implements the 'analyze' method, but its constructor has more than two arguments (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, or one 'self' argument plus another argument for the result of the 'analyze' method, and try the query again."
]
},
"UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD" : {
"message" : [
"Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function does not implement the 'analyze' method, and its constructor has more than one argument (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, and try the query again."
]
},
"UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE" : {
"message" : [
"Failed to evaluate the user-defined table function '<name>' because the function arguments did not match the expected signature of the 'eval' method (<reason>). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its 'eval' method accepts the provided arguments, and then try the query again."
]
},
"UDTF_EXEC_ERROR" : {
"message" : [
"User defined table function encountered an error in the '<method_name>' method: <error>"
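These templates are rendered by PySpark's error framework, which substitutes the `<name>` and `<reason>` placeholders from `message_parameters`. A minimal sketch of how one of them is raised (mirroring the `worker.py` change later in this diff; the UDTF name `my_udtf` is hypothetical):

```python
from pyspark.errors import PySparkRuntimeError

# Substitutes "<name>" in the UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD template.
raise PySparkRuntimeError(
    error_class="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
    message_parameters={"name": "my_udtf"},
)
```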
49 changes: 18 additions & 31 deletions python/pyspark/sql/tests/test_udtf.py
@@ -274,20 +274,21 @@ def eval(self, a: int):
         df = self.spark.sql("SELECT * FROM testUDTF(null)")
         self.assertEqual(df.collect(), [Row(a=None)])
 
+    # These are expected error message substrings to be used in test cases below.
+    tooManyPositionalArguments = "too many positional arguments"
+    missingARequiredArgument = "missing a required argument"
+    multipleValuesForArgument = "multiple values for argument"
+
     def test_udtf_with_wrong_num_input(self):
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a, a + 1
 
-        with self.assertRaisesRegex(
-            PythonException, r"eval\(\) missing 1 required positional argument: 'a'"
-        ):
+        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument):
             TestUDTF().collect()
 
-        with self.assertRaisesRegex(
-            PythonException, r"eval\(\) takes 2 positional arguments but 3 were given"
-        ):
+        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.tooManyPositionalArguments):
             TestUDTF(lit(1), lit(2)).collect()
 
     def test_udtf_init_with_additional_args(self):
@@ -299,9 +300,7 @@ def __init__(self, a: int):
             def eval(self, a: int):
                 yield a,
 
-        with self.assertRaisesRegex(
-            PythonException, r"__init__\(\) missing 1 required positional argument: 'a'"
-        ):
+        with self.assertRaisesRegex(PythonException, r".*constructor has more than one argument.*"):
             TestUDTF(lit(1)).show()
 
     def test_udtf_terminate_with_additional_args(self):
@@ -1582,8 +1581,9 @@ def eval(self):

         with self.assertRaisesRegex(
             AnalysisException,
-            "Output of `analyze` static method of Python UDTFs expects "
-            "a pyspark.sql.udtf.AnalyzeResult but got: <class 'pyspark.sql.types.StringType'>",
+            "'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult, "
+            "but instead this method returned a value of type: "
+            "<class 'pyspark.sql.types.StringType'>",
         ):
             func().collect()
 
@@ -1622,26 +1622,17 @@ class TestUDTF:
             def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult:
                 return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType))
 
-            def eval(self, a):
+            def eval(self, a, b):
                 yield a, a + 1
 
         func = udtf(TestUDTF)
 
-        with self.assertRaisesRegex(
-            AnalysisException, r"analyze\(\) missing 1 required positional argument: 'b'"
-        ):
+        with self.assertRaisesRegex(AnalysisException, r"arguments"):
             func(lit(1)).collect()
 
-        with self.assertRaisesRegex(
-            AnalysisException, r"analyze\(\) takes 2 positional arguments but 3 were given"
-        ):
+        with self.assertRaisesRegex(AnalysisException, r"arguments"):
             func(lit(1), lit(2), lit(3)).collect()
 
-        with self.assertRaisesRegex(
-            PythonException, r"eval\(\) takes 2 positional arguments but 3 were given"
-        ):
-            func(lit(1), lit(2)).collect()
-
     def test_udtf_with_analyze_taking_keyword_arguments(self):
         @udtf
         class TestUDTF:
@@ -1660,12 +1651,12 @@ def eval(self, **kwargs):
         assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected)
 
         with self.assertRaisesRegex(
-            AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given"
+            AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments
         ):
             TestUDTF(lit(1)).collect()
 
         with self.assertRaisesRegex(
-            AnalysisException, r"analyze\(\) takes 0 positional arguments but 2 were given"
+            AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments
         ):
             self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
 
@@ -1924,14 +1915,10 @@ def eval(self, a, b):
         with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
             self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
 
-        with self.assertRaisesRegex(
-            PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
-        ):
+        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument):
            self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()
 
-        with self.assertRaisesRegex(
-            PythonException, r"eval\(\) got multiple values for argument 'a'"
-        ):
+        with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.multipleValuesForArgument):
             self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show()
 
     def test_udtf_with_kwargs(self):
86 changes: 83 additions & 3 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -18,6 +18,7 @@
 import inspect
 import os
 import sys
+from textwrap import dedent
 from typing import Dict, List, IO, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -30,7 +31,8 @@
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.types import _parse_datatype_json_string
+from pyspark.sql.functions import PartitioningColumn
+from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
 from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
@@ -113,15 +115,93 @@ def main(infile: IO, outfile: IO) -> None:

     _accumulatorRegistry.clear()
 
+    udtf_name = utf8_deserializer.loads(infile)
     handler = read_udtf(infile)
     args, kwargs = read_arguments(infile)
 
+    error_prefix = f"Failed to evaluate the user-defined table function '{udtf_name}'"
+
+    def format_error(msg: str) -> str:
+        return dedent(msg).replace("\n", " ")
+
+    # Check that the arguments provided to the UDTF call match the expected parameters defined
+    # in the static 'analyze' method signature.
+    try:
+        inspect.signature(handler.analyze).bind(*args, **kwargs)  # type: ignore[attr-defined]
+    except TypeError as e:
+        # The UDTF call's arguments did not match the expected signature.
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the function arguments did not match the expected
+                signature of the static 'analyze' method ({e}). Please update the query so that
+                this table function call provides arguments matching the expected signature, or
+                else update the table function so that its static 'analyze' method accepts the
+                provided arguments, and then try the query again."""
+            )
+        )
+
+    # Invoke the UDTF's 'analyze' method.
     result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
 
+    # Check invariants about the 'analyze' method after running it.
     if not isinstance(result, AnalyzeResult):
         raise PySparkValueError(
-            "Output of `analyze` static method of Python UDTFs expects "
-            f"a pyspark.sql.udtf.AnalyzeResult but got: {type(result)}"
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method expects a result of type
+                pyspark.sql.udtf.AnalyzeResult, but instead this method returned a value of
+                type: {type(result)}"""
+            )
         )
+    elif not isinstance(result.schema, StructType):
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method expects a result of type
+                pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType,
+                but the 'schema' field had the wrong type: {type(result.schema)}"""
+            )
+        )
+    has_table_arg = any(arg.isTable for arg in args) or any(
+        arg.isTable for arg in kwargs.values()
+    )
+    if not has_table_arg and result.withSinglePartition:
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the 'withSinglePartition' field set to 'true', but
+                the function call did not provide any table argument. Please update the query so
+                that it provides a table argument, or else update the table function so that its
+                'analyze' method returns an 'AnalyzeResult' object with the
+                'withSinglePartition' field set to 'false', and then try the query again."""
+            )
+        )
+    elif not has_table_arg and len(result.partitionBy) > 0:
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the 'partitionBy' list set to non-empty, but the
+                function call did not provide any table argument. Please update the query so
+                that it provides a table argument, or else update the table function so that its
+                'analyze' method returns an 'AnalyzeResult' object with the 'partitionBy' list
+                set to empty, and then try the query again."""
+            )
+        )
+    elif isinstance(result.partitionBy, (list, tuple)) and (
+        len(result.partitionBy) > 0
+        and not all([isinstance(val, PartitioningColumn) for val in result.partitionBy])
+    ):
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the 'partitionBy' field set to a value besides a
+                list or tuple of 'PartitioningColumn' objects. Please update the table function
+                and then try the query again."""
+            )
+        )
 
     # Return the analyzed schema.
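The new argument check relies on `inspect.signature(...).bind`, which performs the same argument matching as a real call, without executing the function, and raises `TypeError` with a readable reason on mismatch. A self-contained sketch of the mechanism (the `analyze` stub below is hypothetical):

```python
import inspect

def analyze(a, b):
    # Stand-in for a UDTF's static 'analyze' method.
    pass

for call_args in [(1,), (1, 2, 3)]:
    try:
        # bind() validates the arguments against the signature without
        # invoking the function body.
        inspect.signature(analyze).bind(*call_args)
    except TypeError as e:
        print(e)
# Prints:
#   missing a required argument: 'b'
#   too many positional arguments
```

These are exactly the substrings ("missing a required argument", "too many positional arguments") that the updated tests in `test_udtf.py` match against.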
50 changes: 31 additions & 19 deletions python/pyspark/worker.py
@@ -22,7 +22,7 @@
 import sys
 import dataclasses
 import time
-from inspect import getfullargspec
+import inspect
 import json
 from typing import Any, Callable, Iterable, Iterator, Optional
 import faulthandler
@@ -616,12 +616,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
         return args_offsets, wrap_arrow_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
+        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
-        argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
+        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets, return_type)
@@ -691,7 +691,6 @@ def read_udtf(pickleSer, infile, eval_type):
         pickled_analyze_result = None
     # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
     # although we may set this to false later if we find otherwise.
-    udtf_init_method_accepts_analyze_result = True
     handler = read_command(pickleSer, infile)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
@@ -704,29 +703,32 @@
         raise PySparkRuntimeError(
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
+    udtf_name = utf8_deserializer.loads(infile)
 
     # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor
     # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor
     # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
     # AnalyzeResult.
+    udtf_init_args = inspect.getfullargspec(handler)
     if has_pickled_analyze_result:
-        prev_handler = handler
+        if len(udtf_init_args.args) > 2:
+            raise PySparkRuntimeError(
+                error_class="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
+                message_parameters={"name": udtf_name},
+            )
+        elif len(udtf_init_args.args) == 2:
+            prev_handler = handler
 
-        def construct_udtf():
-            nonlocal udtf_init_method_accepts_analyze_result
-            if not udtf_init_method_accepts_analyze_result:
-                return prev_handler()
-            else:
-                try:
-                    # Here we pass the AnalyzeResult to the UDTF's __init__ method.
-                    return prev_handler(dataclasses.replace(pickled_analyze_result))
-                except TypeError:
-                    # This means that the UDTF handler does not accept an AnalyzeResult object in
-                    # its __init__ method.
-                    udtf_init_method_accepts_analyze_result = False
-                    return prev_handler()
+            def construct_udtf():
+                # Here we pass the AnalyzeResult to the UDTF's __init__ method.
+                return prev_handler(dataclasses.replace(pickled_analyze_result))
 
-        handler = construct_udtf
+            handler = construct_udtf
+    elif len(udtf_init_args.args) > 1:
+        raise PySparkRuntimeError(
+            error_class="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
+            message_parameters={"name": udtf_name},
+        )
 
     class UDTFWithPartitions:
         """
@@ -854,6 +856,16 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
"the query again."
)

# Check that the arguments provided to the UDTF call match the expected parameters defined
# in the 'eval' method signature.
try:
inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
except TypeError as e:
raise PySparkRuntimeError(
error_class="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
message_parameters={"name": udtf_name, "reason": str(e)},
) from None

def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
def raise_(result_column_index):
raise PySparkRuntimeError(
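The constructor checks above hinge on `inspect.getfullargspec` applied to a class, which reports the parameters of its `__init__` with `self` included; hence the limits of two arguments when `analyze` is implemented and one when it is not. A small sketch with hypothetical handler classes:

```python
import inspect

class AcceptsAnalyzeResult:
    def __init__(self, analyze_result):
        self._analyze_result = analyze_result

class TakesNothing:
    def __init__(self):
        pass

# 'self' counts as an argument, so two entries mean the constructor can
# receive the pickled AnalyzeResult, and one entry means it takes nothing.
print(inspect.getfullargspec(AcceptsAnalyzeResult).args)  # ['self', 'analyze_result']
print(inspect.getfullargspec(TakesNothing).args)          # ['self']
```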
@@ -142,5 +142,7 @@ object PythonUDTFRunner {
     PythonWorkerUtils.writePythonFunction(udtf.func, dataOut)
     // Write the result schema of the UDTF call.
     PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
+    // Write the UDTF name.
+    PythonWorkerUtils.writeUTF(udtf.name, dataOut)
   }
 }
@@ -133,7 +133,7 @@ case class UserDefinedPythonTableFunction(
       case _ => false
     }
     val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => {
-      val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(func, exprs, tableArgs)
+      val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(name, func, exprs, tableArgs)
       runner.runInPython()
     }
     UnresolvedPolymorphicPythonUDTF(
@@ -184,6 +184,7 @@
  * will be thrown when an exception is raised in Python.
  */
 class UserDefinedPythonTableFunctionAnalyzeRunner(
+    name: String,
     func: PythonFunction,
     exprs: Seq[Expression],
     tableArgs: Seq[Boolean]) extends PythonPlannerRunner[PythonUDTFAnalyzeResult](func) {
@@ -192,6 +193,7 @@

   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send Python UDTF
+    PythonWorkerUtils.writeUTF(name, dataOut)
     PythonWorkerUtils.writePythonFunction(func, dataOut)
 
     // Send arguments
@@ -226,6 +228,9 @@
     val length = dataIn.readInt()
     if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
       val msg = PythonWorkerUtils.readUTF(dataIn)
+        // Remove the leading traceback stack trace from the error message string, if any, since it
+        // usually only includes the "analyze_udtf.py" filename and a line number.
+        .split("PySparkValueError:").last.strip()
       throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
     }
 
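The stripping step keeps only the text after the last `PySparkValueError:` marker, so the user sees the formatted message rather than the worker's traceback. The same transformation, expressed in Python for illustration with a made-up worker message:

```python
raw = (
    "Traceback (most recent call last):\n"
    '  File "analyze_udtf.py", line 120, in main\n'
    "PySparkValueError: Failed to evaluate the user-defined table function 'f' ..."
)

# Drop everything up to and including the last marker, then trim whitespace --
# the Python analogue of .split("PySparkValueError:").last.strip() above.
msg = raw.split("PySparkValueError:")[-1].strip()
print(msg)  # Failed to evaluate the user-defined table function 'f' ...
```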