From adfea273ee9a196b4e7ca10a580082c734ad41f4 Mon Sep 17 00:00:00 2001 From: Naoki Takezoe Date: Tue, 11 Jul 2023 20:54:11 +0900 Subject: [PATCH 1/3] airframe-sql: Resolve columns from CTE inside AliasedRelation --- .../airframe/sql/analyzer/TypeResolver.scala | 3 +-- .../airframe/sql/analyzer/TypeResolverTest.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala index 2ea71b5184..06abe0d74d 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala @@ -380,8 +380,7 @@ object TypeResolver extends LogSupport { a.copy(expr = resolved) } case SingleColumn(a: Attribute, qualifier, _) if a.resolved => - // Optimizes the nested attributes, but preserves qualifier in the parent - a.setQualifierIfEmpty(qualifier) + a.withQualifier(qualifier) case m: MultiSourceColumn => var changed = false val resolvedInputs = m.inputs.map { diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala index 203915d8e9..471c6afba0 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala @@ -1070,4 +1070,19 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { analyze("select \"prénom\" from (select name as \"prénom\" from A)") // No error } + + test("resolve CTE in AliasedRelation") { + val p1 = analyze("with t1 as (select id from A) select id from (select id from t1) t2") + p1.outputAttributes shouldMatch { case List(col: ResolvedAttribute) => + col.fullName shouldBe "id" + col.sourceColumn.head.fullName shouldBe "A.id" + } + + val p2 = analyze("with t1 as (select id from A) select count(id) from (select id from t1) t2") + p2.outputAttributes shouldMatch { + case List(SingleColumn(FunctionCall("count", Seq(col: ResolvedAttribute), _, _, _, _), _, _)) => + col.fullName shouldBe "t2.id" + col.sourceColumn.head.fullName shouldBe "A.id" + } + } } From 3f1638f38d6f4eb323d4febddaef70f512f641de Mon Sep 17 00:00:00 2001 From: Naoki Takezoe Date: Wed, 12 Jul 2023 09:46:31 +0900 Subject: [PATCH 2/3] Fixup --- .../parquet/ParquetQueryPlanner.scala | 2 +- .../airframe/sql/analyzer/SQLAnonymizer.scala | 2 +- .../airframe/sql/analyzer/TypeResolver.scala | 51 +++--- .../wvlet/airframe/sql/model/Expression.scala | 55 ++++-- .../airframe/sql/model/LogicalPlan.scala | 18 +- .../airframe/sql/model/ResolvedPlan.scala | 6 + .../airframe/sql/parser/SQLGenerator.scala | 2 +- .../airframe/sql/parser/SQLInterpreter.scala | 8 +- .../sql/analyzer/ResolveAggregationTest.scala | 2 +- .../sql/analyzer/SQLAnalyzerTest.scala | 26 ++- .../sql/analyzer/TypeResolverTest.scala | 171 +++++++++--------- .../airframe/sql/model/ExpressionTest.scala | 10 +- 12 files changed, 200 insertions(+), 153 deletions(-) diff --git a/airframe-parquet/src/main/scala/wvlet/airframe/parquet/ParquetQueryPlanner.scala b/airframe-parquet/src/main/scala/wvlet/airframe/parquet/ParquetQueryPlanner.scala index 1129372307..b7e3548752 100644 --- a/airframe-parquet/src/main/scala/wvlet/airframe/parquet/ParquetQueryPlanner.scala +++ b/airframe-parquet/src/main/scala/wvlet/airframe/parquet/ParquetQueryPlanner.scala @@ -49,7 +49,7 @@ object ParquetQueryPlanner extends LogSupport { val queryPlan = ParquetQueryPlan(sql) logicalPlan match { - case Project(input, Seq(AllColumns(None, _, _)), _) => + case Project(input, Seq(AllColumns(None, _, _, _)), _) => parseRelation(input, queryPlan).selectAllColumns case Project(input, selectItems, _) => val columns = selectItems.map { diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/SQLAnonymizer.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/SQLAnonymizer.scala index e3a43c35ac..3d0f101da6 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/SQLAnonymizer.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/SQLAnonymizer.scala @@ -93,7 +93,7 @@ object SQLAnonymizer extends LogSupport { } else { None } - val v = UnresolvedAttribute(qualifier, parts.last, u.nodeLocation) + val v = UnresolvedAttribute(qualifier, parts.last, None, u.nodeLocation) m += u -> v } this diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala index 06abe0d74d..5c839c953e 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala @@ -125,9 +125,9 @@ object TypeResolver extends LogSupport { inputs(index) match { case a: AllColumns => resolveIndex(index, a.inputColumns) - case SingleColumn(expr, _, _) => + case SingleColumn(expr, _, _, _) => expr - case Alias(_, _, expr, _) => + case Alias(_, _, expr, _, _) => expr case other => other @@ -222,7 +222,7 @@ object TypeResolver extends LogSupport { case Some(cte) => CTERelationRef( qname.fullName, - cte.outputAttributes.map(_.withQualifier(qname.fullName)), + cte.outputAttributes.map(_.withTableAlias(qname.fullName)), plan.nodeLocation ) case None => @@ -260,14 +260,14 @@ object TypeResolver extends LogSupport { val mergedJoinKeys = resolvedJoinKeys .groupBy(_.attributeName).map { case (name, keys) => val resolvedKeys = keys.flatMap { - case SingleColumn(r: ResolvedAttribute, qual, _) => + case SingleColumn(r: ResolvedAttribute, qual, _, _) => Seq(r.withQualifier(qual)) case m: MultiSourceColumn => m.inputs case other => Seq(other) } - MultiSourceColumn(resolvedKeys, None, None) + MultiSourceColumn(resolvedKeys, None, None, None) } .toSeq // Preserve the original USING(k1, k2, ...) order @@ -348,14 +348,14 @@ object TypeResolver extends LogSupport { ): Seq[Attribute] = { val resolvedColumns = Seq.newBuilder[Attribute] outputColumns.map { - case a @ Alias(qualifier, name, expr, _) => + case a @ Alias(qualifier, name, expr, _, _) => val resolved = resolveExpression(context, expr, inputAttributes) if (expr eq resolved) { resolvedColumns += a } else { resolvedColumns += a.copy(expr = resolved) } - case s @ SingleColumn(expr, qualifier, nodeLocation) => + case s @ SingleColumn(expr, qualifier, _, nodeLocation) => resolveExpression(context, expr, inputAttributes) match { case a: Attribute => resolvedColumns += a.withQualifier(qualifier) @@ -372,15 +372,15 @@ object TypeResolver extends LogSupport { def resolveAttribute(attribute: Attribute): Attribute = { attribute match { - case a @ Alias(qualifier, name, attr: Attribute, _) => + case a @ Alias(qualifier, name, attr: Attribute, _, _) => val resolved = resolveAttribute(attr) if (attr eq resolved) { a } else { a.copy(expr = resolved) } - case SingleColumn(a: Attribute, qualifier, _) if a.resolved => - a.withQualifier(qualifier) + case SingleColumn(a: Attribute, qualifier, _, _) if a.resolved => + a case m: MultiSourceColumn => var changed = false val resolvedInputs = m.inputs.map { @@ -402,28 +402,25 @@ object TypeResolver extends LogSupport { } private def toResolvedAttribute(name: String, expr: Expression): Attribute = { - def findSourceColumn(e: Expression): Option[SourceColumn] = { e match { - case r: ResolvedAttribute => - r.sourceColumn - case a: Alias => - findSourceColumn(a.expr) - case _ => None + case r: ResolvedAttribute => r.sourceColumn + case a: Alias => findSourceColumn(a.expr) + case _ => None } } expr match { case a: Alias => - ResolvedAttribute(a.name, a.expr.dataType, a.qualifier, findSourceColumn(a.expr), a.nodeLocation) + ResolvedAttribute(a.name, a.expr.dataType, a.qualifier, findSourceColumn(a.expr), None, a.nodeLocation) case s: SingleColumn => - ResolvedAttribute(name, s.dataType, s.qualifier, findSourceColumn(s.expr), s.nodeLocation) + ResolvedAttribute(name, s.dataType, s.qualifier, findSourceColumn(s.expr), None, s.nodeLocation) case a: Attribute => // No need to resolve Attribute expressions - a + a.withTableAlias(None) case other => // Resolve expr as ResolvedAttribute so as not to pull-up too much details - ResolvedAttribute(name, other.dataType, None, findSourceColumn(expr), other.nodeLocation) + ResolvedAttribute(name, other.dataType, None, findSourceColumn(expr), None, other.nodeLocation) } } @@ -458,20 +455,18 @@ object TypeResolver extends LogSupport { val results = expr match { case i: Identifier => lookup(i.value, context).map(toResolvedAttribute(i.value, _)) - case u @ UnresolvedAttribute(qualifier, name, _) => + case u @ UnresolvedAttribute(qualifier, name, _, _) => lookup(u.fullName, context).map(toResolvedAttribute(name, _).withQualifier(qualifier)) - case a @ AllColumns(qualifier, None, _) => + case a @ AllColumns(qualifier, None, _, _) => // Resolve the inputs of AllColumn as ResolvedAttribute // so as not to pull up too much details val allColumns = resolvedAttributes.map { - case a: Attribute => - // Attribute can be used as is - a - case other => - toResolvedAttribute(other.name, other) + // Attribute can be used as is + case a: Attribute => a + case other => toResolvedAttribute(other.name, other) } List(a.copy(columns = Some((qualifier match { - case Some(q) => allColumns.filter(_.qualifier.contains(q)) + case Some(q) => allColumns.filter(c => c.qualifier.contains(q) || c.tableAlias.contains(q)) case None => allColumns }).map(_.withQualifier(None))))) case _ => diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala index a70efc8e9a..d0f97668d4 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala @@ -23,6 +23,7 @@ import wvlet.airframe.sql.Assertion._ import wvlet.log.LogSupport import java.util.Locale +import scala.annotation.tailrec /** */ @@ -294,12 +295,6 @@ trait Attribute extends LeafExpression with LogSupport { def qualifier: Option[String] def withQualifier(newQualifier: String): Attribute = withQualifier(Some(newQualifier)) def withQualifier(newQualifier: Option[String]): Attribute - def setQualifierIfEmpty(newQualifier: Option[String]): Attribute = { - qualifier match { - case Some(q) => this - case None => this.withQualifier(newQualifier) - } - } import Expression.Alias def alias: Option[String] = { @@ -320,11 +315,15 @@ trait Attribute extends LeafExpression with LogSupport { // No need to have alias other case other => - Alias(qualifier, alias, other, None) + Alias(qualifier, alias, other, other.tableAlias, None) } } } + def tableAlias: Option[String] + def withTableAlias(tableAlias: String): Attribute = withTableAlias(Some(tableAlias)) + def withTableAlias(tableAlias: Option[String]): Attribute + /** * Return columns used for generating this attribute */ @@ -355,7 +354,7 @@ trait Attribute extends LeafExpression with LogSupport { columnPath.table match { // TODO handle (catalog).(database).(table) names in the qualifier case Some(tableName) => - qualifier.exists(_ == tableName) && matchesWith(columnPath.columnName) + (qualifier.contains(tableName) || tableAlias.contains(tableName)) && matchesWith(columnPath.columnName) case None => matchesWith(columnPath.columnName) } @@ -366,12 +365,15 @@ trait Attribute extends LeafExpression with LogSupport { * via Join, Union), return MultiSourceAttribute. */ def matched(columnPath: ColumnPath, context: AnalyzerContext): Option[Attribute] = { + @tailrec def findMatched(tableName: Option[String], columnName: String): Seq[Attribute] = { tableName match { case Some(tableName) => this match { case r: ResolvedAttribute - if r.qualifier.orElse(r.sourceColumn.map(_.table.name)).exists(_.equalsIgnoreCase(tableName)) => + if r.qualifier + .orElse(tableAlias) + .orElse(r.sourceColumn.map(_.table.name)).exists(_.equalsIgnoreCase(tableName)) => findMatched(None, columnName) case _ => Nil @@ -394,6 +396,8 @@ trait Attribute extends LeafExpression with LogSupport { if (databaseName == context.database) { if (qualifier.contains(tableName)) { findMatched(None, columnName).map(_.withQualifier(qualifier)) + } else if (tableAlias.contains(tableName)) { + findMatched(None, columnName) } else { findMatched(Some(tableName), columnName) } @@ -405,8 +409,10 @@ trait Attribute extends LeafExpression with LogSupport { } } case ColumnPath(None, Some(tableName), columnName) => - if (qualifier.exists(_ == tableName)) { + if (qualifier.contains(tableName)) { findMatched(None, columnName).map(_.withQualifier(qualifier)) + } else if (tableAlias.contains(tableName)) { + findMatched(None, columnName) } else { findMatched(Some(tableName), columnName) } @@ -421,7 +427,7 @@ trait Attribute extends LeafExpression with LogSupport { } else { qualifier } - Some(MultiSourceColumn(result, qualifier = q, None)) + Some(MultiSourceColumn(result, qualifier = q, None, None)) } else { result.headOption } @@ -490,6 +496,7 @@ object Expression { case class UnresolvedAttribute( override val qualifier: Option[String], name: String, + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute { override def toString: String = s"UnresolvedAttribute(${fullName})" @@ -498,6 +505,9 @@ object Expression { override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } override def inputColumns: Seq[Attribute] = Seq.empty override def outputColumns: Seq[Attribute] = Seq.empty override def sourceColumns: Seq[SourceColumn] = Seq.empty @@ -566,6 +576,7 @@ object Expression { case class AllColumns( override val qualifier: Option[String], columns: Option[Seq[Attribute]], + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute with LogSupport { @@ -587,7 +598,7 @@ object Expression { } } override def outputColumns: Seq[Attribute] = { - inputColumns.map(_.withQualifier(qualifier)) + inputColumns.map(_.withTableAlias(tableAlias).withQualifier(qualifier)) } override def dataType: DataType = { @@ -599,6 +610,9 @@ object Expression { override def withQualifier(newQualifier: Option[String]): Attribute = { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } override def toString = { columns match { @@ -622,6 +636,7 @@ object Expression { qualifier: Option[String], name: String, expr: Expression, + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute { override def inputColumns: Seq[Attribute] = Seq(this) @@ -633,6 +648,10 @@ object Expression { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } + override def toString: String = { s"<${fullName}> := ${expr}" } @@ -659,7 +678,8 @@ object Expression { */ case class SingleColumn( expr: Expression, - qualifier: Option[String] = None, + qualifier: Option[String], + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute { override def name: String = expr.attributeName @@ -675,6 +695,9 @@ object Expression { override def withQualifier(newQualifier: Option[String]): Attribute = { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } override def sourceColumns: Seq[SourceColumn] = { expr match { @@ -693,6 +716,7 @@ object Expression { case class MultiSourceColumn( inputs: Seq[Expression], qualifier: Option[String], + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute { require(inputs.nonEmpty, s"The inputs of MultiSourceColumn should not be empty: ${this}", nodeLocation) @@ -703,7 +727,7 @@ object Expression { inputs.map { case a: Attribute => a case e: Expression => - SingleColumn(e, qualifier, e.nodeLocation) + SingleColumn(e, qualifier, None, e.nodeLocation) } } override def outputColumns: Seq[Attribute] = Seq(this) @@ -725,6 +749,9 @@ object Expression { override def withQualifier(newQualifier: Option[String]): Attribute = { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } override def sourceColumns: Seq[SourceColumn] = { inputs.flatMap { diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/LogicalPlan.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/LogicalPlan.scala index a332c543f1..523fc53dee 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/LogicalPlan.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/LogicalPlan.scala @@ -477,8 +477,8 @@ object LogicalPlan { private def isSelectAll(selectItems: Seq[Attribute]): Boolean = { selectItems.exists { - case AllColumns(x, _, _) => true - case _ => false + case AllColumns(x, _, _, _) => true + case _ => false } } @@ -514,9 +514,7 @@ object LogicalPlan { override def inputAttributes: Seq[Attribute] = child.inputAttributes override def outputAttributes: Seq[Attribute] = { - val attrs = child.outputAttributes.map { a => - a.withQualifier(alias.value) - } + val attrs = child.outputAttributes.map(_.withTableAlias(alias.value)) val result = columnNames match { case Some(columnNames) => attrs.zip(columnNames).map { case (a, columnName) => @@ -545,7 +543,7 @@ object LogicalPlan { } } val columns = (0 until values.head.size).map { i => - MultiSourceColumn(values.map(_(i)), None, None) + MultiSourceColumn(values.map(_(i)), None, None, None) } columns } @@ -706,6 +704,7 @@ object LogicalPlan { SingleColumn( in, None, + None, alias.nodeLocation ).withAlias(alias.value) } @@ -794,6 +793,7 @@ object LogicalPlan { None } }, + None, None ) // In set operations, if different column names are merged into one column, the first column name will be used @@ -854,9 +854,9 @@ object LogicalPlan { override def outputAttributes: Seq[Attribute] = { columns.map { case arr: ArrayConstructor => - ResolvedAttribute(UUID.randomUUID().toString, arr.elementType, None, None, None) + ResolvedAttribute(UUID.randomUUID().toString, arr.elementType, None, None, None, None) case other => - SingleColumn(other, None, other.nodeLocation) + SingleColumn(other, None, None, other.nodeLocation) } } override def sig(config: QuerySignatureConfig): String = @@ -880,7 +880,7 @@ object LogicalPlan { nodeLocation: Option[NodeLocation] ) extends UnaryRelation { override def outputAttributes: Seq[Attribute] = - columnAliases.map(x => UnresolvedAttribute(Some(tableAlias.value), x.value, None)) + columnAliases.map(x => UnresolvedAttribute(Some(tableAlias.value), x.value, None, None)) override def sig(config: QuerySignatureConfig): String = s"LV(${child.sig(config)})" } diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/ResolvedPlan.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/ResolvedPlan.scala index e058dcb05b..8b406bdb04 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/ResolvedPlan.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/ResolvedPlan.scala @@ -44,6 +44,7 @@ case class TableScan( col.dataType, None, // This must be None first Some(SourceColumn(table, col)), + None, None // ResolvedAttribute always has no NodeLocation ) } @@ -73,6 +74,7 @@ case class ResolvedAttribute( qualifier: Option[String], // If this attribute directly refers to a table column, its source column will be set. sourceColumn: Option[SourceColumn], + tableAlias: Option[String], nodeLocation: Option[NodeLocation] ) extends Attribute with LogSupport { @@ -83,6 +85,10 @@ case class ResolvedAttribute( override def withQualifier(newQualifier: Option[String]): Attribute = { this.copy(qualifier = newQualifier) } + override def withTableAlias(tableAlias: Option[String]): Attribute = { + this.copy(tableAlias = tableAlias) + } + override def inputColumns: Seq[Attribute] = Seq(this) override def outputColumns: Seq[Attribute] = inputColumns diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLGenerator.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLGenerator.scala index 8e497994f0..d1cfc0191c 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLGenerator.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLGenerator.scala @@ -380,7 +380,7 @@ object SQLGenerator extends LogSupport { case a: Alias => val e = printExpression(a.expr) s"${e} AS ${printNameWithQuotationsIfNeeded(a.name)}" - case SingleColumn(ex, _, _) => + case SingleColumn(ex, _, _, _) => printExpression(ex) case m: MultiSourceColumn => m.sqlExpr diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala index 24fe7d34eb..341d55a0b9 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala @@ -398,13 +398,13 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito val name = QName.unquote(ctx.fieldName.getText) - UnresolvedAttribute(qualifier, name, getLocation(ctx)) + UnresolvedAttribute(qualifier, name, None, getLocation(ctx)) } override def visitSelectAll(ctx: SelectAllContext): Attribute = { // TODO parse qName val qualifier = Option(ctx.qualifiedName()).map(_.getText) - AllColumns(qualifier, None, getLocation(ctx)) + AllColumns(qualifier, None, None, getLocation(ctx)) } override def visitSelectSingle(ctx: SelectSingleContext): Attribute = { @@ -416,7 +416,7 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito case a: Attribute => a.qualifier case _ => None } - SingleColumn(child, qualifier, getLocation(ctx)) + SingleColumn(child, qualifier, None, getLocation(ctx)) .withAlias(alias.map(_.value)) } @@ -764,7 +764,7 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito if (ctx.ASTERISK() != null) { FunctionCall( name, - Seq(AllColumns(None, None, getLocation(ctx))), + Seq(AllColumns(None, None, None, getLocation(ctx))), isDistinct, filter, over, diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/ResolveAggregationTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/ResolveAggregationTest.scala index e5701518a4..4f303492ec 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/ResolveAggregationTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/ResolveAggregationTest.scala @@ -55,7 +55,7 @@ class ResolveAggregationTest extends AirSpec with ResolverTestHelper { plan shouldMatch { case a: Aggregate => a.selectItems shouldMatch { case Seq( - ResolvedAttribute("max_id", _, Some("B"), _, _), + ResolvedAttribute("max_id", _, Some("B"), _, _, _), _ ) => } diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/SQLAnalyzerTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/SQLAnalyzerTest.scala index 65c4bfd44a..8c6354baab 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/SQLAnalyzerTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/SQLAnalyzerTest.scala @@ -52,8 +52,8 @@ class SQLAnalyzerTest extends AirSpec { val plan = SQLAnalyzer.analyze("select id, name from a", "public", catalog) plan.resolved shouldBe true plan.outputAttributes.toList shouldBe List( - ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tbl1, tbl1.column("id"))), None), - ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tbl1, tbl1.column("name"))), None) + ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tbl1, tbl1.column("id"))), None, None), + ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tbl1, tbl1.column("name"))), None, None) ) } @@ -65,12 +65,13 @@ class SQLAnalyzerTest extends AirSpec { None, Some( Seq( - ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tbl1, tbl1.column("id"))), None), + ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tbl1, tbl1.column("id"))), None, None), ResolvedAttribute( "name", DataType.StringType, None, Some(SourceColumn(tbl1, tbl1.column("name"))), + None, None ), ResolvedAttribute( @@ -78,10 +79,12 @@ class SQLAnalyzerTest extends AirSpec { DataType.StringType, None, Some(SourceColumn(tbl1, tbl1.column("address"))), + None, None ) ) ), + None, Some(NodeLocation(1, 8)) ) ) @@ -92,7 +95,7 @@ class SQLAnalyzerTest extends AirSpec { plan.resolved shouldBe true plan.outputAttributes.toList shouldMatch { // Attribute should not have a qualifier - case List(Alias(_, "person_id", r, _)) => { + case List(Alias(_, "person_id", r, _, _)) => { r.attributeName shouldBe "id" r.dataType shouldBe DataType.LongType } @@ -112,10 +115,18 @@ class SQLAnalyzerTest extends AirSpec { DataType.LongType, Some("a"), Some(SourceColumn(tbl1, tbl1.column("id"))), + None, None ) attr(1) shouldBe - ResolvedAttribute("name", DataType.StringType, Some("a"), Some(SourceColumn(tbl1, tbl1.column("name"))), None) + ResolvedAttribute( + "name", + DataType.StringType, + Some("a"), + Some(SourceColumn(tbl1, tbl1.column("name"))), + None, + None + ) attr(2) shouldBe ResolvedAttribute( @@ -123,10 +134,11 @@ class SQLAnalyzerTest extends AirSpec { DataType.StringType, Some("a"), Some(SourceColumn(tbl1, tbl1.column("address"))), + None, None ) - attr(3) shouldMatch { case Alias(_, "phone_num", a, _) => - a shouldMatch { case ResolvedAttribute("phone", DataType.StringType, _, _, _) => + attr(3) shouldMatch { case Alias(_, "phone_num", a, _, _) => + a shouldMatch { case ResolvedAttribute("phone", DataType.StringType, _, _, _, _) => // c shouldBe SourceColumn(tbl2, tbl2.column("phone")) } } diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala index 471c6afba0..4c9d982e87 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/analyzer/TypeResolverTest.scala @@ -84,12 +84,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { catalog } - private val ra1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableA, a1)), None) - private val ra2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableA, a2)), None) - private val rb1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableB, b1)), None) - private val rb2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableB, b2)), None) - private val rc1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableC, c1)), None) - private val rc2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableC, c2)), None) + private val ra1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableA, a1)), None, None) + private val ra2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableA, a2)), None, None) + private val rb1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableB, b1)), None, None) + private val rb2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableB, b2)), None, None) + private val rc1 = ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(tableC, c1)), None, None) + private val rc2 = ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(tableC, c2)), None, None) test("resolveTableRef") { test("resolve all columns") { @@ -167,7 +167,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select id from A union all select id from B") p.inputAttributes shouldBe List(ra1, ra2, rb1, rb2) p.outputAttributes shouldBe List( - MultiSourceColumn(List(ra1, rb1), None, None) + MultiSourceColumn(List(ra1, rb1), None, None, None) ) } @@ -175,7 +175,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select id from A union all select id from A") p.inputAttributes shouldBe List(ra1, ra2, ra1, ra2) p.outputAttributes shouldBe List( - MultiSourceColumn(List(ra1, ra1), None, None) + MultiSourceColumn(List(ra1, ra1), None, None, None) ) } @@ -184,8 +184,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p.inputAttributes shouldBe List(ra1, ra2, rb1, rb2) p.outputAttributes shouldMatch { case Seq( - MultiSourceColumn(Seq(`ra1`, `rb1`), None, _), - MultiSourceColumn(Seq(`ra2`, `rb2`), None, _) + MultiSourceColumn(Seq(`ra1`, `rb1`), None, _, _), + MultiSourceColumn(Seq(`ra2`, `rb2`), None, _, _) ) => } } @@ -194,8 +194,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select * from (select * from A union all select * from B)") p.inputAttributes shouldMatch { case Seq( - MultiSourceColumn(Seq(`ra1`, `rb1`), None, _), - MultiSourceColumn(Seq(`ra2`, `rb2`), None, _) + MultiSourceColumn(Seq(`ra1`, `rb1`), None, _, _), + MultiSourceColumn(Seq(`ra2`, `rb2`), None, _, _) ) => } p.outputAttributes shouldMatch { @@ -204,10 +204,11 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { None, Some( Seq( - MultiSourceColumn(Seq(`ra1`, `rb1`), None, _), - MultiSourceColumn(Seq(`ra2`, `rb2`), None, _) + MultiSourceColumn(Seq(`ra1`, `rb1`), None, _, _), + MultiSourceColumn(Seq(`ra2`, `rb2`), None, _, _) ) ), + _, _ ) ) => @@ -216,12 +217,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("ru2: resolve union with column alias") { val p = analyze("select p1 from (select id as p1 from A union all select id as p1 from B)") - p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), None, _)) => + p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), None, _, _)) => m.name shouldBe "p1" c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1.withAlias("p1") } - p.outputAttributes shouldMatch { case Seq(MultiSourceColumn(Seq(c1, c2), None, _)) => + p.outputAttributes shouldMatch { case Seq(MultiSourceColumn(Seq(c1, c2), None, _, _)) => c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1.withAlias("p1") } @@ -229,11 +230,11 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("ru2a: resolve union with different column aliases") { val p = analyze("select p1 from (select id as p1 from A union all select id from B)") - p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), None, _)) => + p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), None, _, _)) => c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1 } - p.outputAttributes shouldMatch { case Seq(MultiSourceColumn(Seq(c1, c2), None, _)) => + p.outputAttributes shouldMatch { case Seq(MultiSourceColumn(Seq(c1, c2), None, _, _)) => c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1 } @@ -246,12 +247,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("ru3: resolve union with column alias and qualifier") { val p = analyze("select q1.p1 from (select id as p1 from A union all select id as p1 from B) q1") - p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), _, _)) => + p.inputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), _, _, _)) => m.name shouldBe "p1" c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1.withAlias("p1") } - p.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), Some("q1"), _)) => + p.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), Some("q1"), _, _)) => c1 shouldBe ra1.withAlias("p1") c2 shouldBe rb1.withAlias("p1") } @@ -260,7 +261,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("ru4: resolve aggregation key with union") { val p = analyze("select count(*), id from (select * from A union all select * from B) group by id") val agg = p shouldMatch { case a: Aggregate => a } - agg.groupingKeys(0).child shouldMatch { case m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _) => + agg.groupingKeys(0).child shouldMatch { case m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _, _) => m.name shouldBe "id" } } @@ -271,10 +272,11 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p.outputAttributes.head shouldMatch { case MultiSourceColumn( Seq( - SingleColumn(ArithmeticBinaryExpr(Add, `ra1`, LongLiteral(1, _), _), None, _), - SingleColumn(ArithmeticBinaryExpr(Add, `rb1`, LongLiteral(1, _), _), None, _) + SingleColumn(ArithmeticBinaryExpr(Add, `ra1`, LongLiteral(1, _), _), None, _, _), + SingleColumn(ArithmeticBinaryExpr(Add, `rb1`, LongLiteral(1, _), _), None, _, _) ), _, + _, _ ) => } @@ -284,8 +286,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select id, name from (select id, name from A union all select id, name from B)") p.outputAttributes shouldMatch { case Seq( - MultiSourceColumn(Seq(`ra1`, `rb1`), None, _), - MultiSourceColumn(Seq(`ra2`, `rb2`), None, _) + MultiSourceColumn(Seq(`ra1`, `rb1`), None, _, _), + MultiSourceColumn(Seq(`ra2`, `rb2`), None, _, _) ) => } } @@ -294,7 +296,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select id from A intersect select id from B") // => Distinct(Intersect(...)) p shouldMatch { case Distinct(i @ Intersect(_, _), _) => i.inputAttributes shouldBe List(ra1, ra2, rb1, rb2) - i.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _)) => + i.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _, _)) => m.name shouldBe "id" } } @@ -333,7 +335,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("group by index of column with alias") { val p = analyze("select id as i, count(*) from A group by 1") - p shouldMatch { case Aggregate(_, _, List(ResolvedGroupingKey(Some(1), SingleColumn(`ra1`, _, _), _)), _, _) => + p shouldMatch { case Aggregate(_, _, List(ResolvedGroupingKey(Some(1), SingleColumn(`ra1`, _, _, _), _)), _, _) => } } @@ -359,7 +361,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("a1: resolve qualified column used in GROUP BY clause") { val p = analyze("SELECT a.cnt, a.name FROM (SELECT count(id) cnt, name FROM A GROUP BY name) a") p.outputAttributes shouldMatch { case Seq(c1, c2) => - c1 shouldMatch { case ResolvedAttribute("cnt", DataType.LongType, Some("a"), None, _) => } + c1 shouldMatch { case ResolvedAttribute("cnt", DataType.LongType, Some("a"), None, _, _) => } c2 shouldBe ra2.withQualifier("a") } } @@ -369,7 +371,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val agg = p shouldMatch { case a: Aggregate => a } agg.groupingKeys shouldMatch { case List(ResolvedGroupingKey(Some(1), r: Attribute, _)) => - r shouldMatch { case ResolvedAttribute("xxx", DataType.LongType, _, c, _) => + r shouldMatch { case ResolvedAttribute("xxx", DataType.LongType, _, c, _, _) => c shouldBe Some(SourceColumn(tableA, a1)) } } @@ -380,12 +382,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p shouldMatch { case Aggregate( _, - List(c1, Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _), _)), + List(c1, Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _, _), _, _)), List(ResolvedGroupingKey(None, `ra1`, _)), Some(GreaterThan(col, LongLiteral(10, _), _)), _ ) if c1.name == "id" && f.functionName == "count" => - f.args shouldMatch { case List(AllColumns(_, Some(cols), _)) => + f.args shouldMatch { case List(AllColumns(_, Some(cols), _, _)) => cols.toSet shouldBe Set(ra1, ra2) } col shouldMatch { case FunctionCall("count", Seq(ac: AllColumns), false, _, _, _) => @@ -400,21 +402,21 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("w1: parse WITH statement") { val p = analyze("with q1 as (select id from A) select id from q1") p.outputAttributes.toList shouldMatch { - case List(ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(`tableA`, `a1`)), _)) => + case List(ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(`tableA`, `a1`)), _, _)) => } } test("w2: resolve CTE redundant column alias") { val p = analyze("with q1 as (select id as id from A) select id from q1") p.outputAttributes.toList shouldMatch { - case List(ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(`tableA`, `a1`)), _)) => + case List(ResolvedAttribute("id", DataType.LongType, None, Some(SourceColumn(`tableA`, `a1`)), _, _)) => } } test("parse multiple WITH sub queries") { val p = analyze("with q1 as (select id, name from A), q2 as (select name from q1) select * from q2") - p.outputAttributes.toList shouldMatch { case List(AllColumns(None, Some(Seq(c)), _)) => - c shouldBe ra2 + p.outputAttributes.toList shouldMatch { case List(AllColumns(None, Some(Seq(c)), _, _)) => + c shouldBe ra2.withTableAlias("q2") } } @@ -422,9 +424,9 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("with q1(p1, p2) as (select id, name from A) select * from q1") p.outputAttributes.toList shouldMatch { // The output should use aliases from the source columns - case List(AllColumns(None, Some(Seq(c1, c2)), _)) => - c1 shouldMatch { case Alias(None, "p1", `ra1`, _) => } - c2 shouldMatch { case Alias(None, "p2", `ra2`, _) => } + case List(AllColumns(None, Some(Seq(c1, c2)), _, _)) => + c1 shouldMatch { case Alias(None, "p1", `ra1`, _, _) => } + c2 shouldMatch { case Alias(None, "p2", `ra2`, _, _) => } } } @@ -486,9 +488,9 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("rename table and select *") { val p = analyze("select * from A a") - p.outputAttributes shouldMatch { case List(AllColumns(None, Some(Seq(c1, c2)), _)) => - c1 shouldBe ra1 - c2 shouldBe ra2 + p.outputAttributes shouldMatch { case List(AllColumns(None, Some(Seq(c1, c2)), _, _)) => + c1 shouldBe ra1.withTableAlias("a") + c2 shouldBe ra2.withTableAlias("a") } } } @@ -504,7 +506,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("join: resolve join attributes") { test("j1: join with USING") { val p = analyze("select id, A.name from A join B using(id)") - p.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), _, _), c3) => + p.outputAttributes shouldMatch { case Seq(m @ MultiSourceColumn(Seq(c1, c2), _, _, _), c3) => m.name shouldBe "id" c1 shouldBe ra1 c2 shouldBe rb1 @@ -526,15 +528,15 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { } joinKeys shouldMatch { case List( - ResolvedJoinUsing(Seq(MultiSourceColumn(Seq(c1, c2), _, _)), _), - ResolvedJoinUsing(Seq(MultiSourceColumn(Seq(c3, c4, c5), _, _)), _) + ResolvedJoinUsing(Seq(MultiSourceColumn(Seq(c1, c2), _, _, _)), _), + ResolvedJoinUsing(Seq(MultiSourceColumn(Seq(c3, c4, c5), _, _, _)), _) ) => - c1 shouldBe ra1.withQualifier("a") - c2 shouldBe rb1.withQualifier("b") + c1 shouldBe ra1 + c2 shouldBe rb1 - c3 shouldBe ra1.withQualifier("a") - c4 shouldBe rb1.withQualifier("b") - c5 shouldBe rc1.withQualifier("c") + c3 shouldBe ra1 + c4 shouldBe rb1 + c5 shouldBe rc1 } } @@ -570,8 +572,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select pid, name from A join (select id as pid from B) on A.id = pid") p.outputAttributes shouldMatch { case List( - ResolvedAttribute("pid", DataType.LongType, None, Some(SourceColumn(`tableB`, `b1`)), _), - ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(`tableA`, `a2`)), _) + ResolvedAttribute("pid", DataType.LongType, None, Some(SourceColumn(`tableB`, `b1`)), _, _), + ResolvedAttribute("name", DataType.StringType, None, Some(SourceColumn(`tableA`, `a2`)), _, _) ) => () } @@ -580,7 +582,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("j7: refer to duplicated key of equi join") { val p = analyze("select B.id from A inner join B on A.id = B.id") p.outputAttributes shouldMatch { - case List(ResolvedAttribute("id", DataType.LongType, Some("B"), Some(SourceColumn(`tableB`, `b1`)), _)) => + case List(ResolvedAttribute("id", DataType.LongType, Some("B"), Some(SourceColumn(`tableB`, `b1`)), _, _)) => } } @@ -605,8 +607,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("self-join with USING") { val p = analyze("select * from A join A using(id)") - p.outputAttributes shouldMatch { case Seq(a @ AllColumns(None, Some(columns), _)) => - columns shouldBe Seq(MultiSourceColumn(Seq(ra1, ra1), None, None), ra2, ra2) + p.outputAttributes shouldMatch { case Seq(a @ AllColumns(None, Some(columns), _, _)) => + columns shouldBe Seq(MultiSourceColumn(Seq(ra1, ra1), None, None, None), ra2, ra2) } } } @@ -683,18 +685,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { |(select id from (select id from A)) x |inner join |(select id from (select id from B)) y on x.id = y.id""".stripMargin) - p.outputAttributes shouldMatch { case Seq(AllColumns(_, Some(c), _)) => + p.outputAttributes shouldMatch { case Seq(AllColumns(_, Some(c), _, _)) => c shouldMatch { case List(c1, c2) => - c1 shouldBe ra1 - c2 shouldBe rb1 + c1 shouldBe ra1.withTableAlias("x") + c2 shouldBe rb1.withTableAlias("y") } } - p shouldMatch { case Project(Join(_, _, _, join: JoinOnEq, _), _, _) => - join.keys shouldBe List( - ra1.withQualifier("x"), - rb1.withQualifier("y") - ) - } } test("resolve column in nested SELECT *") { @@ -747,8 +743,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p.outputAttributes.toList shouldMatch { case List( - ResolvedAttribute("id", DataType.LongType, None, _, _), - ResolvedAttribute("name", DataType.StringType, None, _, _) + ResolvedAttribute("id", DataType.LongType, None, _, _, _), + ResolvedAttribute("name", DataType.StringType, None, _, _, _) ) => } @@ -764,8 +760,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("WITH q1 AS (SELECT id + 1 as id, name FROM A) SELECT id, name FROM q1 WHERE q1.id = 99") p.outputAttributes.toList shouldMatch { case List( - ResolvedAttribute("id", DataType.LongType, _, _, _), - ResolvedAttribute("name", DataType.StringType, None, _, _) + ResolvedAttribute("id", DataType.LongType, _, _, _, _), + ResolvedAttribute("name", DataType.StringType, None, _, _, _) ) => } @@ -783,7 +779,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("resolve simple count(*)") { val p = analyze("select count(*) from A") p.outputAttributes shouldMatch { - case List(SingleColumn(FunctionCall("count", Seq(c @ AllColumns(_, _, _)), _, _, _, _), _, _)) => + case List(SingleColumn(FunctionCall("count", Seq(c @ AllColumns(_, _, _, _)), _, _, _, _), _, _, _)) => c.columns shouldBe Some(Seq(ra1, ra2)) } } @@ -795,11 +791,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { SingleColumn( ArithmeticBinaryExpr( _, - FunctionCall("count", Seq(c @ AllColumns(_, _, _)), _, _, _, _), + FunctionCall("count", Seq(c @ AllColumns(_, _, _, _)), _, _, _, _), LongLiteral(1, _), _ ), _, + _, _ ) ) => @@ -809,12 +806,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("resolve count(*) in sub query") { val p = analyze("select cnt from (select count(*) as cnt from A)") - p.outputAttributes shouldMatch { case List(ResolvedAttribute("cnt", DataType.LongType, _, _, _)) => } + p.outputAttributes shouldMatch { case List(ResolvedAttribute("cnt", DataType.LongType, _, _, _, _)) => } } test("resolve count(*) in CTE") { val p = analyze("WITH q AS (select count(*) as cnt from A) select cnt from q") - p.outputAttributes shouldMatch { case List(ResolvedAttribute("cnt", DataType.LongType, None, _, _)) => } + p.outputAttributes shouldMatch { case List(ResolvedAttribute("cnt", DataType.LongType, None, _, _, _)) => } } test("resolve count(*) in Union") { @@ -822,12 +819,12 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p.outputAttributes shouldMatch { case Seq(m: MultiSourceColumn) => m.inputs.size shouldBe 2 m.inputs(0) shouldMatch { - case Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _), _) if f.functionName == "count" => + case Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _, _), _, _) if f.functionName == "count" => f.args.size shouldBe 1 f.args(0).asInstanceOf[AllColumns].columns shouldBe Some(Seq(ra1, ra2)) } m.inputs(1) shouldMatch { - case Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _), _) if f.functionName == "count" => + case Alias(_, "cnt", SingleColumn(f: FunctionCall, _, _, _), _, _) if f.functionName == "count" => f.args.size shouldBe 1 f.args(0).asInstanceOf[AllColumns].columns shouldBe Some(Seq(rb1, rb2)) } @@ -848,10 +845,11 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { _ ), None, + _, _ ) ) => - ac.columns shouldMatch { case Some(Seq(m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _))) => + ac.columns shouldMatch { case Some(Seq(m @ MultiSourceColumn(Seq(`ra1`, `rb1`), _, _, _))) => m.name shouldBe "id" } } @@ -917,8 +915,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { p.outputAttributes shouldBe List( ra1, - ResolvedAttribute("key", DataType.StringType, Some("t"), None, None), - ResolvedAttribute("value", DataType.LongType, Some("t"), None, None) + ResolvedAttribute("key", DataType.StringType, Some("t"), None, None, None), + ResolvedAttribute("value", DataType.LongType, Some("t"), None, None, None) ) } } @@ -935,6 +933,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { m2 ) ), + _, _ ) ) => @@ -947,14 +946,14 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { test("resolve select 1 from subquery") { val p = analyze("select cnt from (select cnt from (select 1 as cnt))") - p.outputAttributes shouldMatch { case List(ResolvedAttribute(cnt, DataType.LongType, _, _, _)) => + p.outputAttributes shouldMatch { case List(ResolvedAttribute(cnt, DataType.LongType, _, _, _, _)) => () } } test("resolve select * from (select 1)") { val p = analyze("select * from (select 1)") - p.outputAttributes shouldMatch { case List(AllColumns(None, Some(List(r: Attribute)), _)) => + p.outputAttributes shouldMatch { case List(AllColumns(None, Some(List(r: Attribute)), _, _)) => r.dataType shouldBe DataType.LongType } } @@ -1007,7 +1006,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { Project(t, Seq(ra1), None) } val resolved = TypeResolver.resolve(defaultAnalyzerContext, rewritten) - resolved shouldMatch { case Project(_, List(AllColumns(None, Some(List(c)), _)), _) => + resolved shouldMatch { case Project(_, List(AllColumns(None, Some(List(c)), _, _)), _) => c shouldBe ra1 } } @@ -1021,7 +1020,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select t1.* from A t1 inner join B t2 on t1.id = t2.id") p.outputAttributes shouldMatch { case List(a: AllColumns) if a.qualifier == Some("t1") => - a.columns shouldBe Some(Seq(ra1, ra2)) + a.columns shouldBe Some(Seq(ra1.withTableAlias("t1"), ra2.withTableAlias("t1"))) } } @@ -1029,7 +1028,7 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p = analyze("select t1.* from A t1 inner join (select * from B) t2 using (id)") p.outputAttributes shouldMatch { case List(a: AllColumns) if a.qualifier == Some("t1") => - a.columns shouldBe Some(Seq(ra2)) // "id" is not contained + a.columns shouldBe Some(Seq(ra2.withTableAlias("t1"))) // "id" is not contained } } @@ -1080,9 +1079,15 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper { val p2 = analyze("with t1 as (select id from A) select count(id) from (select id from t1) t2") p2.outputAttributes shouldMatch { - case List(SingleColumn(FunctionCall("count", Seq(col: ResolvedAttribute), _, _, _, _), _, _)) => - col.fullName shouldBe "t2.id" + case List(SingleColumn(FunctionCall("count", Seq(col: ResolvedAttribute), _, _, _, _), _, _, _)) => + col.fullName shouldBe "id" col.sourceColumn.head.fullName shouldBe "A.id" } + + val p3 = analyze("with t1 as (select id from A) select t2.id from (select id from t1) t2") + p3.outputAttributes shouldMatch { case List(col: ResolvedAttribute) => + col.fullName shouldBe "t2.id" + col.sourceColumn.head.fullName shouldBe "A.id" + } } } diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/model/ExpressionTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/model/ExpressionTest.scala index f9812ac6dc..a9712692bd 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/model/ExpressionTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/model/ExpressionTest.scala @@ -25,15 +25,16 @@ class ExpressionTest extends AirSpec { val expr = SingleColumn( f, None, + None, None ) val newExpr = expr.transformExpression { - case s @ SingleColumn(f: FunctionCall, _, _) if f.functionName == "count" => + case s @ SingleColumn(f: FunctionCall, _, _, _) if f.functionName == "count" => s.withQualifier("xxx") } - newExpr shouldBe SingleColumn(f, Some("xxx"), None) + newExpr shouldBe SingleColumn(f, Some("xxx"), None, None) } test("transform up in breadth-first order") { @@ -63,15 +64,16 @@ class ExpressionTest extends AirSpec { val expr = SingleColumn( f, None, + None, None ) val newExpr = expr.transformUpExpression { - case s @ SingleColumn(f: FunctionCall, _, _) if f.functionName == "count" => + case s @ SingleColumn(f: FunctionCall, _, _, _) if f.functionName == "count" => s.withQualifier("xxx") } - newExpr shouldBe SingleColumn(f, Some("xxx"), None) + newExpr shouldBe SingleColumn(f, Some("xxx"), None, None) } test("transform up in depth-first order") { From f38744052bb58f5c9ea011557be5d22ce792e9b2 Mon Sep 17 00:00:00 2001 From: Naoki Takezoe Date: Wed, 12 Jul 2023 12:38:22 +0900 Subject: [PATCH 3/3] Fix some warnings in TypeResolver --- .../wvlet/airframe/sql/analyzer/TypeResolver.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala index 5c839c953e..304425cf90 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/analyzer/TypeResolver.scala @@ -19,6 +19,8 @@ import wvlet.airframe.sql.model.LogicalPlan._ import wvlet.airframe.sql.model._ import wvlet.log.LogSupport +import scala.annotation.tailrec + /** * Resolve untyped [[LogicalPlan]]s and [[Expression]]s into typed ones. */ @@ -95,9 +97,6 @@ object TypeResolver extends LogSupport { /** * Translate select i1, i2, ... group by 1, 2, ... query into select i1, i2, ... group by i1, i2 - * - * @param context - * @return */ object resolveAggregationIndexes extends RewriteRule { def apply(context: AnalyzerContext): PlanRewriter = { @@ -137,8 +136,6 @@ object TypeResolver extends LogSupport { /** * Resolve group by keys - * @param context - * @return */ object resolveAggregationKeys extends RewriteRule { def apply(context: AnalyzerContext): PlanRewriter = { @@ -171,6 +168,7 @@ object TypeResolver extends LogSupport { s.copy(orderBy = resolvedSortItems) } + @tailrec private def resolveIndex(index: Int, inputs: Seq[Attribute]): Expression = { inputs(index) match { case a: AllColumns => @@ -197,8 +195,6 @@ object TypeResolver extends LogSupport { /** * Resolve TableRefs in a query inside WITH statement with CTERelationRef - * @param context - * @return */ object resolveCTETableRef extends RewriteRule { def apply(context: AnalyzerContext): PlanRewriter = { case q @ Query(withQuery, body, _) => @@ -402,6 +398,7 @@ object TypeResolver extends LogSupport { } private def toResolvedAttribute(name: String, expr: Expression): Attribute = { + @tailrec def findSourceColumn(e: Expression): Option[SourceColumn] = { e match { case r: ResolvedAttribute => r.sourceColumn