From 17d0d7f4712be2a1a8434b978445ce25a5c7f7c9 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 24 Apr 2024 16:12:20 +0800 Subject: [PATCH] [SPARK-47692][SQL] Fix default StringType meaning in implicit casting ### What changes were proposed in this pull request? Addition of priority flag to StringType. ### Why are the changes needed? In order to follow casting rules for collations, we need to know whether StringType is considered default, implicit or explicit. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Implicit tests in CollationSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45819 from mihailom-db/SPARK-47692. Authored-by: Mihailo Milosevic Signed-off-by: Wenchen Fan --- .../internal/types/AbstractStringType.scala | 3 +- .../analysis/CollationTypeCasts.scala | 30 ++++++--- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- .../org/apache/spark/sql/CollationSuite.scala | 66 ++++++++++++++++++- 4 files changed, 88 insertions(+), 16 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 6403295fe20c4..0828c2d6fc104 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.internal.types +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** * StringTypeCollated is an abstract class for StringType with collation support. */ abstract class AbstractStringType extends AbstractDataType { - override private[sql] def defaultConcreteType: DataType = StringType + override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 3affd91dd3b82..c6232a870dff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, StringRPad} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, StringLPad, StringRPad} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -48,9 +48,9 @@ object CollationTypeCasts extends TypeCoercionRule { case eltExpr: Elt => eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail)) - case overlay: Overlay => - overlay.withNewChildren(collateToSingleType(Seq(overlay.input, overlay.replace)) - ++ Seq(overlay.pos, overlay.len)) + case overlayExpr: Overlay => + overlayExpr.withNewChildren(collateToSingleType(Seq(overlayExpr.input, overlayExpr.replace)) + ++ Seq(overlayExpr.pos, overlayExpr.len)) case stringPadExpr @ (_: StringRPad | _: StringLPad) => val Seq(str, len, pad) = stringPadExpr.children @@ -108,7 +108,12 @@ object CollationTypeCasts extends TypeCoercionRule { * complex DataTypes with collated StringTypes (e.g. ArrayType) */ def getOutputCollation(expr: Seq[Expression]): StringType = { - val explicitTypes = expr.filter(_.isInstanceOf[Collate]) + val explicitTypes = expr.filter { + case _: Collate => true + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined => + cast.dataType.isInstanceOf[StringType] + case _ => false + } .map(_.dataType.asInstanceOf[StringType].collationId) .distinct @@ -123,17 +128,22 @@ object CollationTypeCasts extends TypeCoercionRule { ) // Only implicit or default collations present case 0 => - val implicitTypes = expr.map(_.dataType) + val implicitTypes = expr.filter { + case Literal(_, _: StringType) => false + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty => + cast.child.dataType.isInstanceOf[StringType] + case _ => true + } + .map(_.dataType) .filter(hasStringType) - .map(extractStringType) - .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) - .distinctBy(_.collationId) + .map(extractStringType(_).collationId) + .distinct if (implicitTypes.length > 1) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - implicitTypes.headOption.getOrElse(SQLConf.get.defaultStringType) + implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 259e28b62bca7..506314effde33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -998,9 +998,10 @@ object TypeCoercion extends TypeCoercionBase { case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType + case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st case (any: AtomicType, st: AbstractStringType) - if !any.isInstanceOf[StringType] => st.defaultConcreteType + if !any.isInstanceOf[StringType] => + st.defaultConcreteType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9aad96c696ead..26f7726c39642 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -412,7 +413,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("implicit casting of collated strings") { + test("SPARK-47210: Implicit casting of collated strings") { val tableName = "parquet_dummy_implicit_cast_t22" withTable(tableName) { spark.sql( @@ -566,7 +567,66 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("cast of default collated strings in IN expression") { + test("SPARK-47692: Parameter marker with EXECUTE IMMEDIATE implicit casting") { + sql(s"DECLARE stmtStr1 = 'SELECT collation(:var1 || :var2)';") + sql(s"DECLARE stmtStr2 = 'SELECT collation(:var1 || (\\\'a\\\' COLLATE UNICODE))';") + + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr1 USING + | 'a' AS var1, + | 'b' AS var2;""".stripMargin), + Seq(Row("UTF8_BINARY")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr1 USING + | 'a' AS var1, + | 'b' AS var2;""".stripMargin), + Seq(Row("UNICODE")) + ) + } + + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr2 USING + | 'a' AS var1;""".stripMargin), + Seq(Row("UNICODE")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr2 USING + | 'a' AS var1;""".stripMargin), + Seq(Row("UNICODE")) + ) + } + } + + test("SPARK-47692: Parameter markers with variable mapping") { + checkAnswer( + spark.sql( + "SELECT collation(:var1 || :var2)", + Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), + "var2" -> Literal.create('b', StringType("UNICODE")))), + Seq(Row("UTF8_BINARY")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + spark.sql( + "SELECT collation(:var1 || :var2)", + Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), + "var2" -> Literal.create('b', StringType("UNICODE")))), + Seq(Row("UNICODE")) + ) + } + } + + test("SPARK-47210: Cast of default collated strings in IN expression") { val tableName = "t1" withTable(tableName) { spark.sql( @@ -591,7 +651,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } // TODO(SPARK-47210): Add indeterminate support - test("indeterminate collation checks") { + test("SPARK-47210: Indeterminate collation checks") { val tableName = "t1" val newTableName = "t2" withTable(tableName) {