From 098206357cfb134088ee02b7cf04e4ee66eb0bde Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Wed, 16 Aug 2023 12:41:37 -0700
Subject: [PATCH 1/2] Make Python UDTFs non-deterministic by default

---
 .../docs/source/reference/pyspark.sql/udtf.rst |  2 +-
 python/pyspark/sql/connect/udtf.py             |  6 +++---
 python/pyspark/sql/functions.py                | 11 +++++------
 python/pyspark/sql/tests/test_udtf.py          | 13 ++++++++++++-
 python/pyspark/sql/udtf.py                     | 10 +++++-----
 .../sql/catalyst/analysis/CheckAnalysis.scala  | 16 ++++++++++++----
 .../analysis/PullOutNondeterministic.scala     |  3 +++
 .../spark/sql/IntegratedUDFTestUtils.scala     |  2 +-
 .../sql/execution/python/PythonUDTFSuite.scala | 13 +++++++++++++
 9 files changed, 55 insertions(+), 21 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/udtf.rst b/python/docs/source/reference/pyspark.sql/udtf.rst
index c251e101bbb3b..64400a257ee0e 100644
--- a/python/docs/source/reference/pyspark.sql/udtf.rst
+++ b/python/docs/source/reference/pyspark.sql/udtf.rst
@@ -25,6 +25,6 @@ UDTF
 .. autosummary::
     :toctree: api/
 
-    udtf.UserDefinedTableFunction.asNondeterministic
+    udtf.UserDefinedTableFunction.asDeterministic
     udtf.UserDefinedTableFunction.returnType
     UDTFRegistration.register

diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index c8495626292c5..c8ce98441bad7 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -129,7 +129,7 @@ def __init__(
         returnType: Optional[Union[StructType, str]],
         name: Optional[str] = None,
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
-        deterministic: bool = True,
+        deterministic: bool = False,
     ) -> None:
         _validate_udtf_handler(func, returnType)
 
@@ -175,8 +175,8 @@ def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
         plan = self._build_common_inline_user_defined_table_function(*cols)
         return DataFrame.withPlan(plan, session)
 
-    def asNondeterministic(self) -> "UserDefinedTableFunction":
-        self.deterministic = False
+    def asDeterministic(self) -> "UserDefinedTableFunction":
+        self.deterministic = True
         return self
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1b8a946e02e48..6f9897cb7a2b5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15627,14 +15627,13 @@ def udtf(
 
     Notes
    -----
-    User-defined table functions (UDTFs) are considered deterministic by default.
-    Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+    User-defined table functions (UDTFs) are considered non-deterministic by default.
+    Use `asDeterministic()` to mark a function as deterministic. E.g.:
 
-    >>> import random
-    >>> class RandomUDTF:
+    >>> class PlusOne:
     ...     def eval(self, a: int):
-    ...         yield a * int(random.random() * 100),
-    >>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
+    ...         yield a + 1,
+    >>> plus_one = udtf(PlusOne, returnType="r: int").asDeterministic()
 
     Use "yield" to produce one row for the UDTF result relation as many times as needed.
     In the context of a lateral join, each such result row will be
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 300067716e9de..c7382d987a3aa 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -456,6 +456,17 @@ def terminate(self):
         with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
+    def test_udtf_determinism(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="x: int")
+        # The UDTF is marked as non-deterministic by default.
+        self.assertFalse(func.deterministic)
+        func = func.asDeterministic()
+        self.assertTrue(func.deterministic)
+
     def test_nondeterministic_udtf(self):
         import random
 
@@ -463,7 +474,7 @@ class RandomUDTF:
             def eval(self, a: int):
                 yield a + int(random.random()),
 
-        random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
+        random_udtf = udtf(RandomUDTF, returnType="x: int")
         assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
         self.spark.udtf.register("random_udtf", random_udtf)
         assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)])
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 027a2646a4657..6c24d7ca12538 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -95,7 +95,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     """Create a regular or an Arrow-optimized Python UDTF."""
@@ -257,7 +257,7 @@ def __init__(
         returnType: Optional[Union[StructType, str]],
         name: Optional[str] = None,
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
-        deterministic: bool = True,
+        deterministic: bool = False,
     ):
         _validate_udtf_handler(func, returnType)
 
@@ -349,13 +349,13 @@ def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
         jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
         return DataFrame(jPythonUDTF, spark)
 
-    def asNondeterministic(self) -> "UserDefinedTableFunction":
+    def asDeterministic(self) -> "UserDefinedTableFunction":
         """
-        Updates UserDefinedTableFunction to nondeterministic.
+        Updates UserDefinedTableFunction to deterministic.
         """
         # Explicitly clean the cache to create a JVM UDTF instance.
         self._judtf_placeholder = None
-        self.deterministic = False
+        self.deterministic = True
         return self
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 0b953fc2b61f5..e7d8e455b4530 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1038,10 +1038,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         // A lateral join with a multi-row outer query and a non-deterministic lateral subquery
         // cannot be decorrelated. Otherwise it may produce incorrect results.
         if (!expr.deterministic && !join.left.maxRows.exists(_ <= 1)) {
-          expr.failAnalysis(
-            errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
-              "NON_DETERMINISTIC_LATERAL_SUBQUERIES",
-            messageParameters = Map("treeNode" -> planToString(plan)))
+          cleanQueryInScalarSubquery(join.right.plan) match {
+            // Python UDTFs are by default non-deterministic. They are constructed as a
+            // OneRowRelation subquery and can be rewritten by the optimizer without
+            // any decorrelation.
+            case Generate(_: PythonUDTF, _, _, _, _, _: OneRowRelation)
+              if SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) => // Ok
+            case _ =>
+              expr.failAnalysis(
+                errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
+                  "NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+                messageParameters = Map("treeNode" -> planToString(plan)))
+          }
         }
         // Check if the lateral join's join condition is deterministic.
         if (join.condition.exists(!_.deterministic)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
index 3431c9327f1d5..3955142166831 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
@@ -44,6 +44,9 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
     // and we want to retain them inside the aggregate functions.
     case m: CollectMetrics => m
 
+    // Skip PythonUDTF as it will be planned as its own dedicated logical and physical node.
+    case g @ Generate(_: PythonUDTF, _, _, _, _, _) => g
+
     // todo: It's hard to write a general rule to pull out nondeterministic expressions
     // from LogicalPlan, currently we only do it for UnaryNode which has same output
     // schema with its child.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 10a821c1d1ed7..62bf1a676626c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -393,7 +393,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
       pythonScript: String,
       returnType: StructType,
       evalType: Int = PythonEvalType.SQL_TABLE_UDF,
-      deterministic: Boolean = true): UserDefinedPythonTableFunction = {
+      deterministic: Boolean = false): UserDefinedPythonTableFunction = {
     UserDefinedPythonTableFunction(
       name = name,
       func = SimplePythonFunction(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index 43f61a7c61e8a..8abcb0a6ce15e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTes
 import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
 import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -110,6 +111,18 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("non-deterministic UDTF should pass check analysis") {
+    assume(shouldTestPythonUDFs)
+    withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") {
+      spark.udtf.registerPython("testUDTF", pythonUDTF)
+      withTempView("t") {
+        Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+        val df = sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f")
+        df.queryExecution.assertAnalyzed()
+      }
+    }
+  }
+
   test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") {
     // Positive tests
     assume(shouldTestPythonUDFs)

From 6427b01e5787a397b4baf2956030bdec18e02f4a Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Wed, 16 Aug 2023 17:14:21 -0700
Subject: [PATCH 2/2] Fix remaining deterministic-by-default flags in Spark
 Connect UDTF helpers

---
 python/pyspark/sql/connect/udtf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index c8ce98441bad7..fd3867509967d 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -50,7 +50,7 @@ def _create_udtf(
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
     evalType: int = PythonEvalType.SQL_TABLE_UDF,
-    deterministic: bool = True,
+    deterministic: bool = False,
 ) -> "UserDefinedTableFunction":
     udtf_obj = UserDefinedTableFunction(
         cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -62,7 +62,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     if useArrow is not None:
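
Usage note (illustrative only, not part of the patches above): after these two
commits, a Python UDTF starts out non-deterministic, and asDeterministic()
opts a pure function back in. A minimal sketch, assuming a running Spark
session built with these patches; the class and column names are made up:

    import random

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import lit, udtf

    spark = SparkSession.builder.getOrCreate()

    # Non-deterministic by default: a UDTF that uses random state no longer
    # needs an explicit asNondeterministic() call.
    @udtf(returnType="x: int")
    class RandomUDTF:
        def eval(self, a: int):
            yield a + int(random.random() * 100),

    RandomUDTF(lit(1)).show()

    # A pure UDTF can opt back in with asDeterministic(), marking its results
    # as repeatable so the optimizer may safely reorder or re-execute it.
    @udtf(returnType="r: int")
    class PlusOne:
        def eval(self, a: int):
            yield a + 1,

    plus_one = PlusOne.asDeterministic()
    plus_one(lit(1)).show()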