[SPARK-44822][PYTHON][SQL] Make Python UDTFs by default non-deterministic

### What changes were proposed in this pull request?

This PR changes the default determinism of Python UDTFs from `True` to `False`.

### Why are the changes needed?

To prevent potential regressions: in practice, many Python UDTFs are non-deterministic. Users can always explicitly mark a UDTF as deterministic.
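For illustration, a minimal sketch of the new default (the `PlusOne` class and variable names are hypothetical; it uses the decorator form of `udtf` from `pyspark.sql.functions`):

```python
from pyspark.sql.functions import udtf

# Under this PR, a Python UDTF is non-deterministic unless marked otherwise.
@udtf(returnType="x: int")
class PlusOne:
    def eval(self, a: int):
        yield a + 1,

assert not PlusOne.deterministic  # the default is now False

# Opt in explicitly when the function really is deterministic.
plus_one = PlusOne.asDeterministic()
assert plus_one.deterministic
```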

### Does this PR introduce _any_ user-facing change?

No. Python UDTFs are a new feature that has not yet been released.

### How was this patch tested?

Existing and new tests

Closes apache#42519 from allisonwang-db/spark-44822-non-det-by-default.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
allisonwang-db authored and cloud-fan committed Aug 17, 2023
1 parent 06959e2 commit fce83d4
Showing 9 changed files with 57 additions and 23 deletions.
2 changes: 1 addition & 1 deletion python/docs/source/reference/pyspark.sql/udtf.rst
@@ -25,6 +25,6 @@ UDTF
 .. autosummary::
     :toctree: api/

-    udtf.UserDefinedTableFunction.asNondeterministic
+    udtf.UserDefinedTableFunction.asDeterministic
     udtf.UserDefinedTableFunction.returnType
     UDTFRegistration.register
10 changes: 5 additions & 5 deletions python/pyspark/sql/connect/udtf.py
@@ -50,7 +50,7 @@ def _create_udtf(
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
     evalType: int = PythonEvalType.SQL_TABLE_UDF,
-    deterministic: bool = True,
+    deterministic: bool = False,
 ) -> "UserDefinedTableFunction":
     udtf_obj = UserDefinedTableFunction(
         cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -62,7 +62,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     if useArrow is not None:
@@ -121,7 +121,7 @@ def __init__(
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
     evalType: int = PythonEvalType.SQL_TABLE_UDF,
-    deterministic: bool = True,
+    deterministic: bool = False,
 ) -> None:
     _validate_udtf_handler(func, returnType)

@@ -169,8 +169,8 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram
         plan = self._build_common_inline_user_defined_table_function(*args, **kwargs)
         return DataFrame.withPlan(plan, session)

-    def asNondeterministic(self) -> "UserDefinedTableFunction":
-        self.deterministic = False
+    def asDeterministic(self) -> "UserDefinedTableFunction":
+        self.deterministic = True
         return self

11 changes: 5 additions & 6 deletions python/pyspark/sql/functions.py
@@ -15677,14 +15677,13 @@ def udtf(
     Notes
     -----
-    User-defined table functions (UDTFs) are considered deterministic by default.
-    Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+    User-defined table functions (UDTFs) are considered non-deterministic by default.
+    Use `asDeterministic()` to mark a function as deterministic. E.g.:

-    >>> import random
-    >>> class RandomUDTF:
+    >>> class PlusOne:
     ...     def eval(self, a: int):
-    ...         yield a * int(random.random() * 100),
-    >>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
+    ...         yield a + 1,
+    >>> plus_one = udtf(PlusOne, returnType="r: int").asDeterministic()

     Use "yield" to produce one row for the UDTF result relation as many times
     as needed. In the context of a lateral join, each such result row will be
13 changes: 12 additions & 1 deletion python/pyspark/sql/tests/test_udtf.py
@@ -456,14 +456,25 @@ def terminate(self):
         with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()

+    def test_udtf_determinism(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="x: int")
+        # The UDTF is marked as non-deterministic by default.
+        self.assertFalse(func.deterministic)
+        func = func.asDeterministic()
+        self.assertTrue(func.deterministic)
+
     def test_nondeterministic_udtf(self):
         import random

         class RandomUDTF:
             def eval(self, a: int):
                 yield a + int(random.random()),

-        random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
+        random_udtf = udtf(RandomUDTF, returnType="x: int")
         assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
         self.spark.udtf.register("random_udtf", random_udtf)
         assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)])
10 changes: 5 additions & 5 deletions python/pyspark/sql/udtf.py
@@ -94,7 +94,7 @@ def _create_py_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
-    deterministic: bool = True,
+    deterministic: bool = False,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedTableFunction":
     """Create a regular or an Arrow-optimized Python UDTF."""
@@ -183,7 +183,7 @@ def __init__(
     returnType: Optional[Union[StructType, str]],
     name: Optional[str] = None,
     evalType: int = PythonEvalType.SQL_TABLE_UDF,
-    deterministic: bool = True,
+    deterministic: bool = False,
 ):
     _validate_udtf_handler(func, returnType)

@@ -285,13 +285,13 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram
         jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols))
         return DataFrame(jPythonUDTF, spark)

-    def asNondeterministic(self) -> "UserDefinedTableFunction":
+    def asDeterministic(self) -> "UserDefinedTableFunction":
         """
-        Updates UserDefinedTableFunction to nondeterministic.
+        Updates UserDefinedTableFunction to deterministic.
         """
         # Explicitly clean the cache to create a JVM UDTF instance.
         self._judtf_placeholder = None
-        self.deterministic = False
+        self.deterministic = True
         return self

16 changes: 12 additions & 4 deletions CheckAnalysis.scala
@@ -1038,10 +1038,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         // A lateral join with a multi-row outer query and a non-deterministic lateral subquery
         // cannot be decorrelated. Otherwise it may produce incorrect results.
         if (!expr.deterministic && !join.left.maxRows.exists(_ <= 1)) {
-          expr.failAnalysis(
-            errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
-              "NON_DETERMINISTIC_LATERAL_SUBQUERIES",
-            messageParameters = Map("treeNode" -> planToString(plan)))
+          cleanQueryInScalarSubquery(join.right.plan) match {
+            // Python UDTFs are by default non-deterministic. They are constructed as a
+            // OneRowRelation subquery and can be rewritten by the optimizer without
+            // any decorrelation.
+            case Generate(_: PythonUDTF, _, _, _, _, _: OneRowRelation)
+              if SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) => // Ok
+            case _ =>
+              expr.failAnalysis(
+                errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
+                  "NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+                messageParameters = Map("treeNode" -> planToString(plan)))
+          }
         }
         // Check if the lateral join's join condition is deterministic.
         if (join.condition.exists(!_.deterministic)) {
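For context, a sketch of the lateral-join shape this carve-out admits, mirroring the new `PythonUDTFSuite` test below (the UDTF, view, and session setup here are illustrative, and it assumes the one-row-relation subquery optimization controlled by `SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY` is enabled):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udtf

spark = SparkSession.builder.getOrCreate()

# Illustrative two-argument UDTF; non-deterministic by default after this PR.
@udtf(returnType="s: int")
class SumArgs:
    def eval(self, a: int, b: int):
        yield a + b,

spark.udtf.register("testUDTF", SumArgs)
spark.createDataFrame([(0, 1), (1, 2)], "a int, b int").createOrReplaceTempView("t")

# A non-deterministic lateral subquery over a multi-row outer query would
# previously fail analysis with NON_DETERMINISTIC_LATERAL_SUBQUERIES. The
# carve-out above accepts the Generate-over-OneRowRelation plan produced by
# a Python UDTF call and leaves the rewrite to the optimizer.
spark.sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f").show()
```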
3 changes: 3 additions & 0 deletions PullOutNondeterministic.scala
@@ -44,6 +44,9 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
     // and we want to retain them inside the aggregate functions.
     case m: CollectMetrics => m

+    // Skip PythonUDTF as it will be planned as its own dedicated logical and physical node.
+    case g @ Generate(_: PythonUDTF, _, _, _, _, _) => g
+
     // todo: It's hard to write a general rule to pull out nondeterministic expressions
     // from LogicalPlan, currently we only do it for UnaryNode which has same output
     // schema with its child.
2 changes: 1 addition & 1 deletion IntegratedUDFTestUtils.scala
@@ -393,7 +393,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
       pythonScript: String,
       returnType: StructType,
       evalType: Int = PythonEvalType.SQL_TABLE_UDF,
-      deterministic: Boolean = true): UserDefinedPythonTableFunction = {
+      deterministic: Boolean = false): UserDefinedPythonTableFunction = {
     UserDefinedPythonTableFunction(
       name = name,
       func = SimplePythonFunction(
13 changes: 13 additions & 0 deletions PythonUDTFSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTes
 import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
 import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType

@@ -110,6 +111,18 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
     }
   }

+  test("non-deterministic UDTF should pass check analysis") {
+    assume(shouldTestPythonUDFs)
+    withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") {
+      spark.udtf.registerPython("testUDTF", pythonUDTF)
+      withTempView("t") {
+        Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+        val df = sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f")
+        df.queryExecution.assertAnalyzed()
+      }
+    }
+  }
+
   test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") {
     // Positive tests
     assume(shouldTestPythonUDFs)
