diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 6662efa8ca54b..289b16c9b6066 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -798,6 +798,21 @@ "Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<return_type>'. Please check if the data types match and try again." ] }, + "UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD" : { + "message" : [ + "Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function implements the 'analyze' method, but its constructor has more than two arguments (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, or one 'self' argument plus another argument for the result of the 'analyze' method, and try the query again." + ] + }, + "UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD" : { + "message" : [ + "Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function does not implement the 'analyze' method, and its constructor has more than one argument (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, and try the query again." + ] + }, + "UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE" : { + "message" : [ + "Failed to evaluate the user-defined table function '<name>' because the function arguments did not match the expected signature of the 'eval' method (<reason>). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its 'eval' method accepts the provided arguments, and then try the query again." + ] + }, "UDTF_EXEC_ERROR" : { "message" : [ "User defined table function encountered an error in the '<method_name>' method: <error>" diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 2794b51eb7048..41321f556ac66 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -274,20 +274,21 @@ def eval(self, a: int): df = self.spark.sql("SELECT * FROM testUDTF(null)") self.assertEqual(df.collect(), [Row(a=None)]) + # These are expected error message substrings to be used in test cases below.
+    tooManyPositionalArguments = "too many positional arguments" + missingARequiredArgument = "missing a required argument" + multipleValuesForArgument = "multiple values for argument" + def test_udtf_with_wrong_num_input(self): @udtf(returnType="a: int, b: int") class TestUDTF: def eval(self, a: int): yield a, a + 1 - with self.assertRaisesRegex( - PythonException, r"eval\(\) missing 1 required positional argument: 'a'" - ): + with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument): TestUDTF().collect() - with self.assertRaisesRegex( - PythonException, r"eval\(\) takes 2 positional arguments but 3 were given" - ): + with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.tooManyPositionalArguments): TestUDTF(lit(1), lit(2)).collect() def test_udtf_init_with_additional_args(self): @@ -299,9 +300,7 @@ def __init__(self, a: int): def eval(self, a: int): yield a, - with self.assertRaisesRegex( - PythonException, r"__init__\(\) missing 1 required positional argument: 'a'" - ): + with self.assertRaisesRegex(PythonException, r".*constructor has more than one argument.*"): TestUDTF(lit(1)).show() def test_udtf_terminate_with_additional_args(self): @@ -1582,8 +1581,9 @@ def eval(self): with self.assertRaisesRegex( AnalysisException, - "Output of `analyze` static method of Python UDTFs expects " - "a pyspark.sql.udtf.AnalyzeResult but got: <class 'int'>", + "'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult, " + "but instead this method returned a value of type: " + "<class 'int'>", ): func().collect() @@ -1622,26 +1622,17 @@ class TestUDTF: def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult: return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType)) - def eval(self, a): + def eval(self, a, b): yield a, a + 1 func = udtf(TestUDTF) - with self.assertRaisesRegex( - AnalysisException, r"analyze\(\) missing 1 required positional argument: 'b'" - ): + with self.assertRaisesRegex(AnalysisException, r"arguments"): func(lit(1)).collect() - with self.assertRaisesRegex( - AnalysisException, r"analyze\(\) takes 2 positional arguments but 3 were given" - ): + with self.assertRaisesRegex(AnalysisException, r"arguments"): func(lit(1), lit(2), lit(3)).collect() - with self.assertRaisesRegex( - PythonException, r"eval\(\) takes 2 positional arguments but 3 were given" - ): - func(lit(1), lit(2)).collect() - def test_udtf_with_analyze_taking_keyword_arguments(self): @udtf class TestUDTF: @@ -1660,12 +1651,12 @@ def eval(self, **kwargs): assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected) with self.assertRaisesRegex( - AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given" + AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments ): TestUDTF(lit(1)).collect() with self.assertRaisesRegex( - AnalysisException, r"analyze\(\) takes 0 positional arguments but 2 were given" + AnalysisException, BaseUDTFTestsMixin.tooManyPositionalArguments ): self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect() @@ -1924,14 +1915,10 @@ def eval(self, a, b): with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show() - with self.assertRaisesRegex( - PythonException, r"eval\(\) got an unexpected keyword argument 'c'" - ): + with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.missingARequiredArgument): self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show() - with self.assertRaisesRegex( - PythonException,
r"eval\(\) got multiple values for argument 'a'" - ): + with self.assertRaisesRegex(PythonException, BaseUDTFTestsMixin.multipleValuesForArgument): self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show() def test_udtf_with_kwargs(self): diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index de484c9cf941c..ce21e4859770a 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -18,6 +18,7 @@ import inspect import os import sys +from textwrap import dedent from typing import Dict, List, IO, Tuple from pyspark.accumulators import _accumulatorRegistry @@ -30,7 +31,8 @@ write_with_length, SpecialLengths, ) -from pyspark.sql.types import _parse_datatype_json_string +from pyspark.sql.functions import PartitioningColumn +from pyspark.sql.types import _parse_datatype_json_string, StructType from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult from pyspark.util import handle_worker_exception from pyspark.worker_util import ( @@ -113,15 +115,93 @@ def main(infile: IO, outfile: IO) -> None: _accumulatorRegistry.clear() + udtf_name = utf8_deserializer.loads(infile) handler = read_udtf(infile) args, kwargs = read_arguments(infile) + error_prefix = f"Failed to evaluate the user-defined table function '{udtf_name}'" + + def format_error(msg: str) -> str: + return dedent(msg).replace("\n", " ") + + # Check that the arguments provided to the UDTF call match the expected parameters defined + # in the static 'analyze' method signature. + try: + inspect.signature(handler.analyze).bind(*args, **kwargs) # type: ignore[attr-defined] + except TypeError as e: + # The UDTF call's arguments did not match the expected signature. + raise PySparkValueError( + format_error( + f""" + {error_prefix} because the function arguments did not match the expected + signature of the static 'analyze' method ({e}). Please update the query so that + this table function call provides arguments matching the expected signature, or + else update the table function so that its static 'analyze' method accepts the + provided arguments, and then try the query again.""" + ) + ) + + # Invoke the UDTF's 'analyze' method. result = handler.analyze(*args, **kwargs) # type: ignore[attr-defined] + # Check invariants about the 'analyze' method after running it. if not isinstance(result, AnalyzeResult): raise PySparkValueError( - "Output of `analyze` static method of Python UDTFs expects " - f"a pyspark.sql.udtf.AnalyzeResult but got: {type(result)}" + format_error( + f""" + {error_prefix} because the static 'analyze' method expects a result of type + pyspark.sql.udtf.AnalyzeResult, but instead this method returned a value of + type: {type(result)}""" + ) + ) + elif not isinstance(result.schema, StructType): + raise PySparkValueError( + format_error( + f""" + {error_prefix} because the static 'analyze' method expects a result of type + pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType, + but the 'schema' field had the wrong type: {type(result.schema)}""" + ) + ) + has_table_arg = any(arg.isTable for arg in args) or any( + arg.isTable for arg in kwargs.values() + ) + if not has_table_arg and result.withSinglePartition: + raise PySparkValueError( + format_error( + f""" + {error_prefix} because the static 'analyze' method returned an + 'AnalyzeResult' object with the 'withSinglePartition' field set to 'true', but + the function call did not provide any table argument. 
Please update the query so + that it provides a table argument, or else update the table function so that its + 'analyze' method returns an 'AnalyzeResult' object with the + 'withSinglePartition' field set to 'false', and then try the query again.""" + ) + ) + elif not has_table_arg and len(result.partitionBy) > 0: + raise PySparkValueError( + format_error( + f""" + {error_prefix} because the static 'analyze' method returned an + 'AnalyzeResult' object with the 'partitionBy' list set to non-empty, but the + function call did not provide any table argument. Please update the query so + that it provides a table argument, or else update the table function so that its + 'analyze' method returns an 'AnalyzeResult' object with the 'partitionBy' list + set to empty, and then try the query again.""" + ) + ) + elif isinstance(result.partitionBy, (list, tuple)) and ( + len(result.partitionBy) > 0 + and not all([isinstance(val, PartitioningColumn) for val in result.partitionBy]) + ): + raise PySparkValueError( + format_error( + f""" + {error_prefix} because the static 'analyze' method returned an + 'AnalyzeResult' object with the 'partitionBy' field set to a value besides a + list or tuple of 'PartitioningColumn' objects. Please update the table function + and then try the query again.""" + ) ) # Return the analyzed schema. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 195c989c41066..060594292ad58 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,7 +22,7 @@ import sys import dataclasses import time -from inspect import getfullargspec +import inspect import json from typing import Any, Callable, Iterable, Iterator, Optional import faulthandler @@ -616,12 +616,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: return args_offsets, wrap_arrow_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - argspec = getfullargspec(chained_func) # signature was lost when wrapping it + argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - argspec = getfullargspec(chained_func) # signature was lost when wrapping it + argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets, return_type) @@ -691,7 +691,4 @@ def read_udtf(pickleSer, infile, eval_type): pickled_analyze_result = None -    # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult, -    # although we may set this to false later if we find otherwise. -    udtf_init_method_accepts_analyze_result = True handler = read_command(pickleSer, infile) if not isinstance(handler, type): raise PySparkRuntimeError( @@ -704,29 +701,32 @@ raise PySparkRuntimeError( f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
) +    udtf_name = utf8_deserializer.loads(infile) -    # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor -    # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor -    # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the -    # AnalyzeResult. +    # Inspect the UDTF constructor's argument list: if it declares a second argument besides 'self', +    # pass the previous AnalyzeResult to it when creating each new UDTF instance. In this way each +    # UDTF class can decide whether it wants to inspect the AnalyzeResult. Constructors with any +    # other number of extra arguments are invalid and raise an error below. +    udtf_init_args = inspect.getfullargspec(handler) if has_pickled_analyze_result: -        prev_handler = handler +        if len(udtf_init_args.args) > 2: +            raise PySparkRuntimeError( +                error_class="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD", +                message_parameters={"name": udtf_name}, +            ) +        elif len(udtf_init_args.args) == 2: +            prev_handler = handler -        def construct_udtf(): -            nonlocal udtf_init_method_accepts_analyze_result -            if not udtf_init_method_accepts_analyze_result: -                return prev_handler() -            else: -                try: -                    # Here we pass the AnalyzeResult to the UDTF's __init__ method. -                    return prev_handler(dataclasses.replace(pickled_analyze_result)) -                except TypeError: -                    # This means that the UDTF handler does not accept an AnalyzeResult object in -                    # its __init__ method. -                    udtf_init_method_accepts_analyze_result = False -                    return prev_handler() +            def construct_udtf(): +                # Here we pass the AnalyzeResult to the UDTF's __init__ method. +                return prev_handler(dataclasses.replace(pickled_analyze_result)) -        handler = construct_udtf +            handler = construct_udtf +    elif len(udtf_init_args.args) > 1: +        raise PySparkRuntimeError( +            error_class="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD", +            message_parameters={"name": udtf_name}, +        ) class UDTFWithPartitions: """ @@ -854,6 +856,16 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: "the query again." ) +    # Check that the arguments provided to the UDTF call match the expected parameters defined +    # in the 'eval' method signature. +    try: +        inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets) +    except TypeError as e: +        raise PySparkRuntimeError( +            error_class="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE", +            message_parameters={"name": udtf_name, "reason": str(e)}, +        ) from None + def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]: def raise_(result_column_index): raise PySparkRuntimeError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 40993f96e7a0c..e6b19910296e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -142,5 +142,7 @@ object PythonUDTFRunner { PythonWorkerUtils.writePythonFunction(udtf.func, dataOut) // Write the result schema of the UDTF call. PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut) +    // Write the UDTF name.
+ PythonWorkerUtils.writeUTF(udtf.name, dataOut) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index dd9a869bf06ac..202159907af82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -133,7 +133,7 @@ case class UserDefinedPythonTableFunction( case _ => false } val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => { - val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(func, exprs, tableArgs) + val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(name, func, exprs, tableArgs) runner.runInPython() } UnresolvedPolymorphicPythonUDTF( @@ -184,6 +184,7 @@ case class UserDefinedPythonTableFunction( * will be thrown when an exception is raised in Python. */ class UserDefinedPythonTableFunctionAnalyzeRunner( + name: String, func: PythonFunction, exprs: Seq[Expression], tableArgs: Seq[Boolean]) extends PythonPlannerRunner[PythonUDTFAnalyzeResult](func) { @@ -192,6 +193,7 @@ class UserDefinedPythonTableFunctionAnalyzeRunner( override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { // Send Python UDTF + PythonWorkerUtils.writeUTF(name, dataOut) PythonWorkerUtils.writePythonFunction(func, dataOut) // Send arguments @@ -226,6 +228,9 @@ class UserDefinedPythonTableFunctionAnalyzeRunner( val length = dataIn.readInt() if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { val msg = PythonWorkerUtils.readUTF(dataIn) + // Remove the leading traceback stack trace from the error message string, if any, since it + // usually only includes the "analyze_udtf.py" filename and a line number. + .split("PySparkValueError:").last.strip() throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out index 078bd790a8453..3a9dfc26bcc92 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -512,6 +512,252 @@ SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnMapType(TABLE(t2)) [Analyzer test output redacted due to nondeterminism] +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs() +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(1, 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFForwardStateFromAnalyzeWithKwargs' because the function arguments did not match the expected signature of the static 'analyze' method (too many positional arguments). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." 
+ }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 57, + "fragment" : "UDTFForwardStateFromAnalyzeWithKwargs(1, 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(invalid => 2) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(argument => 1, argument => 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`UDTFForwardStateFromAnalyzeWithKwargs`", + "parameterName" : "`argument`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 81, + "fragment" : "UDTFForwardStateFromAnalyzeWithKwargs(argument => 1, argument => 2)" + } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithSinglePartitionNoInputTable(argument => 1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithSinglePartitionNoInputTable' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'withSinglePartition' field set to 'true', but the function call did not provide any table argument. Please update the query so that it provides a table argument, or else update the table function so that its 'analyze' method returns an 'AnalyzeResult' object with the 'withSinglePartition' field set to 'false', and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 80, + "fragment" : "InvalidAnalyzeMethodWithSinglePartitionNoInputTable(argument => 1)" + } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithPartitionByNoInputTable(argument => 1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithPartitionByNoInputTable' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'partitionBy' list set to non-empty, but the function call did not provide any table argument. Please update the query so that it provides a table argument, or else update the table function so that its 'analyze' method returns an 'AnalyzeResult' object with the 'partitionBy' list set to empty, and then try the query again." 
+  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 76, +    "fragment" : "InvalidAnalyzeMethodWithPartitionByNoInputTable(argument => 1)" +  } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodReturnsNonStructTypeSchema(TABLE(t2)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodReturnsNonStructTypeSchema' because the static 'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType, but the 'schema' field had the wrong type: <class 'int'>" +  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 71, +    "fragment" : "InvalidAnalyzeMethodReturnsNonStructTypeSchema(TABLE(t2))" +  } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithPartitionByListOfStrings(argument => TABLE(t2)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithPartitionByListOfStrings' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'partitionBy' field set to a value besides a list or tuple of 'PartitioningColumn' objects. Please update the table function and then try the query again." +  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 85, +    "fragment" : "InvalidAnalyzeMethodWithPartitionByListOfStrings(argument => TABLE(t2))" +  } ] +} + + +-- !query +SELECT * FROM InvalidForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM InvalidNotForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (missing a required argument: 'input_table'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." +  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 40, +    "fragment" : "UDTFWithSinglePartition(1)" +  } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, 2, 3) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (too many positional arguments).
Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 46, + "fragment" : "UDTFWithSinglePartition(1, 2, 3)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (missing a required argument: 'input_table'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 63, + "fragment" : "UDTFWithSinglePartition(1, invalid_arg_name => 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (multiple values for argument 'initial_count'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." 
+ }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 60, + "fragment" : "UDTFWithSinglePartition(1, initial_count => 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`UDTFWithSinglePartition`", + "parameterName" : "`initial_count`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 77, + "fragment" : "UDTFWithSinglePartition(initial_count => 1, initial_count => 2)" + } ] +} + + -- !query DROP VIEW t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql index 3d7d0bb325169..68885923e9f77 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -116,6 +116,23 @@ SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayType(TABLE(t2)) SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2)); SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnStructType(TABLE(t2)); SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnMapType(TABLE(t2)); +-- The following UDTF calls exercise various invalid function definitions and calls to show the +-- error messages. +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(); +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(1, 2); +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(invalid => 2); +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(argument => 1, argument => 2); +SELECT * FROM InvalidAnalyzeMethodWithSinglePartitionNoInputTable(argument => 1); +SELECT * FROM InvalidAnalyzeMethodWithPartitionByNoInputTable(argument => 1); +SELECT * FROM InvalidAnalyzeMethodReturnsNonStructTypeSchema(TABLE(t2)); +SELECT * FROM InvalidAnalyzeMethodWithPartitionByListOfStrings(argument => TABLE(t2)); +SELECT * FROM InvalidForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)); +SELECT * FROM InvalidNotForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)); +SELECT * FROM UDTFWithSinglePartition(1); +SELECT * FROM UDTFWithSinglePartition(1, 2, 3); +SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2); +SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2); +SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2); -- cleanup DROP VIEW t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out index 3317f5fada763..1ed7726dde8e8 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -615,6 +615,286 @@ org.apache.spark.api.python.PythonException pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable. 
+-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs() +-- !query schema +struct<> +-- !query output +org.apache.spark.api.python.PythonException +pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE] Failed to evaluate the user-defined table function 'UDTFForwardStateFromAnalyzeWithKwargs' because the function arguments did not match the expected signature of the 'eval' method (missing a required argument: 'argument'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its 'eval' method accepts the provided arguments, and then try the query again. + + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(1, 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFForwardStateFromAnalyzeWithKwargs' because the function arguments did not match the expected signature of the static 'analyze' method (too many positional arguments). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 57, + "fragment" : "UDTFForwardStateFromAnalyzeWithKwargs(1, 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(invalid => 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.api.python.PythonException +pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE] Failed to evaluate the user-defined table function 'UDTFForwardStateFromAnalyzeWithKwargs' because the function arguments did not match the expected signature of the 'eval' method (missing a required argument: 'argument'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its 'eval' method accepts the provided arguments, and then try the query again. 
+ + +-- !query +SELECT * FROM UDTFForwardStateFromAnalyzeWithKwargs(argument => 1, argument => 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`UDTFForwardStateFromAnalyzeWithKwargs`", + "parameterName" : "`argument`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 81, + "fragment" : "UDTFForwardStateFromAnalyzeWithKwargs(argument => 1, argument => 2)" + } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithSinglePartitionNoInputTable(argument => 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithSinglePartitionNoInputTable' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'withSinglePartition' field set to 'true', but the function call did not provide any table argument. Please update the query so that it provides a table argument, or else update the table function so that its 'analyze' method returns an 'AnalyzeResult' object with the 'withSinglePartition' field set to 'false', and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 80, + "fragment" : "InvalidAnalyzeMethodWithSinglePartitionNoInputTable(argument => 1)" + } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithPartitionByNoInputTable(argument => 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithPartitionByNoInputTable' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'partitionBy' list set to non-empty, but the function call did not provide any table argument. Please update the query so that it provides a table argument, or else update the table function so that its 'analyze' method returns an 'AnalyzeResult' object with the 'partitionBy' list set to empty, and then try the query again." 
+  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 76, +    "fragment" : "InvalidAnalyzeMethodWithPartitionByNoInputTable(argument => 1)" +  } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodReturnsNonStructTypeSchema(TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodReturnsNonStructTypeSchema' because the static 'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult with a 'schema' field comprising a StructType, but the 'schema' field had the wrong type: <class 'int'>" +  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 71, +    "fragment" : "InvalidAnalyzeMethodReturnsNonStructTypeSchema(TABLE(t2))" +  } ] +} + + +-- !query +SELECT * FROM InvalidAnalyzeMethodWithPartitionByListOfStrings(argument => TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ +  "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", +  "sqlState" : "38000", +  "messageParameters" : { +    "msg" : "Failed to evaluate the user-defined table function 'InvalidAnalyzeMethodWithPartitionByListOfStrings' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'partitionBy' field set to a value besides a list or tuple of 'PartitioningColumn' objects. Please update the table function and then try the query again." +  }, +  "queryContext" : [ { +    "objectType" : "", +    "objectName" : "", +    "startIndex" : 15, +    "stopIndex" : 85, +    "fragment" : "InvalidAnalyzeMethodWithPartitionByListOfStrings(argument => TABLE(t2))" +  } ] +} + + +-- !query +SELECT * FROM InvalidForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.api.python.PythonException +pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD] Failed to evaluate the user-defined table function 'InvalidForwardStateFromAnalyzeTooManyInitArgs' because its constructor is invalid: the function implements the 'analyze' method, but its constructor has more than two arguments (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, or one 'self' argument plus another argument for the result of the 'analyze' method, and try the query again. + + +-- !query +SELECT * FROM InvalidNotForwardStateFromAnalyzeTooManyInitArgs(TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.api.python.PythonException +pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD] Failed to evaluate the user-defined table function 'InvalidNotForwardStateFromAnalyzeTooManyInitArgs' because its constructor is invalid: the function does not implement the 'analyze' method, and its constructor has more than one argument (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, and try the query again.
+ + +-- !query +SELECT * FROM UDTFWithSinglePartition(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (missing a required argument: 'input_table'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 40, + "fragment" : "UDTFWithSinglePartition(1)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, 2, 3) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (too many positional arguments). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 46, + "fragment" : "UDTFWithSinglePartition(1, 2, 3)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (missing a required argument: 'input_table'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 63, + "fragment" : "UDTFWithSinglePartition(1, invalid_arg_name => 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON", + "sqlState" : "38000", + "messageParameters" : { + "msg" : "Failed to evaluate the user-defined table function 'UDTFWithSinglePartition' because the function arguments did not match the expected signature of the static 'analyze' method (multiple values for argument 'initial_count'). Please update the query so that this table function call provides arguments matching the expected signature, or else update the table function so that its static 'analyze' method accepts the provided arguments, and then try the query again." 
+ }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 60, + "fragment" : "UDTFWithSinglePartition(1, initial_count => 2)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`UDTFWithSinglePartition`", + "parameterName" : "`initial_count`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 77, + "fragment" : "UDTFWithSinglePartition(initial_count => 1, initial_count => 2)" + } ] +} + + -- !query DROP VIEW t1 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 20158bc5cc620..99045ffd86371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -345,11 +345,15 @@ object IntegratedUDFTestUtils extends SQLHelper { } sealed trait TestUDTF { - def apply(session: SparkSession, exprs: Column*): DataFrame + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) - val name: String - val prettyName: String - val udtf: UserDefinedPythonTableFunction + val name: String = getClass.getSimpleName.stripSuffix("$") + val pythonScript: String + lazy val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) } class PythonUDFWithoutId( @@ -459,7 +463,7 @@ object IntegratedUDFTestUtils extends SQLHelper { udfDeterministic = deterministic) } - case class TestPythonUDTF(name: String) extends TestUDTF { + case class TestPythonUDTF(override val name: String) extends TestUDTF { val pythonScript: String = """ |class TestUDTF: @@ -473,20 +477,14 @@ object IntegratedUDFTestUtils extends SQLHelper { | ... 
|""".stripMargin - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + override lazy val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( name = "TestUDTF", pythonScript = pythonScript, returnType = Some(StructType.fromDDL("x int, y int")) ) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = "Regular Python UDTF" } - object TestPythonUDTFCountSumLast extends TestUDTF { - val name: String = "UDTFCountSumLast" + object UDTFCountSumLast extends TestUDTF { val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn @@ -498,7 +496,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = None | | @staticmethod - | def analyze(self): + | def analyze(row: Row): | return AnalyzeResult( | schema=StructType() | .add("count", IntegerType()) @@ -513,21 +511,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._count, self._sum, self._last |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Python UDTF finding the count, sum, and last value from the input rows" } - object TestPythonUDTFLastString extends TestUDTF { - val name: String = "UDTFLastString" + object UDTFLastString extends TestUDTF { val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult @@ -537,7 +523,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = "" | | @staticmethod - | def analyze(self): + | def analyze(row: Row): | return AnalyzeResult( | schema=StructType() | .add("last", StringType())) @@ -548,21 +534,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._last, |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = "Python UDTF returning the last string provided in the input table" } - - object TestPythonUDTFWithSinglePartition extends TestUDTF { - val name: String = "UDTFWithSinglePartition" + object UDTFWithSinglePartition extends TestUDTF { val pythonScript: String = s""" |import json @@ -607,20 +581,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._count, self._sum, self._last |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = "Python UDTF exporting single-partition requirement from 'analyze'" } - object TestPythonUDTFPartitionBy extends TestUDTF { - val name: String = "UDTFPartitionByOrderBy" + object UDTFPartitionByOrderBy extends TestUDTF { val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn @@ -633,7 +596,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = None | | @staticmethod - | def analyze(self): + | def analyze(row: Row): | return AnalyzeResult( | 
schema=StructType() | .add("partition_col", IntegerType()) @@ -656,21 +619,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._partition_col, self._count, self._sum, self._last |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Python UDTF exporting input table partitioning and ordering requirement from 'analyze'" } - object InvalidPartitionByAndWithSinglePartition extends TestUDTF { - val name: String = "UDTFInvalidPartitionByAndWithSinglePartition" + object UDTFInvalidPartitionByAndWithSinglePartition extends TestUDTF { val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn @@ -680,7 +631,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = None | | @staticmethod - | def analyze(self): + | def analyze(row: Row): | return AnalyzeResult( | schema=StructType() | .add("last", IntegerType()), @@ -695,22 +646,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._last, |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Python UDTF exporting invalid input table partitioning requirement from 'analyze' " + - "because the 'withSinglePartition' property is also exported to true" } - object InvalidOrderByWithoutPartitionBy extends TestUDTF { - val name: String = "UDTFInvalidOrderByWithoutPartitionBy" + object UDTFInvalidOrderByWithoutPartitionBy extends TestUDTF { val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn @@ -720,7 +658,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = None | | @staticmethod - | def analyze(self): + | def analyze(row: Row): | return AnalyzeResult( | schema=StructType() | .add("last", IntegerType()), @@ -734,22 +672,170 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._last, |""".stripMargin + } - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) + object UDTFForwardStateFromAnalyze extends TestUDTF { + val pythonScript: String = + s""" + |from dataclasses import dataclass + |from pyspark.sql.functions import AnalyzeResult + |from pyspark.sql.types import StringType, StructType + | + |@dataclass + |class AnalyzeResultWithBuffer(AnalyzeResult): + | buffer: str = "" + | + |class $name: + | def __init__(self, analyze_result): + | self._analyze_result = analyze_result + | + | @staticmethod + | def analyze(argument): + | assert(argument.dataType == StringType()) + | return AnalyzeResultWithBuffer( + | schema=StructType() + | .add("result", StringType()), + | buffer=argument.value) + | + | def eval(self, argument): + | pass + | + | def terminate(self): + | yield self._analyze_result.buffer, + |""".stripMargin + } - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) + object UDTFForwardStateFromAnalyzeWithKwargs extends TestUDTF { + val pythonScript: String = + s""" + |from 
dataclasses import dataclass +      |from pyspark.sql.functions import AnalyzeResult +      |from pyspark.sql.types import StringType, StructType +      | +      |@dataclass +      |class AnalyzeResultWithBuffer(AnalyzeResult): +      |    buffer: str = "" +      | +      |class $name: +      |    def __init__(self, analyze_result): +      |        self._analyze_result = analyze_result +      | +      |    @staticmethod +      |    def analyze(**kwargs): +      |        argument = kwargs.get("argument") +      |        if argument is not None: +      |            assert(argument.dataType == StringType()) +      |            argument_value = argument.value +      |        else: +      |            argument_value = None +      |        return AnalyzeResultWithBuffer( +      |            schema=StructType() +      |                .add("result", StringType()), +      |            buffer=argument_value) +      | +      |    def eval(self, argument: str): +      |        pass +      | +      |    def terminate(self): +      |        yield self._analyze_result.buffer, +      |""".stripMargin +  } -    val prettyName: String = -      "Python UDTF exporting invalid input table ordering requirement from 'analyze' " + -      "without a corresponding partitioning table requirement" +  object InvalidAnalyzeMethodReturnsNonStructTypeSchema extends TestUDTF { +    val pythonScript: String = +      s""" +      |from dataclasses import dataclass +      |from pyspark.sql.functions import AnalyzeResult +      |from pyspark.sql.types import StringType, StructType +      | +      |class $name: +      |    @staticmethod +      |    def analyze(argument): +      |        return AnalyzeResult( +      |            schema=42) +      | +      |    def eval(self, argument): +      |        pass +      | +      |    def terminate(self): +      |        yield 42, +      |""".stripMargin } -  object TestPythonUDTFForwardStateFromAnalyze extends TestUDTF { -    val name: String = "TestPythonUDTFForwardStateFromAnalyze" +  object InvalidAnalyzeMethodWithSinglePartitionNoInputTable extends TestUDTF { +    val pythonScript: String = +      s""" +      |from dataclasses import dataclass +      |from pyspark.sql.functions import AnalyzeResult +      |from pyspark.sql.types import StringType, StructType +      | +      |class $name: +      |    @staticmethod +      |    def analyze(**kwargs): +      |        return AnalyzeResult( +      |            schema=StructType() +      |                .add("result", StringType()), +      |            withSinglePartition=True) +      | +      |    def eval(self, argument): +      |        pass +      | +      |    def terminate(self): +      |        yield 42, +      |""".stripMargin +  } + +  object InvalidAnalyzeMethodWithPartitionByNoInputTable extends TestUDTF { +    val pythonScript: String = +      s""" +      |from dataclasses import dataclass +      |from pyspark.sql.functions import AnalyzeResult, PartitioningColumn +      |from pyspark.sql.types import StringType, StructType +      | +      |class $name: +      |    @staticmethod +      |    def analyze(**kwargs): +      |        return AnalyzeResult( +      |            schema=StructType() +      |                .add("result", StringType()), +      |            partitionBy=[ +      |                PartitioningColumn("partition_col") +      |            ]) +      | +      |    def eval(self, argument): +      |        pass +      | +      |    def terminate(self): +      |        yield 42, +      |""".stripMargin +  } + +  object InvalidAnalyzeMethodWithPartitionByListOfStrings extends TestUDTF { +    val pythonScript: String = +      s""" +      |from dataclasses import dataclass +      |from pyspark.sql.functions import AnalyzeResult, PartitioningColumn +      |from pyspark.sql.types import StringType, StructType +      | +      |class $name: +      |    @staticmethod +      |    def analyze(**kwargs): +      |        return AnalyzeResult( +      |            schema=StructType() +      |                .add("result", StringType()), +      |            partitionBy=[ +      |                "partition_col" +      |            ]) +      | +      |    def eval(self, argument): +      |        pass +      | +      |    def terminate(self): +      |        yield 42, +      |""".stripMargin +  } + +  object InvalidForwardStateFromAnalyzeTooManyInitArgs extends TestUDTF { val pythonScript: String = s""" |from dataclasses import dataclass @@ -761,12 +847,11 @@ object IntegratedUDFTestUtils extends SQLHelper { |    buffer: str = "" |
|class $name: - | def __init__(self, analyze_result): + | def __init__(self, analyze_result, other_argument): | self._analyze_result = analyze_result | | @staticmethod | def analyze(argument): - | assert(argument.dataType == StringType()) | return AnalyzeResultWithBuffer( | schema=StructType() | .add("result", StringType()), @@ -778,20 +863,30 @@ object IntegratedUDFTestUtils extends SQLHelper { | def terminate(self): | yield self._analyze_result.buffer, |""".stripMargin + } + + object InvalidNotForwardStateFromAnalyzeTooManyInitArgs extends TestUDTF { + val pythonScript: String = + s""" + |class $name: + | def __init__(self, other_argument): + | pass + | + | def eval(self, argument): + | pass + | + | def terminate(self): + | yield 'abc', + |""".stripMargin - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + override lazy val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( name = name, pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = "Python UDTF whose 'analyze' method sets state and reads it later" + returnType = Some(StructType.fromDDL("result string")) + ) } object InvalidEvalReturnsNoneToNonNullableColumnScalarType extends TestUDTF { - val name: String = "InvalidEvalReturnsNoneToNonNullableColumnScalarType" val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult @@ -811,21 +906,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def eval(self, *args): | yield None, |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Invalid Python UDTF whose 'eval' method returns None to a non-nullable scalar column" } object InvalidEvalReturnsNoneToNonNullableColumnArrayType extends TestUDTF { - val name: String = "InvalidEvalReturnsNoneToNonNullableColumnArrayType" val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult @@ -845,21 +928,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def eval(self, *args): | yield None, |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Invalid Python UDTF whose 'eval' method returns None to a non-nullable array column" } object InvalidEvalReturnsNoneToNonNullableColumnArrayElementType extends TestUDTF { - val name: String = "InvalidEvalReturnsNoneToNonNullableColumnArrayElementType" val pythonScript: String = s""" |from pyspark.sql.functions import AnalyzeResult @@ -879,21 +950,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | def eval(self, *args): | yield [1, 2, None, 3], |""".stripMargin - - val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - name = name, - pythonScript = pythonScript, - returnType = None) - - def apply(session: SparkSession, exprs: Column*): DataFrame = - udtf.apply(session, exprs: _*) - - val prettyName: String = - "Invalid Python UDTF whose 'eval' method returns None to a non-nullable array element" } object InvalidEvalReturnsNoneToNonNullableColumnStructType extends 
-    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnStructType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -913,21 +972,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def eval(self, *args):
          |        yield Row(field=None),
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable struct column"
   }

   object InvalidEvalReturnsNoneToNonNullableColumnMapType extends TestUDTF {
-    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnMapType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -947,21 +994,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def eval(self, *args):
          |        yield {42: None},
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable map column"
   }

   object InvalidTerminateReturnsNoneToNonNullableColumnScalarType extends TestUDTF {
-    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnScalarType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -984,21 +1019,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def terminate(self):
          |        yield None,
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'terminate' method returns None to a non-nullable column"
   }

   object InvalidTerminateReturnsNoneToNonNullableColumnArrayType extends TestUDTF {
-    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnArrayType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -1021,21 +1044,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def terminate(self):
          |        yield None,
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'terminate' method returns None to a non-nullable array column"
   }

   object InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType extends TestUDTF {
-    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -1058,21 +1069,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def terminate(self):
          |        yield [1, 2, None, 3],
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'terminate' method returns None to a non-nullable array element"
   }

   object InvalidTerminateReturnsNoneToNonNullableColumnStructType extends TestUDTF {
-    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnStructType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -1095,21 +1094,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def terminate(self):
          |        yield Row(field=None),
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'terminate' method returns None to a non-nullable struct column"
   }

   object InvalidTerminateReturnsNoneToNonNullableColumnMapType extends TestUDTF {
-    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnMapType"
     val pythonScript: String =
       s"""
          |from pyspark.sql.functions import AnalyzeResult
@@ -1132,19 +1119,36 @@ object IntegratedUDFTestUtils extends SQLHelper {
          |    def terminate(self):
          |        yield {42: None},
          |""".stripMargin
-
-    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
-      name = name,
-      pythonScript = pythonScript,
-      returnType = None)
-
-    def apply(session: SparkSession, exprs: Column*): DataFrame =
-      udtf.apply(session, exprs: _*)
-
-    val prettyName: String =
-      "Invalid Python UDTF whose 'terminate' method returns None to a non-nullable map column"
   }

+  def AllTestUDTFs: Seq[TestUDTF] = Seq(
+    TestPythonUDTF("udtf"),
+    UDTFCountSumLast,
+    UDTFLastString,
+    UDTFWithSinglePartition,
+    UDTFPartitionByOrderBy,
+    UDTFInvalidPartitionByAndWithSinglePartition,
+    UDTFInvalidOrderByWithoutPartitionBy,
+    UDTFForwardStateFromAnalyze,
+    UDTFForwardStateFromAnalyzeWithKwargs,
+    InvalidAnalyzeMethodReturnsNonStructTypeSchema,
+    InvalidAnalyzeMethodWithSinglePartitionNoInputTable,
+    InvalidAnalyzeMethodWithPartitionByNoInputTable,
+    InvalidAnalyzeMethodWithPartitionByListOfStrings,
+    InvalidForwardStateFromAnalyzeTooManyInitArgs,
+    InvalidNotForwardStateFromAnalyzeTooManyInitArgs,
+    InvalidEvalReturnsNoneToNonNullableColumnScalarType,
+    InvalidEvalReturnsNoneToNonNullableColumnArrayType,
+    InvalidEvalReturnsNoneToNonNullableColumnArrayElementType,
+    InvalidEvalReturnsNoneToNonNullableColumnStructType,
+    InvalidEvalReturnsNoneToNonNullableColumnMapType,
+    InvalidTerminateReturnsNoneToNonNullableColumnScalarType,
+    InvalidTerminateReturnsNoneToNonNullableColumnArrayType,
+    InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType,
+    InvalidTerminateReturnsNoneToNonNullableColumnStructType,
+    InvalidTerminateReturnsNoneToNonNullableColumnMapType
+  )
+
   /**
    * A Scalar Pandas UDF that takes one column, casts into string, executes the
    * Python native function, and casts back to the type of input column.
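
Note: for orientation, the sketch below shows the valid shape of the constructor
contract that the InvalidForwardStateFromAnalyzeTooManyInitArgs and
InvalidNotForwardStateFromAnalyzeTooManyInitArgs fixtures above violate. It is a
minimal standalone sketch, not part of this patch: the class name
ForwardStateUDTF is invented, and it assumes a PySpark build that exposes
AnalyzeResult via pyspark.sql.functions, as the embedded scripts above do.

    from dataclasses import dataclass

    from pyspark.sql.functions import AnalyzeResult
    from pyspark.sql.types import StringType, StructType


    @dataclass
    class AnalyzeResultWithBuffer(AnalyzeResult):
        # Extra state computed once at analysis time.
        buffer: str = ""


    class ForwardStateUDTF:
        # Valid shape: 'self' plus exactly one extra argument, which receives
        # the AnalyzeResult instance that 'analyze' returned. Declaring a
        # second extra argument is what the TooManyInitArgs fixtures do.
        def __init__(self, analyze_result):
            self._analyze_result = analyze_result

        @staticmethod
        def analyze(argument):
            # 'argument.value' carries the literal argument value, if any.
            return AnalyzeResultWithBuffer(
                schema=StructType().add("result", StringType()),
                buffer=str(argument.value))

        def eval(self, argument):
            pass

        def terminate(self):
            # Emit the analysis-time state at the end of each partition.
            yield self._analyze_result.buffer,

Registered under a name and invoked as "select * from <name>('abc')", this
yields a single row ('abc',) from terminate(), the same round trip that the
SPARK-45402 test in PythonUDTFSuite below asserts with checkAnswer.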
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 36b572c7e8b5d..d9f4c685163c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -645,25 +645,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
           s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf)
       }
     } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) {
-      Seq(TestUDTFSet(Seq(
-        TestPythonUDTF("udtf"),
-        TestPythonUDTFCountSumLast,
-        TestPythonUDTFLastString,
-        TestPythonUDTFWithSinglePartition,
-        TestPythonUDTFPartitionBy,
-        InvalidPartitionByAndWithSinglePartition,
-        InvalidOrderByWithoutPartitionBy,
-        InvalidEvalReturnsNoneToNonNullableColumnScalarType,
-        InvalidEvalReturnsNoneToNonNullableColumnArrayType,
-        InvalidEvalReturnsNoneToNonNullableColumnArrayElementType,
-        InvalidEvalReturnsNoneToNonNullableColumnStructType,
-        InvalidEvalReturnsNoneToNonNullableColumnMapType,
-        InvalidTerminateReturnsNoneToNonNullableColumnScalarType,
-        InvalidTerminateReturnsNoneToNonNullableColumnArrayType,
-        InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType,
-        InvalidTerminateReturnsNoneToNonNullableColumnStructType,
-        InvalidTerminateReturnsNoneToNonNullableColumnMapType
-      ))).map { udtfSet =>
+      Seq(TestUDTFSet(AllTestUDTFs)).map { udtfSet =>
         UDTFSetTestCase(
           s"$testCaseName - Python UDTFs", absPath, resultFile, udtfSet)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index efab685236de3..989597ae041db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -48,15 +48,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {

   private val pythonUDTFCountSumLast: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      TestPythonUDTFCountSumLast.name, TestPythonUDTFCountSumLast.pythonScript, None)
+      UDTFCountSumLast.name, UDTFCountSumLast.pythonScript, None)

   private val pythonUDTFWithSinglePartition: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      TestPythonUDTFWithSinglePartition.name, TestPythonUDTFWithSinglePartition.pythonScript, None)
+      UDTFWithSinglePartition.name, UDTFWithSinglePartition.pythonScript, None)

   private val pythonUDTFPartitionByOrderBy: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      TestPythonUDTFPartitionBy.name, TestPythonUDTFPartitionBy.pythonScript, None)
+      UDTFPartitionByOrderBy.name, UDTFPartitionByOrderBy.pythonScript, None)

   private val arrowPythonUDTF: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
@@ -67,8 +67,8 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {

   private val pythonUDTFForwardStateFromAnalyze: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      TestPythonUDTFForwardStateFromAnalyze.name,
-      TestPythonUDTFForwardStateFromAnalyze.pythonScript, None)
+      UDTFForwardStateFromAnalyze.name,
+      UDTFForwardStateFromAnalyze.pythonScript, None)

   test("Simple PythonUDTF") {
     assume(shouldTestPythonUDFs)
@@ -205,14 +205,14 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         stop = 29))
   }

-    spark.udtf.registerPython(TestPythonUDTFCountSumLast.name, pythonUDTFCountSumLast)
+    spark.udtf.registerPython(UDTFCountSumLast.name, pythonUDTFCountSumLast)
     var plan = sql(
       s"""
         |WITH t AS (
         |  VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input)
         |)
         |SELECT count, total, last
-        |FROM ${TestPythonUDTFCountSumLast.name}(TABLE(t) WITH SINGLE PARTITION)
+        |FROM ${UDTFCountSumLast.name}(TABLE(t) WITH SINGLE PARTITION)
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: Repartition => r } match {
@@ -221,7 +221,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
       failure(plan)
     }

-    spark.udtf.registerPython(TestPythonUDTFWithSinglePartition.name, pythonUDTFWithSinglePartition)
+    spark.udtf.registerPython(UDTFWithSinglePartition.name, pythonUDTFWithSinglePartition)
     plan = sql(
       s"""
         |WITH t AS (
@@ -230,7 +230,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         |  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
         |)
         |SELECT count, total, last
-        |FROM ${TestPythonUDTFWithSinglePartition.name}(0, TABLE(t))
+        |FROM ${UDTFWithSinglePartition.name}(0, TABLE(t))
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: Repartition => r } match {
@@ -239,7 +239,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
       failure(plan)
     }

-    spark.udtf.registerPython(TestPythonUDTFPartitionBy.name, pythonUDTFPartitionByOrderBy)
+    spark.udtf.registerPython(UDTFPartitionByOrderBy.name, pythonUDTFPartitionByOrderBy)
     plan = sql(
       s"""
         |WITH t AS (
@@ -248,7 +248,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         |  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
         |)
         |SELECT partition_col, count, total, last
-        |FROM ${TestPythonUDTFPartitionBy.name}(TABLE(t))
+        |FROM ${UDTFPartitionByOrderBy.name}(TABLE(t))
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: RepartitionByExpression => r } match {
@@ -353,11 +353,11 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {

   test("SPARK-45402: Add UDTF API for 'analyze' to return a buffer to consume on class creation") {
     spark.udtf.registerPython(
-      TestPythonUDTFForwardStateFromAnalyze.name,
+      UDTFForwardStateFromAnalyze.name,
       pythonUDTFForwardStateFromAnalyze)
     withTable("t") {
       sql("create table t(col array<int>) using parquet")
-      val query = s"select * from ${TestPythonUDTFForwardStateFromAnalyze.name}('abc')"
+      val query = s"select * from ${UDTFForwardStateFromAnalyze.name}('abc')"
       checkAnswer(
         sql(query),
         Row("abc"))
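
Note: the constructor-arity rule these fixtures drive can be summarized compactly:
with an 'analyze' method the constructor may accept 'self' plus at most one
argument (the analyze result), and without one, only 'self'. The helper below is
a hedged standalone sketch of that rule in plain Python, not Spark's actual
validation code; check_udtf_constructor and TooManyInitArgs are invented names.

    import inspect


    def check_udtf_constructor(cls):
        # Count the declared parameters of __init__, including 'self'.
        num_args = len(inspect.signature(cls.__init__).parameters)
        if hasattr(cls, "analyze"):
            # With 'analyze': 'self' plus at most one extra argument, which
            # will receive the result returned by 'analyze'.
            if num_args > 2:
                raise RuntimeError(
                    f"'{cls.__name__}': constructor has more than two "
                    "arguments (including 'self')")
        elif num_args > 1:
            # Without 'analyze': only 'self' is allowed.
            raise RuntimeError(
                f"'{cls.__name__}': constructor has more than one "
                "argument (including 'self')")


    class TooManyInitArgs:
        def __init__(self, other_argument):  # no 'analyze' method on this class
            pass

        def eval(self, argument):
            pass


    try:
        check_udtf_constructor(TooManyInitArgs)
    except RuntimeError as e:
        print(e)  # constructor has more than one argument (including 'self')

Running the snippet prints the violation for TooManyInitArgs, mirroring what the
InvalidNotForwardStateFromAnalyzeTooManyInitArgs fixture triggers end to end
through the engine.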