From 04959c2ac394a2e70b8c61d7abba5354469320da Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 23 Nov 2022 10:40:07 -0800 Subject: [PATCH 01/31] refactor analyzer adding a new object --- .../sql/catalyst/analysis/Analyzer.scala | 321 ++++++++++-------- 1 file changed, 171 insertions(+), 150 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1daa8ea36bf35..fc149308578c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -182,6 +183,164 @@ object AnalysisContext { } } +object Analyzer extends Logging { + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. 
+ */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + resolver: Resolver, + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim 
unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + withOrigin(u.origin) { + ExtractValue(newChild, fieldName, resolver) + } + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + try { + innerResolve(expr, isTopLevel = true) + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + resolver: Resolver, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + getAttrCandidates = () => plan.output, + resolver = resolver, + throws = throws) + } + + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. 
+ * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + resolver: Resolver): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + resolver = resolver, + throws = true) + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) +} + /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -258,6 +417,17 @@ class Analyzer(override val catalogManager: CatalogManager) TypeCoercion.typeCoercionRules } + private def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws) + } + + private def resolveExpressionByPlanChildren(e: Expression, q: LogicalPlan): Expression = { + Analyzer.resolveExpressionByPlanChildren(e, q, resolver) + } + override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. @@ -1386,6 +1556,7 @@ class Analyzer(override val catalogManager: CatalogManager) * a logical plan node's children. 
*/ object ResolveReferences extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.analysis.Analyzer.containsStar /** Return true if there're conflicting attributes among children's outputs of a plan */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { @@ -1698,12 +1869,6 @@ class Analyzer(override val catalogManager: CatalogManager) }.map(_.asInstanceOf[NamedExpression]) } - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) - private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) @@ -1764,150 +1929,6 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. 
- */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias 
here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - try { - innerResolve(expr, isTopLevel = true) - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - throws = throws) - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. 
- * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - throws = true) - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the From 6f44c8500850b1d122510c66dc7e9b27e6adaf2d Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 23 Nov 2022 13:25:41 -0800 Subject: [PATCH 02/31] lca code (cherry picked from commit 94adb3f98d701e4c4f19189eb11134949b61bc45) --- .../main/resources/error/error-classes.json | 6 ++ .../sql/catalyst/analysis/Analyzer.scala | 100 +++++++++++++++++- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 10 ++ .../apache/spark/sql/internal/SQLConf.scala | 8 ++ 5 files changed, 124 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 77d155bfc21e4..e279ffc87d21e 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -5,6 +5,12 @@ ], "sqlState" : "42000" }, + "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { + "message" : [ + "Lateral column alias is ambiguous and has matches." + ], + "sqlState" : "42000" + }, "AMBIGUOUS_REFERENCE" : { "message" : [ "Reference is ambiguous, could be: ." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc149308578c5..5f595e81d56e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -458,6 +458,7 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: + ResolveLateralColumnAlias :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: @@ -1551,6 +1552,103 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * Resolve lateral column alias, which references the alias defined previously in the SELECT list, + * - in Project inserting a new Project node with the referenced alias so that it can be + * resolved by other rules + * - in Aggregate TODO. + * + * For Project, it rewrites the Project plan by inserting a newly created Project plan between + * the original Project and its child, and updating the project list of the original Project plan. 
+ * The project list of the new Project plan is the lateral column aliases that are referenced + * in the original project list. These aliases in the original project list are updated to + * attribute references. + * + * Before rewrite: + * Project [age AS a, a + 1] + * +- Child + * + * After rewrite: + * Project [a, a + 1] + * +- Project [age AS a] + * +- Child + */ + object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + private case class AliasEntry(alias: Alias, index: Int) + + private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + // TODO: delta + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + var referencedAliases = Seq[AliasEntry]() + def updateAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def searchMatchedLCA(e: Expression): Unit = { + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case _ => + val referencedAlias = aliases.head + // Only resolved alias can be the lateral column alias + if (referencedAlias.alias.resolved) { + referencedAliases :+= referencedAlias + } + } + u + } + } + projectList.zipWithIndex.foreach { + case (a: Alias, idx) => + // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed + // down. 
Unresolved alias is added to the map to perform the ambiguous name check. + // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, + // because only resolved alias can be LCA, in the first round the rule application, + // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are + // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, + // it is pushed down. + searchMatchedLCA(a) + updateAliasMap(a, idx) + case (e, _) => + searchMatchedLCA(e) + } + + referencedAliases = referencedAliases.sortBy(_.index) + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = projectList.to[collection.mutable.Seq] + val innerProjectList = + child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + plan + } else { + rewriteLateralColumnAlias(plan) + } + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index f6bef88ab868e..8493895218f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -66,6 +66,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 63c912c15a156..eeb1dfc213d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3402,4 +3402,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { cause = Option(other)) } } + + def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(name), + "n" -> numOfMatches.toString + ) + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84d78f365acbc..a5b84660e0581 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4027,6 +4027,14 @@ object SQLConf { .checkValues(ErrorMessageFormat.values.map(_.toString)) .createWithDefault(ErrorMessageFormat.PRETTY.toString) + val LATERAL_COLUMN_ALIAS_ENABLED = + buildConf("spark.sql.lateralColumnAlias.enabled") + .internal() + .doc("Enable lateral column alias in analyzer") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * From 725e5ac9df65438f87f2c260ea5507aaf1a1bd2b Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 10:38:29 -0800 Subject: [PATCH 03/31] add tests, refine logic (cherry picked from commit 313b2c98e9513e50d2764b28c447c3a7cd281ebb) --- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++--- .../spark/sql/LateralColumnAliasSuite.scala | 109 ++++++++++++++++++ 2 files changed, 130 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5f595e81d56e2..165bb2ecf4ec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1558,20 +1558,20 @@ class Analyzer(override val catalogManager: CatalogManager) * resolved by other rules * - in Aggregate TODO. * - * For Project, it rewrites the Project plan by inserting a newly created Project plan between - * the original Project and its child, and updating the project list of the original Project plan. - * The project list of the new Project plan is the lateral column aliases that are referenced - * in the original project list. 
These aliases in the original project list are updated to - * attribute references. + * For Project, it rewrites by inserting a newly created Project plan between the original Project + * and its child, pushing the referenced lateral column aliases to this new Project, and updating + * the project list of the original Project. * * Before rewrite: - * Project [age AS a, a + 1] + * Project [age AS a, 'a + 1] * +- Child * * After rewrite: - * Project [a, a + 1] - * +- Project [age AS a] + * Project [a, 'a + 1] + * +- Project [child output, age AS a] * +- Child + * + * For Aggregate TODO. */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) @@ -1581,14 +1581,14 @@ class Analyzer(override val catalogManager: CatalogManager) case p @ Project(projectList, child) if p.childrenResolved && !ResolveReferences.containsStar(projectList) && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - // TODO: delta + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - var referencedAliases = Seq[AliasEntry]() - def updateAliasMap(a: Alias, idx: Int): Unit = { + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def searchMatchedLCA(e: Expression): Unit = { + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => @@ -1600,13 +1600,15 @@ class Analyzer(override val catalogManager: CatalogManager) val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias if (referencedAlias.alias.resolved) { - referencedAliases :+= referencedAlias + matchedLCA = Some(referencedAlias) } } u } + 
matchedLCA } - projectList.zipWithIndex.foreach { + + val referencedAliases = projectList.zipWithIndex.flatMap { case (a: Alias, idx) => // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed // down. Unresolved alias is added to the map to perform the ambiguous name check. @@ -1615,13 +1617,13 @@ class Analyzer(override val catalogManager: CatalogManager) // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, // it is pushed down. - searchMatchedLCA(a) - updateAliasMap(a, idx) + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + matchedLCA case (e, _) => - searchMatchedLCA(e) - } + lookUpLCA(e) + }.toSet - referencedAliases = referencedAliases.sortBy(_.index) if (referencedAliases.isEmpty) { p } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala new file mode 100644 index 0000000000000..daf750c39bb1e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { + protected val testTable: String = "employee" + + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " + + s"using orc") + sql( + s""" + |INSERT INTO $testTable VALUES + | (1, 'amy', 10000, 1000), + | (2, 'alex', 12000, 1200), + | (1, 'cathy', 9000, 1200), + | (2, 'david', 10000, 1300), + | (6, 'jen', 12000, 1200) + |""".stripMargin) + } + + override def afterAll(): Unit = { + try { + sql(s"DROP TABLE IF EXISTS $testTable") + } finally { + super.afterAll() + } + } + + val lcaEnabled: Boolean = true + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + + test("Lateral alias in project") { + checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), + Row(1, 2)) + + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"), + Row(20000, 21000)) + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 23000)) + + // When the lateral alias conflicts with the table column, it should resolved as the table + // column + checkAnswer( + 
sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + + // Corner cases for resolution order + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + } +} From 660e1d231b641c65c979199ed37d57f52db2a3ea Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 10:54:58 -0800 Subject: [PATCH 04/31] move lca rule to a new file --- .../sql/catalyst/analysis/Analyzer.scala | 101 +------------- .../analysis/ResolveLateralColumnAlias.scala | 127 ++++++++++++++++++ .../sql/catalyst/rules/RuleIdCollection.scala | 2 +- 3 files changed, 129 insertions(+), 101 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 165bb2ecf4ec2..95101c0d8130b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import 
org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -1552,105 +1552,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list, - * - in Project inserting a new Project node with the referenced alias so that it can be - * resolved by other rules - * - in Aggregate TODO. - * - * For Project, it rewrites by inserting a newly created Project plan between the original Project - * and its child, pushing the referenced lateral column aliases to this new Project, and updating - * the project list of the original Project. - * - * Before rewrite: - * Project [age AS a, 'a + 1] - * +- Child - * - * After rewrite: - * Project [a, 'a + 1] - * +- Project [child output, age AS a] - * +- Child - * - * For Aggregate TODO. 
- */ - object ResolveLateralColumnAlias extends Rule[LogicalPlan] { - private case class AliasEntry(alias: Alias, index: Int) - - private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved - && !ResolveReferences.containsStar(projectList) - && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - def lookUpLCA(e: Expression): Option[AliasEntry] = { - var matchedLCA: Option[AliasEntry] = None - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case _ => - val referencedAlias = aliases.head - // Only resolved alias can be the lateral column alias - if (referencedAlias.alias.resolved) { - matchedLCA = Some(referencedAlias) - } - } - u - } - matchedLCA - } - - val referencedAliases = projectList.zipWithIndex.flatMap { - case (a: Alias, idx) => - // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed - // down. Unresolved alias is added to the map to perform the ambiguous name check. - // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, - // because only resolved alias can be LCA, in the first round the rule application, - // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are - // all added to the aliasMap. 
On the second round, when 'a + 1 AS b is resolved, - // it is pushed down. - val matchedLCA = lookUpLCA(a) - insertIntoAliasMap(a, idx) - matchedLCA - case (e, _) => - lookUpLCA(e) - }.toSet - - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = projectList.to[collection.mutable.Seq] - val innerProjectList = - child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { - plan - } else { - rewriteLateralColumnAlias(plan) - } - } - } - /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala new file mode 100644 index 0000000000000..ea2648ccde553 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +/** + * Resolve lateral column alias, which references the alias defined previously in the SELECT list, + * - in Project inserting a new Project node with the referenced alias so that it can be + * resolved by other rules + * - in Aggregate TODO. + * + * For Project, it rewrites by inserting a newly created Project plan between the original Project + * and its child, pushing the referenced lateral column aliases to this new Project, and updating + * the project list of the original Project. + * + * Before rewrite: + * Project [age AS a, 'a + 1] + * +- Child + * + * After rewrite: + * Project [a, 'a + 1] + * +- Project [child output, age AS a] + * +- Child + * + * For Aggregate TODO. 
+ */ +object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + private case class AliasEntry(alias: Alias, index: Int) + def resolver: Resolver = conf.resolver + + private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !Analyzer.containsStar(projectList) + && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + Analyzer.resolveExpressionByPlanChildren(u, p, resolver) + .isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case _ => + val referencedAlias = aliases.head + // Only resolved alias can be the lateral column alias + if (referencedAlias.alias.resolved) { + matchedLCA = Some(referencedAlias) + } + } + u + } + matchedLCA + } + + val referencedAliases = projectList.zipWithIndex.flatMap { + case (a: Alias, idx) => + // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed + // down. Unresolved alias is added to the map to perform the ambiguous name check. 
+ // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, + // because only resolved alias can be LCA, in the first round the rule application, + // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are + // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, + // it is pushed down. + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + matchedLCA + case (e, _) => + lookUpLCA(e) + }.toSet + + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = projectList.to[collection.mutable.Seq] + val innerProjectList = + child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + plan + } else { + rewriteLateralColumnAlias(plan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 8493895218f79..032b0e7a08fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -66,7 +66,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences" :: - "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations" :: 
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" :: @@ -89,6 +88,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" :: "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" :: "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" :: + "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" :: From fd0609438d99643d147afad854e8206624437278 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 11:44:35 -0800 Subject: [PATCH 05/31] rename conf --- .../catalyst/analysis/ResolveLateralColumnAlias.scala | 2 +- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 9 ++++++--- .../org/apache/spark/sql/LateralColumnAliasSuite.scala | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index ea2648ccde553..2b435f1c460a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -118,7 +118,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan } else { rewriteLateralColumnAlias(plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a5b84660e0581..575775a0f5519 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4027,10 +4027,13 @@ object SQLConf { .checkValues(ErrorMessageFormat.values.map(_.toString)) .createWithDefault(ErrorMessageFormat.PRETTY.toString) - val LATERAL_COLUMN_ALIAS_ENABLED = - buildConf("spark.sql.lateralColumnAlias.enabled") + val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED = + buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution") .internal() - .doc("Enable lateral column alias in analyzer") + .doc("Enable resolving implicit lateral column alias defined in the same SELECT list. For " + + "example, with this conf turned on, for query `SELECT 1 AS a, a + 1` the `a` in `a + 1` " + + "can be resolved as the previously defined `1 AS a`. But note that table column has " + + "higher resolution priority than the lateral column alias.") .version("3.4.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index daf750c39bb1e..f6b2c919b794c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -53,7 +53,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { override protected def test(testName: String, testTags: Tag*)(testFun: => Any) (implicit pos: Position): Unit = { super.test(testName, testTags: _*) { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { testFun } } From 7d4f80f4c74a77dfceb77b4d86d36cd83d63d9c5 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 
16:07:21 -0800 Subject: [PATCH 06/31] test failure --- .../sql/catalyst/analysis/ResolveLateralColumnAlias.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 2b435f1c460a0..e1372664b791e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -102,9 +102,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { if (referencedAliases.isEmpty) { p } else { - val outerProjectList = projectList.to[collection.mutable.Seq] + val outerProjectList = collection.mutable.Seq(projectList: _*) val innerProjectList = - child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => outerProjectList.update(idx, alias.toAttribute) innerProjectList += alias From b9704d5428fa2f25de9b6da076c972168ee0477d Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 29 Nov 2022 10:27:43 -0800 Subject: [PATCH 07/31] small fix --- .../sql/catalyst/analysis/ResolveLateralColumnAlias.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index e1372664b791e..5462cee65fd09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -61,8 +61,8 @@ object 
ResolveLateralColumnAlias extends Rule[LogicalPlan] { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def lookUpLCA(e: Expression): Option[AliasEntry] = { - var matchedLCA: Option[AliasEntry] = None + def lookUpLCA(e: Expression): Seq[AliasEntry] = { + var matchedLCA: Seq[AliasEntry] = Seq.empty[AliasEntry] e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) @@ -75,7 +75,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias if (referencedAlias.alias.resolved) { - matchedLCA = Some(referencedAlias) + matchedLCA :+= referencedAlias } } u From 777f13a05f58342028585e4b94f4a90743865181 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 29 Nov 2022 10:20:28 -0800 Subject: [PATCH 08/31] temp commit, still in implementation --- .../analysis/ResolveLateralColumnAlias.scala | 73 ++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 5462cee65fd09..2b6637d42e4ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import 
org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -114,6 +114,75 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { child = Project(innerProjectList.toSeq, child) ) } + + /** + * Implementation notes: + * SELECT dept AS a, count(id) AS b, a, b, a + avg(age), b + avg(age), a + b, b + dept + * GROUP BY dept + * + * Project [a, b, a, b, a_plus_avg_age, b + avg_age, a + b, b + dept] + * +- Aggregate [dept] + * [a, count(id) AS b, a, a + avg(age) AS a_plus_avg_age, avg(age) AS avg_age, dept] + * +- Project [child output, dept AS a] + * + * Careful: Doesn't need to create duplicate avg(age), or grouping dept in the Aggregate + * + * Push down: non-aggregate referenced LCA + * Add to Aggregate: Grouping expression (dept) and aggregate expressions (avg(age)), + * if it is used in push-ups. + * Remove from Aggregate: push-up expressions. 
+ * Push up: reference an aggregate LCA + */ + case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) + if agg.childrenResolved + && groupingExpressions.forall(_.resolved) + && !Analyzer.containsStar(aggregateExpressions) + && aggregateExpressions.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + // First resolve using child output + Analyzer.resolveExpressionByPlanChildren(u, agg, resolver) + .isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case _ => + val referencedAlias = aliases.head + // Only resolved alias can be the lateral column alias + if (referencedAlias.alias.resolved) { + matchedLCA = Some(referencedAlias) + } + } + u + } + matchedLCA + } + val downExps = + collection.mutable.Set(agg.child.output.map(_.asInstanceOf[NamedExpression]): _*) + val upExps = collection.mutable.Seq() + aggregateExpressions.zipWithIndex.foreach { + case (a: Alias, idx) => + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + if (matchedLCA.isDefined) { + val alias = matchedLCA.get.alias + if (!alias.containsPattern(AGGREGATE_EXPRESSION)) { + downExps += alias + } else { + + } + } + } + agg + } } From 09480ea1069a84ce00b6a7abe8a67aa384827797 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 30 Nov 2022 10:10:54 -0800 Subject: [PATCH 09/31] a temporary solution, but still fail certain cases --- 
.../analysis/ResolveLateralColumnAlias.scala | 114 +++++++--------- .../spark/sql/LateralColumnAliasSuite.scala | 127 ++++++++++++++++++ 2 files changed, 177 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 2b6637d42e4ab..f72614c5111cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, UNRESOLVED_ATTRIBUTE} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.MetadataBuilder /** * Resolve lateral column alias, which references the alias defined previously in the SELECT list, @@ -115,74 +117,58 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { ) } - /** - * Implementation notes: - * SELECT dept AS a, count(id) AS b, a, b, a + avg(age), b + avg(age), a + b, b + dept - * GROUP BY dept - * - * Project [a, b, a, b, a_plus_avg_age, b + avg_age, a + b, b + dept] - * +- Aggregate [dept] - * [a, count(id) AS b, a, a + avg(age) AS a_plus_avg_age, avg(age) AS avg_age, 
dept] - * +- Project [child output, dept AS a] - * - * Careful: Doesn't need to create duplicate avg(age), or grouping dept in the Aggregate - * - * Push down: non-aggregate referenced LCA - * Add to Aggregate: Grouping expression (dept) and aggregate expressions (avg(age)), - * if it is used in push-ups. - * Remove from Aggregate: push-up expressions. - * Push up: reference an aggregate LCA - */ - case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) + case agg @ Aggregate(groupingExpressions, aggregateExpressions, child) if agg.childrenResolved && groupingExpressions.forall(_.resolved) && !Analyzer.containsStar(aggregateExpressions) && aggregateExpressions.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - def lookUpLCA(e: Expression): Option[AliasEntry] = { - var matchedLCA: Option[AliasEntry] = None - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - // First resolve using child output - Analyzer.resolveExpressionByPlanChildren(u, agg, resolver) - .isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case _ => - val referencedAlias = aliases.head - // Only resolved alias can be the lateral column alias - if (referencedAlias.alias.resolved) { - matchedLCA = Some(referencedAlias) - } - } - u - } - matchedLCA - } - val downExps = - collection.mutable.Set(agg.child.output.map(_.asInstanceOf[NamedExpression]): _*) - val upExps = collection.mutable.Seq() - aggregateExpressions.zipWithIndex.foreach { - case (a: Alias, idx) => - val matchedLCA = lookUpLCA(a) - 
insertIntoAliasMap(a, idx) - if (matchedLCA.isDefined) { - val alias = matchedLCA.get.alias - if (!alias.containsPattern(AGGREGATE_EXPRESSION)) { - downExps += alias - } else { + val newAggNode = agg - } - } + // make sure all aggregate expressions are constructed and resolved, then push them down + val aggFuncCandidates = newAggNode.aggregateExpressions.flatMap { exp => + exp.map { + // TODO: This is problematic. All functions operate on the lca won't be resolved, e.g. + // concat(string(dept_salary_sum), ': dept', string(dept)) + // But without this condition, it may miss certain complex cases like + // SELECT count(bonus), count(salary * 1.5 + 10000 + bonus * 1.0) AS a, a + // when the second count is not resolved to aggregate expression, this rule incorrectly + // applies + case unresolvedFunc: UnresolvedFunction => Some(unresolvedFunc) + case aggExp: AggregateExpression => Some(aggExp) + case _ => None + }.flatten } - agg - + val newAggExprs = collection.mutable.Set.empty[NamedExpression] + if (!aggFuncCandidates.isEmpty && aggFuncCandidates.forall(_.resolved)) { + val upExprs = newAggNode.aggregateExpressions.map { exp => + exp.transformDown { + // TODO: dedup these aggregate expressions + case aggExp: AggregateExpression => + val alias = Alias(aggExp, toPrettySQL(aggExp))( + explicitMetadata = Some(new MetadataBuilder() + .putString("__autoGeneratedAlias", "true") + .build())) + newAggExprs += alias + alias.toAttribute + case e if e.resolved && groupingExpressions.exists(_.semanticEquals(e)) => + // TODO: dedup these grouping expressions + val alias = Alias(e, toPrettySQL(e))( + explicitMetadata = Some(new MetadataBuilder() + .putString("__autoGeneratedAlias", "true") + .build())) + newAggExprs += alias + alias.toAttribute + }.asInstanceOf[NamedExpression] + } + Project( + projectList = upExprs, + child = newAggNode.copy( + aggregateExpressions = newAggExprs.toSeq + ) + ) + } else { + newAggNode + } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index f6b2c919b794c..f679cf5f18c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -106,4 +106,131 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(18000, 18000, 10000) ) } + + test("temp test") { + sql(s"SELECT count(name) AS b, b FROM $testTable GROUP BY dept") + sql(s"SELECT dept AS a, count(name) AS b, a, b FROM $testTable GROUP BY dept") + sql(s"SELECT avg(salary) AS a, count(name) AS b, a, b, a + b FROM $testTable GROUP BY dept") + sql(s"SELECT dept, count(name) AS b, dept + b FROM $testTable GROUP BY dept") + sql(s"SELECT count(bonus), count(salary * 1.5 + 10000 + bonus * 1.0) AS a, a " + + s"FROM $testTable GROUP BY dept") + } + + test("Lateral alias in aggregation") { + // literal as LCA, used in various cases of expressions +// checkAnswer( +// sql( +// s""" +// |SELECT +// | 10000 AS baseline_salary, +// | baseline_salary * 1.5, +// | baseline_salary + dept * 10000, +// | baseline_salary + avg(bonus), +// | avg(baseline_salary * 1.5), +// | avg(baseline_salary * 1.5) + dept * 10000, +// | avg(baseline_salary * 1.5) + avg(bonus) +// |FROM $testTable +// |GROUP BY dept +// |ORDER BY dept +// |""".stripMargin +// ), +// Row(10000, 15000.0, 20000, 11100.0, 15000.0, 25000.0, 16100.0) :: +// Row(10000, 15000.0, 30000, 11250.0, 15000.0, 35000.0, 16250.0) :: +// Row(10000, 15000.0, 70000, 11200.0, 15000.0, 75000.0, 16200.0) :: Nil +// ) + + // grouping attribute as LCA, used in various cases of expressions +// checkAnswer( +// sql( +// s""" +// |SELECT +// | salary + 1000 AS new_salary, +// | new_salary - 1000 AS prev_salary, +// | new_salary - salary, +// | new_salary - avg(salary), +// | avg(new_salary) - 1000, +// | avg(new_salary) - salary, +// | 
avg(new_salary) - avg(salary), +// | avg(new_salary) - avg(prev_salary) +// |FROM $testTable +// |GROUP BY salary +// |ORDER BY salary +// |""".stripMargin), +// Row(10000, 9000, 1000, 1000.0, 9000.0, 1000.0, 1000.0, 1000.0) :: +// Row(11000, 10000, 1000, 1000.0, 10000.0, 1000.0, 1000.0, 1000.0) :: +// Row(13000, 12000, 1000, 1000.0, 12000.0, 1000.0, 1000.0, 1000.0) :: +// Nil +// ) + + // aggregate expression as LCA, used in various cases of expressions +// checkAnswer( +// sql( +// s""" +// |SELECT +// | sum(salary) AS dept_salary_sum, +// | sum(bonus) AS dept_bonus_sum, +// | dept_salary_sum * 1.5, +// | concat(string(dept_salary_sum), ': dept', string(dept)), +// | dept_salary_sum + sum(bonus), +// | dept_salary_sum + dept_bonus_sum +// |FROM $testTable +// |GROUP BY dept +// |ORDER BY dept +// |""".stripMargin +// ), +// Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200) :: +// Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500) :: +// Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200) :: +// Nil +// ) + checkAnswer( + sql(s"SELECT sum(salary) AS s, s + sum(bonus) AS total FROM $testTable"), + Row(53000, 58900) + ) + + // Doesn't support nested aggregate expressions + // TODO: add error class and use CheckError + intercept[AnalysisException] { + sql(s"SELECT sum(salary) AS a, avg(a) FROM $testTable") + } + + // chaining + checkAnswer( + sql( + s""" + |SELECT + | dept, + | sum(salary) AS salary_sum, + | salary_sum + sum(bonus) AS salary_total, + | salary_total * 1.5 AS new_total, + | new_total - salary_sum + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin), + Row(1, 19000, 21200, 31800.0, 12800.0) :: + Row(2, 22000, 24500, 36750.0, 14750.0) :: + Row(6, 12000, 13200, 19800.0, 7800.0) :: Nil + ) + + // conflict names with table columns + checkAnswer( + sql( + s""" + |SELECT + | sum(salary) AS salary, + | sum(bonus) AS bonus, + | avg(salary) AS avg_s, + | avg(salary + bonus) AS avg_t, + | avg_s + avg_t + |FROM $testTable + 
|GROUP BY dept + |ORDER BY dept + |""".stripMargin), + Row(19000, 2200, 9500.0, 10600.0, 20100.0) :: + Row(22000, 2500, 11000.0, 12250.0, 23250.0) :: + Row(12000, 1200, 12000.0, 13200.0, 25200.0) :: + Nil) + } + } From c97273889614e703b0da6c007837f3d79017c41c Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 1 Dec 2022 13:40:08 -0800 Subject: [PATCH 10/31] working solution, needs some refinement --- .../main/resources/error/error-classes.json | 5 + .../sql/catalyst/analysis/Analyzer.scala | 129 +++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../analysis/ResolveLateralColumnAlias.scala | 169 +++++++--- .../expressions/namedExpressions.scala | 29 ++ .../sql/catalyst/rules/RuleIdCollection.scala | 2 +- .../sql/catalyst/trees/TreePatterns.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 13 +- .../spark/sql/LateralColumnAliasSuite.scala | 316 ++++++++++++++---- .../org/apache/spark/sql/QueryTest.scala | 2 +- 10 files changed, 482 insertions(+), 185 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index e279ffc87d21e..57a69bf72684d 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1227,6 +1227,11 @@ "The target JDBC server does not support transactions and can only support ALTER TABLE with a single action." ] }, + "LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : { + "message" : [ + "Referencing a lateral column alias in the aggregate function ." + ] + }, "LATERAL_JOIN_USING" : { "message" : [ "JOIN USING with LATERAL correlation." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 95101c0d8130b..415a89ae0ed90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -637,70 +637,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Replaces [[UnresolvedAlias]]s with concrete aliases. - */ - object ResolveAliases extends Rule[LogicalPlan] { - private def assignAliases(exprs: Seq[NamedExpression]) = { - def extractOnly(e: Expression): Boolean = e match { - case _: ExtractValue => e.children.forall(extractOnly) - case _: Literal => true - case _: Attribute => true - case _ => false - } - def metaForAutoGeneratedAlias = { - new MetadataBuilder() - .putString("__autoGeneratedAlias", "true") - .build() - } - exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { - case u @ UnresolvedAlias(child, optGenAliasFunc) => - child match { - case ne: NamedExpression => ne - case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) - case e if !e.resolved => u - case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() - case e: ExtractValue => - if (extractOnly(e)) { - Alias(e, toPrettySQL(e))() - } else { - Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) - } - case e if optGenAliasFunc.isDefined => - Alias(child, optGenAliasFunc.get.apply(e))() - case l: Literal => Alias(l, toPrettySQL(l))() - case e => - Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) - } - } - ).asInstanceOf[Seq[NamedExpression]] - } - - private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = - exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) - - def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning( - _.containsPattern(UNRESOLVED_ALIAS), ruleId) { - case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => - Aggregate(groups, assignAliases(aggs), child) - - case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) - if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => - Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) - - case up: Unpivot if up.child.resolved && - (up.ids.exists(hasUnresolvedAlias) || up.values.exists(_.exists(hasUnresolvedAlias))) => - up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) - - case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => - Project(assignAliases(projectList), child) - - case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => - c.copy(metrics = assignAliases(c.metrics)) - } - } - object ResolveGroupingAnalytics extends Rule[LogicalPlan] { private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID]) @@ -4147,3 +4083,68 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] { CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts)) } } + +/** + * Replaces [[UnresolvedAlias]]s with concrete aliases. 
+ */ +object ResolveAliases extends Rule[LogicalPlan] { + def metaForAutoGeneratedAlias: Metadata = { + new MetadataBuilder() + .putString("__autoGeneratedAlias", "true") + .build() + } + + def assignAliases(exprs: Seq[NamedExpression]): Seq[NamedExpression] = { + def extractOnly(e: Expression): Boolean = e match { + case _: ExtractValue => e.children.forall(extractOnly) + case _: Literal => true + case _: Attribute => true + case _ => false + } + exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { + case u @ UnresolvedAlias(child, optGenAliasFunc) => + child match { + case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) + case e if !e.resolved => u + case g: Generator => MultiAlias(g, Nil) + case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() + case e: ExtractValue => + if (extractOnly(e)) { + Alias(e, toPrettySQL(e))() + } else { + Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) + } + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case l: Literal => Alias(l, toPrettySQL(l))() + case e => + Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) + } + } + ).asInstanceOf[Seq[NamedExpression]] + } + + private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = + exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsPattern(UNRESOLVED_ALIAS), ruleId) { + case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => + Aggregate(groups, assignAliases(aggs), child) + + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) + + case up: Unpivot if up.child.resolved && + 
(up.ids.exists(hasUnresolvedAlias) || up.values.exists(_.exists(hasUnresolvedAlias))) => + up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) + + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => + Project(assignAliases(projectList), child) + + case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => + c.copy(metrics = assignAliases(c.metrics)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 86bb48410d275..52257d0c4c7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -439,6 +439,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB groupingExprs.foreach(checkValidGroupingExprs) aggregateExprs.foreach(checkValidAggregateExpression) + // TODO: if the Aggregate is resolved, it can't contain the LateralColumnAliasReference case CollectMetrics(name, metrics, _) => if (name == null || name.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index f72614c5111cb..8fdc8b1697337 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,21 +17,20 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, LateralColumnAliasReference, NamedExpression} import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.MetadataBuilder /** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list, - * - in Project inserting a new Project node with the referenced alias so that it can be + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * - in Project, inserting a new Project node below with the referenced alias so that it can be * resolved by other rules - * - in Aggregate TODO. + * - in Aggregate, inserting the Project node above and fall back to the resolution of Project * * For Project, it rewrites by inserting a newly created Project plan between the original Project * and its child, pushing the referenced lateral column aliases to this new Project, and updating @@ -46,14 +45,41 @@ import org.apache.spark.sql.types.MetadataBuilder * +- Project [child output, age AS a] * +- Child * - * For Aggregate TODO. + * For Aggregate, it first wraps the attribute resolved by lateral alias with + * [[LateralColumnAliasReference]]. 
+ * Before wrap (omit some cast or alias): + * Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * After wrap: + * Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * When the whole Aggregate is resolved, it inserts a [[Project]] above with the aggregation + * expression list, but extracts the [[AggregateExpression]] and grouping expressions in the + * list to the current Aggregate. It restores all the [[LateralColumnAliasReference]] back to + * [[UnresolvedAttribute]]. The problem falls back to the lateral alias resolution in Project. + * + * After restore: + * Project [dept#14 AS a#12, 'a + 1, avg(salary)#26 AS b#13, 'b + avg(bonus)#27] + * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14] + * +- Child [dept#14,name#15,salary#16,bonus#17] */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) def resolver: Resolver = conf.resolver + def unwrapLCAReference(exprs: Seq[NamedExpression]): Seq[NamedExpression] = { + exprs.map { expr => + expr.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case l: LateralColumnAliasReference => + UnresolvedAttribute(l.nameParts) + }.asInstanceOf[NamedExpression] + } + } private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.childrenResolved && !Analyzer.containsStar(projectList) && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => @@ -72,7 +98,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { val 
aliases = aliasMap.get(u.nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) case _ => val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias @@ -117,58 +143,91 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { ) } + // wrap LCA + // Implementation notes: + // In Aggregate, introducing and wrapping this resolved leaf expression + // LateralColumnAliasReference is especially needed because it needs an accurate condition to + // trigger adding a Project above and extracting aggregate functions or grouping expressions. + // Such operation can only be done once. With this LateralColumnAliasReference, the condition + // can simply be when the whole Aggregate is resolved. Otherwise, it can't really tell if + // all aggregate functions are created and resolved, because the lateral alias reference + // itself is unresolved. case agg @ Aggregate(groupingExpressions, aggregateExpressions, child) if agg.childrenResolved - && groupingExpressions.forall(_.resolved) && !Analyzer.containsStar(aggregateExpressions) && aggregateExpressions.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - val newAggNode = agg - // make sure all aggregate expressions are constructed and resolved, then push them down - val aggFuncCandidates = newAggNode.aggregateExpressions.flatMap { exp => - exp.map { - // TODO: This is problematic. All functions operate on the lca won't be resolved, e.g. 
- // concat(string(dept_salary_sum), ': dept', string(dept)) - // But without this condition, it may miss certain complex cases like - // SELECT count(bonus), count(salary * 1.5 + 10000 + bonus * 1.0) AS a, a - // when the second count is not resolved to aggregate expression, this rule incorrectly - // applies - case unresolvedFunc: UnresolvedFunction => Some(unresolvedFunc) - case aggExp: AggregateExpression => Some(aggExp) - case _ => None - }.flatten + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def wrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + Analyzer.resolveExpressionByPlanChildren(u, agg, resolver) + .isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + LateralColumnAliasReference(aliases.head.alias, u.nameParts) + case _ => + u + } + }.asInstanceOf[NamedExpression] + } + + val newAggExprs = aggregateExpressions.zipWithIndex.map { + case (a: Alias, idx) => + val LCAResolved = wrapLCAReference(a).asInstanceOf[Alias] + // insert the LCA-resolved alias instead of the unresolved one into map. 
If it is + // resolved, it can be referenced as LCA by later expressions + insertIntoAliasMap(LCAResolved, idx) + LCAResolved + case (e, _) => + wrapLCAReference(e) } + agg.copy(aggregateExpressions = newAggExprs) + + case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) + if agg.resolved + && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + val newAggExprs = collection.mutable.Set.empty[NamedExpression] - if (!aggFuncCandidates.isEmpty && aggFuncCandidates.forall(_.resolved)) { - val upExprs = newAggNode.aggregateExpressions.map { exp => - exp.transformDown { - // TODO: dedup these aggregate expressions - case aggExp: AggregateExpression => - val alias = Alias(aggExp, toPrettySQL(aggExp))( - explicitMetadata = Some(new MetadataBuilder() - .putString("__autoGeneratedAlias", "true") - .build())) - newAggExprs += alias - alias.toAttribute - case e if e.resolved && groupingExpressions.exists(_.semanticEquals(e)) => - // TODO: dedup these grouping expressions - val alias = Alias(e, toPrettySQL(e))( - explicitMetadata = Some(new MetadataBuilder() - .putString("__autoGeneratedAlias", "true") - .build())) - newAggExprs += alias - alias.toAttribute - }.asInstanceOf[NamedExpression] - } - Project( - projectList = upExprs, - child = newAggNode.copy( - aggregateExpressions = newAggExprs.toSeq - ) - ) - } else { - newAggNode - } + val projectExprs = aggregateExpressions.map { exp => + exp.transformDown { + case aggExpr: AggregateExpression => + // TODO (improvement) dedup + // Doesn't support referencing a lateral alias in aggregate function + if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + aggExpr.collectFirst { + case lcaRef: LateralColumnAliasReference => + throw QueryCompilationErrors.LateralColumnAliasInAggFuncUnsupportedError( + lcaRef.nameParts, aggExpr) + } + } + val ne = ResolveAliases.assignAliases(Seq(UnresolvedAlias(aggExpr))).head + newAggExprs += ne + ne.toAttribute + case e if 
groupingExpressions.exists(_.semanticEquals(e)) => + // TODO (improvement) dedup + val alias = ResolveAliases.assignAliases(Seq(UnresolvedAlias(e))).head + newAggExprs += alias + alias.toAttribute + }.asInstanceOf[NamedExpression] + } + val unwrappedAggExprs = unwrapLCAReference(newAggExprs.toSeq) + val unwrappedProjectExprs = unwrapLCAReference(projectExprs) + Project( + projectList = unwrappedProjectExprs, + child = agg.copy(aggregateExpressions = unwrappedAggExprs) + ) + // TODO: think about a corner case, when the Alias passed to LateralColumnAliasReference + // contains a LateralColumnAliasReference. Is it safe to do a.toAttribute when resolving + // the LateralColumnAliasReference? + // TODO withOrigin? } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8dd28e9aaae3d..ea465cf7cdb94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -428,6 +428,35 @@ case class OuterReference(e: NamedExpression) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) } +/** + * A placeholder used to hold a referenced that has been temporarily resolved as the reference + * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. + * + * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all + * analysis check, then all [[LateralColumnAliasReference]] should already be removed. + * + * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute + * @param nameParts The named parts of the original [[UnresolvedAttribute]]. 
Used to restore back + */ +case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + assert(a.resolved) + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = a.exprId + override def qualifier: Seq[String] = a.qualifier + override def toAttribute: Attribute = a.toAttribute + override def newInstance(): NamedExpression = + LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) + + override def nullable: Boolean = a.nullable + override def dataType: DataType = a.dataType + override def prettyName: String = "lateralAliasReference" + override def sql: String = s"$prettyName($name)" + + final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) +} + object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 032b0e7a08fcd..bc1d3295cbb86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -49,7 +49,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions" :: - "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveBinaryArithmetic" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" :: 
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" :: @@ -82,6 +81,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: "org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" :: "org.apache.spark.sql.catalyst.analysis.EliminateUnions" :: + "org.apache.spark.sql.catalyst.analysis.ResolveAliases" :: "org.apache.spark.sql.catalyst.analysis.ResolveDefaultColumns" :: "org.apache.spark.sql.catalyst.analysis.ResolveExpressionsWithNamePlaceholders" :: "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveCoalesceHints" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8fca9ec60cdff..1a8ad7c7d6213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -58,6 +58,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index eeb1dfc213d94..bfcf9ef697231 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3403,7 +3403,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { } } - def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = { + def ambiguousLateralColumnAliasError(name: String, numOfMatches: Int): Throwable = { new 
AnalysisException( errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", messageParameters = Map( @@ -3412,4 +3412,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) ) } + + def LateralColumnAliasInAggFuncUnsupportedError( + lcaNameParts: Seq[String], aggExpr: Expression): Throwable = { + new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + messageParameters = Map( + "lca" -> toSQLId(lcaNameParts), + "aggFunc" -> toSQLExpr(aggExpr) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index f679cf5f18c97..bbe593dd7a3d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -50,6 +50,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } val lcaEnabled: Boolean = true + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) (implicit pos: Position): Unit = { super.test(testName, testTags: _*) { @@ -59,7 +60,17 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Lateral alias in project") { + private def checkDuplicatedAliasErrorHelper( + query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) + } + + test("Basic lateral alias in Project") { checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), Row(1, 2)) @@ -107,6 +118,20 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } + test("Conflict names with CTE - Project") { + checkAnswer( + sql( + s""" + |WITH temp_table(x, y) + |AS (SELECT 1, 2) + |SELECT 100 AS x, x + 1 + |FROM temp_table + 
|""".stripMargin + ), + Row(100, 2) + ) + } + test("temp test") { sql(s"SELECT count(name) AS b, b FROM $testTable GROUP BY dept") sql(s"SELECT dept AS a, count(name) AS b, a, b FROM $testTable GROUP BY dept") @@ -116,73 +141,90 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { s"FROM $testTable GROUP BY dept") } - test("Lateral alias in aggregation") { + test("Basic lateral alias in Aggregate") { + // doesn't support lca used in aggregation functions + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT 10000 AS lca, count(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"count(lateralAliasReference(lca))\"" + ) + ) + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT dept AS lca, avg(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"avg(lateralAliasReference(lca))\"" + ) + ) + // literal as LCA, used in various cases of expressions -// checkAnswer( -// sql( -// s""" -// |SELECT -// | 10000 AS baseline_salary, -// | baseline_salary * 1.5, -// | baseline_salary + dept * 10000, -// | baseline_salary + avg(bonus), -// | avg(baseline_salary * 1.5), -// | avg(baseline_salary * 1.5) + dept * 10000, -// | avg(baseline_salary * 1.5) + avg(bonus) -// |FROM $testTable -// |GROUP BY dept -// |ORDER BY dept -// |""".stripMargin -// ), -// Row(10000, 15000.0, 20000, 11100.0, 15000.0, 25000.0, 16100.0) :: -// Row(10000, 15000.0, 30000, 11250.0, 15000.0, 35000.0, 16250.0) :: -// Row(10000, 15000.0, 70000, 11200.0, 15000.0, 75000.0, 16200.0) :: Nil -// ) + checkAnswer( + sql( + s""" + |SELECT + | 10000 AS baseline_salary, + | baseline_salary * 1.5, + | baseline_salary + dept * 10000, + | baseline_salary + avg(bonus) + 
|FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin + ), + Row(10000, 15000.0, 20000, 11100.0) :: + Row(10000, 15000.0, 30000, 11250.0) :: + Row(10000, 15000.0, 70000, 11200.0) :: Nil + ) // grouping attribute as LCA, used in various cases of expressions -// checkAnswer( -// sql( -// s""" -// |SELECT -// | salary + 1000 AS new_salary, -// | new_salary - 1000 AS prev_salary, -// | new_salary - salary, -// | new_salary - avg(salary), -// | avg(new_salary) - 1000, -// | avg(new_salary) - salary, -// | avg(new_salary) - avg(salary), -// | avg(new_salary) - avg(prev_salary) -// |FROM $testTable -// |GROUP BY salary -// |ORDER BY salary -// |""".stripMargin), -// Row(10000, 9000, 1000, 1000.0, 9000.0, 1000.0, 1000.0, 1000.0) :: -// Row(11000, 10000, 1000, 1000.0, 10000.0, 1000.0, 1000.0, 1000.0) :: -// Row(13000, 12000, 1000, 1000.0, 12000.0, 1000.0, 1000.0, 1000.0) :: -// Nil -// ) + checkAnswer( + sql( + s""" + |SELECT + | salary + 1000 AS new_salary, + | new_salary - 1000 AS prev_salary, + | new_salary - salary, + | new_salary - avg(salary) + |FROM $testTable + |GROUP BY salary + |ORDER BY salary + |""".stripMargin), + Row(10000, 9000, 1000, 1000.0) :: + Row(11000, 10000, 1000, 1000.0) :: + Row(13000, 12000, 1000, 1000.0) :: + Nil + ) // aggregate expression as LCA, used in various cases of expressions -// checkAnswer( -// sql( -// s""" -// |SELECT -// | sum(salary) AS dept_salary_sum, -// | sum(bonus) AS dept_bonus_sum, -// | dept_salary_sum * 1.5, -// | concat(string(dept_salary_sum), ': dept', string(dept)), -// | dept_salary_sum + sum(bonus), -// | dept_salary_sum + dept_bonus_sum -// |FROM $testTable -// |GROUP BY dept -// |ORDER BY dept -// |""".stripMargin -// ), -// Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200) :: -// Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500) :: -// Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200) :: -// Nil -// ) + checkAnswer( + sql( + s""" + |SELECT + | sum(salary) AS dept_salary_sum, + | 
sum(bonus) AS dept_bonus_sum, + | dept_salary_sum * 1.5, + | concat(string(dept_salary_sum), ': dept', string(dept)), + | dept_salary_sum + sum(bonus), + | dept_salary_sum + dept_bonus_sum + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin + ), + Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200) :: + Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500) :: + Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200) :: + Nil + ) checkAnswer( sql(s"SELECT sum(salary) AS s, s + sum(bonus) AS total FROM $testTable"), Row(53000, 58900) @@ -233,4 +275,152 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Nil) } + test("non-deterministic expression as LCA is evaluated only once - Project") { + sql(s"SELECT dept, rand(0) AS r, r FROM $testTable").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(1), row(2))) + } + sql(s"SELECT dept + rand(0) AS r, r FROM $testTable").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(0), row(1))) + } + } + + test("non-deterministic expression as LCA is evaluated only once - Aggregate") { + val groupBySnippet = s"FROM $testTable GROUP BY dept" + sql(s"SELECT dept, rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(1), row(2))) + } + sql(s"SELECT dept + rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(0), row(1))) + } + sql(s"SELECT avg(salary) + rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(0), row(1))) + } + } + + test("Case insensitive lateral column alias") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer( + sql(s"SELECT salary AS new_salary, New_Salary + 1 FROM $testTable WHERE name = 'jen'"), + Row(12000, 12001)) + checkAnswer( + sql( + s""" + |SELECT avg(salary) AS AVG_SALARY, avg_salary + avg(bonus) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 1 + |""".stripMargin), + Row(9500, 
10600)) + } + } + + test("Duplicated lateral alias names - Project") { + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 1200) + ) + checkAnswer( + sql(s"SELECT salary AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 12000, 10000) + ) + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql( + s""" + |SELECT salary * 1.5 AS salary, salary, 10000 AS salary, salary + |FROM $testTable + |WHERE name = 'jen' + |""".stripMargin), + Row(18000, 12000, 10000, 12000) + ) + } + + test("Duplicated lateral alias names - Aggregate") { + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT dept AS d, name AS d FROM $testTable GROUP BY dept, name ORDER BY dept, name"), + Row(1, "amy") :: Row(1, "cathy") :: Row(2, "alex") :: Row(2, "david") :: Row(6, "jen") :: Nil + ) + checkAnswer( + sql(s"SELECT dept AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(1, 1, 10) :: Row(2, 2, 10) :: Row(6, 6, 10) :: Nil + ) + checkAnswer( + sql(s"SELECT sum(salary * 1.5) AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(28500, 28500, 10) :: 
Row(33000, 33000, 10) :: Row(18000, 18000, 10) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT sum(salary * 1.5) AS d, d, d + sum(bonus) AS d + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin), + Row(28500, 28500, 30700) :: Row(33000, 33000, 35500) :: Row(18000, 18000, 19200) :: Nil + ) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT dept * 2.0 AS d, d, 10000 AS d, d + 1 FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, dept * 2.0 AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT avg(salary) AS d, d * 1.0, avg(bonus * 1.5) AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT dept AS d, d + 1 AS d, d + 1 AS d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql(s""" + |SELECT avg(salary * 1.5) AS salary, sum(salary), dept AS salary, avg(salary) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 6 + |""".stripMargin), + Row(18000, 12000, 6, 12000) + ) + } + + test("Attribute cannot be resolved by LCA remain unresolved") { + assert(intercept[AnalysisException] { + sql(s"SELECT dept AS d, d AS new_dept, new_dep + 1 AS newer_dept FROM $testTable") + }.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + + assert(intercept[AnalysisException] { + sql(s"SELECT count(name) AS cnt, cnt + 1, count(unresovled) FROM $testTable GROUP BY dept") + }.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + + // TODO: subquery + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 0bb5e5230c188..22cc4fd46cbd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -360,7 +360,7 @@ object QueryTest extends Assertions { None } - private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { case (null, null) => true case (null, _) => false case (_, null) => false From 5785943fbb53b525b4434b7566d4f466461ceb61 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 1 Dec 2022 18:36:33 -0800 Subject: [PATCH 11/31] make changes to accomodate the recent refactor --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../analysis/ResolveLateralColumnAlias.scala | 64 ++++++--- .../expressions/namedExpressions.scala | 12 +- .../sql/catalyst/expressions/subquery.scala | 4 +- .../spark/sql/LateralColumnAliasSuite.scala | 131 ++++++++++++++++++ 5 files changed, 188 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d5d87b768770..3bc98c68d8486 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1844,7 +1844,7 @@ class Analyzer(override val catalogManager: CatalogManager) // Only Project and Aggregate can host star expressions. 
case u @ (_: Project | _: Aggregate) => Try(s.expand(u.children.head, resolver)) match { - case Success(expanded) => expanded.map(wrapOuterReference) + case Success(expanded) => expanded.map(wrapOuterReference(_)) case Failure(_) => throw e } // Do not use the outer plan to resolve the star expression @@ -2165,7 +2165,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { try { AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match { - case Some(resolved) => wrapOuterReference(resolved) + case Some(resolved) => wrapOuterReference(resolved, Some(nameParts)) case None => u } } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 5462cee65fd09..37e93f7c9105b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -31,6 +31,14 @@ import org.apache.spark.sql.internal.SQLConf * resolved by other rules * - in Aggregate TODO. 
* + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and restores it back to + * [[UnresolvedAttribute]] + * * For Project, it rewrites by inserting a newly created Project plan between the original Project * and its child, pushing the referenced lateral column aliases to this new Project, and updating * the project list of the original Project. @@ -51,19 +59,35 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.childrenResolved && !Analyzer.containsStar(projectList) - && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) def insertIntoAliasMap(a: Alias, idx: Int): Unit = { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def lookUpLCA(e: Expression): Seq[AliasEntry] = { - var matchedLCA: Seq[AliasEntry] = Seq.empty[AliasEntry] - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def resolveLCA(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case o: OuterReference + if 
aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + val name = o.nameParts.map(_.head).getOrElse(o.name) + val aliases = aliasMap.get(name).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(o.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + referencedAliases += aliases.head + o.nameParts.map(UnresolvedAttribute(_)).getOrElse(UnresolvedAttribute(o.name)) + case _ => + o + } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) .isInstanceOf[UnresolvedAttribute] => @@ -71,19 +95,15 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { aliases.size match { case n if n > 1 => throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case _ => - val referencedAlias = aliases.head + case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - if (referencedAlias.alias.resolved) { - matchedLCA :+= referencedAlias - } + referencedAliases += aliases.head + case _ => } u - } - matchedLCA + }.asInstanceOf[NamedExpression] } - - val referencedAliases = projectList.zipWithIndex.flatMap { + val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed // down. Unresolved alias is added to the map to perform the ambiguous name check. @@ -92,17 +112,17 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, // it is pushed down. 
- val matchedLCA = lookUpLCA(a) - insertIntoAliasMap(a, idx) - matchedLCA + val lcaResolved = resolveLCA(a).asInstanceOf[Alias] + insertIntoAliasMap(lcaResolved, idx) + lcaResolved case (e, _) => - lookUpLCA(e) - }.toSet + resolveLCA(e) + } if (referencedAliases.isEmpty) { p } else { - val outerProjectList = collection.mutable.Seq(projectList: _*) + val outerProjectList = collection.mutable.Seq(newProjectList: _*) val innerProjectList = collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8dd28e9aaae3d..f83c3aa462614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -424,8 +424,18 @@ case class OuterReference(e: NamedExpression) override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute - override def newInstance(): NamedExpression = OuterReference(e.newInstance()) + override def newInstance(): NamedExpression = + OuterReference(e.newInstance()).setNameParts(nameParts) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) + + // optional field of the original name parts of UnresolvedAttribute before it is resolved to + // OuterReference. Used in rule ResolveLateralColumnAlias to restore OuterReference back to + // UnresolvedAttribute. 
+ var nameParts: Option[Seq[String]] = None + def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { + nameParts = newNameParts + this + } } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e7384dac2d53e..d249a2b5a6bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -158,8 +158,8 @@ object SubExprUtils extends PredicateHelper { /** * Wrap attributes in the expression with [[OuterReference]]s. */ - def wrapOuterReference[E <: Expression](e: E): E = { - e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E] + def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { + e.transform { case a: Attribute => OuterReference(a).setNameParts(nameParts) }.asInstanceOf[E] } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index f6b2c919b794c..adf18958a1e92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -59,6 +60,17 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } + private def withLCAOff(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { + f + } + } + private def withLCAOn(f: => Unit): Unit = { + 
withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { + f + } + } + test("Lateral alias in project") { checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), Row(1, 2)) @@ -106,4 +118,123 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(18000, 18000, 10000) ) } + + test("Duplicated lateral alias names - Project") { + def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) + } + + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 1200) + ) + checkAnswer( + sql(s"SELECT salary AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 12000, 10000) + ) + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + checkAnswer( + sql(s"SELECT salary + 1000 AS new_salary, new_salary * 1.0 AS new_salary " + + s"FROM $testTable WHERE name = 'jen'"), + Row(13000, 13000.0)) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql( + s""" + |SELECT salary * 1.5 AS 
salary, salary, 10000 AS salary, salary + |FROM $testTable + |WHERE name = 'jen' + |""".stripMargin), + Row(18000, 12000, 10000, 12000) + ) + } + + test("Lateral alias conflicts with OuterReference - Project") { + // an attribute can both be resolved as LCA and OuterReference + val query1 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, id + 1 AS id2)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { checkAnswer(sql(query1), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query1), Seq.empty) } + + // an attribute can only be resolved as LCA + val query2 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id1, id1 + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { + assert(intercept[AnalysisException] { sql(query2) } + .getErrorClass == "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + } + withLCAOn { checkAnswer(sql(query2), Seq.empty) } + + // an attribute should only be resolved as OuterReference + val query3 = + s""" + |SELECT * + |FROM range(1, 7) outer_table + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, outer_table.id + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + + // a bit complex subquery that the id + 1 is first wrapped with OuterReference + // test if lca rule strips the OuterReference and resolves to lateral alias + val query4 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. 
+ withLCAOn { + val analyzedPlan = sql(query4).queryExecution.analyzed + assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) + // but running it triggers exception + // checkAnswer(sql(query4), Range(1, 7).map(Row(_))) + } + } + // TODO: LCA in subquery } From 757cffb4f0adbb512eb4738fe3e38eea943b474a Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 5 Dec 2022 11:48:19 -0800 Subject: [PATCH 12/31] introduce leaf exp in Project as well --- .../analysis/ResolveLateralColumnAlias.scala | 162 ++++++++++++------ .../expressions/namedExpressions.scala | 36 +++- .../sql/catalyst/trees/TreePatterns.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 9 + .../spark/sql/LateralColumnAliasSuite.scala | 89 ++++++---- 5 files changed, 211 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 37e93f7c9105b..a674c8cdbb423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,76 +17,90 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, OuterReference} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} +import org.apache.spark.sql.catalyst.expressions.{Alias, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.rules.{Rule, UnknownRuleId} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, 
OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list, - * - in Project inserting a new Project node with the referenced alias so that it can be - * resolved by other rules + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * Plan-wise it handles two types of operators: Project and Aggregate. + * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve + * the attributes referencing these aliases * - in Aggregate TODO. * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and restores it back to - * [[UnresolvedAttribute]] - * - * For Project, it rewrites by inserting a newly created Project plan between the original Project - * and its child, pushing the referenced lateral column aliases to this new Project, and updating - * the project list of the original Project. + * The whole process is generally divided into two phases: + * 1) recognize lateral alias, wrap the attributes referencing them with + * [[LateralColumnAliasReference]] + * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. + * For Project, it further resolves the attributes and push down the referenced lateral aliases. 
+ * For Aggregate, TODO * + * Example for Project: * Before rewrite: * Project [age AS a, 'a + 1] * +- Child * - * After rewrite: - * Project [a, 'a + 1] + * After phase 1: + * Project [age AS a, lateralalias(a) + 1] + * +- Child + * + * After phase 2: + * Project [a, a + 1] * +- Project [child output, age AS a] * +- Child * - * For Aggregate TODO. + * Example for Aggregate TODO + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def resolver: Resolver = conf.resolver private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning( + // phase 1: wrap + val rewrittenPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.childrenResolved && !Analyzer.containsStar(projectList) && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - val referencedAliases = 
collection.mutable.Set.empty[AliasEntry] - def resolveLCA(e: NamedExpression): NamedExpression = { + def wrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { case o: OuterReference if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - val name = o.nameParts.map(_.head).getOrElse(o.name) - val aliases = aliasMap.get(name).get + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(o.name, n) + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - referencedAliases += aliases.head - o.nameParts.map(UnresolvedAttribute(_)).getOrElse(UnresolvedAttribute(o.name)) - case _ => - o + // TODO We need to resolve to the nested field type, e.g. for query + // SELECT named_struct() AS foo, foo.a, we can't say this foo.a is the + // LateralColumnAliasReference(foo, foo.a). Otherwise, the type can be mismatched + LateralColumnAliasReference(aliases.head.alias, nameParts) + case _ => o } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) @@ -97,28 +111,74 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - referencedAliases += aliases.head - case _ => + // TODO similar problem + LateralColumnAliasReference(aliases.head.alias, u.nameParts) + case _ => u } - u }.asInstanceOf[NamedExpression] } val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - // Add all alias to the aliasMap. 
But note only resolved alias can be LCA and pushed - // down. Unresolved alias is added to the map to perform the ambiguous name check. - // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, - // because only resolved alias can be LCA, in the first round the rule application, - // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are - // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, - // it is pushed down. - val lcaResolved = resolveLCA(a).asInstanceOf[Alias] - insertIntoAliasMap(lcaResolved, idx) - lcaResolved + val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but only + // resolved alias can be LCA + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped case (e, _) => - resolveLCA(e) + wrapLCAReference(e) + } + p.copy(projectList = newProjectList) + } + + // phase 2: unwrap + rewrittenPlan.resolveOperatorsUpWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { + case p @ Project(projectList, child) if p.resolved + && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + // build the map again in case the project list changes and index goes off + // TODO one risk: is there any rule that strips off the Alias? that the LCA is resolved + // in the beginning, but when it comes to push down, it really can't find the matching one? 
+ // Restore back to UnresolvedAttribute + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.nameParts.head) => + val aliasEntry = aliasMap.get(lcaRef.nameParts.head).get.head + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + // If there is no chaining, push down the alias and resolve the attribute by + // constructing a dummy plan + referencedAliases += aliasEntry + // Implementation notes (to-delete): + // this is a design decision whether to restore the UnresolvedAttribute, or + // directly resolve by constructing a plan and using resolveExpressionByPlanChildren + Analyzer.resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(lcaRef.nameParts), + plan = Project(Seq(aliasEntry.alias), OneRowRelation()), + resolver = resolver, + throws = false + ) + } else { + // If there is chaining, don't resolve and save to future rounds + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.nameParts.head) => + // It shouldn't happen. Restore to unresolved attribute to be safe. 
+ UnresolvedAttribute(lcaRef.name) + }.asInstanceOf[NamedExpression] } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap = insertIntoAliasMap(a, idx, aliasMap) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } if (referencedAliases.isEmpty) { p } else { @@ -134,7 +194,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { child = Project(innerProjectList.toSeq, child) ) } - } + } } override def apply(plan: LogicalPlan): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f83c3aa462614..4a3e5a6487f13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -428,9 +428,9 @@ case class OuterReference(e: NamedExpression) OuterReference(e.newInstance()).setNameParts(nameParts) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) - // optional field of the original name parts of UnresolvedAttribute before it is resolved to - // OuterReference. Used in rule ResolveLateralColumnAlias to restore OuterReference back to - // UnresolvedAttribute. + // optional field, the original name parts of UnresolvedAttribute before it is resolved to + // OuterReference. Used in rule ResolveLateralColumnAlias to convert OuterReference back to + // LateralColumnAliasReference. 
var nameParts: Option[Seq[String]] = None def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { nameParts = newNameParts @@ -438,6 +438,36 @@ case class OuterReference(e: NamedExpression) } } +/** + * A placeholder used to hold a referenced that has been temporarily resolved as the reference + * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. + * + * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all + * analysis check, then all [[LateralColumnAliasReference]] should already be removed. + * + * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute + * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to resolve + * the attribute, or restore back. + */ +case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + assert(a.resolved) + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = a.exprId + override def qualifier: Seq[String] = a.qualifier + override def toAttribute: Attribute = a.toAttribute + override def newInstance(): NamedExpression = + LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) + + override def nullable: Boolean = a.nullable + override def dataType: DataType = a.dataType + override def prettyName: String = "lateralAliasReference" + override def sql: String = s"$prettyName($name)" + + final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) +} + object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. 
val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8fca9ec60cdff..1a8ad7c7d6213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -58,6 +58,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 25c509732f9b6..209a80fee2ff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3413,4 +3413,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) ) } + def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(nameParts), + "n" -> numOfMatches.toString + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index adf18958a1e92..d78a661c5a7e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -71,7 +71,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession 
{ } } - test("Lateral alias in project") { + test("Lateral alias basics - Project") { checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), Row(1, 2)) @@ -91,28 +91,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { s"new_income from $testTable where name = 'amy'"), Row(20000, 23000)) - // When the lateral alias conflicts with the table column, it should resolved as the table - // column - checkAnswer( - sql( - "select salary * 2 as salary, salary * 2 + bonus as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 21000)) - - checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 22000)) - - checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + - s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + - " where name = 'amy'"), - Row(20000, 22000, 11000, 22000)) - - // Corner cases for resolution order + // should referring to the previously defined LCA checkAnswer( sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), Row(18000, 18000, 10000) @@ -176,6 +155,27 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + } + test("Lateral alias conflicts 
with OuterReference - Project") { // an attribute can both be resolved as LCA and OuterReference val query1 = @@ -220,14 +220,14 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // a bit complex subquery that the id + 1 is first wrapped with OuterReference // test if lca rule strips the OuterReference and resolves to lateral alias val query4 = - s""" - |SELECT * - |FROM range(1, 7) - |WHERE ( - | SELECT id2 - | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 - |ORDER BY id - |""".stripMargin + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed @@ -236,5 +236,30 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // checkAnswer(sql(query4), Range(1, 7).map(Row(_))) } } - // TODO: LCA in subquery + // TODO: more tests on LCA in subquery + + test("Lateral alias of a struct - Project") { + // This test fails now +// checkAnswer( +// sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), +// Row(Row(1), 2)) + } + + test("Lateral alias chaining - Project") { + checkAnswer( + sql( + s""" + |SELECT bonus * 1.1 AS new_bonus, salary + new_bonus AS new_base, + | new_base * 1.1 AS new_total, new_total - new_base AS r, + | new_total - r + |FROM $testTable WHERE name = 'cathy' + |""".stripMargin), + Row(1320, 10320, 11352, 1032, 10320) + ) + + checkAnswer( + sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 AS d, d - a AS e, e + 1"), + Row(1, 2, 1, 3, 4, 3, 4) + ) + } } From 29de892ba1c76f11684167fae78a6abb91165750 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 5 Dec 2022 14:10:56 -0800 Subject: [PATCH 13/31] handle a corner case --- .../analysis/ResolveLateralColumnAlias.scala | 47 ++++++++++++------- 
.../expressions/namedExpressions.scala | 24 +++++----- .../spark/sql/LateralColumnAliasSuite.scala | 8 ++-- 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index a674c8cdbb423..7a9b6d43c8c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf * - in Aggregate TODO. * * The whole process is generally divided into two phases: - * 1) recognize lateral alias, wrap the attributes referencing them with + * 1) recognize resolved lateral alias, wrap the attributes referencing them with * [[LateralColumnAliasReference]] * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. * For Project, it further resolves the attributes and push down the referenced lateral aliases. @@ -64,7 +64,10 @@ import org.apache.spark.sql.internal.SQLConf * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with * [[LateralColumnAliasReference]]. */ +// TODO revisit resolving order: top down, or bottom up object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + def resolver: Resolver = conf.resolver + private case class AliasEntry(alias: Alias, index: Int) private def insertIntoAliasMap( a: Alias, @@ -74,7 +77,20 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def resolver: Resolver = conf.resolver + /** + * Use the given the lateral alias candidate to resolve the name parts. + * @return The resolved attribute if succeeds. None if fails to resolve. 
+ */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[NamedExpression] = { + val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = Project(Seq(lateralAlias), OneRowRelation()), + resolver = resolver, + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) Some(resolvedAttr) else None + } private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { // phase 1: wrap @@ -96,10 +112,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - // TODO We need to resolve to the nested field type, e.g. for query - // SELECT named_struct() AS foo, foo.a, we can't say this foo.a is the - // LateralColumnAliasReference(foo, foo.a). Otherwise, the type can be mismatched - LateralColumnAliasReference(aliases.head.alias, nameParts) + resolveByLateralAlias(nameParts, aliases.head.alias) + .map(LateralColumnAliasReference(_, nameParts)) + .getOrElse(o) case _ => o } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && @@ -111,8 +126,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - // TODO similar problem - LateralColumnAliasReference(aliases.head.alias, u.nameParts) + resolveByLateralAlias(u.nameParts, aliases.head.alias) + .map(LateralColumnAliasReference(_, u.nameParts)) + .getOrElse(u) case _ => u } }.asInstanceOf[NamedExpression] @@ -138,9 +154,13 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => // build the map again 
in case the project list changes and index goes off - // TODO one risk: is there any rule that strips off the Alias? that the LCA is resolved + // TODO one risk: is there any rule that strips off /add the Alias? that the LCA is resolved // in the beginning, but when it comes to push down, it really can't find the matching one? - // Restore back to UnresolvedAttribute + // Restore back to UnresolvedAttribute. + // Also, when resolving from bottom up should I worry about cases like: + // Project [b AS c, c + 1 AS d] + // +- Project [1 AS a, a AS b] + // b AS c is resolved, even b refers to an alias contains the lateral alias? var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { @@ -154,12 +174,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // Implementation notes (to-delete): // this is a design decision whether to restore the UnresolvedAttribute, or // directly resolve by constructing a plan and using resolveExpressionByPlanChildren - Analyzer.resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(lcaRef.nameParts), - plan = Project(Seq(aliasEntry.alias), OneRowRelation()), - resolver = resolver, - throws = false - ) + resolveByLateralAlias(lcaRef.nameParts, aliasEntry.alias).getOrElse(lcaRef) } else { // If there is chaining, don't resolve and save to future rounds lcaRef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 4a3e5a6487f13..7972cd1399f2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -439,29 +439,29 @@ case class OuterReference(e: NamedExpression) } /** - * A placeholder used to 
hold a referenced that has been temporarily resolved as the reference + * A placeholder used to hold a attribute that has been temporarily resolved as the reference * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. * * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all * analysis check, then all [[LateralColumnAliasReference]] should already be removed. * - * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute - * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to resolve - * the attribute, or restore back. + * @param ne the current attribute is resolved to by lateral column alias + * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to later resolve + * the attribute or restore back. */ -case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) +case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String]) extends LeafExpression with NamedExpression with Unevaluable { - assert(a.resolved) + assert(ne.resolved) override def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") - override def exprId: ExprId = a.exprId - override def qualifier: Seq[String] = a.qualifier - override def toAttribute: Attribute = a.toAttribute + override def exprId: ExprId = ne.exprId + override def qualifier: Seq[String] = ne.qualifier + override def toAttribute: Attribute = ne.toAttribute override def newInstance(): NamedExpression = - LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) + LateralColumnAliasReference(ne.newInstance(), nameParts) - override def nullable: Boolean = a.nullable - override def dataType: DataType = a.dataType + override def nullable: Boolean = ne.nullable + override def dataType: DataType = ne.dataType override def prettyName: String = "lateralAliasReference" override def sql: 
String = s"$prettyName($name)" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index d78a661c5a7e4..17b7f5750697b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -239,10 +239,10 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // TODO: more tests on LCA in subquery test("Lateral alias of a struct - Project") { - // This test fails now -// checkAnswer( -// sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), -// Row(Row(1), 2)) + checkAnswer( + sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), + Row(Row(1), 2)) + // TODO: more tests } test("Lateral alias chaining - Project") { From 72991c6210b34ef3a0af3e7b2c075f73812f89cb Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 6 Dec 2022 11:46:44 -0800 Subject: [PATCH 14/31] add more tests; add check rule --- .../sql/catalyst/analysis/CheckAnalysis.scala | 21 +++++++++- .../spark/sql/LateralColumnAliasSuite.scala | 42 ++++++++++++++----- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 12dac5c632a3b..9937a06de9a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION +import 
org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} @@ -638,6 +638,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case UnresolvedWindowExpression(_, windowSpec) => throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name) }) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. + projectList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if p.resolved => + failUnresolvedAttribute( + p, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + }) case j: Join if !j.duplicateResolved => val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) @@ -714,6 +722,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "operator" -> other.nodeName, "invalidExprSqls" -> invalidExprSqls.mkString(", "))) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. + case agg @ Aggregate(_, aggList, _) + if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved => + aggList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference => + failUnresolvedAttribute( + agg, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + }) + case _ => // Analysis successful! 
} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 17b7f5750697b..5c0c120a87d2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -29,16 +29,24 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { override def beforeAll(): Unit = { super.beforeAll() - sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " + - s"using orc") + sql( + s""" + |CREATE TABLE $testTable ( + | dept INTEGER, + | name String, + | salary INTEGER, + | bonus INTEGER, + | properties STRUCT) + |USING orc + |""".stripMargin) sql( s""" |INSERT INTO $testTable VALUES - | (1, 'amy', 10000, 1000), - | (2, 'alex', 12000, 1200), - | (1, 'cathy', 9000, 1200), - | (2, 'david', 10000, 1300), - | (6, 'jen', 12000, 1200) + | (1, 'amy', 10000, 1000, named_struct('joinYear', 2019, 'mostRecentEmployer', 'A')), + | (2, 'alex', 12000, 1200, named_struct('joinYear', 2017, 'mostRecentEmployer', 'A')), + | (1, 'cathy', 9000, 1200, named_struct('joinYear', 2020, 'mostRecentEmployer', 'B')), + | (2, 'david', 10000, 1300, named_struct('joinYear', 2019, 'mostRecentEmployer', 'C')), + | (6, 'jen', 12000, 1200, named_struct('joinYear', 2018, 'mostRecentEmployer', 'D')) |""".stripMargin) } @@ -174,6 +182,16 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + " where name = 'amy'"), Row(20000, 22000, 11000, 22000)) + + checkAnswer( + sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + 
Row(Row("someone"), "amy")) } test("Lateral alias conflicts with OuterReference - Project") { @@ -240,9 +258,13 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { test("Lateral alias of a struct - Project") { checkAnswer( - sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), - Row(Row(1), 2)) - // TODO: more tests + sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), + Row(Row(1), 2, 3)) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), + Row(Row(Row(1)), 2) + ) } test("Lateral alias chaining - Project") { From d45fe31f0aec6ddb670a012e53495554f03c05cb Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 14:48:11 -0800 Subject: [PATCH 15/31] uplift the necessity to resolve expression in second phase; add more tests --- .../analysis/ResolveLateralColumnAlias.scala | 91 +++++++++---------- .../expressions/namedExpressions.scala | 17 ++-- .../spark/sql/LateralColumnAliasSuite.scala | 30 +++++- 3 files changed, 82 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 7a9b6d43c8c17..ad6867042ffa2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules.{Rule, 
UnknownRuleId} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} @@ -64,7 +64,6 @@ import org.apache.spark.sql.internal.SQLConf * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with * [[LateralColumnAliasReference]]. */ -// TODO revisit resolving order: top down, or bottom up object ResolveLateralColumnAlias extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver @@ -78,18 +77,27 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { } /** - * Use the given the lateral alias candidate to resolve the name parts. - * @return The resolved attribute if succeeds. None if fails to resolve. + * Use the given lateral alias to resolve the unresolved attribute with the name parts. + * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. */ private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[NamedExpression] = { + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + // TODO question: everytime it resolves the extract field it generates a new exprId. + // Does it matter? 
val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( expr = UnresolvedAttribute(nameParts), plan = Project(Seq(lateralAlias), OneRowRelation()), resolver = resolver, throws = false ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) Some(resolvedAttr) else None + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } } private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { @@ -103,20 +111,6 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) def wrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - resolveByLateralAlias(nameParts, aliases.head.alias) - .map(LateralColumnAliasReference(_, nameParts)) - .getOrElse(o) - case _ => o - } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) .isInstanceOf[UnresolvedAttribute] => @@ -126,11 +120,23 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - resolveByLateralAlias(u.nameParts, aliases.head.alias) - .map(LateralColumnAliasReference(_, u.nameParts)) - .getOrElse(u) + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to 
resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) case _ => u } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } }.asInstanceOf[NamedExpression] } val newProjectList = projectList.zipWithIndex.map { @@ -139,7 +145,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). // Unresolved Alias is also added to the map to perform ambiguous name check, but only - // resolved alias can be LCA + // resolved alias can be LCA. aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => @@ -153,47 +159,36 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - // build the map again in case the project list changes and index goes off - // TODO one risk: is there any rule that strips off /add the Alias? that the LCA is resolved - // in the beginning, but when it comes to push down, it really can't find the matching one? - // Restore back to UnresolvedAttribute. - // Also, when resolving from bottom up should I worry about cases like: - // Project [b AS c, c + 1 AS d] - // +- Project [1 AS a, a AS b] - // b AS c is resolved, even b refers to an alias contains the lateral alias? 
- var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + var aliasMap = Map[Attribute, AliasEntry]() val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.nameParts.head) => - val aliasEntry = aliasMap.get(lcaRef.nameParts.head).get.head + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap(lcaRef.a) + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - // If there is no chaining, push down the alias and resolve the attribute by - // constructing a dummy plan referencedAliases += aliasEntry - // Implementation notes (to-delete): - // this is a design decision whether to restore the UnresolvedAttribute, or - // directly resolve by constructing a plan and using resolveExpressionByPlanChildren - resolveByLateralAlias(lcaRef.nameParts, aliasEntry.alias).getOrElse(lcaRef) + lcaRef.ne } else { - // If there is chaining, don't resolve and save to future rounds lcaRef } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.nameParts.head) => - // It shouldn't happen. Restore to unresolved attribute to be safe. - UnresolvedAttribute(lcaRef.name) + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. 
+ UnresolvedAttribute(lcaRef.nameParts) }.asInstanceOf[NamedExpression] } - val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => val lcaResolved = unwrapLCAReference(a) // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap = insertIntoAliasMap(a, idx, aliasMap) + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) lcaResolved case (e, _) => unwrapLCAReference(e) } + if (referencedAliases.isEmpty) { p } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7972cd1399f2d..ff65eecafc48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -439,17 +439,20 @@ case class OuterReference(e: NamedExpression) } /** - * A placeholder used to hold a attribute that has been temporarily resolved as the reference - * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. + * A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the + * reference to a lateral column alias. * + * This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]]. * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all * analysis check, then all [[LateralColumnAliasReference]] should already be removed. * - * @param ne the current attribute is resolved to by lateral column alias - * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to later resolve - * the attribute or restore back. + * @param ne the resolved [[NamedExpression]] by lateral column alias + * @param nameParts the named parts of the original [[UnresolvedAttribute]]. 
Used to restore back + * to [[UnresolvedAttribute]] when needed + * @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping + * and resolving LateralColumnAliasReference */ -case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String]) +case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute) extends LeafExpression with NamedExpression with Unevaluable { assert(ne.resolved) override def name: String = @@ -458,7 +461,7 @@ case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[Strin override def qualifier: Seq[String] = ne.qualifier override def toAttribute: Attribute = ne.toAttribute override def newInstance(): NamedExpression = - LateralColumnAliasReference(ne.newInstance(), nameParts) + LateralColumnAliasReference(ne.newInstance(), nameParts, a) override def nullable: Boolean = ne.nullable override def dataType: DataType = ne.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 5c0c120a87d2c..3c528e5997e8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -256,7 +256,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } // TODO: more tests on LCA in subquery - test("Lateral alias of a struct - Project") { + test("Lateral alias of a complex type - Project") { checkAnswer( sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), Row(Row(1), 2, 3)) @@ -265,6 +265,34 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), Row(Row(Row(1)), 2) ) + + checkAnswer( + sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"), + Row(Seq(1, 2, 3), 2, 3) + ) + 
checkAnswer( + sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"), + Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101) + ) + checkAnswer( + sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"), + Row(Seq(Row(1), Row(2)), 2) + ) + + checkAnswer( + sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"), + Row(Map("a" -> 1, "b" -> 2), 2, 3) + ) + } + + test("Lateral alias reference attribute further be used by upper plan - Project") { + // this is out of the scope of lateral alias project functionality requirements, but naturally + // supported by the current design + checkAnswer( + sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + + s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), + Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil + ) } test("Lateral alias chaining - Project") { From 1f55f7381e728b0feff5fd89e71b8a4fc1c60ccd Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 15:25:07 -0800 Subject: [PATCH 16/31] address comments to add tests for LCA off --- .../spark/sql/LateralColumnAliasSuite.scala | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 3c528e5997e8e..abeb3bb784124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -59,6 +59,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } val lcaEnabled: Boolean = true + // by default the tests in this suites run with LCA on override protected def test(testName: String, testTags: Tag*)(testFun: => Any) (implicit pos: Position): Unit = { super.test(testName, testTags: _*) { @@ -67,6 +68,11 @@ class 
LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } } + // mark special testcases test both LCA on and off + protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*)(testFun) + } private def withLCAOff(f: => Unit): Unit = { withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { @@ -79,29 +85,35 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Lateral alias basics - Project") { - checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), + testOnAndOff("Lateral alias basics - Project") { + def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + } + + checkAnswerWhenOnAndExceptionWhenOff( + s"select dept as d, d + 1 as e from $testTable where name = 'amy'", Row(1, 2)) - checkAnswer( - sql( - s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", Row(20000, 21000)) - checkAnswer( - sql( - s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + - s" where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'", Row(20000, 22000)) - checkAnswer( - sql( - "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + - s"new_income from $testTable where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from 
$testTable where name = 'amy'", Row(20000, 23000)) // should referring to the previously defined LCA - checkAnswer( - sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", Row(18000, 18000, 10000) ) } @@ -194,7 +206,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(Row("someone"), "amy")) } - test("Lateral alias conflicts with OuterReference - Project") { + testOnAndOff("Lateral alias conflicts with OuterReference - Project") { // an attribute can both be resolved as LCA and OuterReference val query1 = s""" From f753529afe5ca21543a2d1915fcdbe6f63f218d4 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 17:55:40 -0800 Subject: [PATCH 17/31] revert the refactor, split LCA into two rules --- .../sql/catalyst/analysis/Analyzer.scala | 424 +++++++++++------- .../analysis/ResolveLateralColumnAlias.scala | 217 --------- .../ResolveLateralColumnAliasReference.scala | 127 ++++++ .../sql/catalyst/rules/RuleIdCollection.scala | 3 +- 4 files changed, 380 insertions(+), 391 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3bc98c68d8486..6bbf2de445418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} -import 
org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -183,164 +182,6 @@ object AnalysisContext { } } -object Analyzer extends Logging { - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. 
- */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - resolver: Resolver, - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim 
unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - try { - innerResolve(expr, isTopLevel = true) - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - resolver: Resolver, - throws: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - resolver = resolver, - throws = throws) - } - - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. 
- * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan, - resolver: Resolver): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - resolver = resolver, - throws = true) - } - - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) -} - /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -417,17 +258,6 @@ class Analyzer(override val catalogManager: CatalogManager) TypeCoercion.typeCoercionRules } - private def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false): Expression = { - Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws) - } - - private def resolveExpressionByPlanChildren(e: Expression, q: LogicalPlan): Expression = { - Analyzer.resolveExpressionByPlanChildren(e, q, resolver) - } - override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. 
@@ -458,7 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: - ResolveLateralColumnAlias :: + WrapLateralColumnAliasReference :: + ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: @@ -1558,7 +1389,6 @@ class Analyzer(override val catalogManager: CatalogManager) * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.analysis.Analyzer.containsStar /** Return true if there're conflicting attributes among children's outputs of a plan */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { @@ -1871,6 +1701,12 @@ class Analyzer(override val catalogManager: CatalogManager) }.map(_.asInstanceOf[NamedExpression]) } + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) + private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) @@ -1927,10 +1763,252 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * The first phase to resolve lateral column alias. See comments in + * [[ResolveLateralColumnAliasReference]] for more detailed explanation. + */ + object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { + import ResolveLateralColumnAliasReference.AliasEntry + + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + + /** + * Use the given lateral alias to resolve the unresolved attribute with the name parts. 
+ * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. + */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + // TODO question: everytime it resolves the extract field it generates a new exprId. + // Does it matter? + val resolvedAttr = resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = Project(Seq(lateralAlias), OneRowRelation()), + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + // phase 1: wrap + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def wrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct 
+ // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but + // only resolved alias can be LCA. 
+ aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped + case (e, _) => + wrapLCAReference(e) + } + p.copy(projectList = newProjectList) + } + } + } + } + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. 
+ */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. 
+ case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + withOrigin(u.origin) { + ExtractValue(newChild, fieldName, resolver) + } + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + try { + innerResolve(expr, isTopLevel = true) + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + getAttrCandidates = () => plan.output, + throws = throws) + } + + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. 
+ */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + throws = true) + } + /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala deleted file mode 100644 index ad6867042ffa2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression, OuterReference} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} -import org.apache.spark.sql.catalyst.rules.{Rule, UnknownRuleId} -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf - -/** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list. - * Plan-wise it handles two types of operators: Project and Aggregate. - * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve - * the attributes referencing these aliases - * - in Aggregate TODO. - * - * The whole process is generally divided into two phases: - * 1) recognize resolved lateral alias, wrap the attributes referencing them with - * [[LateralColumnAliasReference]] - * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. - * For Project, it further resolves the attributes and push down the referenced lateral aliases. 
- * For Aggregate, TODO - * - * Example for Project: - * Before rewrite: - * Project [age AS a, 'a + 1] - * +- Child - * - * After phase 1: - * Project [age AS a, lateralalias(a) + 1] - * +- Child - * - * After phase 2: - * Project [a, a + 1] - * +- Project [child output, age AS a] - * +- Child - * - * Example for Aggregate TODO - * - * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with - * [[LateralColumnAliasReference]]. - */ -object ResolveLateralColumnAlias extends Rule[LogicalPlan] { - def resolver: Resolver = conf.resolver - - private case class AliasEntry(alias: Alias, index: Int) - private def insertIntoAliasMap( - a: Alias, - idx: Int, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - /** - * Use the given lateral alias to resolve the unresolved attribute with the name parts. - * - * Construct a dummy plan with the given lateral alias as project list, use the output of the - * plan to resolve. - * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. - */ - private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - // TODO question: everytime it resolves the extract field it generates a new exprId. - // Does it matter? 
- val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(nameParts), - plan = Project(Seq(lateralAlias), OneRowRelation()), - resolver = resolver, - throws = false - ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) { - Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) - } else { - None - } - } - - private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - // phase 1: wrap - val rewrittenPlan = plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved - && !Analyzer.containsStar(projectList) - && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def wrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - Analyzer.resolveExpressionByPlanChildren(u, p, resolver) - .isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw 
QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] - // Insert the LCA-resolved alias instead of the unresolved one into map. If it is - // resolved, it can be referenced as LCA by later expressions (chaining). - // Unresolved Alias is also added to the map to perform ambiguous name check, but only - // resolved alias can be LCA. - aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) - lcaWrapped - case (e, _) => - wrapLCAReference(e) - } - p.copy(projectList = newProjectList) - } - - // phase 2: unwrap - rewrittenPlan.resolveOperatorsUpWithPruning( - _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { - case p @ Project(projectList, child) if p.resolved - && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = Map[Attribute, AliasEntry]() - val referencedAliases = collection.mutable.Set.empty[AliasEntry] - def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap(lcaRef.a) - // If there is no chaining of lateral column alias reference, push down the alias - // and unwrap the LateralColumnAliasReference to the NamedExpression inside - // If there is chaining, don't resolve and save to future rounds - if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - referencedAliases += aliasEntry - lcaRef.ne - } else { - lcaRef - } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => - // It shouldn't happen, but restore to unresolved attribute to be safe. 
- UnresolvedAttribute(lcaRef.nameParts) - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaResolved = unwrapLCAReference(a) - // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap += (a.toAttribute -> AliasEntry(a, idx)) - lcaResolved - case (e, _) => - unwrapLCAReference(e) - } - - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = collection.mutable.Seq(newProjectList: _*) - val innerProjectList = - collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else { - rewriteLateralColumnAlias(plan) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala new file mode 100644 index 0000000000000..cd0e0b86a8e48 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE} +import org.apache.spark.sql.internal.SQLConf + +/** + * This rule is the second phase to resolve lateral column alias. + * + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * Plan-wise, it handles two types of operators: Project and Aggregate. + * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve + * the attributes referencing these aliases + * - in Aggregate TODO. + * + * The whole process is generally divided into two phases: + * 1) recognize resolved lateral alias, wrap the attributes referencing them with + * [[LateralColumnAliasReference]] + * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. + * For Project, it further resolves the attributes and push down the referenced lateral aliases. 
+ * For Aggregate, TODO + * + * Example for Project: + * Before rewrite: + * Project [age AS a, 'a + 1] + * +- Child + * + * After phase 1: + * Project [age AS a, lateralalias(a) + 1] + * +- Child + * + * After phase 2: + * Project [a, a + 1] + * +- Project [child output, age AS a] + * +- Child + * + * Example for Aggregate TODO + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. + */ +object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { + case class AliasEntry(alias: Alias, index: Int) + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + // phase 2: unwrap + plan.resolveOperatorsUpWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { + case p @ Project(projectList, child) if p.resolved + && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + var aliasMap = Map[Attribute, AliasEntry]() + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap(lcaRef.a) + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + referencedAliases += 
aliasEntry + lcaRef.ne + } else { + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.nameParts) + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } + + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = collection.mutable.Seq(newProjectList: _*) + val innerProjectList = + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 032b0e7a08fcd..efafd3cfbcde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -77,6 +77,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: 
"org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: @@ -88,7 +89,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" :: "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" :: "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" :: - "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAlias" :: + "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" :: From b9f706f9ea23bf80e9248e61136612b1c6ee363b Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 18:13:15 -0800 Subject: [PATCH 18/31] better refactor --- .../sql/catalyst/analysis/Analyzer.scala | 82 +++++++++++-------- 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6bbf2de445418..5d94defc68d2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1801,50 +1801,61 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * Recognize all the attributes in the given expression that reference lateral column aliases + * by looking up the alias map. Resolve these attributes and replace by wrapping with + * [[LateralColumnAliasReference]]. 
+ * + * @param currentPlan Because lateral alias has lower resolution priority than table columns, + * the current plan is needed to first try resolving the attribute by its + * children + */ + private def wrapLCARefHelper( + e: NamedExpression, + currentPlan: LogicalPlan, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan } else { - // phase 1: wrap plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved + case p @ Project(projectList, _) if 
p.childrenResolved && !ResolveReferences.containsStar(projectList) && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def wrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + val lcaWrapped = wrapLCARefHelper(a, p, aliasMap).asInstanceOf[Alias] // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). 
// Unresolved Alias is also added to the map to perform ambiguous name check, but @@ -1852,7 +1863,7 @@ class Analyzer(override val catalogManager: CatalogManager) aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => - wrapLCAReference(e) + wrapLCARefHelper(e, p, aliasMap) } p.copy(projectList = newProjectList) } @@ -1914,7 +1925,7 @@ class Analyzer(override val catalogManager: CatalogManager) attrCandidates(ordinal) case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => val attrCandidates = getAttrCandidates() val matched = attrCandidates.filter(a => resolver(a.name, colName)) if (matched.length != expectedNumCandidates) { @@ -1985,7 +1996,6 @@ class Analyzer(override val catalogManager: CatalogManager) throws = throws) } - /** * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the * input plan's children output attributes. 
From 94d5c9ee7c095b40ea5fe676fa50bc7acc5fe885 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Fri, 9 Dec 2022 13:28:07 -0800 Subject: [PATCH 19/31] address comments --- .../spark/sql/catalyst/expressions/AttributeMap.scala | 3 ++- .../spark/sql/catalyst/expressions/AttributeMap.scala | 3 +++ .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 +--- .../analysis/ResolveLateralColumnAliasReference.scala | 8 ++++---- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index c55c542d957de..504b65e3db693 100644 --- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined - override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv + override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 3d5d6471d26d4..ac6149f3acc4d 100644 --- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined + override def + [B1 >: A](kv: (Attribute, B1)): 
AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) + override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] = baseMap.values.toMap + (key -> value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d94defc68d2a..a56a3d9cb6bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1787,11 +1787,9 @@ class Analyzer(override val catalogManager: CatalogManager) */ private def resolveByLateralAlias( nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - // TODO question: everytime it resolves the extract field it generates a new exprId. - // Does it matter? val resolvedAttr = resolveExpressionByPlanOutput( expr = UnresolvedAttribute(nameParts), - plan = Project(Seq(lateralAlias), OneRowRelation()), + plan = LocalRelation(Seq(lateralAlias.toAttribute)), throws = false ).asInstanceOf[NamedExpression] if (resolvedAttr.resolved) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index cd0e0b86a8e48..c86d0a6dff0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE} +import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE import org.apache.spark.sql.internal.SQLConf /** @@ -76,12 +76,12 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = Map[Attribute, AliasEntry]() + var aliasMap = AttributeMap.empty[AliasEntry] val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap(lcaRef.a) + val aliasEntry = aliasMap.get(lcaRef.a).get // If there is no chaining of lateral column alias reference, push down the alias // and unwrap the LateralColumnAliasReference to the NamedExpression inside // If there is chaining, don't resolve and save to future rounds From edde37c5a6f4bab92faae65bcd09c4ba7c80fba6 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Fri, 9 Dec 2022 15:02:48 -0800 Subject: [PATCH 20/31] basic version passing all tests --- .../sql/catalyst/analysis/Analyzer.scala | 90 ++++++++++--------- .../ResolveLateralColumnAliasReference.scala | 63 ++++--------- .../expressions/namedExpressions.scala | 29 ------ .../sql/errors/QueryCompilationErrors.scala | 2 +- .../spark/sql/LateralColumnAliasSuite.scala | 44 --------- 5 files changed, 69 insertions(+), 159 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 88845d8558bfa..e0c74f915afb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1754,7 +1754,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliases = aliasMap.get(u.nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias // The lateral alias can be a struct and have nested field, need to construct @@ -1769,7 +1769,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) case _ => o @@ -1804,51 +1804,59 @@ class Analyzer(override val catalogManager: CatalogManager) // wrap LCA // Implementation notes: // In Aggregate, introducing and wrapping this resolved leaf expression - // LateralColumnAliasReference is especially needed because it needs an accurate condition to - // trigger adding a Project above and extracting aggregate functions or grouping expressions. - // Such operation can only be done once. With this LateralColumnAliasReference, the condition - // can simply be when the whole Aggregate is resolved. Otherwise, it can't really tell if - // all aggregate functions are created and resolved, because the lateral alias reference - // itself is unresolved. 
- case agg @ Aggregate(groupingExpressions, aggregateExpressions, child) - if agg.childrenResolved - && !ResolveReferences.containsStar(aggregateExpressions) - && aggregateExpressions.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + // LateralColumnAliasReference is especially needed because it needs an accurate condition + // to trigger adding a Project above and extracting aggregate functions or grouping + // expressions. Such operation can only be done once. With this + // LateralColumnAliasReference, the condition can simply be when the whole Aggregate is + // resolved. Otherwise, it can't really tell if all aggregate functions are created and + // resolved, because the lateral alias reference itself is unresolved. + case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved + && !ResolveReferences.containsStar(aggExprs) + && aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - def wrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, agg, resolver) - .isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - LateralColumnAliasReference(aliases.head.alias, u.nameParts) - case _ => - u - } - }.asInstanceOf[NamedExpression] - } - - val newAggExprs = aggregateExpressions.zipWithIndex.map { + val newAggExprs = aggExprs.zipWithIndex.map { case (a: Alias, idx) => - val LCAResolved = 
wrapLCAReference(a).asInstanceOf[Alias] - // insert the LCA-resolved alias instead of the unresolved one into map. If it is - // resolved, it can be referenced as LCA by later expressions - insertIntoAliasMap(LCAResolved, idx) - LCAResolved + val lcaWrapped = wrapLCARefHelper(a, agg, aliasMap).asInstanceOf[Alias] + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped case (e, _) => - wrapLCAReference(e) + wrapLCARefHelper(e, agg, aliasMap) } agg.copy(aggregateExpressions = newAggExprs) - +// def insertIntoAliasMap(a: Alias, idx: Int): Unit = { +// val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) +// aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) +// } +// def wrapLCAReference(e: NamedExpression): NamedExpression = { +// e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { +// case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && +// resolveExpressionByPlanChildren(u, agg) +// .isInstanceOf[UnresolvedAttribute] => +// val aliases = aliasMap.get(u.nameParts.head).get +// aliases.size match { +// case n if n > 1 => +// throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) +// case n if n == 1 && aliases.head.alias.resolved => +// LateralColumnAliasReference( +// aliases.head.alias, u.nameParts, aliases.head.alias.toAttribute) +// case _ => +// u +// } +// }.asInstanceOf[NamedExpression] +// } + +// val newAggExprs = aggExprs.zipWithIndex.map { +// case (a: Alias, idx) => +// val LCAResolved = wrapLCAReference(a).asInstanceOf[Alias] +// // insert the LCA-resolved alias instead of the unresolved one into map. 
If it is +// // resolved, it can be referenced as LCA by later expressions +// insertIntoAliasMap(LCAResolved, idx) +// LCAResolved +// case (e, _) => +// wrapLCAReference(e) +// } +// agg.copy(aggregateExpressions = newAggExprs) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 278fd254328fa..a9824fda3c118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf * Plan-wise, it handles two types of operators: Project and Aggregate. * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve * the attributes referencing these aliases - * - in Aggregate TODO. + * - in Aggregate TODO inserting the Project node above and fall back to the resolution of Project. * * The whole process is generally divided into two phases: * 1) recognize resolved lateral alias, wrap the attributes referencing them with @@ -57,35 +57,6 @@ import org.apache.spark.sql.internal.SQLConf * * Example for Aggregate TODO * - * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with - * [[LateralColumnAliasReference]]. - */ - -/** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list. 
- * - in Project, inserting a new Project node below with the referenced alias so that it can be - * resolved by other rules - * - in Aggregate, inserting the Project node above and fall back to the resolution of Project - * - * For Project, it rewrites by inserting a newly created Project plan between the original Project - * and its child, pushing the referenced lateral column aliases to this new Project, and updating - * the project list of the original Project. - * - * Before rewrite: - * Project [age AS a, 'a + 1] - * +- Child - * - * After rewrite: - * Project [a, 'a + 1] - * +- Project [child output, age AS a] - * +- Child - * * For Aggregate, it first wraps the attribute resolved by lateral alias with * [[LateralColumnAliasReference]]. * Before wrap (omit some cast or alias): @@ -105,18 +76,20 @@ import org.apache.spark.sql.internal.SQLConf * Project [dept#14 AS a#12, 'a + 1, avg(salary)#26 AS b#13, 'b + avg(bonus)#27] * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14] * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. 
*/ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) - def unwrapLCAReference(exprs: Seq[NamedExpression]): Seq[NamedExpression] = { - exprs.map { expr => - expr.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case l: LateralColumnAliasReference => - UnresolvedAttribute(l.nameParts) - }.asInstanceOf[NamedExpression] - } - } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { @@ -172,8 +145,8 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { child = Project(innerProjectList.toSeq, child) ) } - case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) - if agg.resolved + + case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => val newAggExprs = collection.mutable.Set.empty[NamedExpression] @@ -194,16 +167,18 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { ne.toAttribute case e if groupingExpressions.exists(_.semanticEquals(e)) => // TODO (improvement) dedup + // TODO one concern here, is condition here be able to match all grouping + // expressions? For example, Agg [age + 10] [a + age + 10], when transforming down, + // is it possible that (a + age) + 10, so that it won't be able to match (age + 10) + // add a test. 
val alias = ResolveAliases.assignAliases(Seq(UnresolvedAlias(e))).head newAggExprs += alias alias.toAttribute }.asInstanceOf[NamedExpression] } - val unwrappedAggExprs = unwrapLCAReference(newAggExprs.toSeq) - val unwrappedProjectExprs = unwrapLCAReference(projectExprs) Project( - projectList = unwrappedProjectExprs, - child = agg.copy(aggregateExpressions = unwrappedAggExprs) + projectList = projectExprs, + child = agg.copy(aggregateExpressions = newAggExprs.toSeq) ) // TODO: think about a corner case, when the Alias passed to LateralColumnAliasReference // contains a LateralColumnAliasReference. Is it safe to do a.toAttribute when resolving diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 24978449424ae..ff65eecafc48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -471,35 +471,6 @@ case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[Strin final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) } -/** - * A placeholder used to hold a referenced that has been temporarily resolved as the reference - * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. - * - * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all - * analysis check, then all [[LateralColumnAliasReference]] should already be removed. - * - * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute - * @param nameParts The named parts of the original [[UnresolvedAttribute]]. 
Used to restore back - */ -case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) - extends LeafExpression with NamedExpression with Unevaluable { - assert(a.resolved) - override def name: String = - nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") - override def exprId: ExprId = a.exprId - override def qualifier: Seq[String] = a.qualifier - override def toAttribute: Attribute = a.toAttribute - override def newInstance(): NamedExpression = - LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) - - override def nullable: Boolean = a.nullable - override def dataType: DataType = a.dataType - override def prettyName: String = "lateralAliasReference" - override def sql: String = s"$prettyName($name)" - - final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) -} - object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a9173d3fd5b0d..1718cab406a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3414,7 +3414,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) } - def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = { + def ambiguousLateralColumnAliasError(nameParts: Seq[String], numOfMatches: Int): Throwable = { new AnalysisException( errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", messageParameters = Map( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 
bdb4bb0ef9dfb..aa6475ff397cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -531,50 +531,6 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Duplicated lateral alias names - Project") { - // Has duplicated names but not referenced is fine - checkAnswer( - sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), - Row(12000, 1200) - ) - checkAnswer( - sql(s"SELECT salary AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), - Row(12000, 12000, 10000) - ) - checkAnswer( - sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), - Row(18000, 18000, 10000) - ) - - // Referencing duplicated names raises error - checkDuplicatedAliasErrorHelper( - s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - - checkAnswer( - sql( - s""" - |SELECT salary * 1.5 AS salary, salary, 10000 AS salary, salary - |FROM $testTable - |WHERE name = 'jen' - |""".stripMargin), - Row(18000, 12000, 10000, 12000) - ) - } - test("Duplicated lateral alias names - Aggregate") { // Has duplicated names but not referenced is fine checkAnswer( From fb7b18cd480bf9369ea9e6b128dfcb42d0361407 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 12 Dec 2022 15:56:14 -0800 Subject: [PATCH 21/31] update the logic, add and refactor tests --- 
.../sql/catalyst/analysis/Analyzer.scala | 45 +- .../ResolveLateralColumnAliasReference.scala | 40 +- .../spark/sql/LateralColumnAliasSuite.scala | 698 ++++++++++-------- 3 files changed, 425 insertions(+), 358 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e0c74f915afb2..b3edb5eb9a4b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1801,15 +1801,15 @@ class Analyzer(override val catalogManager: CatalogManager) } p.copy(projectList = newProjectList) - // wrap LCA // Implementation notes: // In Aggregate, introducing and wrapping this resolved leaf expression // LateralColumnAliasReference is especially needed because it needs an accurate condition - // to trigger adding a Project above and extracting aggregate functions or grouping - // expressions. Such operation can only be done once. With this - // LateralColumnAliasReference, the condition can simply be when the whole Aggregate is - // resolved. Otherwise, it can't really tell if all aggregate functions are created and - // resolved, because the lateral alias reference itself is unresolved. + // to trigger adding a Project above and extracting and pushing down aggregate functions + // or grouping expressions. Such operation can only be done once. With this + // LateralColumnAliasReference, that condition can simply be when the whole Aggregate is + // resolved. Otherwise, it can't tell if all aggregate functions are created and + // resolved so that it can start the extraction, because the lateral alias reference is + // unresolved and can be the argument to functions, blocking the resolution of functions. 
case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved && !ResolveReferences.containsStar(aggExprs) && aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => @@ -1824,39 +1824,6 @@ class Analyzer(override val catalogManager: CatalogManager) wrapLCARefHelper(e, agg, aliasMap) } agg.copy(aggregateExpressions = newAggExprs) -// def insertIntoAliasMap(a: Alias, idx: Int): Unit = { -// val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) -// aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) -// } -// def wrapLCAReference(e: NamedExpression): NamedExpression = { -// e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { -// case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && -// resolveExpressionByPlanChildren(u, agg) -// .isInstanceOf[UnresolvedAttribute] => -// val aliases = aliasMap.get(u.nameParts.head).get -// aliases.size match { -// case n if n > 1 => -// throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) -// case n if n == 1 && aliases.head.alias.resolved => -// LateralColumnAliasReference( -// aliases.head.alias, u.nameParts, aliases.head.alias.toAttribute) -// case _ => -// u -// } -// }.asInstanceOf[NamedExpression] -// } - -// val newAggExprs = aggExprs.zipWithIndex.map { -// case (a: Alias, idx) => -// val LCAResolved = wrapLCAReference(a).asInstanceOf[Alias] -// // insert the LCA-resolved alias instead of the unresolved one into map. 
If it is -// // resolved, it can be referenced as LCA by later expressions -// insertIntoAliasMap(LCAResolved, idx) -// LCAResolved -// case (e, _) => -// wrapLCAReference(e) -// } -// agg.copy(aggregateExpressions = newAggExprs) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index a9824fda3c118..4ee7749a3b236 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -89,8 +89,6 @@ import org.apache.spark.sql.internal.SQLConf object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) - - override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan @@ -150,10 +148,10 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => val newAggExprs = collection.mutable.Set.empty[NamedExpression] + val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] val projectExprs = aggregateExpressions.map { exp => exp.transformDown { case aggExpr: 
AggregateExpression => - // TODO (improvement) dedup // Doesn't support referencing a lateral alias in aggregate function if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { aggExpr.collectFirst { @@ -162,29 +160,37 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { lcaRef.nameParts, aggExpr) } } - val ne = ResolveAliases.assignAliases(Seq(UnresolvedAlias(aggExpr))).head + val ne = expressionMap.getOrElseUpdate( + aggExpr.canonicalized, + ResolveAliases.assignAliases(Seq(UnresolvedAlias(aggExpr))).map { + // TODO temporarily clear the metadata for an issue found in test + case a: Alias => a.copy(a.child, a.name)( + a.exprId, a.qualifier, None, a.nonInheritableMetadataKeys) + case other => other + }.head) newAggExprs += ne ne.toAttribute case e if groupingExpressions.exists(_.semanticEquals(e)) => - // TODO (improvement) dedup // TODO one concern here, is condition here be able to match all grouping // expressions? For example, Agg [age + 10] [a + age + 10], when transforming down, // is it possible that (a + age) + 10, so that it won't be able to match (age + 10) // add a test. - val alias = ResolveAliases.assignAliases(Seq(UnresolvedAlias(e))).head - newAggExprs += alias - alias.toAttribute + val ne = expressionMap.getOrElseUpdate( + e.canonicalized, + ResolveAliases.assignAliases(Seq(UnresolvedAlias(e))).head) + newAggExprs += ne + ne.toAttribute }.asInstanceOf[NamedExpression] } - Project( - projectList = projectExprs, - child = agg.copy(aggregateExpressions = newAggExprs.toSeq) - ) - // TODO: think about a corner case, when the Alias passed to LateralColumnAliasReference - // contains a LateralColumnAliasReference. Is it safe to do a.toAttribute when resolving - // the LateralColumnAliasReference? + if (newAggExprs.isEmpty) { + agg + } else { + Project( + projectList = projectExprs, + child = agg.copy(aggregateExpressions = newAggExprs.toSeq) + ) + } // TODO withOrigin? 
- } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index aa6475ff397cb..bba351d0aa31c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,11 +20,28 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionSet} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { +/** + * Lateral column alias base suite with LCA off, extended by LateralColumnAliasSuite with LCA on. + * Should test behaviors remaining the same no matter LCA conf is on or off. 
+ */ +class LateralColumnAliasSuiteBase extends QueryTest with SharedSparkSession { + // by default the tests in this suites run with LCA off + val lcaEnabled: Boolean = false + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + protected val testTable: String = "employee" override def beforeAll(): Unit = { @@ -58,33 +75,111 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - val lcaEnabled: Boolean = true - // by default the tests in this suites run with LCA on - override protected def test(testName: String, testTags: Tag*)(testFun: => Any) - (implicit pos: Position): Unit = { - super.test(testName, testTags: _*) { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { - testFun - } - } - } - // mark special testcases test both LCA on and off - protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) - (implicit pos: Position): Unit = { - super.test(testName, testTags: _*)(testFun) - } - - private def withLCAOff(f: => Unit): Unit = { + protected def withLCAOff(f: => Unit): Unit = { withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { f } } - private def withLCAOn(f: => Unit): Unit = { + protected def withLCAOn(f: => Unit): Unit = { withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { f } } + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql(s"SELECT 
named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row("someone"), "amy")) + + // CTE table + checkAnswer( + sql( + s""" + |WITH temp_table(x, y) AS (SELECT 1, 2) + |SELECT 100 AS x, x + 1 + |FROM temp_table + |""".stripMargin + ), + Row(100, 2)) + } + + test("Lateral alias conflicts with table column - Aggregate") { + checkAnswer( + sql( + s""" + |SELECT + | sum(salary) AS salary, + | sum(bonus) AS bonus, + | avg(salary) AS avg_s, + | avg(salary + bonus) AS avg_t + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin), + Row(19000, 2200, 9500.0, 10600.0) :: + Row(22000, 2500, 11000.0, 12250.0) :: + Row(12000, 1200, 12000.0, 13200.0) :: + Nil) + + // TODO: how does it correctly resolve to the right dept in SORT? + checkAnswer( + sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(1100, 1, 9500.0) :: Row(1250, 2, 11000) :: Row(1200, 6, 12000) :: Nil + ) + + checkAnswer( + sql("SELECT named_struct('joinYear', 2022) AS properties, min(properties.joinYear) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(2022), 2019) :: Row(Row(2022), 2017) :: Row(Row(2022), 2018) :: Nil) + + checkAnswer( + sql(s"SELECT named_struct('salary', 20000) AS $testTable, avg($testTable.salary) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(20000), 9500) :: Row(Row(20000), 11000) :: Row(Row(20000), 12000) :: Nil) + + // CTE table + checkAnswer( + sql( + s""" + |WITH temp_table(x, y) AS (SELECT 1, 2) + |SELECT 100 AS x, x + 1 + |FROM temp_table + |GROUP BY x + |""".stripMargin), + Row(100, 2)) + } +} + +/** + * Lateral column alias base with LCA on. 
+ */ +class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { + // by default the tests in this suites run with LCA on + override val lcaEnabled: Boolean = true + + // mark special testcases test both LCA on and off + protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*)(testFun) + } + private def checkDuplicatedAliasErrorHelper( query: String, parameters: Map[String, String]): Unit = { checkError( @@ -95,49 +190,160 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } - testOnAndOff("Lateral alias basics - Project") { - def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { - withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } - withLCAOff { - assert(intercept[AnalysisException]{ sql(query) } - .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - } + private def checkAnswerWhenOnAndExceptionWhenOff( + query: String, expectedAnswerLCAOn: Seq[Row]): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") } + } + testOnAndOff("Lateral alias basics - Project") { checkAnswerWhenOnAndExceptionWhenOff( s"select dept as d, d + 1 as e from $testTable where name = 'amy'", - Row(1, 2)) + Row(1, 2) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", - Row(20000, 21000)) + Row(20000, 21000) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + s" where name = 'amy'", - Row(20000, 22000)) + Row(20000, 22000) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + s"new_income from $testTable where name = 'amy'", - 
Row(20000, 23000)) + Row(20000, 23000) :: Nil) // should referring to the previously defined LCA checkAnswerWhenOnAndExceptionWhenOff( s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", - Row(18000, 18000, 10000) - ) + Row(18000, 18000, 10000) :: Nil) + + // LCA and conflicted table column mixed + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'", + Row(20000, 22000, 11000, 22000) :: Nil) } - test("Duplicated lateral alias names - Project") { - def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + testOnAndOff("Lateral alias basics - Aggregate") { + // doesn't support lca used in aggregation functions + withLCAOn( checkError( - exception = intercept[AnalysisException] {sql(query)}, - errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", - sqlState = "42000", - parameters = parameters - ) - } + exception = intercept[AnalysisException] { + sql(s"SELECT 10000 AS lca, count(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"count(lateralAliasReference(lca))\"" + ))) + withLCAOn( + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT dept AS lca, avg(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"avg(lateralAliasReference(lca))\"" + ))) + // doesn't support nested aggregate expressions + withLCAOn( + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT sum(salary) AS a, avg(a) FROM $testTable") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + 
"lca" -> "`a`", + "aggFunc" -> "\"avg(lateralAliasReference(a))\"" + ))) + + // literal as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | 10000 AS baseline_salary, + | baseline_salary * 1.5, + | baseline_salary + dept * 10000, + | baseline_salary + avg(bonus) + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin, + Row(10000, 15000.0, 20000, 11100.0) :: + Row(10000, 15000.0, 30000, 11250.0) :: + Row(10000, 15000.0, 70000, 11200.0) :: Nil + ) + + // grouping attribute as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | salary + 1000 AS new_salary, + | new_salary - 1000 AS prev_salary, + | new_salary - salary, + | new_salary - avg(salary) + |FROM $testTable + |GROUP BY salary + |ORDER BY salary + |""".stripMargin, + Row(10000, 9000, 1000, 1000.0) :: + Row(11000, 10000, 1000, 1000.0) :: + Row(13000, 12000, 1000, 1000.0) :: Nil + ) + + // aggregate expression as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | sum(salary) AS dept_salary_sum, + | sum(bonus) AS dept_bonus_sum, + | dept_salary_sum * 1.5, + | concat(string(dept_salary_sum), ': dept', string(dept)), + | dept_salary_sum + sum(bonus), + | dept_salary_sum + dept_bonus_sum, + | avg(salary * 1.5 + 10000 + bonus * 1.0) AS avg_total, + | avg_total + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin, + Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200, 25350, 25350) :: + Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500, 27750, 27750) :: + Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200, 29200, 29200) :: + Nil + ) + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT sum(salary) AS s, s + sum(bonus) AS total FROM $testTable", + Row(53000, 58900) :: Nil + ) + // LCA and conflicted table column mixed + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | sum(salary) AS salary, + | sum(bonus) AS 
bonus, + | avg(salary) AS avg_s, + | avg(salary + bonus) AS avg_t, + | avg_s + avg_t + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin, + Row(19000, 2200, 9500.0, 10600.0, 20100.0) :: + Row(22000, 2500, 11000.0, 12250.0, 23250.0) :: + Row(12000, 1200, 12000.0, 13200.0, 25200.0) :: Nil) + } + + test("Duplicated lateral alias names - Project") { // Has duplicated names but not referenced is fine checkAnswer( sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), @@ -185,35 +391,58 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } - test("Lateral alias conflicts with table column - Project") { + test("Duplicated lateral alias names - Aggregate") { + // Has duplicated names but not referenced is fine checkAnswer( - sql( - "select salary * 2 as salary, salary * 2 + bonus as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 21000)) - + sql(s"SELECT dept AS d, name AS d FROM $testTable GROUP BY dept, name ORDER BY dept, name"), + Row(1, "amy") :: Row(1, "cathy") :: Row(2, "alex") :: Row(2, "david") :: Row(6, "jen") :: Nil + ) checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 22000)) - + sql(s"SELECT dept AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(1, 1, 10) :: Row(2, 2, 10) :: Row(6, 6, 10) :: Nil + ) + checkAnswer( + sql(s"SELECT sum(salary * 1.5) AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(28500, 28500, 10) :: Row(33000, 33000, 10) :: Row(18000, 18000, 10) :: Nil + ) checkAnswer( sql( - "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + - s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + - " where name = 'amy'"), - Row(20000, 22000, 11000, 22000)) + s""" + |SELECT sum(salary * 1.5) AS d, d, d + sum(bonus) AS d + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + 
|""".stripMargin), + Row(28500, 28500, 30700) :: Row(33000, 33000, 35500) :: Row(18000, 18000, 19200) :: Nil + ) - checkAnswer( - sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + - s"FROM $testTable WHERE name = 'amy'"), - Row(Row(2022), 2019)) + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT dept * 2.0 AS d, d, 10000 AS d, d + 1 FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, dept * 2.0 AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT avg(salary) AS d, d * 1.0, avg(bonus * 1.5) AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT dept AS d, d + 1 AS d, d + 1 AS d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) checkAnswer( - sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + - s"FROM $testTable WHERE name = 'amy'"), - Row(Row("someone"), "amy")) + sql(s""" + |SELECT avg(salary * 1.5) AS salary, sum(salary), dept AS salary, avg(salary) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 6 + |""".stripMargin), + Row(18000, 12000, 6, 12000) + ) } testOnAndOff("Lateral alias conflicts with OuterReference - Project") { @@ -278,46 +507,87 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } // TODO: more tests on LCA in subquery - test("Lateral alias of a complex type - Project") { - checkAnswer( - sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), - Row(Row(1), 2, 3)) + test("Lateral alias conflicts with OuterReference - Aggregate") { + // test if lca rule strips the OuterReference and resolves to lateral alias + val query = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 
avg(salary * 1.0) AS id, id + 1 AS id2 FROM $testTable GROUP BY dept)) > 5 + |""".stripMargin + // TODO: It no longer returns the following failure: + // [UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY] + // Unsupported subquery expression: A GROUP BY clause in a scalar correlated subquery cannot + // contain non-correlated columns + val analyzedPlan = sql(query).queryExecution.analyzed + assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) + } - checkAnswer( - sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), - Row(Row(Row(1)), 2) - ) + test("Lateral alias of a complex type") { + // test both Project and Aggregate + val querySuffixes = Seq("", s"FROM $testTable GROUP BY dept HAVING dept = 6") + querySuffixes.foreach { querySuffix => + checkAnswer( + sql(s"SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1 $querySuffix"), + Row(Row(1), 2, 3)) + checkAnswer( + sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar " + + s"$querySuffix"), + Row(Row(Row(1)), 2)) + checkAnswer( + sql(s"SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1 $querySuffix"), + Row(Seq(1, 2, 3), 2, 3)) checkAnswer( - sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"), - Row(Seq(1, 2, 3), 2, 3) - ) - checkAnswer( - sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"), - Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101) - ) + sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar " + + s"$querySuffix"), + Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101)) checkAnswer( - sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"), - Row(Seq(Row(1), Row(2)), 2) - ) + sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar" + + s" $querySuffix"), + Row(Seq(Row(1), Row(2)), 2)) + + checkAnswer( + sql(s"SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1 
$querySuffix"), + Row(Map("a" -> 1, "b" -> 2), 2, 3)) + } checkAnswer( - sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"), - Row(Map("a" -> 1, "b" -> 2), 2, 3) - ) - } + sql("SELECT named_struct('s', salary * 1.0) AS foo, foo.s + 1 AS bar, bar + 1 " + + s"FROM $testTable WHERE dept = 1 ORDER BY name"), + Row(Row(10000), 10001, 10002) :: Row(Row(9000), 9001, 9002) :: Nil) - test("Lateral alias reference attribute further be used by upper plan - Project") { - // this is out of the scope of lateral alias project functionality requirements, but naturally - // supported by the current design checkAnswer( - sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + - s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), - Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil + sql(s"SELECT properties AS foo, foo.joinYear AS bar, bar + 1 " + + s"FROM $testTable GROUP BY properties HAVING properties.mostRecentEmployer = 'B'"), + Row(Row(2020, "B"), 2020, 2021)) + // TODO fix this case without clearing out the metadata + // After applying rule org.apache.spark.sql.catalyst.optimizer.CollapseProject in batch + // Operator Optimization before Inferring Filters, the structural integrity of the plan + // is broken. + // It is because one output with the same exprId has auto generated alias as metadata, but + // others not. 
+ checkAnswer( + sql(s"SELECT named_struct('avg_salary', avg(salary)) AS foo, foo.avg_salary + 1 AS bar " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(9500), 9501) :: Row(Row(11000), 11001) :: Row(Row(12000), 12001) :: Nil ) } - test("Lateral alias chaining - Project") { +// test("Lateral alias reference attribute further be used by upper plan - Project") { +// // this is out of the scope of lateral alias project functionality requirements, but naturally +// // supported by the current design +// checkAnswer( +// sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + +// s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), +// Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil +// ) +// } + + test("Lateral alias chaining") { + // Project checkAnswer( sql( s""" @@ -333,127 +603,8 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 AS d, d - a AS e, e + 1"), Row(1, 2, 1, 3, 4, 3, 4) ) - } - - test("Conflict names with CTE - Project") { - checkAnswer( - sql( - s""" - |WITH temp_table(x, y) - |AS (SELECT 1, 2) - |SELECT 100 AS x, x + 1 - |FROM temp_table - |""".stripMargin - ), - Row(100, 2) - ) - } - - test("temp test") { - sql(s"SELECT count(name) AS b, b FROM $testTable GROUP BY dept") - sql(s"SELECT dept AS a, count(name) AS b, a, b FROM $testTable GROUP BY dept") - sql(s"SELECT avg(salary) AS a, count(name) AS b, a, b, a + b FROM $testTable GROUP BY dept") - sql(s"SELECT dept, count(name) AS b, dept + b FROM $testTable GROUP BY dept") - sql(s"SELECT count(bonus), count(salary * 1.5 + 10000 + bonus * 1.0) AS a, a " + - s"FROM $testTable GROUP BY dept") - } - - test("Basic lateral alias in Aggregate") { - // doesn't support lca used in aggregation functions - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT 10000 AS lca, count(lca) FROM $testTable GROUP BY dept") - }, - errorClass = 
"UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", - sqlState = "0A000", - parameters = Map( - "lca" -> "`lca`", - "aggFunc" -> "\"count(lateralAliasReference(lca))\"" - ) - ) - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT dept AS lca, avg(lca) FROM $testTable GROUP BY dept") - }, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", - sqlState = "0A000", - parameters = Map( - "lca" -> "`lca`", - "aggFunc" -> "\"avg(lateralAliasReference(lca))\"" - ) - ) - - // literal as LCA, used in various cases of expressions - checkAnswer( - sql( - s""" - |SELECT - | 10000 AS baseline_salary, - | baseline_salary * 1.5, - | baseline_salary + dept * 10000, - | baseline_salary + avg(bonus) - |FROM $testTable - |GROUP BY dept - |ORDER BY dept - |""".stripMargin - ), - Row(10000, 15000.0, 20000, 11100.0) :: - Row(10000, 15000.0, 30000, 11250.0) :: - Row(10000, 15000.0, 70000, 11200.0) :: Nil - ) - - // grouping attribute as LCA, used in various cases of expressions - checkAnswer( - sql( - s""" - |SELECT - | salary + 1000 AS new_salary, - | new_salary - 1000 AS prev_salary, - | new_salary - salary, - | new_salary - avg(salary) - |FROM $testTable - |GROUP BY salary - |ORDER BY salary - |""".stripMargin), - Row(10000, 9000, 1000, 1000.0) :: - Row(11000, 10000, 1000, 1000.0) :: - Row(13000, 12000, 1000, 1000.0) :: - Nil - ) - - // aggregate expression as LCA, used in various cases of expressions - checkAnswer( - sql( - s""" - |SELECT - | sum(salary) AS dept_salary_sum, - | sum(bonus) AS dept_bonus_sum, - | dept_salary_sum * 1.5, - | concat(string(dept_salary_sum), ': dept', string(dept)), - | dept_salary_sum + sum(bonus), - | dept_salary_sum + dept_bonus_sum - |FROM $testTable - |GROUP BY dept - |ORDER BY dept - |""".stripMargin - ), - Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200) :: - Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500) :: - Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200) :: - Nil - ) - 
checkAnswer( - sql(s"SELECT sum(salary) AS s, s + sum(bonus) AS total FROM $testTable"), - Row(53000, 58900) - ) - // Doesn't support nested aggregate expressions - // TODO: add error class and use CheckError - intercept[AnalysisException] { - sql(s"SELECT sum(salary) AS a, avg(a) FROM $testTable") - } - - // chaining + // Aggregate checkAnswer( sql( s""" @@ -471,46 +622,20 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(2, 22000, 24500, 36750.0, 14750.0) :: Row(6, 12000, 13200, 19800.0, 7800.0) :: Nil ) - - // conflict names with table columns - checkAnswer( - sql( - s""" - |SELECT - | sum(salary) AS salary, - | sum(bonus) AS bonus, - | avg(salary) AS avg_s, - | avg(salary + bonus) AS avg_t, - | avg_s + avg_t - |FROM $testTable - |GROUP BY dept - |ORDER BY dept - |""".stripMargin), - Row(19000, 2200, 9500.0, 10600.0, 20100.0) :: - Row(22000, 2500, 11000.0, 12250.0, 23250.0) :: - Row(12000, 1200, 12000.0, 13200.0, 25200.0) :: - Nil) } - test("non-deterministic expression as LCA is evaluated only once - Project") { - sql(s"SELECT dept, rand(0) AS r, r FROM $testTable").collect().toSeq.foreach { row => - assert(QueryTest.compare(row(1), row(2))) - } - sql(s"SELECT dept + rand(0) AS r, r FROM $testTable").collect().toSeq.foreach { row => - assert(QueryTest.compare(row(0), row(1))) - } - } - - test("non-deterministic expression as LCA is evaluated only once - Aggregate") { - val groupBySnippet = s"FROM $testTable GROUP BY dept" - sql(s"SELECT dept, rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => - assert(QueryTest.compare(row(1), row(2))) - } - sql(s"SELECT dept + rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => - assert(QueryTest.compare(row(0), row(1))) + test("non-deterministic expression as LCA is evaluated only once") { + val querySuffixes = Seq(s"FROM $testTable", s"FROM $testTable GROUP BY dept") + querySuffixes.foreach { querySuffix => + sql(s"SELECT dept, rand(0) AS r, r 
$querySuffix").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(1), row(2))) + } + sql(s"SELECT dept + rand(0) AS r, r $querySuffix").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(0), row(1))) + } } - sql(s"SELECT avg(salary) + rand(0) AS r, r $groupBySnippet").collect().toSeq.foreach { row => - assert(QueryTest.compare(row(0), row(1))) + sql(s"SELECT avg(salary) + rand(0) AS r, r ${querySuffixes(1)}").collect().toSeq.foreach { + row => assert(QueryTest.compare(row(0), row(1))) } } @@ -531,60 +656,6 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Duplicated lateral alias names - Aggregate") { - // Has duplicated names but not referenced is fine - checkAnswer( - sql(s"SELECT dept AS d, name AS d FROM $testTable GROUP BY dept, name ORDER BY dept, name"), - Row(1, "amy") :: Row(1, "cathy") :: Row(2, "alex") :: Row(2, "david") :: Row(6, "jen") :: Nil - ) - checkAnswer( - sql(s"SELECT dept AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), - Row(1, 1, 10) :: Row(2, 2, 10) :: Row(6, 6, 10) :: Nil - ) - checkAnswer( - sql(s"SELECT sum(salary * 1.5) AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), - Row(28500, 28500, 10) :: Row(33000, 33000, 10) :: Row(18000, 18000, 10) :: Nil - ) - checkAnswer( - sql( - s""" - |SELECT sum(salary * 1.5) AS d, d, d + sum(bonus) AS d - |FROM $testTable - |GROUP BY dept - |ORDER BY dept - |""".stripMargin), - Row(28500, 28500, 30700) :: Row(33000, 33000, 35500) :: Row(18000, 18000, 19200) :: Nil - ) - - // Referencing duplicated names raises error - checkDuplicatedAliasErrorHelper( - s"SELECT dept * 2.0 AS d, d, 10000 AS d, d + 1 FROM $testTable GROUP BY dept", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT 10000 AS d, d * 1.0, dept * 2.0 AS d, d FROM $testTable GROUP BY dept", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT avg(salary) AS d, 
d * 1.0, avg(bonus * 1.5) AS d, d FROM $testTable GROUP BY dept", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - checkDuplicatedAliasErrorHelper( - s"SELECT dept AS d, d + 1 AS d, d + 1 AS d FROM $testTable GROUP BY dept", - parameters = Map("name" -> "`d`", "n" -> "2") - ) - - checkAnswer( - sql(s""" - |SELECT avg(salary * 1.5) AS salary, sum(salary), dept AS salary, avg(salary) - |FROM $testTable - |GROUP BY dept - |HAVING dept = 6 - |""".stripMargin), - Row(18000, 12000, 6, 12000) - ) - } - test("Attribute cannot be resolved by LCA remain unresolved") { assert(intercept[AnalysisException] { sql(s"SELECT dept AS d, d AS new_dept, new_dep + 1 AS newer_dept FROM $testTable") @@ -596,4 +667,27 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // TODO: subquery } + + test("Pushed-down aggregateExpressions should have no duplicates") { + val query = s""" + |SELECT dept, avg(salary) AS a, a + avg(bonus), dept + 1, + | concat(string(dept), string(avg(bonus))), avg(salary) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 2 + |""".stripMargin + val analyzedPlan = sql(query).queryExecution.analyzed + analyzedPlan.collect { + case Aggregate(_, aggregateExpressions, _) => + val extracted = aggregateExpressions.collect { + case Alias(child, _) => child + case a: Attribute => a + } + val expressionSet = ExpressionSet(extracted) + assert( + extracted.size == expressionSet.size, + "The pushed-down aggregateExpressions in Aggregate should have no duplicates " + + s"after extracted from Alias. 
Current aggregateExpressions: $aggregateExpressions") + } + } } From 3698cfffa865f39f1262be7b407dbae4a9eeee57 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 12 Dec 2022 16:17:31 -0800 Subject: [PATCH 22/31] update comments --- .../ResolveLateralColumnAliasReference.scala | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 4ee7749a3b236..98ee92b147a66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -32,22 +32,26 @@ import org.apache.spark.sql.internal.SQLConf * Plan-wise, it handles two types of operators: Project and Aggregate. * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve * the attributes referencing these aliases - * - in Aggregate TODO inserting the Project node above and fall back to the resolution of Project. + * - in Aggregate, inserting the Project node above and falling back to the resolution of Project. * * The whole process is generally divided into two phases: * 1) recognize resolved lateral alias, wrap the attributes referencing them with * [[LateralColumnAliasReference]] - * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. - * For Project, it further resolves the attributes and push down the referenced lateral aliases. - * For Aggregate, TODO + * 2) when the whole operator is resolved, + * For Project, it unwrap [[LateralColumnAliasReference]], further resolves the attributes and + * push down the referenced lateral aliases. 
+ * For Aggregate, it goes through the whole aggregation
+ * expressions and grouping expressions to keep them in this Aggregate node, and adds a Project
+ * above with the original output. It doesn't do anything on [[LateralColumnAliasReference]], but
+ * completely leaves it to the Project in future turns of this rule.
  *
- * Example for Project:
+ * ** Example for Project:
  * Before rewrite:
  * Project [age AS a, 'a + 1]
  * +- Child
  *
  * After phase 1:
- * Project [age AS a, lateralalias(a) + 1]
+ * Project [age AS a, lca(a) + 1]
  * +- Child
  *
  * After phase 2:
@@ -55,28 +59,28 @@ import org.apache.spark.sql.internal.SQLConf
  * +- Project [child output, age AS a]
  * +- Child
  *
- * Example for Aggregate TODO
- *
- * For Aggregate, it first wraps the attribute resolved by lateral alias with
- * [[LateralColumnAliasReference]].
- * Before wrap (omit some cast or alias):
+ * ** Example for Aggregate:
+ * Before rewrite:
  * Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)]
  * +- Child [dept#14,name#15,salary#16,bonus#17]
  *
- * After wrap:
+ * After phase 1:
  * Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)]
  * +- Child [dept#14,name#15,salary#16,bonus#17]
  *
- * When the whole Aggregate is resolved, it inserts a [[Project]] above with the aggregation
- * expression list, but extracts the [[AggregateExpression]] and grouping expressions in the
- * list to the current Aggregate. It restores all the [[LateralColumnAliasReference]] back to
- * [[UnresolvedAttribute]]. The problem falls back to the lateral alias resolution in Project.
- * - * After restore: - * Project [dept#14 AS a#12, 'a + 1, avg(salary)#26 AS b#13, 'b + avg(bonus)#27] + * After phase 2: + * Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27] * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14] * +- Child [dept#14,name#15,salary#16,bonus#17] * + * Now the problem falls back to the lateral alias resolution in Project. + * After future rounds of this rule: + * Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27] + * +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13] + * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, + * dept#14] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * * * The name resolution priority: * local table column > local lateral column alias > outer reference @@ -172,8 +176,8 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { ne.toAttribute case e if groupingExpressions.exists(_.semanticEquals(e)) => // TODO one concern here, is condition here be able to match all grouping - // expressions? For example, Agg [age + 10] [a + age + 10], when transforming down, - // is it possible that (a + age) + 10, so that it won't be able to match (age + 10) + // expressions? For example, Agg [age + 10] [1 + age + 10], when transforming down, + // is it possible that (1 + age) + 10, so that it won't be able to match (age + 10) // add a test. 
val ne = expressionMap.getOrElseUpdate( e.canonicalized, From e700d6a76079238e80f1bec23e61d4d8a740df00 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 12 Dec 2022 17:02:54 -0800 Subject: [PATCH 23/31] add a corner case comment --- .../catalyst/analysis/ResolveLateralColumnAliasReference.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 98ee92b147a66..94c822a14bec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -195,6 +195,9 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { ) } // TODO withOrigin? + // TODO potential risk. named_struct('a', avg(bonus)) AS foo, foo.a AS bar.. + // foo.a is resolved to LCA(get struct field .. ). But later avg(bonus) is pushed down. + // it just resolves to the ne in the struct. is it still valid? 
} } } From 8d20986ee90145b1f3abeafa48c22d463d5a0c99 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 12 Dec 2022 20:03:51 -0800 Subject: [PATCH 24/31] address comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 15 ++++++++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++++++---- .../ResolveLateralColumnAliasReference.scala | 8 ++++++++ .../catalyst/expressions/namedExpressions.scala | 12 +----------- .../spark/sql/catalyst/expressions/subquery.scala | 7 ++++++- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a56a3d9cb6bad..e28a2f5dfda9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1808,7 +1808,7 @@ class Analyzer(override val catalogManager: CatalogManager) * the current plan is needed to first try resolving the attribute by its * children */ - private def wrapLCARefHelper( + private def wrapLCARef( e: NamedExpression, currentPlan: LogicalPlan, aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { @@ -1827,9 +1827,14 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => u } case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + if aliasMap.contains( + o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .map(_.head) + .getOrElse(o.name)) => // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val nameParts = o + .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .getOrElse(Seq(o.name)) val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => @@ -1853,7 +1858,7 @@ class Analyzer(override val 
catalogManager: CatalogManager) var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - val lcaWrapped = wrapLCARefHelper(a, p, aliasMap).asInstanceOf[Alias] + val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). // Unresolved Alias is also added to the map to perform ambiguous name check, but @@ -1861,7 +1866,7 @@ class Analyzer(override val catalogManager: CatalogManager) aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => - wrapLCARefHelper(e, p, aliasMap) + wrapLCARef(e, p, aliasMap) } p.copy(projectList = newProjectList) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9937a06de9a98..ff8450d524c47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -643,8 +643,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB projectList.foreach(_.transformDownWithPruning( _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference if p.resolved => - failUnresolvedAttribute( - p, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + throw SparkException.internalError("Resolved Project should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $p", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) }) case j: Join if !j.duplicateResolved => @@ -729,8 +731,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB 
aggList.foreach(_.transformDownWithPruning( _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference => - failUnresolvedAttribute( - agg, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + throw SparkException.internalError("Resolved Aggregate should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $agg", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) }) case _ => // Analysis successful! diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index c86d0a6dff0bb..2ca187b95ffda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE import org.apache.spark.sql.internal.SQLConf @@ -67,6 +68,13 @@ import org.apache.spark.sql.internal.SQLConf object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) + /** + * A tag to store the nameParts from the original unresolved attribute. + * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back + * to [[LateralColumnAliasReference]]. 
+ */ + val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index ff65eecafc48d..0f5239be6cae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -424,18 +424,8 @@ case class OuterReference(e: NamedExpression) override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute - override def newInstance(): NamedExpression = - OuterReference(e.newInstance()).setNameParts(nameParts) + override def newInstance(): NamedExpression = OuterReference(e.newInstance()) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) - - // optional field, the original name parts of UnresolvedAttribute before it is resolved to - // OuterReference. Used in rule ResolveLateralColumnAlias to convert OuterReference back to - // LateralColumnAliasReference. 
- var nameParts: Option[Seq[String]] = None - def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { - nameParts = newNameParts - this - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d249a2b5a6bb7..b510893f370e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan} @@ -159,7 +160,11 @@ object SubExprUtils extends PredicateHelper { * Wrap attributes in the expression with [[OuterReference]]s. 
*/ def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { - e.transform { case a: Attribute => OuterReference(a).setNameParts(nameParts) }.asInstanceOf[E] + e.transform { case a: Attribute => + val o = OuterReference(a) + nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _)) + o + }.asInstanceOf[E] } /** From ccebc1c46ac2e05074227a2a37b82ae33f1fe783 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 13 Dec 2022 10:36:13 -0800 Subject: [PATCH 25/31] revert some changes --- .../sql/catalyst/analysis/Analyzer.scala | 128 +++++++++--------- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 - .../ResolveLateralColumnAliasReference.scala | 22 +-- .../sql/catalyst/rules/RuleIdCollection.scala | 2 +- 4 files changed, 76 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 64f21a95e0780..a9edd491bed39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -469,6 +469,70 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * Replaces [[UnresolvedAlias]]s with concrete aliases. 
+ */ + object ResolveAliases extends Rule[LogicalPlan] { + private def assignAliases(exprs: Seq[NamedExpression]) = { + def extractOnly(e: Expression): Boolean = e match { + case _: ExtractValue => e.children.forall(extractOnly) + case _: Literal => true + case _: Attribute => true + case _ => false + } + def metaForAutoGeneratedAlias = { + new MetadataBuilder() + .putString("__autoGeneratedAlias", "true") + .build() + } + exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { + case u @ UnresolvedAlias(child, optGenAliasFunc) => + child match { + case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) + case e if !e.resolved => u + case g: Generator => MultiAlias(g, Nil) + case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() + case e: ExtractValue => + if (extractOnly(e)) { + Alias(e, toPrettySQL(e))() + } else { + Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) + } + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case l: Literal => Alias(l, toPrettySQL(l))() + case e => + Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) + } + } + ).asInstanceOf[Seq[NamedExpression]] + } + + private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = + exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsPattern(UNRESOLVED_ALIAS), ruleId) { + case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => + Aggregate(groups, assignAliases(aggs), child) + + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) + + case up: Unpivot if up.child.resolved && + (up.ids.exists(hasUnresolvedAlias) || 
up.values.exists(_.exists(hasUnresolvedAlias))) => + up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) + + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => + Project(assignAliases(projectList), child) + + case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => + c.copy(metrics = assignAliases(c.metrics)) + } + } + object ResolveGroupingAnalytics extends Rule[LogicalPlan] { private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID]) @@ -4209,67 +4273,3 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] { } } -/** - * Replaces [[UnresolvedAlias]]s with concrete aliases. - */ -object ResolveAliases extends Rule[LogicalPlan] { - def metaForAutoGeneratedAlias: Metadata = { - new MetadataBuilder() - .putString("__autoGeneratedAlias", "true") - .build() - } - - def assignAliases(exprs: Seq[NamedExpression]): Seq[NamedExpression] = { - def extractOnly(e: Expression): Boolean = e match { - case _: ExtractValue => e.children.forall(extractOnly) - case _: Literal => true - case _: Attribute => true - case _ => false - } - exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { - case u @ UnresolvedAlias(child, optGenAliasFunc) => - child match { - case ne: NamedExpression => ne - case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) - case e if !e.resolved => u - case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() - case e: ExtractValue => - if (extractOnly(e)) { - Alias(e, toPrettySQL(e))() - } else { - Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) - } - case e if optGenAliasFunc.isDefined => - Alias(child, optGenAliasFunc.get.apply(e))() - case l: Literal => Alias(l, toPrettySQL(l))() - case e => - Alias(e, toPrettySQL(e))(explicitMetadata = 
Some(metaForAutoGeneratedAlias)) - } - } - ).asInstanceOf[Seq[NamedExpression]] - } - - private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = - exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) - - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsPattern(UNRESOLVED_ALIAS), ruleId) { - case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => - Aggregate(groups, assignAliases(aggs), child) - - case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) - if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => - Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) - - case up: Unpivot if up.child.resolved && - (up.ids.exists(hasUnresolvedAlias) || up.values.exists(_.exists(hasUnresolvedAlias))) => - up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) - - case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => - Project(assignAliases(projectList), child) - - case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => - c.copy(metrics = assignAliases(c.metrics)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b09a1da0b7396..e7e153a319d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -446,7 +446,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB groupingExprs.foreach(checkValidGroupingExprs) aggregateExprs.foreach(checkValidAggregateExpression) - // TODO: if the Aggregate is resolved, it can't contain the LateralColumnAliasReference case CollectMetrics(name, metrics, _) => if (name == null || 
name.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 68be0317e0d26..26044722487af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -101,6 +102,14 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { */ val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + private def assignAlias(expr: Expression): NamedExpression = { + expr match { + case ne: NamedExpression => ne + case e => + Alias(e, toPrettySQL(e))() + } + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan @@ -172,14 +181,7 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { lcaRef.nameParts, aggExpr) } } - val ne = expressionMap.getOrElseUpdate( - aggExpr.canonicalized, - ResolveAliases.assignAliases(Seq(UnresolvedAlias(aggExpr))).map { - // TODO temporarily clear the metadata for an issue found in test - case a: Alias => a.copy(a.child, a.name)( - a.exprId, a.qualifier, None, a.nonInheritableMetadataKeys) - case other => other - }.head) + val ne = expressionMap.getOrElseUpdate(aggExpr.canonicalized, assignAlias(aggExpr)) 
newAggExprs += ne ne.toAttribute case e if groupingExpressions.exists(_.semanticEquals(e)) => @@ -187,9 +189,7 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { // expressions? For example, Agg [age + 10] [1 + age + 10], when transforming down, // is it possible that (1 + age) + 10, so that it won't be able to match (age + 10) // add a test. - val ne = expressionMap.getOrElseUpdate( - e.canonicalized, - ResolveAliases.assignAliases(Seq(UnresolvedAlias(e))).head) + val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e)) newAggExprs += ne ne.toAttribute }.asInstanceOf[NamedExpression] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index d131859e25f29..efafd3cfbcde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -49,6 +49,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveBinaryArithmetic" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" :: @@ -82,7 +83,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: "org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" :: "org.apache.spark.sql.catalyst.analysis.EliminateUnions" :: - "org.apache.spark.sql.catalyst.analysis.ResolveAliases" :: "org.apache.spark.sql.catalyst.analysis.ResolveDefaultColumns" :: 
"org.apache.spark.sql.catalyst.analysis.ResolveExpressionsWithNamePlaceholders" :: "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveCoalesceHints" :: From 5540b70c5216e2f28cdc007530f892c66f650bff Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 13 Dec 2022 11:07:13 -0800 Subject: [PATCH 26/31] fix few todos --- .../ResolveLateralColumnAliasReference.scala | 6 +-- .../spark/sql/LateralColumnAliasSuite.scala | 41 +++++++++---------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 26044722487af..cf582780f35ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -105,8 +105,7 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { private def assignAlias(expr: Expression): NamedExpression = { expr match { case ne: NamedExpression => ne - case e => - Alias(e, toPrettySQL(e))() + case e => Alias(e, toPrettySQL(e))() } } @@ -203,9 +202,6 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { ) } // TODO withOrigin? - // TODO potential risk. named_struct('a', avg(bonus)) AS foo, foo.a AS bar.. - // foo.a is resolved to LCA(get struct field .. ). But later avg(bonus) is pushed down. - // it just resolves to the ne in the struct. is it still valid? 
} } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index bba351d0aa31c..400e5db422577 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -497,7 +497,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 |ORDER BY id |""".stripMargin - withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. + withLCAOff { intercept[AnalysisException] { sql(query4) } } withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) @@ -517,10 +517,6 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { | SELECT id2 | FROM (SELECT avg(salary * 1.0) AS id, id + 1 AS id2 FROM $testTable GROUP BY dept)) > 5 |""".stripMargin - // TODO: It no longer returns the following failure: - // [UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY] - // Unsupported subquery expression: A GROUP BY clause in a scalar correlated subquery cannot - // contain non-correlated columns val analyzedPlan = sql(query).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) } @@ -563,12 +559,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { sql(s"SELECT properties AS foo, foo.joinYear AS bar, bar + 1 " + s"FROM $testTable GROUP BY properties HAVING properties.mostRecentEmployer = 'B'"), Row(Row(2020, "B"), 2020, 2021)) - // TODO fix this case without clearing out the metadata - // After applying rule org.apache.spark.sql.catalyst.optimizer.CollapseProject in batch - // Operator Optimization before Inferring Filters, the structural integrity of the plan - // is broken. 
- // It is because one output with the same exprId has auto generated alias as metadata, but - // others not. + checkAnswer( sql(s"SELECT named_struct('avg_salary', avg(salary)) AS foo, foo.avg_salary + 1 AS bar " + s"FROM $testTable GROUP BY dept ORDER BY dept"), @@ -576,15 +567,20 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { ) } -// test("Lateral alias reference attribute further be used by upper plan - Project") { -// // this is out of the scope of lateral alias project functionality requirements, but naturally -// // supported by the current design -// checkAnswer( -// sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + -// s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), -// Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil -// ) -// } + test("Lateral alias reference attribute further be used by upper plan") { + // underlying this is not in the scope of lateral alias project but things already supported + checkAnswer( + sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + + s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), + Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil + ) + + checkAnswer( + sql(s"SELECT avg(bonus) AS avg_bonus, avg_bonus * 1.0 AS new_avg_bonus, avg(salary) " + + s"FROM $testTable GROUP BY dept ORDER BY new_avg_bonus"), + Row(1100, 1100, 9500.0) :: Row(1200, 1200, 12000) :: Row(1250, 1250, 11000) :: Nil + ) + } test("Lateral alias chaining") { // Project @@ -665,7 +661,10 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { sql(s"SELECT count(name) AS cnt, cnt + 1, count(unresovled) FROM $testTable GROUP BY dept") }.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - // TODO: subquery + assert(intercept[AnalysisException] { + sql(s"SELECT * FROM range(1, 7) WHERE (" + + s"SELECT id2 FROM (SELECT 1 AS id, other_id + 1 AS id2)) > 5") + }.getErrorClass == 
"UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") } test("Pushed-down aggregateExpressions should have no duplicates") { From 136a9308e623fee5b1303103e6397f96d8bb6788 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Fri, 16 Dec 2022 14:00:48 -0800 Subject: [PATCH 27/31] fix the failing test --- .../sql/catalyst/analysis/Analyzer.scala | 308 ++++++++++-------- .../analysis/ResolveLateralColumnAlias.scala | 9 +- 2 files changed, 170 insertions(+), 147 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 12efed3b46137..219e0061446d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -182,6 +183,157 @@ object AnalysisContext { } } +object Analyzer extends Logging { + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. 
+ */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + resolver: Resolver, + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim 
unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + withOrigin(u.origin) { + ExtractValue(newChild, fieldName, resolver) + } + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + try { + innerResolve(expr, isTopLevel = true) + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + resolver: Resolver, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + getAttrCandidates = () => plan.output, + resolver = resolver, + throws = throws) + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. 
+ * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + resolver: Resolver): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + resolver = resolver, + throws = true) + } +} + /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -230,6 +382,18 @@ class Analyzer(override val catalogManager: CatalogManager) def resolver: Resolver = conf.resolver + private def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws) + } + private def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan): Expression = { + Analyzer.resolveExpressionByPlanChildren(e, q, resolver) + } + /** * If the plan cannot be resolved within maxIterations, analyzer will throw exception to inform * user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS. 
@@ -1767,150 +1931,6 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. - */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. 
- */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. 
- case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - try { - innerResolve(expr, isTopLevel = true) - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - throws = throws) - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. 
- */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - throws = true) - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 56f0444685e84..89ea673d1ad49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -95,6 +95,8 @@ import org.apache.spark.sql.internal.SQLConf object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { import ResolveLateralColumnAliasReference.AliasEntry + def resolver: Resolver = conf.resolver + private def insertIntoAliasMap( a: Alias, idx: Int, @@ -112,9 +114,10 @@ object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { */ private def resolveByLateralAlias( nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - val resolvedAttr = SimpleAnalyzer.resolveExpressionByPlanOutput( + val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( expr = UnresolvedAttribute(nameParts), plan = LocalRelation(Seq(lateralAlias.toAttribute)), + resolver = resolver, throws = false ).asInstanceOf[NamedExpression] if (resolvedAttr.resolved) { @@ -139,8 +142,8 @@ object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { 
e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - SimpleAnalyzer.resolveExpressionByPlanChildren( - u, currentPlan).isInstanceOf[UnresolvedAttribute] => + Analyzer.resolveExpressionByPlanChildren( + u, currentPlan, resolver).isInstanceOf[UnresolvedAttribute] => val aliases = aliasMap.get(u.nameParts.head).get aliases.size match { case n if n > 1 => From 5076ad2316c27c6bb3982355d5dba5b05451756f Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 19 Dec 2022 15:21:58 -0800 Subject: [PATCH 28/31] fix the missing_aggregate issue, turn on conf to see failed tests --- .../analysis/ResolveLateralColumnAlias.scala | 28 ++++++++- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/LateralColumnAliasSuite.scala | 58 +++++++++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 89ea673d1ad49..d0984150b4fc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression, OuterReference, ScalarSubquery} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -336,6 +336,10 
@@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { if (newAggExprs.isEmpty) { agg } else { + // perform an early check on current Aggregate before any lift-up / push-down to throw + // the same exception such as non-aggregate expressions not in group by, which becomes + // missing input after transformation + earlyCheckAggregate(agg) Project( projectList = projectExprs, child = agg.copy(aggregateExpressions = newAggExprs.toSeq) @@ -345,4 +349,26 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { } } } + + private def earlyCheckAggregate(plan: Aggregate): Unit = { + val Aggregate(groupingExprs, aggregateExprs, _) = plan + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case expr: Expression if AggregateExpression.isAggregate(expr) => + // doesn't perform any check on aggregation functions + case _: Attribute if groupingExprs.isEmpty => + plan.failAnalysis( + errorClass = "MISSING_GROUP_BY", + messageParameters = Map.empty) + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => + throw QueryCompilationErrors.columnNotInGroupByClauseError(e) + case s: ScalarSubquery + if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) => + s.failAnalysis( + errorClass = "_LEGACY_ERROR_TEMP_2423", + messageParameters = Map("sqlExpr" -> s.sql)) + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + aggregateExprs.foreach(checkValidAggregateExpression) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5565926afddfa..19c302641b16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4045,7 +4045,7 @@ object SQLConf { "higher resolution priority than the lateral column alias.") 
.version("3.4.0") .booleanConf - .createWithDefault(false) + .createWithDefault(true) /** * Holds information about keys that have been deprecated. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 400e5db422577..226bb7f287314 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -327,6 +327,30 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { Row(53000, 58900) :: Nil ) + // grouping expression are correctly recognized and pushed down + checkAnswer( + sql( + s""" + |SELECT dept AS a, dept + 10 AS b, avg(salary) + dept, avg(salary) AS c, + | c + dept, avg(salary + dept), count(dept) + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin), + Row(1, 11, 9501, 9500, 9501, 9501, 2) :: + Row(2, 12, 11002, 11000, 11002, 11002, 2) :: + Row(6, 16, 12006, 12000, 12006, 12006, 1) :: Nil) + + // two grouping expressions + checkAnswer( + sql( + s""" + |SELECT dept + salary, avg(salary) + dept, avg(bonus) AS c, c + salary + dept, + | avg(bonus) + salary + |FROM $testTable GROUP BY dept, salary HAVING dept = 2 ORDER BY dept, salary + |""".stripMargin + ), + Row(10002, 10002, 1300, 11302, 11300) :: Row(12002, 12002, 1200, 13202, 13200) :: Nil + ) + // LCA and conflicted table column mixed checkAnswerWhenOnAndExceptionWhenOff( s""" @@ -689,4 +713,38 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"after extracted from Alias. 
Current aggregateExpressions: $aggregateExpressions") } } + + test("Non-aggregating expressions not in group by still throws the same error") { + // query without lateral alias + assert( + intercept[AnalysisException] { + sql(s"SELECT dept AS a, salary FROM $testTable GROUP BY dept") + }.getErrorClass == "MISSING_AGGREGATION") + + assert( + intercept[AnalysisException] { + sql(s"SELECT avg(salary), avg(avg(salary)) FROM $testTable GROUP BY dept") + }.getErrorClass == "NESTED_AGGREGATE_FUNCTION") + + // query with lateral alias throws the same error + assert( + intercept[AnalysisException] { + sql(s"SELECT dept AS a, a, salary FROM $testTable GROUP BY dept") + }.getErrorClass == "MISSING_AGGREGATION") + // no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE + assert( + intercept[AnalysisException] { + sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept") + }.getErrorClass == "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC") + + // checkAnalysis doesn't canonicalize expressions when performing check of non-aggregation + // expression in group by. 
With LCA, it doesn't change and throw same exception + val e1 = intercept[AnalysisException] { + sql(s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10") + } + val e2 = intercept[AnalysisException] { + sql(s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10") + } + assert(e1.getErrorClass == e2.getErrorClass) + } } From 2f2dee5560d5cf16bdf20b12edd408fb57aeb0fc Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 19 Dec 2022 15:27:23 -0800 Subject: [PATCH 29/31] remove few todos --- .../sql/catalyst/analysis/ResolveLateralColumnAlias.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index d0984150b4fc6..1514dbedf5d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -324,10 +324,6 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { newAggExprs += ne ne.toAttribute case e if groupingExpressions.exists(_.semanticEquals(e)) => - // TODO one concern here, is condition here be able to match all grouping - // expressions? For example, Agg [age + 10] [1 + age + 10], when transforming down, - // is it possible that (1 + age) + 10, so that it won't be able to match (age + 10) - // add a test. val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e)) newAggExprs += ne ne.toAttribute @@ -345,7 +341,6 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { child = agg.copy(aggregateExpressions = newAggExprs.toSeq) ) } - // TODO withOrigin? 
} } } From 3a5509aa56218be561eb391cf116f1e6c406f560 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 19 Dec 2022 21:50:42 -0800 Subject: [PATCH 30/31] better fix to maintain aggregate error: only lift up in certain cases --- .../analysis/ResolveLateralColumnAlias.scala | 49 +++++------ .../spark/sql/LateralColumnAliasSuite.scala | 82 ++++++++++++------- 2 files changed, 73 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 1514dbedf5d1c..639eb2c634bdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression, OuterReference, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, LeafExpression, Literal, NamedExpression, OuterReference, ScalarSubquery} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -307,6 +307,27 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + // Check if current Aggregate is eligible to lift up with Project: the aggregate + // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) lateral + // column alias reference or 4) 
literals. + // This check is to prevent unnecessary transformation on invalid plan, to guarantee it + // throws the same exception. For example, cases like non-aggregate expressions not + // in group by, once transformed, will throw a different exception: missing input. + def eligibleToLiftUp(exp: Expression): Boolean = { + exp match { + case e: AggregateExpression if AggregateExpression.isAggregate(e) => true + case e if groupingExpressions.exists(_.semanticEquals(e)) => true + case _: Literal | _: LateralColumnAliasReference => true + case s: ScalarSubquery if s.children.nonEmpty + && !groupingExpressions.exists(_.semanticEquals(s)) => false + case _: LeafExpression => false + case e => e.children.forall(eligibleToLiftUp) + } + } + if (!aggregateExpressions.forall(eligibleToLiftUp)) { + return agg + } + val newAggExprs = collection.mutable.Set.empty[NamedExpression] val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] val projectExprs = aggregateExpressions.map { exp => @@ -332,10 +353,6 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { if (newAggExprs.isEmpty) { agg } else { - // perform an early check on current Aggregate before any lift-up / push-down to throw - // the same exception such as non-aggregate expressions not in group by, which becomes - // missing input after transformation - earlyCheckAggregate(agg) Project( projectList = projectExprs, child = agg.copy(aggregateExpressions = newAggExprs.toSeq) @@ -344,26 +361,4 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { } } } - - private def earlyCheckAggregate(plan: Aggregate): Unit = { - val Aggregate(groupingExprs, aggregateExprs, _) = plan - def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case expr: Expression if AggregateExpression.isAggregate(expr) => - // doesn't perform any check on aggregation functions - case _: Attribute if groupingExprs.isEmpty => - plan.failAnalysis( - errorClass = 
"MISSING_GROUP_BY", - messageParameters = Map.empty) - case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => - throw QueryCompilationErrors.columnNotInGroupByClauseError(e) - case s: ScalarSubquery - if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) => - s.failAnalysis( - errorClass = "_LEGACY_ERROR_TEMP_2423", - messageParameters = Map("sqlExpr" -> s.sql)) - case e if groupingExprs.exists(_.semanticEquals(e)) => // OK - case e => e.children.foreach(checkValidAggregateExpression) - } - aggregateExprs.foreach(checkValidAggregateExpression) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 226bb7f287314..624d5f98642a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -714,37 +714,57 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { } } - test("Non-aggregating expressions not in group by still throws the same error") { - // query without lateral alias - assert( - intercept[AnalysisException] { - sql(s"SELECT dept AS a, salary FROM $testTable GROUP BY dept") - }.getErrorClass == "MISSING_AGGREGATION") - - assert( - intercept[AnalysisException] { - sql(s"SELECT avg(salary), avg(avg(salary)) FROM $testTable GROUP BY dept") - }.getErrorClass == "NESTED_AGGREGATE_FUNCTION") - - // query with lateral alias throws the same error - assert( - intercept[AnalysisException] { - sql(s"SELECT dept AS a, a, salary FROM $testTable GROUP BY dept") - }.getErrorClass == "MISSING_AGGREGATION") - // no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE - assert( - intercept[AnalysisException] { - sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept") - }.getErrorClass == "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC") - - // checkAnalysis doesn't 
canonicalize expressions when performing check of non-aggregation - // expression in group by. With LCA, it doesn't change and throw same exception - val e1 = intercept[AnalysisException] { - sql(s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10") - } - val e2 = intercept[AnalysisException] { - sql(s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10") + test("Aggregate expressions not eligible to lift up, throws same error as inline") { + def checkSameMissingAggregationError(q1: String, q2: String, expressionParam: String): Unit = { + Seq(q1, q2).foreach { query => + val e = intercept[AnalysisException] { sql(query) } + assert(e.getErrorClass == "MISSING_AGGREGATION") + assert(e.messageParameters.get("expression").exists(_ == expressionParam)) + } } - assert(e1.getErrorClass == e2.getErrorClass) + + val suffix = s"FROM $testTable GROUP BY dept" + checkSameMissingAggregationError( + s"SELECT dept AS a, dept, salary $suffix", + s"SELECT dept AS a, a, salary $suffix", + "\"salary\"") + checkSameMissingAggregationError( + s"SELECT dept AS a, dept + salary $suffix", + s"SELECT dept AS a, a + salary $suffix", + "\"salary\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary) + bonus $suffix", + s"SELECT avg(salary) AS a, a + bonus $suffix", + "\"bonus\"") + checkSameMissingAggregationError( + s"SELECT dept AS a, dept, avg(salary) + bonus + 10 $suffix", + s"SELECT dept AS a, a, avg(salary) + bonus + 10 $suffix", + "\"bonus\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary), dept FROM $testTable GROUP BY dept + 10", + s"SELECT avg(salary) AS a, a, dept FROM $testTable GROUP BY dept + 10", + "\"dept\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10", + s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10", + "\"dept\"") + Seq( + s"SELECT dept AS a, dept, " + + 
s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix", + s"SELECT dept AS a, a, " + + s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix" + ).foreach { query => + val e = intercept[AnalysisException] { sql(query) } + assert(e.getErrorClass == "_LEGACY_ERROR_TEMP_2423") } + + // one exception: no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map("lca" -> "`a`", "aggFunc" -> "\"avg(lateralAliasReference(a))\"") + ) } } From b200da0fd23a9e6e4f2cc6b32fa66b176efb5820 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 20 Dec 2022 10:18:26 -0800 Subject: [PATCH 31/31] typo --- .../catalyst/analysis/ResolveLateralColumnAliasReference.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index f3995675cb67f..ec8bdb97fbc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -175,7 +175,7 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { // in group by, once transformed, will throw a different exception: missing input. 
def eligibleToLiftUp(exp: Expression): Boolean = { exp match { - case e: AggregateExpression if AggregateExpression.isAggregate(e) => true + case e if AggregateExpression.isAggregate(e) => true case e if groupingExpressions.exists(_.semanticEquals(e)) => true case _: Literal | _: LateralColumnAliasReference => true case s: ScalarSubquery if s.children.nonEmpty