diff --git a/python/docs/source/reference/pyspark.sql/udtf.rst b/python/docs/source/reference/pyspark.sql/udtf.rst
index c251e101bbb3b..64400a257ee0e 100644
--- a/python/docs/source/reference/pyspark.sql/udtf.rst
+++ b/python/docs/source/reference/pyspark.sql/udtf.rst
@@ -25,6 +25,6 @@ UDTF
 .. autosummary::
     :toctree: api/
 
-    udtf.UserDefinedTableFunction.asNondeterministic
+    udtf.UserDefinedTableFunction.asDeterministic
     udtf.UserDefinedTableFunction.returnType
     UDTFRegistration.register
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index d5ce91803581d..16f9b990760dc 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -50,7 +50,7 @@ def _create_udtf(
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
     evalType: int = PythonEvalType.SQL_TABLE_UDF,
-    deterministic: bool = True,
+    deterministic: bool = False,
 ) -> "UserDefinedTableFunction":
     udtf_obj = UserDefinedTableFunction(
         cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -62,7 +62,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     if useArrow is not None:
@@ -121,7 +121,7 @@ def __init__(
         returnType: Optional[Union[StructType, str]],
         name: Optional[str] = None,
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
-        deterministic: bool = True,
+        deterministic: bool = False,
     ) -> None:
 
         _validate_udtf_handler(func, returnType)
@@ -169,8 +169,8 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram
         plan = self._build_common_inline_user_defined_table_function(*args, **kwargs)
         return DataFrame.withPlan(plan, session)
 
-    def asNondeterministic(self) -> "UserDefinedTableFunction":
-        self.deterministic = False
+    def asDeterministic(self) -> "UserDefinedTableFunction":
+        self.deterministic = True
         return self
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index cf88932d1ec1b..d6d8c1322f1a5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15677,14 +15677,13 @@ def udtf(
 
     Notes
     -----
-    User-defined table functions (UDTFs) are considered deterministic by default.
-    Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+    User-defined table functions (UDTFs) are considered non-deterministic by default.
+    Use `asDeterministic()` to mark a function as deterministic. E.g.:
 
-    >>> import random
-    >>> class RandomUDTF:
+    >>> class PlusOne:
     ...     def eval(self, a: int):
-    ...         yield a * int(random.random() * 100),
-    >>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
+    ...         yield a + 1,
+    >>> plus_one = udtf(PlusOne, returnType="r: int").asDeterministic()
 
     Use "yield" to produce one row for the UDTF result relation as many times
     as needed. In the context of a lateral join, each such result row will be
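The doctest above only shows the deterministic opt-in; a minimal sketch of the flipped default end to end may help. This is a hedged example, not part of the patch — it assumes only `pyspark.sql.functions.udtf` and the `deterministic` attribute visible in the diffs above:

```python
from pyspark.sql.functions import udtf

class PlusOne:
    def eval(self, a: int):
        # Yield exactly one output row per input value.
        yield a + 1,

# With this patch the UDTF starts out non-deterministic...
plus_one = udtf(PlusOne, returnType="r: int")
assert plus_one.deterministic is False

# ...and asDeterministic() opts back in, re-enabling optimizations that
# assume repeated evaluation produces the same rows.
plus_one = plus_one.asDeterministic()
assert plus_one.deterministic is True
```

The two `assert`s mirror what the new `test_udtf_determinism` below checks through `unittest`.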
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index aa7df815e81be..59a482ad41182 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -456,6 +456,17 @@ def terminate(self):
         with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
+    def test_udtf_determinism(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="x: int")
+        # The UDTF is marked as non-deterministic by default.
+        self.assertFalse(func.deterministic)
+        func = func.asDeterministic()
+        self.assertTrue(func.deterministic)
+
     def test_nondeterministic_udtf(self):
         import random
 
@@ -463,7 +474,7 @@ class RandomUDTF:
             def eval(self, a: int):
                 yield a + int(random.random()),
 
-        random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
+        random_udtf = udtf(RandomUDTF, returnType="x: int")
         assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
         self.spark.udtf.register("random_udtf", random_udtf)
         assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)])
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index fa0bf548bd00a..833a2b9289485 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -94,7 +94,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     """Create a regular or an Arrow-optimized Python UDTF."""
@@ -183,7 +183,7 @@ def __init__(
         returnType: Optional[Union[StructType, str]],
         name: Optional[str] = None,
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
-        deterministic: bool = True,
+        deterministic: bool = False,
     ):
 
         _validate_udtf_handler(func, returnType)
@@ -285,13 +285,13 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram
         jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols))
         return DataFrame(jPythonUDTF, spark)
 
-    def asNondeterministic(self) -> "UserDefinedTableFunction":
+    def asDeterministic(self) -> "UserDefinedTableFunction":
         """
-        Updates UserDefinedTableFunction to nondeterministic.
+        Updates UserDefinedTableFunction to deterministic.
         """
         # Explicitly clean the cache to create a JVM UDTF instance.
         self._judtf_placeholder = None
-        self.deterministic = False
+        self.deterministic = True
         return self
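From the user's side, the updated `test_nondeterministic_udtf` boils down to the following. A sketch assuming an active SparkSession bound to `spark`:

```python
import random

from pyspark.sql.functions import lit, udtf

class RandomUDTF:
    def eval(self, a: int):
        # int(random.random()) is always 0, so the output stays assertable
        # even though the function is (correctly) treated as non-deterministic.
        yield a + int(random.random()),

# No .asNondeterministic() call needed any more: non-deterministic is now
# the default for Python UDTFs.
random_udtf = udtf(RandomUDTF, returnType="x: int")
random_udtf(lit(1)).show()

# SQL registration works the same way.
spark.udtf.register("random_udtf", random_udtf)
spark.sql("SELECT * FROM random_udtf(1)").show()
```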
+ - "NON_DETERMINISTIC_LATERAL_SUBQUERIES", - messageParameters = Map("treeNode" -> planToString(plan))) + cleanQueryInScalarSubquery(join.right.plan) match { + // Python UDTFs are by default non-deterministic. They are constructed as a + // OneRowRelation subquery and can be rewritten by the optimizer without + // any decorrelation. + case Generate(_: PythonUDTF, _, _, _, _, _: OneRowRelation) + if SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) => // Ok + case _ => + expr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "NON_DETERMINISTIC_LATERAL_SUBQUERIES", + messageParameters = Map("treeNode" -> planToString(plan))) + } } // Check if the lateral join's join condition is deterministic. if (join.condition.exists(!_.deterministic)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala index 3431c9327f1d5..3955142166831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala @@ -44,6 +44,9 @@ object PullOutNondeterministic extends Rule[LogicalPlan] { // and we want to retain them inside the aggregate functions. case m: CollectMetrics => m + // Skip PythonUDTF as it will be planned as its own dedicated logical and physical node. + case g @ Generate(_: PythonUDTF, _, _, _, _, _) => g + // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 10a821c1d1ed7..62bf1a676626c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -393,7 +393,7 @@ object IntegratedUDFTestUtils extends SQLHelper { pythonScript: String, returnType: StructType, evalType: Int = PythonEvalType.SQL_TABLE_UDF, - deterministic: Boolean = true): UserDefinedPythonTableFunction = { + deterministic: Boolean = false): UserDefinedPythonTableFunction = { UserDefinedPythonTableFunction( name = name, func = SimplePythonFunction( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 43f61a7c61e8a..8abcb0a6ce15e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTes import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias} import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -110,6 +111,18 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { } } + 
test("non-deterministic UDTF should pass check analysis") { + assume(shouldTestPythonUDFs) + withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") { + spark.udtf.registerPython("testUDTF", pythonUDTF) + withTempView("t") { + Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t") + val df = sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f") + df.queryExecution.assertAnalyzed() + } + } + } + test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") { // Positive tests assume(shouldTestPythonUDFs)