[SPARK-46625] CTE with Identifier clause as reference
### What changes were proposed in this pull request?
DECLARE agg = 'max';
DECLARE col = 'c1';
DECLARE tab = 'T';

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
      T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab);

-- OR

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
      T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T');

Currently, the IDENTIFIER clause is not supported as part of a CTE reference.
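For illustration, a minimal end-to-end sketch of the now-supported behavior, assuming a Spark build that includes this change (local-mode session; the variable and CTE names come from the example above, and `CteIdentifierDemo` is just a hypothetical driver):

// Hypothetical driver; assumes a Spark version with this change merged.
import org.apache.spark.sql.SparkSession

object CteIdentifierDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // Session variables holding the function, column, and CTE names.
    spark.sql("DECLARE agg = 'max'")
    spark.sql("DECLARE col = 'c1'")
    spark.sql("DECLARE tab = 'T'")

    // IDENTIFIER(tab) now resolves against the CTE definition T, not the catalog.
    spark.sql(
      """WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
        |     T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
        |SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)""".stripMargin
    ).show() // expected: one row containing 'c' (max of T.c1)

    spark.stop()
  }
}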

### Why are the changes needed?
This adds support for the IDENTIFIER clause as a CTE reference, for both constant string expressions and session variables.

### Does this PR introduce _any_ user-facing change?
Yes. The IDENTIFIER clause can now be used as a CTE reference.

### How was this patch tested?
Added tests as part of this PR.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47180 from nebojsa-db/SPARK-46625.

Authored-by: Nebojsa Savic <nebojsa.savic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
nebojsa-db authored and cloud-fan committed Jul 9, 2024
1 parent fdbacdf commit d824e9e
Showing 8 changed files with 181 additions and 6 deletions.
@@ -1796,6 +1796,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
       resolveReferencesInSort(s)

+    case u: UnresolvedWithCTERelations =>
+      UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations)
+
     case q: LogicalPlan =>
       logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
       q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
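Because `UnresolvedWithCTERelations` is a leaf node, the analyzer's generic child traversal never reaches the plan stored inside it; the new case above therefore re-applies resolution to the wrapped plan by hand on each pass, leaving the captured CTE definitions untouched. A minimal standalone sketch of that pattern (hypothetical ADT, not Spark's actual API):

// Sketch only: a "leaf" wrapper hides its payload from generic child traversal,
// so a resolution rule must recurse into it explicitly.
sealed trait Plan
case class Unresolved(name: String) extends Plan
case class Resolved(name: String) extends Plan
case class LeafWrapper(hidden: Plan) extends Plan // hidden is not a traversable child

def resolvePass(p: Plan): Plan = p match {
  case Unresolved(n)  => Resolved(n)
  case LeafWrapper(h) => LeafWrapper(resolvePass(h)) // the explicit recursion added here
  case other          => other
}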
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Command, CTEInChildren, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.TypeUtils._
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_POLICY
@@ -272,7 +272,8 @@ object CTESubstitution extends Rule[LogicalPlan] {
       alwaysInline: Boolean,
       cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = {
     plan.resolveOperatorsUpWithPruning(
-      _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
+      _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION,
+        UNRESOLVED_IDENTIFIER)) {
       case RelationTimeTravel(UnresolvedRelation(Seq(table), _, _), _, _)
           if cteRelations.exists(r => plan.conf.resolver(r._1, table)) =>
         throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))
@@ -287,6 +288,14 @@ object CTESubstitution extends Rule[LogicalPlan] {
         }
       }.getOrElse(u)

+      case p: PlanWithUnresolvedIdentifier =>
+        // We must look up CTE relations first when resolving `UnresolvedRelation`s,
+        // but we can't do it here because `PlanWithUnresolvedIdentifier` is a leaf node
+        // and may produce an `UnresolvedRelation` later.
+        // Here we wrap it with `UnresolvedWithCTERelations` so that the CTE relations
+        // lookup is delayed until after `PlanWithUnresolvedIdentifier` is resolved.
+        UnresolvedWithCTERelations(p, cteRelations)
+
       case other =>
         // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
         other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
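To see why the lookup has to be deferred, here is a self-contained sketch of the substitution phase (hypothetical ADT and names, not Spark's API): when the relation name is still hidden behind IDENTIFIER(...), the rule can only capture the in-scope CTE definitions for later.

sealed trait Plan
case class UnresolvedRelation(name: String) extends Plan             // FROM T
case class UnresolvedIdentifierPlan(expr: String) extends Plan       // FROM IDENTIFIER(tab)
case class DeferredCteLookup(inner: Plan, ctes: Seq[(String, Plan)]) extends Plan

def substituteCtes(plan: Plan, ctes: Seq[(String, Plan)]): Plan = plan match {
  // Name known now: an in-scope CTE shadows any catalog table of the same name.
  case UnresolvedRelation(n) =>
    ctes.collectFirst { case (name, body) if name.equalsIgnoreCase(n) => body }
      .getOrElse(UnresolvedRelation(n))
  // Name not known yet: capture the in-scope CTEs and decide later.
  case p: UnresolvedIdentifierPlan => DeferredCteLookup(p, ctes)
  case other => other
}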
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
+import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE}
 import org.apache.spark.sql.types.StringType

 /**
@@ -35,9 +35,18 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]
   }

   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
-    _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
+    _.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) {
     case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
       executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
+    case u @ UnresolvedWithCTERelations(p, cteRelations) =>
+      this.apply(p) match {
+        case u @ UnresolvedRelation(Seq(table), _, _) =>
+          cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
+            // Add a `SubqueryAlias` for hint-resolving rules to match relation names.
+            SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
+          }.getOrElse(u)
+        case other => other
+      }
     case other =>
       other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
         case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
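And a matching sketch of the second phase (again hypothetical names, not Spark's API): once the identifier expression has been evaluated to a concrete relation name, the captured CTE list is consulted, using a resolver for case sensitivity as `plan.conf.resolver` does above, before falling back to ordinary catalog resolution. The match is wrapped in an alias so that downstream rules that key on relation names, such as hint resolution, still see the original name.

sealed trait Plan
case class UnresolvedRelation(name: String) extends Plan
case class CteRef(name: String) extends Plan
case class SubqueryAlias(alias: String, child: Plan) extends Plan
case class DeferredCteLookup(inner: Plan, ctes: Seq[(String, Plan)]) extends Plan

def finishDeferredLookup(plan: Plan, resolver: (String, String) => Boolean): Plan =
  plan match {
    case DeferredCteLookup(UnresolvedRelation(table), ctes) =>
      ctes.find { case (name, _) => resolver(name, table) }
        .map { case (name, _) => SubqueryAlias(table, CteRef(name)) } // keep the name visible
        .getOrElse(UnresolvedRelation(table)) // no matching CTE: normal catalog lookup applies
    case other => other
  }

// Example: after IDENTIFIER(tab) evaluates to "T", the CTE wins over the catalog.
val resolved = finishDeferredLookup(
  DeferredCteLookup(UnresolvedRelation("T"), Seq("S" -> CteRef("S"), "T" -> CteRef("T"))),
  (a, b) => a.equalsIgnoreCase(b))
// resolved == SubqueryAlias("T", CteRef("T"))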
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -65,6 +65,17 @@ case class PlanWithUnresolvedIdentifier(
   final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER)
 }

+/**
+ * A logical plan placeholder that delays CTE resolution until
+ * `PlanWithUnresolvedIdentifier` is resolved.
+ */
+case class UnresolvedWithCTERelations(
+    unresolvedPlan: LogicalPlan,
+    cteRelations: Seq[(String, CTERelationDef)])
+  extends UnresolvedLeafNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE)
+}
+
 /**
  * An expression placeholder that holds the identifier clause string expression. It will be
  * replaced by the actual expression with the evaluated identifier string.
@@ -151,6 +151,7 @@ object TreePattern extends Enumeration {
   val UNRESOLVED_FUNCTION: Value = Value
   val UNRESOLVED_HINT: Value = Value
   val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
+  val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value

   // Unresolved Plan patterns (Alphabetically ordered)
   val UNRESOLVED_FUNC: Value = Value
@@ -985,6 +985,79 @@ DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2


-- !query
DECLARE agg = 'max'
-- !query analysis
CreateVariable defaultvalueexpression(max, 'max'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.agg


-- !query
DECLARE col = 'c1'
-- !query analysis
CreateVariable defaultvalueexpression(c1, 'c1'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.col


-- !query
DECLARE tab = 'T'
-- !query analysis
CreateVariable defaultvalueexpression(T, 'T'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.tab


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias S
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias T
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias T
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias S
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias T
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias T
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias ABC
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias ABC
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query analysis
16 changes: 16 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
@@ -141,6 +141,22 @@ drop view v1;
drop table t1;
drop table t2;

-- SPARK-46625: CTE reference with identifier clause and session variables
DECLARE agg = 'max';
DECLARE col = 'c1';
DECLARE tab = 'T';

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab);

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T');

WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC');

-- Not supported
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'));
@@ -1115,6 +1115,59 @@ struct<>



-- !query
DECLARE agg = 'max'
-- !query schema
struct<>
-- !query output



-- !query
DECLARE col = 'c1'
-- !query schema
struct<>
-- !query output



-- !query
DECLARE tab = 'T'
-- !query schema
struct<>
-- !query output



-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
-- !query schema
struct<max(c1):string>
-- !query output
c


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
-- !query schema
struct<max(c1):string>
-- !query output
c


-- !query
WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
-- !query schema
struct<max(c1):int>
-- !query output
2


-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query schema