From 835be0f8ece234e4343987eed0dffc1576637bca Mon Sep 17 00:00:00 2001 From: stefankandic <154237371+stefankandic@users.noreply.github.com> Date: Fri, 2 Feb 2024 15:47:51 +0100 Subject: [PATCH] New Collate Grammar (#6) * initial change of grammar to support string collation * initial change of grammar to support string collation --- .../sql/catalyst/parser/SqlBaseParser.g4 | 11 ++++----- .../catalyst/parser/DataTypeAstBuilder.scala | 16 +++++++++---- .../sql/catalyst/parser/AstBuilder.scala | 23 ++----------------- .../org/apache/spark/sql/CollationSuite.scala | 18 ++++++++++++++- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f36131e0ea5aa..dce3a7dd3026f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -345,10 +345,6 @@ commentSpec : COMMENT stringLit ; -collationSpec - : COLLATE stringLit - ; - query : ctes? queryTerm queryOrganization ; @@ -1098,6 +1094,10 @@ colPosition : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier ; +collation + : COLLATE collationName=stringLit + ; + type : BOOLEAN | TINYINT | BYTE @@ -1108,7 +1108,7 @@ type | DOUBLE | DATE | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ - | STRING + | STRING collation? | CHARACTER | CHAR | VARCHAR | BINARY @@ -1175,7 +1175,6 @@ colDefinitionOption | defaultExpression | generationExpression | commentSpec - | collationSpec ; generationExpression diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 29b24cec6121c..b5c66223581c4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -24,6 +24,7 @@ import org.antlr.v4.runtime.Token import org.antlr.v4.runtime.tree.ParseTree import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.util.CollatorFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -58,8 +59,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { * Resolve/create a primitive type. */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - val typeName = ctx.`type`.start.getType - (typeName, ctx.INTEGER_VALUE().asScala.toList) match { + val typeCtx = ctx.`type` + (typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match { case (BOOLEAN, Nil) => BooleanType case (TINYINT | BYTE, Nil) => ByteType case (SMALLINT | SHORT, Nil) => ShortType @@ -71,7 +72,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType case (TIMESTAMP_NTZ, Nil) => TimestampNTZType case (TIMESTAMP_LTZ, Nil) => TimestampType - case (STRING, Nil) => StringType + case (STRING, Nil) => + typeCtx.children.asScala.toSeq match { + case Seq(_) => StringType + case Seq(_, collationCtx: CollationContext) => + val collationStr = visitCollation(collationCtx) + val collationId = CollatorFactory.getInstance().collationNameToId(collationStr) + StringType(collationId) + } case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt) case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt) case (BINARY, Nil) => BinaryType @@ -209,7 +217,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { /** * Create a collation string. */ - override def visitCollationSpec(ctx: CollationSpecContext): String = withOrigin(ctx) { + override def visitCollation(ctx: CollationContext): String = withOrigin(ctx) { string(visitStringLit(ctx.stringLit)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ba5e0190175af..cf7fc0866664b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, CollatorFactory, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -3146,7 +3146,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { var defaultExpression: Option[DefaultExpressionContext] = None var generationExpression: Option[GenerationExpressionContext] = None var commentSpec: Option[CommentSpecContext] = None - var collationSpec: Option[CollationSpecContext] = None ctx.colDefinitionOption().asScala.foreach { option => if (option.NULL != null) { @@ -3177,13 +3176,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } commentSpec = Some(spec) } - Option(option.collationSpec()).foreach { spec => - if (collationSpec.isDefined) { - throw QueryParsingErrors.duplicateTableColumnDescriptor( - option, colName.getText, "COLLATE") - } - collationSpec = Some(spec) - } } val builder = new MetadataBuilder @@ -3191,7 +3183,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { commentSpec.map(visitCommentSpec).foreach { builder.putString("comment", _) } - // Add the 'DEFAULT expression' clause in the column definition, if any, to the column metadata. defaultExpression.map(visitDefaultExpression).foreach { field => if (conf.getConf(SQLConf.ENABLE_DEFAULT_COLUMNS)) { @@ -3208,21 +3199,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { builder.putString(GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY, field) } - val collation = collationSpec.map(visitCollationSpec) val name: String = colName.getText - val dataType = (collation, typedVisit[DataType](ctx.dataType)) match { - case (None, _) => typedVisit[DataType](ctx.dataType) - case (Some(collation), StringType) => - val collationId = CollatorFactory.getInstance().collationNameToId(collation) - StringType(collationId) - case (Some(collation), dataType) => - throw QueryParsingErrors.invalidCollationSpecified(ctx, dataType.catalogString, collation) - } - StructField( name = name, - dataType = dataType, + dataType = typedVisit[DataType](ctx.dataType), nullable = nullable, metadata = builder.build()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 31c1dd8c1c7f8..1437f934aa2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -298,7 +298,6 @@ class CollationSuite extends QueryTest } test("create table support") { - // TODO: Filter pushdown and partitioning are todos. val tableName = "parquet_dummy_t" withTable(tableName) { sql(s"CREATE TABLE IF NOT EXISTS $tableName (c1 STRING COLLATE 'SR_CI_AI') USING PARQUET") @@ -309,6 +308,23 @@ class CollationSuite extends QueryTest } } + test("create table with nested collations in struct") { + val tableName = "nested_collation_tbl" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |(c1 STRUCT) + |USING PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))") + sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))") + + checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"), Seq(Row("SR_CI_AI"))) + checkAnswer(sql(s"SELECT COUNT(DISTINCT c1.name) FROM $tableName"), Seq(Row(1))) + } + } + test("disable partition on collated string column") { def createTable(partitionColumns: String*): Unit = { val tableName = "test_partition"