[SPARK-41631][SQL] Support implicit lateral column alias resolution on Aggregate #39040

Closed
wants to merge 37 commits
Changes from 34 commits

Commits (37)
04959c2
refactor analyzer adding a new object
anchovYu Nov 23, 2022
6f44c85
lca code
anchovYu Nov 23, 2022
725e5ac
add tests, refine logic
anchovYu Nov 28, 2022
660e1d2
move lca rule to a new file
anchovYu Nov 28, 2022
fd06094
rename conf
anchovYu Nov 28, 2022
7d4f80f
test failure
anchovYu Nov 29, 2022
b9704d5
small fix
anchovYu Nov 29, 2022
777f13a
temp commit, still in implementation
anchovYu Nov 29, 2022
09480ea
a temporary solution, but still fail certain cases
anchovYu Nov 30, 2022
c972738
working solution, needs some refinement
anchovYu Dec 1, 2022
97ee293
Merge remote-tracking branch 'apache/master' into SPARK-27561-refactor
anchovYu Dec 1, 2022
5785943
make changes to accomodate the recent refactor
anchovYu Dec 2, 2022
757cffb
introduce leaf exp in Project as well
anchovYu Dec 5, 2022
29de892
handle a corner case
anchovYu Dec 5, 2022
72991c6
add more tests; add check rule
anchovYu Dec 6, 2022
d45fe31
uplift the necessity to resolve expression in second phase; add more …
anchovYu Dec 8, 2022
1f55f73
address comments to add tests for LCA off
anchovYu Dec 8, 2022
f753529
revert the refactor, split LCA into two rules
anchovYu Dec 9, 2022
b9f706f
better refactor
anchovYu Dec 9, 2022
94d5c9e
address comments
anchovYu Dec 9, 2022
d2e75fd
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 9, 2022
edde37c
basic version passing all tests
anchovYu Dec 9, 2022
fb7b18c
update the logic, add and refactor tests
anchovYu Dec 12, 2022
3698cff
update comments
anchovYu Dec 13, 2022
e700d6a
add a corner case comment
anchovYu Dec 13, 2022
8d20986
address comments
anchovYu Dec 13, 2022
d952aa7
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 13, 2022
44d5a3d
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 13, 2022
ccebc1c
revert some changes
anchovYu Dec 13, 2022
5540b70
fix few todos
anchovYu Dec 13, 2022
338ba11
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 16, 2022
136a930
fix the failing test
anchovYu Dec 16, 2022
5076ad2
fix the missing_aggregate issue, turn on conf to see failed tests
anchovYu Dec 19, 2022
2f2dee5
remove few todos
anchovYu Dec 19, 2022
3a5509a
better fix to maintain aggregate error: only lift up in certain cases
anchovYu Dec 20, 2022
a23debb
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 20, 2022
b200da0
typo
anchovYu Dec 20, 2022
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
@@ -1334,6 +1334,11 @@
"The target JDBC server does not support transactions and can only support ALTER TABLE with a single action."
]
},
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
"message" : [
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."
]
},
"LATERAL_JOIN_USING" : {
"message" : [
"JOIN USING with LATERAL correlation."
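
For context, here is a minimal sketch of a query that would hit the new error class. It is not part of the PR: the employees table and the aliases are made up, and it assumes a SparkSession named spark with lateral column alias resolution enabled.

// Hypothetical repro for LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC: `dollars` is a
// lateral column alias, and avg(dollars) references it inside an aggregate function.
spark.sql("SELECT avg(salary) AS dollars, avg(dollars) AS avg_dollars FROM employees")
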
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Random, Success, Try}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -182,6 +183,157 @@
}
}

object Analyzer extends Logging {
Contributor comment: This is even more change than reverting #39054 ...

// support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
("user", () => CurrentUser(), toPrettySQL),
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
)

/**
* Literal functions do not require the user to specify braces when calling them.
* When an attribute is not resolvable, we try to resolve it as a literal function.
*/
private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = {
if (nameParts.length != 1) return None
val name = nameParts.head
literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map {
case (_, getFuncExpr, getAliasName) =>
val funcExpr = getFuncExpr()
Alias(funcExpr, getAliasName(funcExpr))()
}
}
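
// Illustration (not part of the diff): resolveLiteralFunction(Seq("current_date")) yields
// Some(Alias(CurrentDate(), "current_date()")), matched case-insensitively, while a
// multi-part name such as Seq("db", "current_date") returns None: only single-part names qualify.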

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract-value expressions by
* traversing the input expression in a top-down manner. It must be top-down because we need to
* skip over unbound lambda function expressions. The lambda expressions are resolved in a
* different rule, [[ResolveLambdaVariables]].
*
* Example:
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)
*
* In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]].
*/
private def resolveExpression(
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
resolver: Resolver,
throws: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = {
if (e.resolved) return e
e match {
case f: LambdaFunction if !f.bound => f

case GetColumnByOrdinal(ordinal, _) =>
val attrCandidates = getAttrCandidates()
assert(ordinal >= 0 && ordinal < attrCandidates.length)
attrCandidates(ordinal)

case GetViewColumnByNameAndOrdinal(
viewName, colName, ordinal, expectedNumCandidates, viewDDL) =>
val attrCandidates = getAttrCandidates()
val matched = attrCandidates.filter(a => resolver(a.name, colName))
if (matched.length != expectedNumCandidates) {
throw QueryCompilationErrors.incompatibleViewSchemaChange(
viewName, colName, expectedNumCandidates, matched, viewDDL)
}
matched(ordinal)

case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
// We trim unnecessary aliases here. Note that we cannot trim the alias at the top level,
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
// can trim the top-level alias if it's safe to do so. Since we will call
// CleanupAliases later in the Analyzer, trimming non-top-level unnecessary aliases is safe.
case Alias(child, _) if !isTopLevel => child
case other => other
}.getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result

case u @ UnresolvedExtractValue(child, fieldName) =>
val newChild = innerResolve(child, isTopLevel = false)
if (newChild.resolved) {
withOrigin(u.origin) {
ExtractValue(newChild, fieldName, resolver)
}
} else {
u.copy(child = newChild)
}

case _ => e.mapChildren(innerResolve(_, isTopLevel = false))
}
}

try {
innerResolve(expr, isTopLevel = true)
} catch {
case ae: AnalysisException if !throws =>
logDebug(ae.getMessage)
expr
}
}
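
// Illustration (not part of the diff): for transform(array(1, 2, 3), (x, i) -> x + i),
// innerResolve reaches the unbound LambdaFunction before its variables and returns it as-is,
// leaving x and i to ResolveLambdaVariables; a bottom-up walk would instead try to resolve
// x and i as input columns and fail.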

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract-value expressions by the
* input plan's output attributes. In order to resolve nested fields correctly, this function
* makes use of the `throws` parameter to control when to raise an AnalysisException.
*
* Example:
* SELECT * FROM t ORDER BY a.b
*
* In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b`
* if there is no such nested field named "b". We should not fail and wait for other rules to
* resolve it if possible.
*/
def resolveExpressionByPlanOutput(
expr: Expression,
plan: LogicalPlan,
resolver: Resolver,
throws: Boolean = false): Expression = {
resolveExpression(
expr,
resolveColumnByName = nameParts => {
plan.resolve(nameParts, resolver)
},
getAttrCandidates = () => plan.output,
resolver = resolver,
throws = throws)
}
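
// Illustration (not part of the diff): for SELECT * FROM t ORDER BY a.b where the struct
// column `a` has no field "b", resolution raises an AnalysisException; because throws
// defaults to false here, the catch block in resolveExpression swallows it and returns the
// original expression so later rules get another chance.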

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract-value expressions by the
* input plan's children's output attributes.
*
* @param e The expression that needs to be resolved.
* @param q The LogicalPlan whose children are used to resolve the expression's attributes.
* @return The resolved expression.
*/
def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan,
resolver: Resolver): Expression = {
resolveExpression(
e,
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, resolver)
},
getAttrCandidates = () => {
assert(q.children.length == 1)
q.children.head.output
},
resolver = resolver,
throws = true)
}
}
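
Because these helpers now live on the companion object, they can be called without an Analyzer instance. Below is a hedged usage sketch, not from the PR: the wrapper object, the function, and the plan argument are hypothetical, and it assumes the default case-insensitive resolver.

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, Analyzer, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object ResolveSketch {
  // Hypothetical caller illustrating the two entry points.
  def resolveAgainst(plan: LogicalPlan): (Expression, Expression) = {
    val attr = UnresolvedAttribute(Seq("a", "b")) // e.g. ORDER BY a.b
    // Resolves against plan.output; throws defaults to false, so a failure leaves
    // the expression unresolved for later rules.
    val byOutput = Analyzer.resolveExpressionByPlanOutput(attr, plan, caseInsensitiveResolution)
    // Resolves against the single child's output; asserts exactly one child and
    // raises on resolution failure (throws = true).
    val byChildren = Analyzer.resolveExpressionByPlanChildren(attr, plan, caseInsensitiveResolution)
    (byOutput, byChildren)
  }
}
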

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -230,6 +382,18 @@ class Analyzer(override val catalogManager: CatalogManager)

def resolver: Resolver = conf.resolver

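// These wrappers bind the session's resolver (conf.resolver) so existing call sites in the
// Analyzer keep their pre-refactor signatures while the logic lives on the companion object.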
private def resolveExpressionByPlanOutput(
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false): Expression = {
Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws)
}
private def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan): Expression = {
Analyzer.resolveExpressionByPlanChildren(e, q, resolver)
}

/**
* If the plan cannot be resolved within maxIterations, the analyzer will throw an exception to
* inform the user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS.
@@ -1767,150 +1931,6 @@
exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer]))
}

// support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
("user", () => CurrentUser(), toPrettySQL),
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
)

/**
* Literal functions do not require the user to specify braces when calling them
* When an attributes is not resolvable, we try to resolve it as a literal function.
*/
private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = {
if (nameParts.length != 1) return None
val name = nameParts.head
literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map {
case (_, getFuncExpr, getAliasName) =>
val funcExpr = getFuncExpr()
Alias(funcExpr, getAliasName(funcExpr))()
}
}

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by
* traversing the input expression in top-down manner. It must be top-down because we need to
* skip over unbound lambda function expression. The lambda expressions are resolved in a
* different place [[ResolveLambdaVariables]].
*
* Example :
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
*
* In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]].
*/
private def resolveExpression(
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
throws: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = {
if (e.resolved) return e
e match {
case f: LambdaFunction if !f.bound => f

case GetColumnByOrdinal(ordinal, _) =>
val attrCandidates = getAttrCandidates()
assert(ordinal >= 0 && ordinal < attrCandidates.length)
attrCandidates(ordinal)

case GetViewColumnByNameAndOrdinal(
viewName, colName, ordinal, expectedNumCandidates, viewDDL) =>
val attrCandidates = getAttrCandidates()
val matched = attrCandidates.filter(a => resolver(a.name, colName))
if (matched.length != expectedNumCandidates) {
throw QueryCompilationErrors.incompatibleViewSchemaChange(
viewName, colName, expectedNumCandidates, matched, viewDDL)
}
matched(ordinal)

case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
// We trim unnecessary alias here. Note that, we cannot trim the alias at top-level,
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
// can trim the top-level alias if it's safe to do so. Since we will call
// CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe.
case Alias(child, _) if !isTopLevel => child
case other => other
}.getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result

case u @ UnresolvedExtractValue(child, fieldName) =>
val newChild = innerResolve(child, isTopLevel = false)
if (newChild.resolved) {
withOrigin(u.origin) {
ExtractValue(newChild, fieldName, resolver)
}
} else {
u.copy(child = newChild)
}

case _ => e.mapChildren(innerResolve(_, isTopLevel = false))
}
}

try {
innerResolve(expr, isTopLevel = true)
} catch {
case ae: AnalysisException if !throws =>
logDebug(ae.getMessage)
expr
}
}

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the
* input plan's output attributes. In order to resolve the nested fields correctly, this function
* makes use of `throws` parameter to control when to raise an AnalysisException.
*
* Example :
* SELECT * FROM t ORDER BY a.b
*
* In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b`
* if there is no such nested field named "b". We should not fail and wait for other rules to
* resolve it if possible.
*/
def resolveExpressionByPlanOutput(
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false): Expression = {
resolveExpression(
expr,
resolveColumnByName = nameParts => {
plan.resolve(nameParts, resolver)
},
getAttrCandidates = () => plan.output,
throws = throws)
}

/**
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the
* input plan's children output attributes.
*
* @param e The expression need to be resolved.
* @param q The LogicalPlan whose children are used to resolve expression's attribute.
* @return resolved Expression.
*/
def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan): Expression = {
resolveExpression(
e,
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, resolver)
},
getAttrCandidates = () => {
assert(q.children.length == 1)
q.children.head.output
},
throws = true)
}

/**
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
@@ -4137,3 +4157,4 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
}
}
