[SPARK-41631][SQL] Support implicit lateral column alias resolution on Aggregate #39040

Closed
wants to merge 37 commits into from
Changes from 30 commits
04959c2
refactor analyzer adding a new object
anchovYu Nov 23, 2022
6f44c85
lca code
anchovYu Nov 23, 2022
725e5ac
add tests, refine logic
anchovYu Nov 28, 2022
660e1d2
move lca rule to a new file
anchovYu Nov 28, 2022
fd06094
rename conf
anchovYu Nov 28, 2022
7d4f80f
test failure
anchovYu Nov 29, 2022
b9704d5
small fix
anchovYu Nov 29, 2022
777f13a
temp commit, still in implementation
anchovYu Nov 29, 2022
09480ea
a temporary solution, but still fail certain cases
anchovYu Nov 30, 2022
c972738
working solution, needs some refinement
anchovYu Dec 1, 2022
97ee293
Merge remote-tracking branch 'apache/master' into SPARK-27561-refactor
anchovYu Dec 1, 2022
5785943
make changes to accomodate the recent refactor
anchovYu Dec 2, 2022
757cffb
introduce leaf exp in Project as well
anchovYu Dec 5, 2022
29de892
handle a corner case
anchovYu Dec 5, 2022
72991c6
add more tests; add check rule
anchovYu Dec 6, 2022
d45fe31
uplift the necessity to resolve expression in second phase; add more …
anchovYu Dec 8, 2022
1f55f73
address comments to add tests for LCA off
anchovYu Dec 8, 2022
f753529
revert the refactor, split LCA into two rules
anchovYu Dec 9, 2022
b9f706f
better refactor
anchovYu Dec 9, 2022
94d5c9e
address comments
anchovYu Dec 9, 2022
d2e75fd
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 9, 2022
edde37c
basic version passing all tests
anchovYu Dec 9, 2022
fb7b18c
update the logic, add and refactor tests
anchovYu Dec 12, 2022
3698cff
update comments
anchovYu Dec 13, 2022
e700d6a
add a corner case comment
anchovYu Dec 13, 2022
8d20986
address comments
anchovYu Dec 13, 2022
d952aa7
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 13, 2022
44d5a3d
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 13, 2022
ccebc1c
revert some changes
anchovYu Dec 13, 2022
5540b70
fix few todos
anchovYu Dec 13, 2022
338ba11
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 16, 2022
136a930
fix the failing test
anchovYu Dec 16, 2022
5076ad2
fix the missing_aggregate issue, turn on conf to see failed tests
anchovYu Dec 19, 2022
2f2dee5
remove few todos
anchovYu Dec 19, 2022
3a5509a
better fix to maintain aggregate error: only lift up in certain cases
anchovYu Dec 20, 2022
a23debb
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 20, 2022
b200da0
typo
anchovYu Dec 20, 2022
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,11 @@
"The target JDBC server does not support transactions and can only support ALTER TABLE with a single action."
]
},
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
"message" : [
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."
]
},
"LATERAL_JOIN_USING" : {
"message" : [
"JOIN USING with LATERAL correlation."
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aliases = aliasMap.get(u.nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n)
throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n)
case n if n == 1 && aliases.head.alias.resolved =>
// Only resolved alias can be the lateral column alias
// The lateral alias can be a struct and have nested field, need to construct
Expand All @@ -1838,7 +1838,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aliases = aliasMap.get(nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n)
throw QueryCompilationErrors.ambiguousLateralColumnAliasError(nameParts, n)
case n if n == 1 && aliases.head.alias.resolved =>
resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o)
case _ => o
Expand Down Expand Up @@ -1869,6 +1869,30 @@ class Analyzer(override val catalogManager: CatalogManager)
wrapLCARef(e, p, aliasMap)
}
p.copy(projectList = newProjectList)

// Implementation notes:
// In Aggregate, wrapping references in the resolved leaf expression
// LateralColumnAliasReference is especially important: the rule needs an accurate
// condition to trigger adding a Project above and extracting and pushing down aggregate
// functions or grouping expressions, and that operation can only be done once. With
// LateralColumnAliasReference, the condition is simply that the whole Aggregate is
// resolved. Without it, the rule can't tell whether all aggregate functions have been
// created and resolved so that extraction can start, because an unresolved lateral alias
// reference can be an argument to a function and block that function's resolution.
case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved
&& !ResolveReferences.containsStar(aggExprs)
&& aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>

var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
val newAggExprs = aggExprs.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaWrapped = wrapLCARef(a, agg, aliasMap).asInstanceOf[Alias]
aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap)
lcaWrapped
case (e, _) =>
wrapLCARef(e, agg, aliasMap)
}
agg.copy(aggregateExpressions = newAggExprs)
}
}
}
Expand Down Expand Up @@ -4248,3 +4272,4 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf

/**
Expand All @@ -31,30 +34,54 @@ import org.apache.spark.sql.internal.SQLConf
* Plan-wise, it handles two types of operators: Project and Aggregate.
* - in Project, pushing down the referenced lateral alias into a newly created Project, resolve
* the attributes referencing these aliases
* - in Aggregate TODO.
* - in Aggregate, inserting the Project node above and falling back to the resolution of Project.
*
* The whole process is generally divided into two phases:
* 1) recognize resolved lateral alias, wrap the attributes referencing them with
* [[LateralColumnAliasReference]]
* 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
* For Project, it further resolves the attributes and push down the referenced lateral aliases.
* For Aggregate, TODO
* 2) when the whole operator is resolved,
* For Project, it unwraps [[LateralColumnAliasReference]], further resolves the attributes and
* pushes down the referenced lateral aliases.
* For Aggregate, it goes through the whole aggregation list, extracts the aggregate
* expressions and grouping expressions to keep them in this Aggregate node, and adds a Project
* above with the original output. It doesn't touch [[LateralColumnAliasReference]], but
* leaves it entirely to the Project in future rounds of this rule.
*
* Example for Project:
* ** Example for Project:
* Before rewrite:
* Project [age AS a, 'a + 1]
* +- Child
*
* After phase 1:
* Project [age AS a, lateralalias(a) + 1]
* Project [age AS a, lca(a) + 1]
* +- Child
*
* After phase 2:
* Project [a, a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* Example for Aggregate TODO
* ** Example for Aggregate:
* Before rewrite:
* Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* After phase 1:
* Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* After phase 2:
* Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27]
* +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* Now the problem falls back to the lateral alias resolution in Project.
* After future rounds of this rule:
* Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27]
* +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13]
* +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,
* dept#14]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
*
* The name resolution priority:
Expand All @@ -75,6 +102,13 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
*/
val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")

private def assignAlias(expr: Expression): NamedExpression = {
expr match {
case ne: NamedExpression => ne
case e => Alias(e, toPrettySQL(e))()
}
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
Expand Down Expand Up @@ -129,6 +163,45 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
child = Project(innerProjectList.toSeq, child)
)
}

case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved
&& aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>

val newAggExprs = collection.mutable.Set.empty[NamedExpression]
val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression]
val projectExprs = aggregateExpressions.map { exp =>
exp.transformDown {
case aggExpr: AggregateExpression =>
// Doesn't support referencing a lateral alias in aggregate function
if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
aggExpr.collectFirst {
case lcaRef: LateralColumnAliasReference =>
throw QueryCompilationErrors.lateralColumnAliasInAggFuncUnsupportedError(
lcaRef.nameParts, aggExpr)
}
}
val ne = expressionMap.getOrElseUpdate(aggExpr.canonicalized, assignAlias(aggExpr))
newAggExprs += ne
ne.toAttribute
case e if groupingExpressions.exists(_.semanticEquals(e)) =>
// TODO one concern here, is condition here be able to match all grouping
Contributor Author

Surprisingly, I found that this existing query fails analysis:

select 1 + dept + 10 from $testTable group by dept + 10
-- error: [MISSING_AGGREGATION] The non-aggregating expression "dept" is based on columns which are not participating in the GROUP BY clause

It seems that checkAnalysis does not canonicalize expressions before comparing them. The expression is parsed as (1 + dept) + 10, so it can't match the grouping expression (dept + 10).
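
A minimal sketch of the mismatch described in this comment, using a toy expression tree. The types `Col`, `Lit`, `Add` and the helpers `containsSubtree` and `flatten` are hypothetical stand-ins for illustration, not Catalyst classes, and the sketch makes no claim about what Spark's own canonicalization does. It only shows that a left-associated `(1 + dept) + 10` has no subtree structurally equal to `dept + 10`, even though both operands of the grouping expression are present.

```scala
object GroupingMatchSketch {
  // Toy expression tree; hypothetical types, not Catalyst classes.
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr

  // `1 + dept + 10` associates to the left: (1 + dept) + 10.
  val selectExpr: Expr = Add(Add(Lit(1), Col("dept")), Lit(10))
  // The GROUP BY expression: dept + 10.
  val groupingExpr: Expr = Add(Col("dept"), Lit(10))

  // True if `target` appears as a structurally equal subtree of `e`.
  def containsSubtree(e: Expr, target: Expr): Boolean =
    e == target || (e match {
      case Add(l, r) => containsSubtree(l, target) || containsSubtree(r, target)
      case _ => false
    })

  // Flattens a chain of additions into its operands, ignoring nesting.
  def flatten(e: Expr): List[Expr] = e match {
    case Add(l, r) => flatten(l) ++ flatten(r)
    case other => List(other)
  }
}
```

Here `containsSubtree(selectExpr, groupingExpr)` is false, while every operand of `flatten(groupingExpr)` does appear in `flatten(selectExpr)`, which is why a purely structural comparison reports `dept` as not participating in the GROUP BY.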

// expressions? For example, Agg [age + 10] [1 + age + 10], when transforming down,
// is it possible that (1 + age) + 10, so that it won't be able to match (age + 10)
// add a test.
val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e))
newAggExprs += ne
ne.toAttribute
}.asInstanceOf[NamedExpression]
}
if (newAggExprs.isEmpty) {
agg
} else {
Project(
projectList = projectExprs,
child = agg.copy(aggregateExpressions = newAggExprs.toSeq)
)
}
// TODO withOrigin?
Member

No need. All the results from resolveOperatorsUpWithPruning will have withOrigin around it.

}
}
}
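
The Aggregate branch of this rule can be approximated on a toy expression tree, as a rough sketch rather than the actual implementation. Everything here is hypothetical: `Attr`, `Lit`, `Add`, `Avg`, `LcaRef`, `Named`, `pretty`, and `split` stand in for Catalyst's attributes, aggregate functions, `LateralColumnAliasReference`, `Alias`, `toPrettySQL`, and the rewrite itself. The sketch assumes plain structural equality instead of `canonicalized` / `semanticEquals`, keeps extracted expressions in a `LinkedHashMap` so the output order is deterministic (the diff uses a mutable `Set`), and omits the check that rejects lateral aliases inside aggregate functions.

```scala
import scala.collection.mutable

object AggregateLcaSketch {
  // Toy expression tree; hypothetical types, not Catalyst classes.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Avg(child: Expr) extends Expr      // stands in for an aggregate function
  case class LcaRef(name: String) extends Expr  // stands in for LateralColumnAliasReference
  case class Named(expr: Expr, name: String)    // stands in for Alias

  // Generates a readable name for an extracted expression.
  def pretty(e: Expr): String = e match {
    case Attr(n) => n
    case Lit(v) => v.toString
    case Add(l, r) => s"(${pretty(l)} + ${pretty(r)})"
    case Avg(c) => s"avg(${pretty(c)})"
    case LcaRef(n) => s"lca($n)"
  }

  // Splits an aggregate list into (project list above, aggregate list below).
  // Aggregate functions and grouping expressions stay in the Aggregate and are
  // replaced by attribute references in the Project; LcaRef nodes are left untouched.
  def split(grouping: Seq[Expr], aggList: Seq[Named]): (Seq[Named], Seq[Named]) = {
    val extracted = mutable.LinkedHashMap.empty[Expr, Named]
    def rewrite(e: Expr): Expr = e match {
      case a: Avg =>
        Attr(extracted.getOrElseUpdate(a, Named(a, pretty(a))).name)
      case g if grouping.contains(g) =>
        Attr(extracted.getOrElseUpdate(g, Named(g, pretty(g))).name)
      case Add(l, r) => Add(rewrite(l), rewrite(r))
      case other => other
    }
    val projectList = aggList.map(n => Named(rewrite(n.expr), n.name))
    (projectList, extracted.values.toSeq)
  }
}
```

Applied to the aggregate list from the Scaladoc example (`dept AS a`, `lca(a) + 1`, `avg(salary) AS b`, `lca(b) + avg(bonus)`) with grouping `dept`, the lower list holds `dept`, `avg(salary)`, and `avg(bonus)`, and the upper list still contains the `LcaRef` nodes for the Project branch to resolve later.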
Original file line number Diff line number Diff line change
Expand Up @@ -3400,7 +3400,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
}
}

def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = {
def ambiguousLateralColumnAliasError(name: String, numOfMatches: Int): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
messageParameters = Map(
Expand All @@ -3409,7 +3409,8 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
)
)
}
def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = {

def ambiguousLateralColumnAliasError(nameParts: Seq[String], numOfMatches: Int): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
messageParameters = Map(
Expand All @@ -3418,4 +3419,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
)
)
}

def lateralColumnAliasInAggFuncUnsupportedError(
lcaNameParts: Seq[String], aggExpr: Expression): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC",
messageParameters = Map(
"lca" -> toSQLId(lcaNameParts),
"aggFunc" -> toSQLExpr(aggExpr)
)
)
}
}
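
A small sketch of when the new error would fire, again on a toy tree and not on Catalyst classes. `Attr`, `LcaRef`, `Avg`, `Add`, `render`, `firstLca`, and `check` are hypothetical names; the message template is copied from the error-classes.json entry in this change, and the placeholder substitution is plain text (the quoting applied by `toSQLId` and `toSQLExpr` is not reproduced here).

```scala
object LcaInAggFuncSketch {
  // Toy expression tree; hypothetical types, not Catalyst classes.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class LcaRef(name: String) extends Expr  // stands in for LateralColumnAliasReference
  case class Avg(child: Expr) extends Expr      // stands in for an aggregate function
  case class Add(left: Expr, right: Expr) extends Expr

  // Message template from the LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC entry.
  val template =
    "Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."

  // Plain-text rendering; an assumption, not Spark's SQL formatting.
  def render(e: Expr): String = e match {
    case Attr(n) => n
    case LcaRef(n) => n
    case Avg(c) => s"avg(${render(c)})"
    case Add(l, r) => s"${render(l)} + ${render(r)}"
  }

  // First lateral alias reference found anywhere under `e`, if any.
  def firstLca(e: Expr): Option[LcaRef] = e match {
    case r: LcaRef => Some(r)
    case Avg(c) => firstLca(c)
    case Add(l, r) => firstLca(l).orElse(firstLca(r))
    case _: Attr => None
  }

  // Left(message) if an aggregate function under `e` references a lateral alias.
  def check(e: Expr): Either[String, Unit] = e match {
    case agg: Avg =>
      firstLca(agg) match {
        case Some(ref) =>
          Left(template.replace("<lca>", ref.name).replace("<aggFunc>", render(agg)))
        case None => Right(())
      }
    case Add(l, r) => check(l).flatMap(_ => check(r))
    case _ => Right(())
  }
}
```

Under these assumptions, a lateral alias next to an aggregate function (`b + avg(bonus)`) passes, while one inside it (`avg(a + bonus)`) is rejected.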