Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44822][PYTHON][SQL] Make Python UDTFs by default non-deterministic #42519

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/docs/source/reference/pyspark.sql/udtf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ UDTF
.. autosummary::
:toctree: api/

udtf.UserDefinedTableFunction.asNondeterministic
udtf.UserDefinedTableFunction.asDeterministic
udtf.UserDefinedTableFunction.returnType
UDTFRegistration.register
10 changes: 5 additions & 5 deletions python/pyspark/sql/connect/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _create_udtf(
returnType: Optional[Union[StructType, str]],
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_TABLE_UDF,
deterministic: bool = True,
deterministic: bool = False,
) -> "UserDefinedTableFunction":
udtf_obj = UserDefinedTableFunction(
cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
Expand All @@ -62,7 +62,7 @@ def _create_py_udtf(
cls: Type,
returnType: Optional[Union[StructType, str]],
name: Optional[str] = None,
deterministic: bool = True,
deterministic: bool = False,
useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
if useArrow is not None:
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(
returnType: Optional[Union[StructType, str]],
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_TABLE_UDF,
deterministic: bool = True,
deterministic: bool = False,
) -> None:
_validate_udtf_handler(func, returnType)

Expand Down Expand Up @@ -175,8 +175,8 @@ def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
plan = self._build_common_inline_user_defined_table_function(*cols)
return DataFrame.withPlan(plan, session)

def asNondeterministic(self) -> "UserDefinedTableFunction":
self.deterministic = False
def asDeterministic(self) -> "UserDefinedTableFunction":
self.deterministic = True
return self


Expand Down
11 changes: 5 additions & 6 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15627,14 +15627,13 @@ def udtf(

Notes
-----
User-defined table functions (UDTFs) are considered deterministic by default.
Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
User-defined table functions (UDTFs) are considered non-deterministic by default.
Use `asDeterministic()` to mark a function as deterministic. E.g.:

>>> import random
>>> class RandomUDTF:
>>> class PlusOne:
... def eval(self, a: int):
... yield a * int(random.random() * 100),
>>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
... yield a + 1,
>>> plus_one = udtf(PlusOne, returnType="r: int").asDeterministic()

Use "yield" to produce one row for the UDTF result relation as many times
as needed. In the context of a lateral join, each such result row will be
Expand Down
13 changes: 12 additions & 1 deletion python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,25 @@ def terminate(self):
with self.assertRaisesRegex(PythonException, err_msg):
TestUDTF(lit(1)).show()

def test_udtf_determinism(self):
    """A UDTF is non-deterministic by default; asDeterministic() flips the flag."""

    class EchoUDTF:
        def eval(self, a: int):
            yield a,

    echo = udtf(EchoUDTF, returnType="x: int")
    # Python UDTFs are marked non-deterministic unless told otherwise.
    self.assertFalse(echo.deterministic)
    # asDeterministic() mutates the flag and returns the same function object.
    self.assertTrue(echo.asDeterministic().deterministic)

def test_nondeterministic_udtf(self):
import random

class RandomUDTF:
def eval(self, a: int):
yield a + int(random.random()),

random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
random_udtf = udtf(RandomUDTF, returnType="x: int")
assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
self.spark.udtf.register("random_udtf", random_udtf)
assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)])
Expand Down
10 changes: 5 additions & 5 deletions python/pyspark/sql/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_py_udtf(
cls: Type,
returnType: Optional[Union[StructType, str]],
name: Optional[str] = None,
deterministic: bool = True,
deterministic: bool = False,
useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
"""Create a regular or an Arrow-optimized Python UDTF."""
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(
returnType: Optional[Union[StructType, str]],
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_TABLE_UDF,
deterministic: bool = True,
deterministic: bool = False,
):
_validate_udtf_handler(func, returnType)

Expand Down Expand Up @@ -349,13 +349,13 @@ def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
return DataFrame(jPythonUDTF, spark)

def asNondeterministic(self) -> "UserDefinedTableFunction":
def asDeterministic(self) -> "UserDefinedTableFunction":
"""
Updates UserDefinedTableFunction to nondeterministic.
Updates UserDefinedTableFunction to deterministic.
"""
# Explicitly clean the cache to create a JVM UDTF instance.
self._judtf_placeholder = None
self.deterministic = False
self.deterministic = True
return self


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1038,10 +1038,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// A lateral join with a multi-row outer query and a non-deterministic lateral subquery
// cannot be decorrelated. Otherwise it may produce incorrect results.
if (!expr.deterministic && !join.left.maxRows.exists(_ <= 1)) {
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"NON_DETERMINISTIC_LATERAL_SUBQUERIES",
messageParameters = Map("treeNode" -> planToString(plan)))
cleanQueryInScalarSubquery(join.right.plan) match {
// Python UDTFs are by default non-deterministic. They are constructed as a
// OneRowRelation subquery and can be rewritten by the optimizer without
// any decorrelation.
case Generate(_: PythonUDTF, _, _, _, _, _: OneRowRelation)
if SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) => // Ok
case _ =>
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"NON_DETERMINISTIC_LATERAL_SUBQUERIES",
messageParameters = Map("treeNode" -> planToString(plan)))
}
}
// Check if the lateral join's join condition is deterministic.
if (join.condition.exists(!_.deterministic)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
// and we want to retain them inside the aggregate functions.
case m: CollectMetrics => m

// Skip PythonUDTF as it will be planned as its own dedicated logical and physical node.
case g @ Generate(_: PythonUDTF, _, _, _, _, _) => g

// todo: It's hard to write a general rule to pull out nondeterministic expressions
// from LogicalPlan, currently we only do it for UnaryNode which has same output
// schema with its child.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
pythonScript: String,
returnType: StructType,
evalType: Int = PythonEvalType.SQL_TABLE_UDF,
deterministic: Boolean = true): UserDefinedPythonTableFunction = {
deterministic: Boolean = false): UserDefinedPythonTableFunction = {
UserDefinedPythonTableFunction(
name = name,
func = SimplePythonFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTes
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -110,6 +111,18 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
}
}

// Regression test for SPARK-44822: with OPTIMIZE_ONE_ROW_RELATION_SUBQUERY
// enabled, a (by default non-deterministic) Python UDTF used as a lateral
// subquery must not be rejected by CheckAnalysis.
test("non-deterministic UDTF should pass check analysis") {
assume(shouldTestPythonUDFs)
withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") {
spark.udtf.registerPython("testUDTF", pythonUDTF)
withTempView("t") {
Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
val df = sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f")
// Only analysis success is asserted; the query result itself is not checked.
df.queryExecution.assertAnalyzed()
}
}
}

test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") {
// Positive tests
assume(shouldTestPythonUDFs)
Expand Down