Skip to content

Commit

Permalink
Refactor ResolveFunctions analyzer rule to delay making lateral join …
Browse files Browse the repository at this point in the history
…when table arguments are used
  • Loading branch information
ueshin committed Sep 22, 2023
1 parent fdedec1 commit 6051a5e
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 124 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -3440,6 +3440,11 @@
"message" : [
"IN/EXISTS predicate subqueries can only be used in filters, joins, aggregations, window functions, projections, and UPDATE/MERGE/DELETE commands<treeNode>."
]
},
"UNSUPPORTED_TABLE_ARGUMENT" : {
"message" : [
"Table arguments are used in a function where they are not supported<treeNode>."
]
}
},
"sqlState" : "0A000"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ Correlated scalar subqueries can only be used in filters, aggregations, projecti

IN/EXISTS predicate subqueries can only be used in filters, joins, aggregations, window functions, projections, and UPDATE/MERGE/DELETE commands`<treeNode>`.

## UNSUPPORTED_TABLE_ARGUMENT

Table arguments are used in a function where they are not supported`<treeNode>`.


Original file line number Diff line number Diff line change
Expand Up @@ -2080,7 +2080,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
withPosition(u) {
try {
val resolvedTvf = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name)
if (CatalogV2Util.isSessionCatalog(catalog)) {
v1SessionCatalog.resolvePersistentTableFunction(
Expand All @@ -2090,93 +2090,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
catalog, "table-valued functions")
}
}
// Resolve Python UDTF calls if needed.
val resolvedFunc = resolvedTvf match {
case g @ Generate(u: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
val analyzeResult: PythonUDTFAnalyzeResult =
u.resolveElementMetadata(u.func, u.children)
g.copy(generator =
PythonUDTF(u.name, u.func, analyzeResult.schema, u.children,
u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes,
analyzeResult = Some(analyzeResult)))
case other =>
other
}
val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
val functionTableSubqueryArgs =
mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
val tvf = resolvedFunc.transformAllExpressionsWithPruning(
_.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) {
resolvedFunc.transformAllExpressionsWithPruning(
_.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
case t: FunctionTableSubqueryArgumentExpression =>
val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
val (
pythonUDTFName: String,
pythonUDTFAnalyzeResult: Option[PythonUDTFAnalyzeResult]) =
resolvedFunc match {
case Generate(p: PythonUDTF, _, _, _, _, _) =>
(p.name,
p.analyzeResult)
case _ =>
assert(!t.hasRepartitioning,
"Cannot evaluate the table-valued function call because it included the " +
"PARTITION BY clause, but only Python table functions support this " +
"clause")
("", None)
}
// Check if this is a call to a Python user-defined table function whose polymorphic
// 'analyze' method returned metadata indicated requested partitioning and/or
// ordering properties of the input relation. In that event, make sure that the UDTF
// call did not include any explicit PARTITION BY and/or ORDER BY clauses for the
// corresponding TABLE argument, and then update the TABLE argument representation
// to apply the requested partitioning and/or ordering.
pythonUDTFAnalyzeResult.map { analyzeResult =>
val newTableArgument: FunctionTableSubqueryArgumentExpression =
analyzeResult.applyToTableArgument(pythonUDTFName, t)
tableArgs.append(SubqueryAlias(alias, newTableArgument.evaluable))
functionTableSubqueryArgs.append(newTableArgument)
}.getOrElse {
tableArgs.append(SubqueryAlias(alias, t.evaluable))
functionTableSubqueryArgs.append(t)
resolvedFunc match {
case Generate(_: PythonUDTF, _, _, _, _, _) =>
case Generate(_: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
case _ =>
assert(!t.hasRepartitioning,
"Cannot evaluate the table-valued function call because it included the " +
"PARTITION BY clause, but only Python table functions support this " +
"clause")
}
UnresolvedAttribute(Seq(alias, "c"))
}
if (tableArgs.nonEmpty) {
if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
tableArgs.size)
}
val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
// Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
def assignUDTFPartitionColumnIndexes(
fn: PythonUDTFPartitionColumnIndexes => LogicalPlan): Option[LogicalPlan] = {
val indexes: Seq[Int] = functionTableSubqueryArgs.headOption
.map(_.partitioningExpressionIndexes).getOrElse(Seq.empty)
if (indexes.nonEmpty) {
Some(fn(PythonUDTFPartitionColumnIndexes(indexes)))
} else {
None
}
}
val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
case g@Generate(p: PythonUDTF, _, _, _, _, _) =>
assignUDTFPartitionColumnIndexes(
i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
.getOrElse(g)
case g@Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
assignUDTFPartitionColumnIndexes(
i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
.getOrElse(g)
case _ =>
tvf
}
Project(
Seq(UnresolvedStar(Some(Seq(alias)))),
LateralJoin(
tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
)
} else {
tvf
t
}
} catch {
case _: NoSuchFunctionException =>
Expand Down Expand Up @@ -2204,6 +2130,46 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
Project(aliases, u.child)

case p: LogicalPlan
if p.resolved && p.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) =>
withPosition(p) {
val tableArgs =
mutable.ArrayBuffer.empty[(FunctionTableSubqueryArgumentExpression, LogicalPlan)]

val tvf = p.transformExpressionsWithPruning(
_.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
case t: FunctionTableSubqueryArgumentExpression =>
val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
tableArgs.append((t, SubqueryAlias(alias, t.evaluable)))
UnresolvedAttribute(Seq(alias, "c"))
}

assert(tableArgs.nonEmpty)
if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
tableArgs.size)
}
val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")

// Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
val tvfWithTableColumnIndexes = tvf match {
case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _)
if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty =>
val partitionColumnIndexes =
PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes)
g.copy(generator = pyudtf.copy(
pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes)))
case _ => tvf
}

Project(
Seq(UnresolvedStar(Some(Seq(alias)))),
LateralJoin(
tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
)
}

case q: LogicalPlan =>
q.transformExpressionsWithPruning(
_.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR),
Expand Down Expand Up @@ -2249,9 +2215,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
val elementSchema = u.resolveElementMetadata(u.func, u.children).schema
PythonUDTF(u.name, u.func, elementSchema, u.children,
u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes)
// Check if this is a call to a Python user-defined table function whose polymorphic
// 'analyze' method returned metadata indicated requested partitioning and/or
// ordering properties of the input relation. In that event, make sure that the UDTF
// call did not include any explicit PARTITION BY and/or ORDER BY clauses for the
// corresponding TABLE argument, and then update the TABLE argument representation
// to apply the requested partitioning and/or ordering.
val analyzeResult = u.resolveElementMetadata(u.func, u.children)
val newChildren = u.children.map {
case t: FunctionTableSubqueryArgumentExpression =>
analyzeResult.applyToTableArgument(u.name, t)
case c => c
}
PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
u.evalType, u.udfDeterministic, u.resultId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isLateral = true)

case _: FunctionTableSubqueryArgumentExpression =>
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
messageParameters = Map("treeNode" -> planToString(plan)))

case inSubqueryOrExistsSubquery =>
plan match {
case _: Filter | _: SupportsSubquery | _: Join |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ case class PythonUDTF(
evalType: Int,
udfDeterministic: Boolean,
resultId: ExprId = NamedExpression.newExprId,
pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None,
analyzeResult: Option[PythonUDTFAnalyzeResult] = None)
pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
extends UnevaluableGenerator with PythonFuncExpression {

override lazy val canonicalized: Expression = {
Expand Down Expand Up @@ -210,8 +209,7 @@ case class UnresolvedPolymorphicPythonUDTF(
evalType: Int,
udfDeterministic: Boolean,
resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult,
resultId: ExprId = NamedExpression.newExprId,
pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
resultId: ExprId = NamedExpression.newExprId)
extends UnevaluableGenerator with PythonFuncExpression {

override lazy val resolved = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,17 @@ SELECT * FROM explode(collection => TABLE(v))
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
"sqlState" : "0A000",
"messageParameters" : {
"inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
"inputType" : "\"STRUCT<id: BIGINT>\"",
"paramIndex" : "1",
"requiredType" : "(\"ARRAY\" or \"MAP\")",
"sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
"treeNode" : "'Generate explode(table-argument#x []), false\n: +- SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range (0, 8, step=1, splits=None)\n+- OneRowRelation\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
"stopIndex" : 45,
"fragment" : "explode(collection => TABLE(v))"
"startIndex" : 37,
"stopIndex" : 44,
"fragment" : "TABLE(v)"
} ]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,17 @@ struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
"sqlState" : "42K09",
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
"sqlState" : "0A000",
"messageParameters" : {
"inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
"inputType" : "\"STRUCT<id: BIGINT>\"",
"paramIndex" : "1",
"requiredType" : "(\"ARRAY\" or \"MAP\")",
"sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
"treeNode" : "'Generate explode(table-argument#x []), false\n: +- SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range (0, 8, step=1, splits=None)\n+- OneRowRelation\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
"stopIndex" : 45,
"fragment" : "explode(collection => TABLE(v))"
"startIndex" : 37,
"stopIndex" : 44,
"fragment" : "TABLE(v)"
} ]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
def failure(plan: LogicalPlan): Unit = {
fail(s"Unexpected plan: $plan")
}

spark.udtf.registerPython("testUDTF", pythonUDTF)
sql(
"""
|SELECT * FROM testUDTF(
Expand Down Expand Up @@ -187,19 +189,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
withTable("t") {
sql("create table t(col array<int>) using parquet")
val query = "select * from explode(table(t))"
checkError(
checkErrorMatchPVals(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"explode(outer(__auto_generated_subquery_name_0.c))\"",
"paramIndex" -> "1",
"inputSql" -> "\"outer(__auto_generated_subquery_name_0.c)\"",
"inputType" -> "\"STRUCT<col: ARRAY<INT>>\"",
"requiredType" -> "(\"ARRAY\" or \"MAP\")"),
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
sqlState = None,
parameters = Map("treeNode" -> "(?s).*"),
context = ExpectedContext(
fragment = "explode(table(t))",
start = 14,
stop = 30))
fragment = "table(t)",
start = 22,
stop = 29))
}

spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast)
Expand Down

0 comments on commit 6051a5e

Please sign in to comment.