From b34544a04f59d02dd41cab27995406c7485aa02d Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 5 Mar 2024 08:54:23 +0100 Subject: [PATCH 01/87] Implicit casting on collated expressions --- .../main/resources/error/error-classes.json | 24 +++ .../apache/spark/sql/types/StringType.scala | 8 + .../catalyst/analysis/AnsiTypeCoercion.scala | 25 +-- .../sql/catalyst/analysis/TypeCoercion.scala | 148 +++++++++++++++--- .../spark/sql/catalyst/expressions/Cast.scala | 58 +++---- .../expressions/collectionOperations.scala | 4 +- .../sql/errors/QueryCompilationErrors.scala | 24 +++ .../org/apache/spark/sql/CollationSuite.scala | 118 ++++++++++++++ 8 files changed, 346 insertions(+), 63 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 7cf3e9c533ca8..e53c3b254bfbd 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -475,6 +475,24 @@ ], "sqlState" : "42704" }, + "COLLATION_MISMATCH" : { + "message" : [ + "Could not determine which collation to use for string comparison." + ], + "subClass" : { + "EXPLICIT" : { + "message" : [ + "Error occurred due to the mismatch between explicit collations \"\" and \"\"" + ] + }, + "IMPLICIT" : { + "message" : [ + "Error occurred due to the mismatch between multiple implicit collations. Use COLLATE function to set the collation explicitly." + ] + } + }, + "sqlState" : "42P21" + }, "COLLECTION_SIZE_LIMIT_EXCEEDED" : { "message" : [ "Can't create array with elements which exceeding the array size limit ," @@ -1574,6 +1592,12 @@ ], "sqlState" : "22003" }, + "INDETERMINATE_COLLATION" : { + "message" : [ + "Function called requires knowledge of the collation it should apply, but indeterminate collation was found. Use COLLATE function to set the collation explicitly." + ], + "sqlState" : "42P22" + }, "INDEX_ALREADY_EXISTS" : { "message" : [ "Cannot create the index on table because it already exists." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 313f525742ae9..79966adf5e43b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -41,6 +41,12 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation + /** + * Returns whether the collation is indeterminate. An indeterminate collation is + * a result of combination of conflicting non-default implicit collations. + */ + def isIndeterminateCollation: Boolean = collationId == StringType.INDETERMINATE_COLLATION_ID + /** * Type name that is shown to the customer. * If this is an UCS_BASIC collation output is `string` due to backwards compatibility. @@ -69,5 +75,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ @Stable case object StringType extends StringType(0) { + val DEFAULT_COLLATION_ID = 0 + val INDETERMINATE_COLLATION_ID = -1 def apply(collationId: Int): StringType = new StringType(collationId) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 8857f0b5a25ec..522dc2d5c829f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -90,6 +90,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { Division :: IntegralDivision :: ImplicitTypeCasts :: + CollationTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: GetDateFieldOperations:: Nil) :: Nil @@ -138,15 +139,16 @@ object AnsiTypeCoercion extends TypeCoercionBase { @scala.annotation.tailrec private def findWiderTypeForString(dt1: DataType, dt2: DataType): Option[DataType] = { (dt1, dt2) match { - case (StringType, _: IntegralType) => Some(LongType) - case (StringType, _: FractionalType) => Some(DoubleType) - case (StringType, NullType) => Some(StringType) + case (_: StringType, _: IntegralType) => Some(LongType) + case (_: StringType, _: FractionalType) => Some(DoubleType) + case (st: StringType, NullType) => Some(st) // If a binary operation contains interval type and string, we can't decide which // interval type the string should be promoted as. There are many possible interval // types, such as year interval, month interval, day interval, hour interval, etc. - case (StringType, _: AnsiIntervalType) => None - case (StringType, a: AtomicType) => Some(a) - case (other, StringType) if other != StringType => findWiderTypeForString(StringType, other) + case (_: StringType, _: AnsiIntervalType) => None + case (_: StringType, a: AtomicType) => Some(a) + case (other, st: StringType) if !other.isInstanceOf[StringType] => + findWiderTypeForString(st, other) case _ => None } } @@ -186,23 +188,26 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (NullType, target) if !target.isInstanceOf[TypeCollection] => Some(target.defaultConcreteType) + case (_: StringType, st: StringType) => + Some(st) + // This type coercion system will allow implicit converting String type as other // primitive types, in case of breaking too many existing Spark SQL queries. - case (StringType, a: AtomicType) => + case (_: StringType, a: AtomicType) => Some(a) // If the target type is any Numeric type, convert the String type as Double type. - case (StringType, NumericType) => + case (_: StringType, NumericType) => Some(DoubleType) // If the target type is any Decimal type, convert the String type as the default // Decimal type. - case (StringType, DecimalType) => + case (_: StringType, DecimalType) => Some(DecimalType.SYSTEM_DEFAULT) // If the target type is any timestamp type, convert the String type as the default // Timestamp type. - case (StringType, AnyTimestampType) => + case (_: StringType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) case (DateType, AnyTimestampType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 56e8843fda537..a14af29ccaee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -389,6 +389,11 @@ abstract class TypeCoercionBase { i } + case i @ In(a, b) if CollationTypeCasts.shouldCast(a.dataType +: b.map(_.dataType)) => + // resolve collations between the children of IN expression + val newChildren = CollationTypeCasts.collateToSingleType(a +: b) + i.copy(newChildren.head, newChildren.tail) + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) @@ -611,7 +616,18 @@ abstract class TypeCoercionBase { val newChildren = c.children.map { e => implicitCast(e, StringType).getOrElse(e) } - c.copy(children = newChildren) + + val newNode = c.copy(children = newChildren) + if (CollationTypeCasts.shouldCast(children.map(_.dataType))) { + // if original children had different collations we need to + // cast the output to the expected collation + val collationId = CollationTypeCasts.getOutputCollation( + children, failOnIndeterminate = false) + Cast(newNode, StringType(collationId)) + } + else { + newNode + } } } @@ -764,6 +780,91 @@ abstract class TypeCoercionBase { } } + object CollationTypeCasts extends TypeCoercionRule { + override def transform: PartialFunction[Expression, Expression] = { + case e if !e.childrenResolved => e + + case b @ BinaryComparison(left, right) if shouldCast(Seq(left.dataType, right.dataType)) => + val newChildren = collateToSingleType(Seq(left, right)) + b.withNewChildren(newChildren) + } + + def shouldCast(types: Seq[DataType]): Boolean = { + types.forall(_.isInstanceOf[StringType]) && types.distinct.length > 1 + } + + /** + * Collates the input expression to a single collation. + */ + def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { + val collationId = getOutputCollation(exprs) + + exprs.map { expression => + expression.dataType match { + case st: StringType if st.collationId == collationId => + expression + case _: StringType => + Cast(expression, StringType(collationId)) + } + } + } + + /** + * Based on the data types of the input expressions this method determines + * a collation type which the output will be. + */ + def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { + val explicitTypes = exprs.filter(hasExplicitCollation).map(_.dataType).distinct + + explicitTypes.size match { + case 1 => explicitTypes.head.asInstanceOf[StringType].collationId + case size if size > 1 => throw QueryCompilationErrors.explicitCollationMismatchError( + explicitTypes.head.simpleString, explicitTypes.tail.head.simpleString) + case _ => + val dataTypes = exprs.map(_.dataType.asInstanceOf[StringType]) + + if (isIndeterminate(dataTypes)) { + if (failOnIndeterminate) { + throw QueryCompilationErrors.indeterminateCollationError() + } else { + StringType.INDETERMINATE_COLLATION_ID + } + } + else if (hasMultipleImplicits(dataTypes)) { + if (failOnIndeterminate) { + throw QueryCompilationErrors.implicitCollationMismatchError() + } else { + StringType.INDETERMINATE_COLLATION_ID + } + } + else { + dataTypes.find(!_.isDefaultCollation) + .getOrElse(StringType) + .collationId + } + } + } + + private def isIndeterminate(dataTypes: Seq[StringType]): Boolean = + dataTypes.exists(_.isIndeterminateCollation) + + + private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = + dataTypes.filter(!_.isDefaultCollation).distinct.size > 1 + + private def hasExplicitCollation(expression: Expression): Boolean = { + if (!expression.dataType.isInstanceOf[StringType]) { + false + } + else { + expression match { + case _: Collate => true + case _ => expression.children.exists(hasExplicitCollation) + } + } + } + } + /** * Cast WindowFrame boundaries to the type they operate upon. */ @@ -850,6 +951,7 @@ object TypeCoercion extends TypeCoercionBase { StackCoercion :: Division :: IntegralDivision :: + CollationTypeCasts :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -885,8 +987,8 @@ object TypeCoercion extends TypeCoercionBase { /** Promotes all the way to StringType. */ private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case (st: StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(st) + case (t1: AtomicType, st: StringType) if t1 != BinaryType && t1 != BooleanType => Some(st) case _ => None } @@ -909,16 +1011,16 @@ object TypeCoercion extends TypeCoercionBase { */ def findCommonTypeForBinaryComparison( dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { - case (StringType, DateType) - => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) - case (DateType, StringType) - => if (conf.castDatetimeToString) Some(StringType) else Some(DateType) - case (StringType, TimestampType) - => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) - case (TimestampType, StringType) - => if (conf.castDatetimeToString) Some(StringType) else Some(TimestampType) - case (StringType, NullType) => Some(StringType) - case (NullType, StringType) => Some(StringType) + case (st: StringType, DateType) + => if (conf.castDatetimeToString) Some(st) else Some(DateType) + case (DateType, st: StringType) + => if (conf.castDatetimeToString) Some(st) else Some(DateType) + case (st: StringType, TimestampType) + => if (conf.castDatetimeToString) Some(st) else Some(TimestampType) + case (TimestampType, st: StringType) + => if (conf.castDatetimeToString) Some(st) else Some(TimestampType) + case (st: StringType, NullType) => Some(st) + case (NullType, st: StringType) => Some(st) // Cast to TimestampType when we compare DateType with TimestampType // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true @@ -958,7 +1060,8 @@ object TypeCoercion extends TypeCoercionBase { override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { implicitCast(e.dataType, expectedType).map { dt => - if (dt == e.dataType) e else Cast(e, dt) + if (dt == e.dataType) { e } + else { Cast(e, dt) } } } @@ -966,6 +1069,7 @@ object TypeCoercion extends TypeCoercionBase { // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. // We wrap immediately an Option after this. @Nullable val ret: DataType = (inType, expectedType) match { + case (_: StringType, st2: StringType) => st2 // If the expected type is already a parent of the input type, no need to cast. case _ if expectedType.acceptsType(inType) => inType @@ -974,7 +1078,7 @@ object TypeCoercion extends TypeCoercionBase { // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double - case (StringType, NumericType) => NumericType.defaultConcreteType + case (_: StringType, NumericType) => NumericType.defaultConcreteType // Implicit cast among numeric types. When we reach here, input type is not acceptable. @@ -989,13 +1093,13 @@ object TypeCoercion extends TypeCoercionBase { case (_: DatetimeType, AnyTimestampType) => AnyTimestampType.defaultConcreteType // Implicit cast from/to string - case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT - case (StringType, target: NumericType) => target - case (StringType, datetime: DatetimeType) => datetime - case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType - case (StringType, BinaryType) => BinaryType + case (_: StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT + case (_: StringType, target: NumericType) => target + case (_: StringType, datetime: DatetimeType) => datetime + case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType + case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => StringType + case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. @@ -1073,7 +1177,7 @@ object TypeCoercion extends TypeCoercionBase { */ @tailrec def hasStringType(dt: DataType): Boolean = dt match { - case StringType => true + case _: StringType => true case ArrayType(et, _) => hasStringType(et) // Add StructType if we support string promotion for struct fields in the future. case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3f14f1458433d..0bfdf7386efa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -95,22 +95,22 @@ object Cast extends QueryErrorsBase { case (_, _: StringType) => true - case (StringType, _: BinaryType) => true + case (_: StringType, _: BinaryType) => true - case (StringType, BooleanType) => true + case (_: StringType, BooleanType) => true case (_: NumericType, BooleanType) => true - case (StringType, TimestampType) => true + case (_: StringType, TimestampType) => true case (DateType, TimestampType) => true case (TimestampNTZType, TimestampType) => true case (_: NumericType, TimestampType) => true - case (StringType, TimestampNTZType) => true + case (_: StringType, TimestampNTZType) => true case (DateType, TimestampNTZType) => true case (TimestampType, TimestampNTZType) => true - case (StringType, _: CalendarIntervalType) => true - case (StringType, _: AnsiIntervalType) => true + case (_: StringType, _: CalendarIntervalType) => true + case (_: StringType, _: AnsiIntervalType) => true case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true @@ -118,12 +118,12 @@ object Cast extends QueryErrorsBase { case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true - case (StringType, DateType) => true + case (_: StringType, DateType) => true case (TimestampType, DateType) => true case (TimestampNTZType, DateType) => true case (_: NumericType, _: NumericType) => true - case (StringType, _: NumericType) => true + case (_: StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true case (TimestampType, _: NumericType) => true @@ -192,33 +192,33 @@ object Cast extends QueryErrorsBase { case (NullType, _) => true - case (_, StringType) => true + case (_, _: StringType) => true - case (StringType, BinaryType) => true + case (_: StringType, BinaryType) => true case (_: IntegralType, BinaryType) => true - case (StringType, BooleanType) => true + case (_: StringType, BooleanType) => true case (DateType, BooleanType) => true case (TimestampType, BooleanType) => true case (_: NumericType, BooleanType) => true - case (StringType, TimestampType) => true + case (_: StringType, TimestampType) => true case (BooleanType, TimestampType) => true case (DateType, TimestampType) => true case (_: NumericType, TimestampType) => true case (TimestampNTZType, TimestampType) => true - case (StringType, TimestampNTZType) => true + case (_: StringType, TimestampNTZType) => true case (DateType, TimestampNTZType) => true case (TimestampType, TimestampNTZType) => true - case (StringType, DateType) => true + case (_: StringType, DateType) => true case (TimestampType, DateType) => true case (TimestampNTZType, DateType) => true - case (StringType, CalendarIntervalType) => true - case (StringType, _: DayTimeIntervalType) => true - case (StringType, _: YearMonthIntervalType) => true + case (_: StringType, CalendarIntervalType) => true + case (_: StringType, _: DayTimeIntervalType) => true + case (_: StringType, _: YearMonthIntervalType) => true case (_: IntegralType, DayTimeIntervalType(s, e)) if s == e => true case (_: IntegralType, YearMonthIntervalType(s, e)) if s == e => true @@ -227,7 +227,7 @@ object Cast extends QueryErrorsBase { case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true - case (StringType, _: NumericType) => true + case (_: StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true case (DateType, _: NumericType) => true case (TimestampType, _: NumericType) => true @@ -341,9 +341,9 @@ object Cast extends QueryErrorsBase { case (NullType, _) => false // empty array or map case case (_, _) if from == to => false - case (StringType, BinaryType) => false - case (StringType, _) => true - case (_, StringType) => false + case (_: StringType, BinaryType) => false + case (_: StringType, _) => true + case (_, _: StringType) => false case (TimestampType, ByteType | ShortType | IntegerType) => true case (FloatType | DoubleType, TimestampType) => true @@ -522,7 +522,7 @@ case class Cast( child.nullable || Cast.forceNullable(child.dataType, dataType) } else { (child.dataType, dataType) match { - case (StringType, BinaryType) => child.nullable + case (_: StringType, BinaryType) => child.nullable // TODO: Implement a more accurate method for checking whether a decimal value can be cast // as integral types without overflow. Currently, the cast can overflow even if // "Cast.canUpCast" method returns true. @@ -870,9 +870,9 @@ case class Cast( // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => buildCast[UTF8String](_, v => UTF8StringUtils.toByteExact(v, getContextOrNull())) - case StringType => + case _: StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte @@ -1158,8 +1158,8 @@ case class Cast( false } else { (child.dataType, dataType) match { - case (StringType, _: FractionalType) => true - case (StringType, _: DatetimeType) => true + case (_: StringType, _: FractionalType) => true + case (_: StringType, _: DatetimeType) => true case _ => false } } @@ -1251,7 +1251,7 @@ case class Cast( } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { - case StringType => + case _: StringType => (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" case _: IntegralType => @@ -1775,11 +1775,11 @@ case class Cast( } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, $errorContext);" - case StringType => + case _: StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a090bdf2bebf6..4cf45b1b0cb00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2643,7 +2643,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) ByteArray.concat(inputs: _*) } - case StringType => + case _: StringType => input => { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) UTF8String.concat(inputs: _*) @@ -2714,7 +2714,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val (concat, initCode) = dataType match { case BinaryType => (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") - case StringType => + case _: StringType => ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, containsNull) => val concat = genCodeForArrays(ctx, elementType, containsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 38f2228f33892..c9abcb2997bc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3605,6 +3605,30 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def implicitCollationMismatchError(): Throwable = { + new AnalysisException( + errorClass = "COLLATION_MISMATCH.IMPLICIT", + messageParameters = Map.empty + ) + } + + def explicitCollationMismatchError(left: String, right: String): Throwable = { + new AnalysisException( + errorClass = "COLLATION_MISMATCH.EXPLICIT", + messageParameters = Map( + "left" -> left, + "right" -> right + ) + ) + } + + def indeterminateCollationError(): Throwable = { + new AnalysisException( + errorClass = "INDETERMINATE_COLLATION", + messageParameters = Map.empty + ) + } + def cannotConvertProtobufTypeToSqlTypeError( protobufColumn: String, sqlColumn: Seq[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2ff50722c3da9..2fc13b7494bf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -518,6 +518,124 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("implicit cast of default collated strings") { + val tableName = "parquet_dummy_implicit_cast_t22" + withTable(tableName) { + spark.sql( + s""" + | CREATE TABLE $tableName(c1 STRING COLLATE 'UCS_BASIC_LCASE', + | c2 STRING COLLATE 'UNICODE', c3 STRING COLLATE 'UNICODE_CI', c4 STRING) + | USING PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableName VALUES ('a', 'a', 'a', 'a')") + sql(s"INSERT INTO $tableName VALUES ('A', 'A', 'A', 'A')") + + // collate literal to c1's collation + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = 'a'"), + Seq(Row("a"), Row("A"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE 'a' = c1"), + Seq(Row("a"), Row("A"))) + + // collate c1 to UCS_BASIC because it is explicitly set + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = COLLATE('a', 'UCS_BASIC')"), + Seq(Row("a"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = SUBSTR(COLLATE('a', 'UCS_BASIC'), 0)"), + Seq(Row("a"))) + + // in operator + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('a')"), + Seq(Row("a"), Row("A"))) + // explicitly set collation inside IN operator + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('b', COLLATE('a', 'UCS_BASIC'))"), + Seq(Row("a"))) + + // concat should not change collation + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || 'a' || 'a' = 'aaa'"), + Seq(Row("a"), Row("A"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE(c2, 'UCS_BASIC') = 'aa'"), + Seq(Row("a"))) + + // concat of columns of different collations is allowed + // as long as we don't use binary comparison on the result + sql(s"SELECT c1 || c3 from $tableName") + + // concat + in + checkAnswer(sql(s"SELECT c1 FROM $tableName where c1 || 'a' " + + s"IN (COLLATE('aa', 'UCS_BASIC_LCASE'))"), Seq(Row("a"), Row("A"))) + checkAnswer(sql(s"SELECT c1 FROM $tableName where (c1 || 'a') " + + s"IN (COLLATE('aa', 'UCS_BASIC'))"), Seq(Row("a"))) + + // columns have different collation + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName WHERE c1 = c3") + }, + errorClass = "COLLATION_MISMATCH.IMPLICIT" + ) + + // different explicit collations are set + checkError( + exception = intercept[AnalysisException] { + sql( + s""" + |SELECT c1 FROM $tableName + |WHERE COLLATE('a', 'UCS_BASIC') = COLLATE('a', 'UNICODE')""" + .stripMargin) + }, + errorClass = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "left" -> "string", + "right" -> "string COLLATE 'UNICODE'" + ) + ) + + // in operator has different collations + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName WHERE c1 IN " + + "(COLLATE('a', 'UCS_BASIC'), COLLATE('b', 'UNICODE'))") + }, + errorClass = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "left" -> "string", + "right" -> "string COLLATE 'UNICODE'" + ) + ) + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName WHERE COLLATE(c1, 'UNICODE') IN " + + "(COLLATE('a', 'UCS_BASIC'))") + }, + errorClass = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "left" -> "string COLLATE 'UNICODE'", + "right" -> "string" + ) + ) + + // concat on different implicit collations should fail + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") + }, + errorClass = "INDETERMINATE_COLLATION" + ) + + // concat + in + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UCS_BASIC') IN " + + s"(COLLATE('a', 'UNICODE'))") + }, + errorClass = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "left" -> "string", + "right" -> "string COLLATE 'UNICODE'" + ) + ) + } + } + test("create v2 table with collation column") { val tableName = "testcat.table_name" val collationName = "UCS_BASIC_LCASE" From fdbfa44150e2aa4068bf48f45349a47f03a67d26 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 5 Mar 2024 10:58:26 +0100 Subject: [PATCH 02/87] Fix doc files --- ...nditions-collation-mismatch-error-class.md | 41 +++++++++++++++++++ docs/sql-error-conditions.md | 14 +++++++ 2 files changed, 55 insertions(+) create mode 100644 docs/sql-error-conditions-collation-mismatch-error-class.md diff --git a/docs/sql-error-conditions-collation-mismatch-error-class.md b/docs/sql-error-conditions-collation-mismatch-error-class.md new file mode 100644 index 0000000000000..91f4f5dbbc800 --- /dev/null +++ b/docs/sql-error-conditions-collation-mismatch-error-class.md @@ -0,0 +1,41 @@ +--- +layout: global +title: COLLATION_MISMATCH error class +displayTitle: COLLATION_MISMATCH error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + + + +[SQLSTATE: 42P21](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Could not determine which collation to use for string comparison. + +This error class has the following derived error classes: + +## EXPLICIT + +Error occurred due to the mismatch between explicit collations "``" and "``" + +## IMPLICIT + +Error occurred due to the mismatch between multiple implicit collations. Use COLLATE function to set the collation explicitly. + + diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 7be01f8cb513d..b28ebf27ab267 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -398,6 +398,14 @@ Cannot find a short name for the codec ``. The value `` does not represent a correct collation name. Suggested valid collation name: [``]. +### [COLLATION_MISMATCH](sql-error-conditions-collation-mismatch-error-class.html) + +[SQLSTATE: 42P21](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Could not determine which collation to use for string comparison. + +For more details see [COLLATION_MISMATCH](sql-error-conditions-collation-mismatch-error-class.html) + ### [COLLECTION_SIZE_LIMIT_EXCEEDED](sql-error-conditions-collection-size-limit-exceeded-error-class.html) [SQLSTATE: 54000](sql-error-conditions-sqlstates.html#class-54-program-limit-exceeded) @@ -945,6 +953,12 @@ For more details see [INCONSISTENT_BEHAVIOR_CROSS_VERSION](sql-error-conditions- Max offset with `` rowsPerSecond is ``, but 'rampUpTimeSeconds' is ``. +### INDETERMINATE_COLLATION + +[SQLSTATE: 42P22](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Function called requires knowledge of the collation it should apply, but indeterminate collation was found. Use COLLATE function to set the collation explicitly. + ### INDEX_ALREADY_EXISTS [SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) From ce9b027b691ea1489901bb75c8f0708bdd1af90c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 5 Mar 2024 14:51:31 +0100 Subject: [PATCH 03/87] Fix contains, startWith, endWith tests --- .../sql/catalyst/util/CollationFactory.java | 1 + .../main/resources/error/error-classes.json | 5 -- .../apache/spark/sql/types/StringType.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 11 +++-- .../expressions/stringExpressions.scala | 18 ++----- .../org/apache/spark/sql/CollationSuite.scala | 48 +++++++------------ 6 files changed, 31 insertions(+), 56 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index c0c011926be9c..96655265b1126 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -105,6 +105,7 @@ public Collation( private static final Collation[] collationTable = new Collation[4]; private static final HashMap collationNameToIdMap = new HashMap<>(); + public static final int INDETERMINATE_COLLATION_ID = -1; public static final int DEFAULT_COLLATION_ID = 0; public static final int LOWERCASE_COLLATION_ID = 1; diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index e53c3b254bfbd..66414812bb272 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -714,11 +714,6 @@ "To convert values from to , you can use the functions instead." ] }, - "COLLATION_MISMATCH" : { - "message" : [ - "Collations and are not compatible. Please use the same collation for both strings." - ] - }, "CREATE_MAP_KEY_DIFF_TYPES" : { "message" : [ "The given keys of function should all be the same type, but they are ." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 79966adf5e43b..443d040e7289d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -45,7 +45,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * Returns whether the collation is indeterminate. An indeterminate collation is * a result of combination of conflicting non-default implicit collations. */ - def isIndeterminateCollation: Boolean = collationId == StringType.INDETERMINATE_COLLATION_ID + def isIndeterminateCollation: Boolean = collationId == CollationFactory.INDETERMINATE_COLLATION_ID /** * Type name that is shown to the customer. @@ -75,7 +75,5 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ @Stable case object StringType extends StringType(0) { - val DEFAULT_COLLATION_ID = 0 - val INDETERMINATE_COLLATION_ID = -1 def apply(collationId: Int): StringType = new StringType(collationId) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a14af29ccaee7..a330340f71763 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import scala.collection.mutable @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -718,9 +718,10 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { + case (expr: Expression, st2: StringType) if expr.dataType.isInstanceOf[StringType] => expr // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + case (in, expected) => implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -827,14 +828,14 @@ abstract class TypeCoercionBase { if (failOnIndeterminate) { throw QueryCompilationErrors.indeterminateCollationError() } else { - StringType.INDETERMINATE_COLLATION_ID + CollationFactory.INDETERMINATE_COLLATION_ID } } else if (hasMultipleImplicits(dataTypes)) { if (failOnIndeterminate) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - StringType.INDETERMINATE_COLLATION_ID + CollationFactory.INDETERMINATE_COLLATION_ID } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e6114ca277cad..a02057eedcc9d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -509,18 +509,10 @@ abstract class StringPredicate extends BinaryExpression return checkResult } // Additional check needed for collation compatibility - val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId - if (collationId != rightCollationId) { - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, - "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName - ) - ) - } else { - TypeCheckResult.TypeCheckSuccess - } + val outputCollationId: Int = TypeCoercion + .CollationTypeCasts + .getOutputCollation(Seq(left, right)) + TypeCheckResult.TypeCheckSuccess } protected override def nullSafeEval(input1: Any, input2: Any): Any = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2fc13b7494bf4..2aadf9a305de9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -193,56 +193,44 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // contains right = left.substring(1, 2); checkError( - exception = intercept[ExtendedAnalysisException] { + exception = intercept[AnalysisException] { spark.sql(s"SELECT contains(collate('$left', '$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", + errorClass = "COLLATION_MISMATCH.EXPLICIT", + sqlState = "42P21", parameters = Map( - "collationNameLeft" -> s"$leftCollationName", - "collationNameRight" -> s"$rightCollationName", - "sqlExpr" -> "\"contains(collate(abc), collate(b))\"" - ), - context = ExpectedContext(fragment = - s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))", - start = 7, stop = 68) + "left" -> s"string COLLATE '$leftCollationName'", + "right" -> s"string COLLATE '$rightCollationName'" + ) ) // startsWith right = left.substring(0, 1); checkError( - exception = intercept[ExtendedAnalysisException] { + exception = intercept[AnalysisException] { spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", + errorClass = "COLLATION_MISMATCH.EXPLICIT", + sqlState = "42P21", parameters = Map( - "collationNameLeft" -> s"$leftCollationName", - "collationNameRight" -> s"$rightCollationName", - "sqlExpr" -> "\"startswith(collate(abc), collate(a))\"" - ), - context = ExpectedContext(fragment = - s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))", - start = 7, stop = 70) + "left" -> s"string COLLATE '$leftCollationName'", + "right" -> s"string COLLATE '$rightCollationName'" + ) ) // endsWith right = left.substring(2, 3); checkError( - exception = intercept[ExtendedAnalysisException] { + exception = intercept[AnalysisException] { spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", + errorClass = "COLLATION_MISMATCH.EXPLICIT", + sqlState = "42P21", parameters = Map( - "collationNameLeft" -> s"$leftCollationName", - "collationNameRight" -> s"$rightCollationName", - "sqlExpr" -> "\"endswith(collate(abc), collate(c))\"" - ), - context = ExpectedContext(fragment = - s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))", - start = 7, stop = 68) + "left" -> s"string COLLATE '$leftCollationName'", + "right" -> s"string COLLATE '$rightCollationName'" + ) ) } From e537190c856eaedeaac115592a7760df0dbe7602 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 5 Mar 2024 15:06:02 +0100 Subject: [PATCH 04/87] Fix imports --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a330340f71763..6955ae9184f85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import scala.collection.mutable From b5a79c140b9fc5bca729b4d85610b9048e7362eb Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 6 Mar 2024 09:23:38 +0100 Subject: [PATCH 05/87] Fix docs and incorporate changes --- .../src/main/resources/error/error-classes.json | 4 ++-- ...r-conditions-collation-mismatch-error-class.md | 4 ++-- ...or-conditions-datatype-mismatch-error-class.md | 4 ---- .../sql/catalyst/analysis/TypeCoercion.scala | 15 +++++++-------- .../spark/sql/errors/QueryCompilationErrors.scala | 5 ++--- 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 66414812bb272..b5e3997237962 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -482,12 +482,12 @@ "subClass" : { "EXPLICIT" : { "message" : [ - "Error occurred due to the mismatch between explicit collations \"\" and \"\"" + "Error occurred due to the mismatch between explicit collations: " ] }, "IMPLICIT" : { "message" : [ - "Error occurred due to the mismatch between multiple implicit collations. Use COLLATE function to set the collation explicitly." + "Error occurred due to the mismatch between multiple implicit non-default collations. Use COLLATE function to set the collation explicitly." ] } }, diff --git a/docs/sql-error-conditions-collation-mismatch-error-class.md b/docs/sql-error-conditions-collation-mismatch-error-class.md index 91f4f5dbbc800..122196e1bb2cc 100644 --- a/docs/sql-error-conditions-collation-mismatch-error-class.md +++ b/docs/sql-error-conditions-collation-mismatch-error-class.md @@ -32,10 +32,10 @@ This error class has the following derived error classes: ## EXPLICIT -Error occurred due to the mismatch between explicit collations "``" and "``" +Error occurred due to the mismatch between explicit collations: `` ## IMPLICIT -Error occurred due to the mismatch between multiple implicit collations. Use COLLATE function to set the collation explicitly. +Error occurred due to the mismatch between multiple implicit non-default collations. Use COLLATE function to set the collation explicitly. diff --git a/docs/sql-error-conditions-datatype-mismatch-error-class.md b/docs/sql-error-conditions-datatype-mismatch-error-class.md index cd7feb9262f3a..1d18836ac9e77 100644 --- a/docs/sql-error-conditions-datatype-mismatch-error-class.md +++ b/docs/sql-error-conditions-datatype-mismatch-error-class.md @@ -76,10 +76,6 @@ If you have to cast `` to ``, you can set `` as `` to ``. To convert values from `` to ``, you can use the functions `` instead. -## COLLATION_MISMATCH - -Collations `` and `` are not compatible. Please use the same collation for both strings. - ## CREATE_MAP_KEY_DIFF_TYPES The given keys of function `` should all be the same type, but they are ``. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6955ae9184f85..c7e667c67ae09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -796,7 +796,7 @@ abstract class TypeCoercionBase { } /** - * Collates the input expression to a single collation. + * Collates the input expressions to a single collation. */ def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { val collationId = getOutputCollation(exprs) @@ -813,19 +813,19 @@ abstract class TypeCoercionBase { /** * Based on the data types of the input expressions this method determines - * a collation type which the output will be. + * a collation type which the output will have. */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { val explicitTypes = exprs.filter(hasExplicitCollation).map(_.dataType).distinct explicitTypes.size match { case 1 => explicitTypes.head.asInstanceOf[StringType].collationId - case size if size > 1 => throw QueryCompilationErrors.explicitCollationMismatchError( - explicitTypes.head.simpleString, explicitTypes.tail.head.simpleString) + case size if size > 1 => + throw QueryCompilationErrors.explicitCollationMismatchError(explicitTypes.map(t => t.toString)) case _ => val dataTypes = exprs.map(_.dataType.asInstanceOf[StringType]) - if (isIndeterminate(dataTypes)) { + if (hasIndeterminate(dataTypes)) { if (failOnIndeterminate) { throw QueryCompilationErrors.indeterminateCollationError() } else { @@ -847,7 +847,7 @@ abstract class TypeCoercionBase { } } - private def isIndeterminate(dataTypes: Seq[StringType]): Boolean = + private def hasIndeterminate(dataTypes: Seq[StringType]): Boolean = dataTypes.exists(_.isIndeterminateCollation) @@ -1062,8 +1062,7 @@ object TypeCoercion extends TypeCoercionBase { override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { implicitCast(e.dataType, expectedType).map { dt => - if (dt == e.dataType) { e } - else { Cast(e, dt) } + if (dt == e.dataType) e else Cast(e, dt) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index c9abcb2997bc7..991329cc7f5d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3612,12 +3612,11 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def explicitCollationMismatchError(left: String, right: String): Throwable = { + def explicitCollationMismatchError(explicitTypes: Seq[String]): Throwable = { new AnalysisException( errorClass = "COLLATION_MISMATCH.EXPLICIT", messageParameters = Map( - "left" -> left, - "right" -> right + "explicitTypes" -> toSQLId(explicitTypes) ) ) } From 8321d0c5fe67987e5f91c0f5934f067eb3f97583 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 6 Mar 2024 09:36:41 +0100 Subject: [PATCH 06/87] Fix tests in CollationSuite --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +++- .../org/apache/spark/sql/CollationSuite.scala | 24 ++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c7e667c67ae09..617e6818d8323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -821,7 +821,10 @@ abstract class TypeCoercionBase { explicitTypes.size match { case 1 => explicitTypes.head.asInstanceOf[StringType].collationId case size if size > 1 => - throw QueryCompilationErrors.explicitCollationMismatchError(explicitTypes.map(t => t.toString)) + throw QueryCompilationErrors + .explicitCollationMismatchError( + explicitTypes.map(t => t.asInstanceOf[StringType].typeName) + ) case _ => val dataTypes = exprs.map(_.dataType.asInstanceOf[StringType]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2aadf9a305de9..2cd7922f1505e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -200,8 +200,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( - "left" -> s"string COLLATE '$leftCollationName'", - "right" -> s"string COLLATE '$rightCollationName'" + "explicitTypes" -> + s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" ) ) // startsWith @@ -214,8 +214,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( - "left" -> s"string COLLATE '$leftCollationName'", - "right" -> s"string COLLATE '$rightCollationName'" + "explicitTypes" -> + s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" ) ) // endsWith @@ -228,8 +228,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( - "left" -> s"string COLLATE '$leftCollationName'", - "right" -> s"string COLLATE '$rightCollationName'" + "explicitTypes" -> + s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" ) ) } @@ -572,8 +572,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "left" -> "string", - "right" -> "string COLLATE 'UNICODE'" + "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" ) ) @@ -585,8 +584,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "left" -> "string", - "right" -> "string COLLATE 'UNICODE'" + "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" ) ) checkError( @@ -596,8 +594,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "left" -> "string COLLATE 'UNICODE'", - "right" -> "string" + "explicitTypes" -> "`string COLLATE 'UNICODE'`.`string`" ) ) @@ -617,8 +614,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "left" -> "string", - "right" -> "string COLLATE 'UNICODE'" + "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" ) ) } From d178233cb7228703560073fd7ae551bd52522491 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 7 Mar 2024 10:10:02 +0100 Subject: [PATCH 07/87] Add test and incorporate changes --- .../main/resources/error/error-classes.json | 2 +- .../org/apache/spark/sql/CollationSuite.scala | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index b5e3997237962..f0ced7478f986 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -482,7 +482,7 @@ "subClass" : { "EXPLICIT" : { "message" : [ - "Error occurred due to the mismatch between explicit collations: " + "Error occurred due to the mismatch between explicit collations: . Decide on a single explicit collation and remove others." ] }, "IMPLICIT" : { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2cd7922f1505e..421e3fa5666da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -620,6 +620,30 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("cast of default collated string in IN expression") { + val tableName = "t1" + withTable(tableName) { + spark.sql( + s""" + | CREATE TABLE $tableName(ucs_basic STRING COLLATE 'UCS_BASIC', + | ucs_basic_lcase STRING COLLATE 'UCS_BASIC_LCASE') + | USING PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')") + sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')") + sql(s"INSERT INTO $tableName VALUES ('bbb', 'bbb')") + sql(s"INSERT INTO $tableName VALUES ('BBB', 'BBB')") + + checkAnswer(sql(s"SELECT * FROM $tableName " + + s"WHERE ucs_basic_lcase IN " + + s"('aaa' COLLATE 'UCS_BASIC_LCASE', 'bbb' collate 'UCS_BASIC_LCASE')"), + Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) + checkAnswer(sql(s"SELECT * FROM $tableName " + + s"WHERE ucs_basic_lcase IN ('aaa' COLLATE 'UCS_BASIC_LCASE', 'bbb')"), + Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) + } + } + test("create v2 table with collation column") { val tableName = "testcat.table_name" val collationName = "UCS_BASIC_LCASE" From a4b9be70f2472b41a9f71b98123be489d3aa9991 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 7 Mar 2024 11:34:13 +0100 Subject: [PATCH 08/87] Fix godlen files --- docs/sql-error-conditions-collation-mismatch-error-class.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-error-conditions-collation-mismatch-error-class.md b/docs/sql-error-conditions-collation-mismatch-error-class.md index 122196e1bb2cc..616aed2029759 100644 --- a/docs/sql-error-conditions-collation-mismatch-error-class.md +++ b/docs/sql-error-conditions-collation-mismatch-error-class.md @@ -32,7 +32,7 @@ This error class has the following derived error classes: ## EXPLICIT -Error occurred due to the mismatch between explicit collations: `` +Error occurred due to the mismatch between explicit collations: ``. Decide on a single explicit collation and remove others. ## IMPLICIT From a6e7662e2235b9d2123e204a5b7d8e6aa4b4bdf3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 8 Mar 2024 14:34:53 +0100 Subject: [PATCH 09/87] Incorporate StringType in findWiderCommonType --- .../apache/spark/sql/types/StringType.scala | 1 + .../catalyst/analysis/AnsiTypeCoercion.scala | 19 +++-- .../sql/catalyst/analysis/TypeCoercion.scala | 73 +++++++++---------- 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 443d040e7289d..134827f7fb3d9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -53,6 +53,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ override def typeName: String = if (isDefaultCollation) "string" + else if (isIndeterminateCollation) s"string COLLATE 'INDETERMINATE_COLLATION'" else s"string COLLATE '${CollationFactory.fetchCollation(collationId).collationName}'" override def equals(obj: Any): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 522dc2d5c829f..7648b40b38ced 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -153,8 +154,17 @@ object AnsiTypeCoercion extends TypeCoercionBase { } } - override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => + override def findWiderCommonType(exprs: Seq[Expression], + failOnIndeterminate: Boolean = false): Option[DataType] = { + val (stringTypes, nonStringTypes) = exprs.map(_.dataType).partition(hasStringType) + (if (stringTypes.distinct.size > 1) { + val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) + exprs.map(e => + if (e.exists(e => e.dataType.isInstanceOf[StringType])) { + Cast(e, StringType(collationId)) + } + else e) + } else exprs).map(_.dataType).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case _ => None @@ -175,6 +185,8 @@ object AnsiTypeCoercion extends TypeCoercionBase { inType: DataType, expectedType: AbstractDataType): Option[DataType] = { (inType, expectedType) match { + case (_: StringType, st: StringType) => + Some(st) // If the expected type equals the input type, no need to cast. case _ if expectedType.acceptsType(inType) => Some(inType) @@ -188,9 +200,6 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (NullType, target) if !target.isInstanceOf[TypeCollection] => Some(target.defaultConcreteType) - case (_: StringType, st: StringType) => - Some(st) - // This type coercion system will allow implicit converting String type as other // primitive types, in case of breaking too many existing Spark SQL queries. case (_: StringType, a: AtomicType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 617e6818d8323..bb44afed4b1ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -65,7 +65,8 @@ abstract class TypeCoercionBase { * is larger than decimal, and yet decimal is more precise than double, but in * union we would cast the decimal into double. */ - def findWiderCommonType(types: Seq[DataType]): Option[DataType] + def findWiderCommonType(children: Seq[Expression], + failOnIndeterminate: Boolean = false): Option[DataType] /** * Given an expected data type, try to cast the expression and return the cast expression. @@ -320,7 +321,7 @@ abstract class TypeCoercionBase { if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - val widenTypeOpt = findWiderCommonType(children.map(_.output(attrIndex).dataType)) + val widenTypeOpt = findWiderCommonType(children.map(_.output(attrIndex))) castedTypes.enqueue(widenTypeOpt) getWidestTypes(children, attrIndex + 1, castedTypes) } @@ -390,13 +391,8 @@ abstract class TypeCoercionBase { i } - case i @ In(a, b) if CollationTypeCasts.shouldCast(a.dataType +: b.map(_.dataType)) => - // resolve collations between the children of IN expression - val newChildren = CollationTypeCasts.collateToSingleType(a +: b) - i.copy(newChildren.head, newChildren.tail) - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + findWiderCommonType(i.children, failOnIndeterminate = true) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -413,16 +409,14 @@ abstract class TypeCoercionBase { case e if !e.childrenResolved => e case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) => - val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(children) match { case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(c.inputTypesForMerging) => - val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(children) match { case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } @@ -437,30 +431,26 @@ abstract class TypeCoercionBase { case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren.map(_.dataType)) => - val types = s.coercibleChildren.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(s.coercibleChildren) match { case Some(widerDataType) => s.castChildrenTo(widerDataType) case None => s } case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && !haveSameType(m.inputTypesForMerging) => - val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(children) match { case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) case None => m } case m @ CreateMap(children, _) if m.keys.length == m.values.length && (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => - val keyTypes = m.keys.map(_.dataType) - val newKeys = findWiderCommonType(keyTypes) match { + val newKeys = findWiderCommonType(m.keys) match { case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) case None => m.keys } - val valueTypes = m.values.map(_.dataType) - val newValues = findWiderCommonType(valueTypes) match { + val newValues = findWiderCommonType(m.values) match { case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) case None => m.values } @@ -475,8 +465,7 @@ abstract class TypeCoercionBase { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => - val types = es.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(es) match { case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => @@ -554,7 +543,8 @@ abstract class TypeCoercionBase { object CaseWhenCoercion extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => - val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) + val maybeCommonType = findWiderCommonType( + c.branches.map(_._2) ++ c.elseValue) maybeCommonType.map { commonType => val newBranches = c.branches.map { case (condition, value) => (condition, castIfNotSameType(value, commonType)) @@ -573,7 +563,7 @@ abstract class TypeCoercionBase { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => - findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + findWiderCommonType(Seq(left, right)).map { widestType => val newLeft = castIfNotSameType(left, widestType) val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) @@ -615,20 +605,20 @@ abstract class TypeCoercionBase { case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => - implicitCast(e, StringType).getOrElse(e) + if (e.dataType.isInstanceOf[StringType]) e + else implicitCast(e, StringType).getOrElse(e) } - val newNode = c.copy(children = newChildren) - if (CollationTypeCasts.shouldCast(children.map(_.dataType))) { + val collationId = if (CollationTypeCasts.shouldCast(children.map(_.dataType))) { // if original children had different collations we need to // cast the output to the expected collation - val collationId = CollationTypeCasts.getOutputCollation( + CollationTypeCasts.getOutputCollation( children, failOnIndeterminate = false) - Cast(newNode, StringType(collationId)) - } - else { - newNode } + else 0 + + c.copy( + children = newChildren.map(e => implicitCast(e, StringType(collationId)).getOrElse(e))) } } @@ -642,7 +632,7 @@ abstract class TypeCoercionBase { case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && MapType.acceptsType(a.dataType)) && !DataTypeUtils.sameType(m.leftKeyType, m.rightKeyType) => - findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + findWiderCommonType(Seq(m.left, m.right)) match { case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && !Cast.forceNullable(m.rightKeyType, finalKeyType) => val newLeft = castIfNotSameType( @@ -674,7 +664,8 @@ abstract class TypeCoercionBase { val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { children.tail.map { e => - implicitCast(e, StringType).getOrElse(e) + if (e.dataType.isInstanceOf[StringType]) e + else implicitCast(e, StringType).getOrElse(e) } } else { children.tail @@ -904,7 +895,7 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case DateAdd(l, r) if r.dataType == StringType && r.foldable => + case DateAdd(l, r) if r.dataType.isInstanceOf[StringType] && r.foldable => val days = try { Cast(r, IntegerType, ansiEnabled = true).eval().asInstanceOf[Int] } catch { @@ -912,7 +903,7 @@ abstract class TypeCoercionBase { throw QueryCompilationErrors.secondArgumentOfFunctionIsNotIntegerError("date_add", e) } DateAdd(l, Literal(days)) - case DateSub(l, r) if r.dataType == StringType && r.foldable => + case DateSub(l, r) if r.dataType.isInstanceOf[StringType] && r.foldable => val days = try { Cast(r, IntegerType, ansiEnabled = true).eval().asInstanceOf[Int] } catch { @@ -1050,13 +1041,17 @@ object TypeCoercion extends TypeCoercionBase { .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } - override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + override def findWiderCommonType(exprs: Seq[Expression], + failOnIndeterminate: Boolean = false): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. - val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) - (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + val (stringTypes, nonStringTypes) = exprs.map(_.dataType).partition(hasStringType) + (if (stringTypes.distinct.size > 1) Option(StringType(CollationTypeCasts + .getOutputCollation(exprs, failOnIndeterminate))) ++ nonStringTypes + else stringTypes.distinct ++ nonStringTypes) + .foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case _ => None From b3b1356b7c63e5b3434f690a91e05cc393dcdc64 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 11 Mar 2024 08:30:24 +0100 Subject: [PATCH 10/87] Fix ArrayType(StringType, _) casting in findWiderCommonType --- .../sql/catalyst/analysis/AnsiTypeCoercion.scala | 10 +++++----- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 7648b40b38ced..3cb129394c7aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{castStringType, hasStringType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -156,12 +156,12 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def findWiderCommonType(exprs: Seq[Expression], failOnIndeterminate: Boolean = false): Option[DataType] = { - val (stringTypes, nonStringTypes) = exprs.map(_.dataType).partition(hasStringType) - (if (stringTypes.distinct.size > 1) { + (if (exprs.map(_.dataType).partition(hasStringType)._1.distinct.size > 1) { val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) exprs.map(e => - if (e.exists(e => e.dataType.isInstanceOf[StringType])) { - Cast(e, StringType(collationId)) + if (hasStringType(e.dataType)) { + castStringType(e.dataType, collationId) + e } else e) } else exprs).map(_.dataType).foldLeft[Option[DataType]](Some(NullType))((r, c) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index bb44afed4b1ab..2d1530245c741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1048,8 +1048,10 @@ object TypeCoercion extends TypeCoercionBase { // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. val (stringTypes, nonStringTypes) = exprs.map(_.dataType).partition(hasStringType) - (if (stringTypes.distinct.size > 1) Option(StringType(CollationTypeCasts - .getOutputCollation(exprs, failOnIndeterminate))) ++ nonStringTypes + (if (stringTypes.distinct.size > 1) { + val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) + stringTypes.distinct.map(castStringType(_, collationId)) ++ nonStringTypes + } else stringTypes.distinct ++ nonStringTypes) .foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { @@ -1182,6 +1184,14 @@ object TypeCoercion extends TypeCoercionBase { case _ => false } + /** + * Method to cast nested StringTypes that hasStringType filters. + */ + def castStringType(fromType: DataType, collationId: Int): Option[DataType] = fromType match { + case ArrayType(_, n) => implicitCast(fromType, ArrayType(StringType(collationId), n)) + case _ => implicitCast(fromType, StringType(collationId)) + } + /** * Promotes strings that appear in arithmetic expressions. */ From 7773d133dca2144b6eb53846cdaf845f926ab578 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 11 Mar 2024 08:55:17 +0100 Subject: [PATCH 11/87] Fix type mismatch error --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2d1530245c741..2403d5eecdeff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1187,9 +1187,9 @@ object TypeCoercion extends TypeCoercionBase { /** * Method to cast nested StringTypes that hasStringType filters. */ - def castStringType(fromType: DataType, collationId: Int): Option[DataType] = fromType match { - case ArrayType(_, n) => implicitCast(fromType, ArrayType(StringType(collationId), n)) - case _ => implicitCast(fromType, StringType(collationId)) + def castStringType(fromType: DataType, collationId: Int): DataType = fromType match { + case ArrayType(_, n) => implicitCast(fromType, ArrayType(StringType(collationId), n)).get + case _ => implicitCast(fromType, StringType(collationId)).get } /** From 255b1ab7063d5a94986fe64c06385ae092689412 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 11 Mar 2024 15:32:18 +0100 Subject: [PATCH 12/87] Incorporate changes and fix errors --- python/pyspark/sql/tests/test_types.py | 2 ++ python/pyspark/sql/types.py | 6 ++-- .../catalyst/analysis/AnsiTypeCoercion.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 21 ++++++++---- .../org/apache/spark/sql/CollationSuite.scala | 34 ++++++++++++------- 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index a0dfdce1a96e8..d6c7a816f7964 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -863,6 +863,7 @@ def test_parse_datatype_string(self): self.assertEqual(t(), _parse_datatype_string(k)) self.assertEqual(IntegerType(), _parse_datatype_string("int")) self.assertEqual(StringType(), _parse_datatype_string("string COLLATE UCS_BASIC")) + self.assertEqual(StringType(-1), _parse_datatype_string("string COLLATE INDETERMINATE_COLLATION")) self.assertEqual(StringType(0), _parse_datatype_string("string")) self.assertEqual(StringType(0), _parse_datatype_string("string COLLATE UCS_BASIC")) self.assertEqual(StringType(0), _parse_datatype_string("string COLLATE UCS_BASIC")) @@ -1215,6 +1216,7 @@ def test_repr(self): instances = [ NullType(), StringType(), + StringType(-1), StringType(0), StringType(1), StringType(2), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index a30f41ae40239..7aa276a0e34dd 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -262,8 +262,10 @@ def __init__(self, collationId: int = 0): def collationIdToName(self) -> str: return ( " COLLATE %s" % StringType.collationNames[self.collationId] - if self.collationId != 0 - else "" + if self.collationId != 0 && self.collationId != -1 + else ("UNDETERMINATE_COLLATION" + if self.collationId == -1 + else "") ) @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 3cb129394c7aa..6ba8e7aa818f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -78,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new AnsiCombinedTypeCoercionRule( + CollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -91,7 +92,6 @@ object AnsiTypeCoercion extends TypeCoercionBase { Division :: IntegralDivision :: ImplicitTypeCasts :: - CollationTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: GetDateFieldOperations:: Nil) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2403d5eecdeff..a43efebd4c514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -774,16 +774,22 @@ abstract class TypeCoercionBase { } object CollationTypeCasts extends TypeCoercionRule { - override def transform: PartialFunction[Expression, Expression] = { + override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e + case s : SortOrder if s.dataType.isInstanceOf[StringType] && + hasIndeterminate(s.children.map(_.dataType.asInstanceOf[StringType])) => + val newChildren = collateToSingleType(s.children) + s.withNewChildren(newChildren) + case b @ BinaryComparison(left, right) if shouldCast(Seq(left.dataType, right.dataType)) => val newChildren = collateToSingleType(Seq(left, right)) b.withNewChildren(newChildren) } def shouldCast(types: Seq[DataType]): Boolean = { - types.forall(_.isInstanceOf[StringType]) && types.distinct.length > 1 + types.forall(_.isInstanceOf[StringType]) && + types.map(t => t.asInstanceOf[StringType].collationId).distinct.length > 1 } /** @@ -807,16 +813,17 @@ abstract class TypeCoercionBase { * a collation type which the output will have. */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { - val explicitTypes = exprs.filter(hasExplicitCollation).map(_.dataType).distinct + val explicitTypes = exprs.filter(hasExplicitCollation) + .map(_.dataType.asInstanceOf[StringType].collationId).distinct explicitTypes.size match { - case 1 => explicitTypes.head.asInstanceOf[StringType].collationId + case 1 => explicitTypes.head case size if size > 1 => throw QueryCompilationErrors .explicitCollationMismatchError( - explicitTypes.map(t => t.asInstanceOf[StringType].typeName) + explicitTypes.map(t => StringType(t).typeName) ) - case _ => + case 0 => val dataTypes = exprs.map(_.dataType.asInstanceOf[StringType]) if (hasIndeterminate(dataTypes)) { @@ -934,6 +941,7 @@ object TypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new CombinedTypeCoercionRule( + CollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -947,7 +955,6 @@ object TypeCoercion extends TypeCoercionBase { StackCoercion :: Division :: IntegralDivision :: - CollationTypeCasts :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9d25b0339e92e..7064bb71ad4c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -202,7 +202,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" + s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" ) ) // startsWith @@ -216,7 +216,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" + s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" ) ) // endsWith @@ -230,7 +230,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE '$leftCollationName'`.`string COLLATE '$rightCollationName'`" + s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" ) ) } @@ -517,8 +517,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { spark.sql( s""" - | CREATE TABLE $tableName(c1 STRING COLLATE 'UCS_BASIC_LCASE', - | c2 STRING COLLATE 'UNICODE', c3 STRING COLLATE 'UNICODE_CI', c4 STRING) + | CREATE TABLE $tableName(c1 STRING COLLATE UCS_BASIC_LCASE, + | c2 STRING COLLATE UNICODE, c3 STRING COLLATE UNICODE_CI, c4 STRING) | USING PARQUET |""".stripMargin) sql(s"INSERT INTO $tableName VALUES ('a', 'a', 'a', 'a')") @@ -578,7 +578,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" + "explicitTypes" -> "`string`.`string COLLATE UNICODE`" ) ) @@ -590,7 +590,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" + "explicitTypes" -> "`string`.`string COLLATE UNICODE`" ) ) checkError( @@ -600,7 +600,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string COLLATE 'UNICODE'`.`string`" + "explicitTypes" -> "`string COLLATE UNICODE`.`string`" ) ) @@ -612,6 +612,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "INDETERMINATE_COLLATION" ) + // concat should fail on indeterminate collation + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName ORDER BY c1 || c3") + }, + errorClass = "INDETERMINATE_COLLATION" + ) + // concat + in checkError( exception = intercept[AnalysisException] { @@ -620,7 +628,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE 'UNICODE'`" + "explicitTypes" -> "`string`.`string COLLATE UNICODE`" ) ) } @@ -631,8 +639,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { spark.sql( s""" - | CREATE TABLE $tableName(ucs_basic STRING COLLATE 'UCS_BASIC', - | ucs_basic_lcase STRING COLLATE 'UCS_BASIC_LCASE') + | CREATE TABLE $tableName(ucs_basic STRING COLLATE UCS_BASIC, + | ucs_basic_lcase STRING COLLATE UCS_BASIC_LCASE) | USING PARQUET |""".stripMargin) sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')") @@ -642,10 +650,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql(s"SELECT * FROM $tableName " + s"WHERE ucs_basic_lcase IN " + - s"('aaa' COLLATE 'UCS_BASIC_LCASE', 'bbb' collate 'UCS_BASIC_LCASE')"), + s"('aaa' COLLATE UCS_BASIC_LCASE, 'bbb' COLLATE UCS_BASIC_LCASE)"), Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) checkAnswer(sql(s"SELECT * FROM $tableName " + - s"WHERE ucs_basic_lcase IN ('aaa' COLLATE 'UCS_BASIC_LCASE', 'bbb')"), + s"WHERE ucs_basic_lcase IN ('aaa' COLLATE UCS_BASIC_LCASE, 'bbb')"), Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) } } From 50f3aa277f7c55e38eeccba5e9bded9f6d46c5c2 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 12 Mar 2024 09:57:14 +0100 Subject: [PATCH 13/87] Fix errors --- python/pyspark/sql/types.py | 4 +-- .../catalyst/analysis/AnsiTypeCoercion.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 7 ++-- .../org/apache/spark/sql/CollationSuite.scala | 36 ++++++++++--------- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index bff7357f49fc5..1fe311ab244e2 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -262,8 +262,8 @@ def __init__(self, collationId: int = 0): def collationIdToName(self) -> str: return ( " COLLATE %s" % StringType.collationNames[self.collationId] - if self.collationId != 0 && self.collationId != -1 - else ("UNDETERMINATE_COLLATION" + if self.collationId != 0 and self.collationId != -1 + else ("INDETERMINATE_COLLATION" if self.collationId == -1 else "") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 6ba8e7aa818f4..fe1c989a92b1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -156,7 +156,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def findWiderCommonType(exprs: Seq[Expression], failOnIndeterminate: Boolean = false): Option[DataType] = { - (if (exprs.map(_.dataType).partition(hasStringType)._1.distinct.size > 1) { + (if (exprs.map(_.dataType).filter(hasStringType).distinct.size > 1) { val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) exprs.map(e => if (hasStringType(e.dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a43efebd4c514..253261de00bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1192,11 +1192,12 @@ object TypeCoercion extends TypeCoercionBase { } /** - * Method to cast nested StringTypes that hasStringType filters. + * Method to cast StringTypes that hasStringType filters. */ + @tailrec def castStringType(fromType: DataType, collationId: Int): DataType = fromType match { - case ArrayType(_, n) => implicitCast(fromType, ArrayType(StringType(collationId), n)).get - case _ => implicitCast(fromType, StringType(collationId)).get + case _: StringType => implicitCast(fromType, StringType(collationId)).get + case ArrayType(et, _) => castStringType(et, collationId) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9991258c37a57..6554cf118a98f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -531,7 +531,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { spark.sql( s""" - | CREATE TABLE $tableName(c1 STRING COLLATE UCS_BASIC_LCASE, + | CREATE TABLE $tableName(c1 STRING COLLATE UTF8_BINARY_LCASE, | c2 STRING COLLATE UNICODE, c3 STRING COLLATE UNICODE_CI, c4 STRING) | USING PARQUET |""".stripMargin) @@ -544,23 +544,25 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE 'a' = c1"), Seq(Row("a"), Row("A"))) - // collate c1 to UCS_BASIC because it is explicitly set - checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = COLLATE('a', 'UCS_BASIC')"), + // collate c1 to UTF8_BINARY because it is explicitly set + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = COLLATE('a', 'UTF8_BINARY')"), Seq(Row("a"))) - checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = SUBSTR(COLLATE('a', 'UCS_BASIC'), 0)"), + checkAnswer( + sql(s"SELECT c1 FROM $tableName " + + s"WHERE c1 = SUBSTR(COLLATE('a', 'UTF8_BINARY'), 0)"), Seq(Row("a"))) // in operator checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('a')"), Seq(Row("a"), Row("A"))) // explicitly set collation inside IN operator - checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('b', COLLATE('a', 'UCS_BASIC'))"), + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('b', COLLATE('a', 'UTF8_BINARY'))"), Seq(Row("a"))) // concat should not change collation checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || 'a' || 'a' = 'aaa'"), Seq(Row("a"), Row("A"))) - checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE(c2, 'UCS_BASIC') = 'aa'"), + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE(c2, 'UTF8_BINARY') = 'aa'"), Seq(Row("a"))) // concat of columns of different collations is allowed @@ -569,9 +571,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat + in checkAnswer(sql(s"SELECT c1 FROM $tableName where c1 || 'a' " + - s"IN (COLLATE('aa', 'UCS_BASIC_LCASE'))"), Seq(Row("a"), Row("A"))) + s"IN (COLLATE('aa', 'UTF8_BINARY_LCASE'))"), Seq(Row("a"), Row("A"))) checkAnswer(sql(s"SELECT c1 FROM $tableName where (c1 || 'a') " + - s"IN (COLLATE('aa', 'UCS_BASIC'))"), Seq(Row("a"))) + s"IN (COLLATE('aa', 'UTF8_BINARY'))"), Seq(Row("a"))) // columns have different collation checkError( @@ -587,7 +589,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql( s""" |SELECT c1 FROM $tableName - |WHERE COLLATE('a', 'UCS_BASIC') = COLLATE('a', 'UNICODE')""" + |WHERE COLLATE('a', 'UTF8_BINARY') = COLLATE('a', 'UNICODE')""" .stripMargin) }, errorClass = "COLLATION_MISMATCH.EXPLICIT", @@ -600,7 +602,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 IN " + - "(COLLATE('a', 'UCS_BASIC'), COLLATE('b', 'UNICODE'))") + "(COLLATE('a', 'UTF8_BINARY'), COLLATE('b', 'UNICODE'))") }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( @@ -610,7 +612,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE COLLATE(c1, 'UNICODE') IN " + - "(COLLATE('a', 'UCS_BASIC'))") + "(COLLATE('a', 'UTF8_BINARY'))") }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( @@ -637,7 +639,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat + in checkError( exception = intercept[AnalysisException] { - sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UCS_BASIC') IN " + + sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UTF8_BINARY') IN " + s"(COLLATE('a', 'UNICODE'))") }, errorClass = "COLLATION_MISMATCH.EXPLICIT", @@ -653,8 +655,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { spark.sql( s""" - | CREATE TABLE $tableName(ucs_basic STRING COLLATE UCS_BASIC, - | ucs_basic_lcase STRING COLLATE UCS_BASIC_LCASE) + | CREATE TABLE $tableName(utf8_binary STRING COLLATE UTF8_BINARY, + | utf8_binary_lcase STRING COLLATE UTF8_BINARY_LCASE) | USING PARQUET |""".stripMargin) sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')") @@ -663,11 +665,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"INSERT INTO $tableName VALUES ('BBB', 'BBB')") checkAnswer(sql(s"SELECT * FROM $tableName " + - s"WHERE ucs_basic_lcase IN " + - s"('aaa' COLLATE UCS_BASIC_LCASE, 'bbb' COLLATE UCS_BASIC_LCASE)"), + s"WHERE utf8_binary_lcase IN " + + s"('aaa' COLLATE UTF8_BINARY_LCASE, 'bbb' COLLATE UTF8_BINARY_LCASE)"), Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) checkAnswer(sql(s"SELECT * FROM $tableName " + - s"WHERE ucs_basic_lcase IN ('aaa' COLLATE UCS_BASIC_LCASE, 'bbb')"), + s"WHERE utf8_binary_lcase IN ('aaa' COLLATE UTF8_BINARY_LCASE, 'bbb')"), Seq(Row("aaa", "aaa"), Row("AAA", "AAA"), Row("bbb", "bbb"), Row("BBB", "BBB"))) } } From ca0c84dcedc257dd9519af3497e7a570a27ca2a6 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 13 Mar 2024 14:29:08 +0100 Subject: [PATCH 14/87] Rework casting --- python/pyspark/sql/tests/test_types.py | 1 - python/pyspark/sql/types.py | 4 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 14 +- .../sql/catalyst/analysis/TypeCoercion.scala | 131 ++++++++++-------- .../org/apache/spark/sql/CollationSuite.scala | 10 +- 5 files changed, 76 insertions(+), 84 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 6f09a35ce1a0f..c5a8a6399b3a8 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -863,7 +863,6 @@ def test_parse_datatype_string(self): self.assertEqual(t(), _parse_datatype_string(k)) self.assertEqual(IntegerType(), _parse_datatype_string("int")) self.assertEqual(StringType(), _parse_datatype_string("string COLLATE UTF8_BINARY")) - self.assertEqual(StringType(-1), _parse_datatype_string("string COLLATE INDETERMINATE_COLLATION")) self.assertEqual(StringType(0), _parse_datatype_string("string")) self.assertEqual(StringType(0), _parse_datatype_string("string COLLATE UTF8_BINARY")) self.assertEqual(StringType(0), _parse_datatype_string("string COLLATE UTF8_BINARY")) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1fe311ab244e2..b3c5913c07b37 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -263,9 +263,7 @@ def collationIdToName(self) -> str: return ( " COLLATE %s" % StringType.collationNames[self.collationId] if self.collationId != 0 and self.collationId != -1 - else ("INDETERMINATE_COLLATION" - if self.collationId == -1 - else "") + else ("INDETERMINATE_COLLATION" if self.collationId == -1 else "") ) @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index fe1c989a92b1f..faa8bae8c7308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{castStringType, hasStringType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -154,17 +153,8 @@ object AnsiTypeCoercion extends TypeCoercionBase { } } - override def findWiderCommonType(exprs: Seq[Expression], - failOnIndeterminate: Boolean = false): Option[DataType] = { - (if (exprs.map(_.dataType).filter(hasStringType).distinct.size > 1) { - val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) - exprs.map(e => - if (hasStringType(e.dataType)) { - castStringType(e.dataType, collationId) - e - } - else e) - } else exprs).map(_.dataType).foldLeft[Option[DataType]](Some(NullType))((r, c) => + override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 253261de00bc1..40ae508c8f7c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -65,8 +65,7 @@ abstract class TypeCoercionBase { * is larger than decimal, and yet decimal is more precise than double, but in * union we would cast the decimal into double. */ - def findWiderCommonType(children: Seq[Expression], - failOnIndeterminate: Boolean = false): Option[DataType] + def findWiderCommonType(children: Seq[DataType]): Option[DataType] /** * Given an expected data type, try to cast the expression and return the cast expression. @@ -321,7 +320,7 @@ abstract class TypeCoercionBase { if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - val widenTypeOpt = findWiderCommonType(children.map(_.output(attrIndex))) + val widenTypeOpt = findWiderCommonType(children.map(_.output(attrIndex).dataType)) castedTypes.enqueue(widenTypeOpt) getWidestTypes(children, attrIndex + 1, castedTypes) } @@ -392,7 +391,7 @@ abstract class TypeCoercionBase { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children, failOnIndeterminate = true) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -409,14 +408,16 @@ abstract class TypeCoercionBase { case e if !e.childrenResolved => e case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) => - findWiderCommonType(children) match { + val types = children.map(_.dataType) + findWiderCommonType(types) match { case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(c.inputTypesForMerging) => - findWiderCommonType(children) match { + val types = children.map(_.dataType) + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } @@ -431,26 +432,30 @@ abstract class TypeCoercionBase { case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren.map(_.dataType)) => - findWiderCommonType(s.coercibleChildren) match { + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { case Some(widerDataType) => s.castChildrenTo(widerDataType) case None => s } case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && !haveSameType(m.inputTypesForMerging) => - findWiderCommonType(children) match { + val types = children.map(_.dataType) + findWiderCommonType(types) match { case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) case None => m } case m @ CreateMap(children, _) if m.keys.length == m.values.length && (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => - val newKeys = findWiderCommonType(m.keys) match { + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) case None => m.keys } - val newValues = findWiderCommonType(m.values) match { + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) case None => m.values } @@ -465,7 +470,8 @@ abstract class TypeCoercionBase { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => - findWiderCommonType(es) match { + val types = es.map(_.dataType) + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => @@ -543,8 +549,7 @@ abstract class TypeCoercionBase { object CaseWhenCoercion extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => - val maybeCommonType = findWiderCommonType( - c.branches.map(_._2) ++ c.elseValue) + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => val newBranches = c.branches.map { case (condition, value) => (condition, castIfNotSameType(value, commonType)) @@ -563,7 +568,7 @@ abstract class TypeCoercionBase { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => - findWiderCommonType(Seq(left, right)).map { widestType => + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => val newLeft = castIfNotSameType(left, widestType) val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) @@ -609,16 +614,7 @@ abstract class TypeCoercionBase { else implicitCast(e, StringType).getOrElse(e) } - val collationId = if (CollationTypeCasts.shouldCast(children.map(_.dataType))) { - // if original children had different collations we need to - // cast the output to the expected collation - CollationTypeCasts.getOutputCollation( - children, failOnIndeterminate = false) - } - else 0 - - c.copy( - children = newChildren.map(e => implicitCast(e, StringType(collationId)).getOrElse(e))) + c.copy(children = newChildren) } } @@ -632,7 +628,7 @@ abstract class TypeCoercionBase { case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && MapType.acceptsType(a.dataType)) && !DataTypeUtils.sameType(m.leftKeyType, m.rightKeyType) => - findWiderCommonType(Seq(m.left, m.right)) match { + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && !Cast.forceNullable(m.rightKeyType, finalKeyType) => val newLeft = castIfNotSameType( @@ -777,26 +773,44 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case s : SortOrder if s.dataType.isInstanceOf[StringType] && - hasIndeterminate(s.children.map(_.dataType.asInstanceOf[StringType])) => - val newChildren = collateToSingleType(s.children) - s.withNewChildren(newChildren) + case c @ (_: Concat) if shouldCast(c.children.map(_.dataType)) => + val newChildren = collateToSingleType(c.children, failOnIndeterminate = false) + c.withNewChildren(newChildren) - case b @ BinaryComparison(left, right) if shouldCast(Seq(left.dataType, right.dataType)) => - val newChildren = collateToSingleType(Seq(left, right)) - b.withNewChildren(newChildren) + case e if shouldCast(e.children.map(_.dataType)) => + val newChildren = collateToSingleType(e.children) + e.withNewChildren(newChildren) } def shouldCast(types: Seq[DataType]): Boolean = { - types.forall(_.isInstanceOf[StringType]) && - types.map(t => t.asInstanceOf[StringType].collationId).distinct.length > 1 + types.filter(hasStringType).map(dt => extractStringType(dt).collationId).distinct.size > 1 + } + + /** + * Whether the data type contains StringType. + */ + @tailrec + def hasStringType(dt: DataType): Boolean = dt match { + case _: StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + + /** + * Extracts StringTypes from flitered hasStringType + */ + private def extractStringType(dt: DataType): StringType = dt match { + case st: StringType => st + case ArrayType(et, _) => et.asInstanceOf[StringType] } /** * Collates the input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { - val collationId = getOutputCollation(exprs) + def collateToSingleType(exprs: Seq[Expression], + failOnIndeterminate: Boolean = true): Seq[Expression] = { + val collationId = getOutputCollation(exprs, failOnIndeterminate) exprs.map { expression => expression.dataType match { @@ -804,6 +818,11 @@ abstract class TypeCoercionBase { expression case _: StringType => Cast(expression, StringType(collationId)) + case at: ArrayType if at.elementType.isInstanceOf[StringType] + && at.elementType.asInstanceOf[StringType].collationId != collationId => + Cast(expression, ArrayType(StringType(collationId), at.containsNull)) + case at: ArrayType if at.elementType.isInstanceOf[StringType] => + expression } } } @@ -814,7 +833,7 @@ abstract class TypeCoercionBase { */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { val explicitTypes = exprs.filter(hasExplicitCollation) - .map(_.dataType.asInstanceOf[StringType].collationId).distinct + .map(e => extractStringType(e.dataType).collationId).distinct explicitTypes.size match { case 1 => explicitTypes.head @@ -824,7 +843,7 @@ abstract class TypeCoercionBase { explicitTypes.map(t => StringType(t).typeName) ) case 0 => - val dataTypes = exprs.map(_.dataType.asInstanceOf[StringType]) + val dataTypes = exprs.map(e => extractStringType(e.dataType)) if (hasIndeterminate(dataTypes)) { if (failOnIndeterminate) { @@ -848,22 +867,18 @@ abstract class TypeCoercionBase { } } - private def hasIndeterminate(dataTypes: Seq[StringType]): Boolean = - dataTypes.exists(_.isIndeterminateCollation) + private def hasIndeterminate(dataTypes: Seq[DataType]): Boolean = + dataTypes.exists(dt => dt.isInstanceOf[StringType] + && dt.asInstanceOf[StringType].isIndeterminateCollation) private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.filter(!_.isDefaultCollation).distinct.size > 1 + dataTypes.filter(!_.isDefaultCollation).map(_.collationId).distinct.size > 1 private def hasExplicitCollation(expression: Expression): Boolean = { - if (!expression.dataType.isInstanceOf[StringType]) { - false - } - else { - expression match { - case _: Collate => true - case _ => expression.children.exists(hasExplicitCollation) - } + expression match { + case _: Collate => true + case _ => expression.children.exists(hasExplicitCollation) } } } @@ -1048,19 +1063,13 @@ object TypeCoercion extends TypeCoercionBase { .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } - override def findWiderCommonType(exprs: Seq[Expression], - failOnIndeterminate: Boolean = false): Option[DataType] = { + override def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. - val (stringTypes, nonStringTypes) = exprs.map(_.dataType).partition(hasStringType) - (if (stringTypes.distinct.size > 1) { - val collationId = CollationTypeCasts.getOutputCollation(exprs, failOnIndeterminate) - stringTypes.distinct.map(castStringType(_, collationId)) ++ nonStringTypes - } - else stringTypes.distinct ++ nonStringTypes) - .foldLeft[Option[DataType]](Some(NullType))((r, c) => + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case _ => None @@ -1077,7 +1086,7 @@ object TypeCoercion extends TypeCoercionBase { // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. // We wrap immediately an Option after this. @Nullable val ret: DataType = (inType, expectedType) match { - case (_: StringType, st2: StringType) => st2 + case (st1: StringType, st2: StringType) if st1 != st2 => st2 // If the expected type is already a parent of the input type, no need to cast. case _ if expectedType.acceptsType(inType) => inType @@ -1196,8 +1205,10 @@ object TypeCoercion extends TypeCoercionBase { */ @tailrec def castStringType(fromType: DataType, collationId: Int): DataType = fromType match { - case _: StringType => implicitCast(fromType, StringType(collationId)).get + case _: StringType if fromType != StringType(collationId) + => implicitCast(fromType, StringType(collationId)).get case ArrayType(et, _) => castStringType(et, collationId) + case _ => fromType } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 6554cf118a98f..b3fde853c4920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -628,14 +628,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "INDETERMINATE_COLLATION" ) - // concat should fail on indeterminate collation - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT c1 FROM $tableName ORDER BY c1 || c3") - }, - errorClass = "INDETERMINATE_COLLATION" - ) - // concat + in checkError( exception = intercept[AnalysisException] { @@ -648,6 +640,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) ) } + + } test("cast of default collated string in IN expression") { From 56d6c7c67d73deb9c96280130478e50c4cc60a1e Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 14 Mar 2024 09:10:40 +0100 Subject: [PATCH 15/87] Fix failing tests --- .../sql/catalyst/analysis/TypeCoercion.scala | 15 ++++++++------- .../org/apache/spark/sql/CollationSuite.scala | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 40ae508c8f7c8..71e4e722a0b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -773,13 +773,14 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case c @ (_: Concat) if shouldCast(c.children.map(_.dataType)) => - val newChildren = collateToSingleType(c.children, failOnIndeterminate = false) - c.withNewChildren(newChildren) - - case e if shouldCast(e.children.map(_.dataType)) => - val newChildren = collateToSingleType(e.children) - e.withNewChildren(newChildren) + case snf @ (_: Concat) if shouldCast(snf.children.map(_.dataType)) => + val newChildren = collateToSingleType(snf.children, failOnIndeterminate = false) + snf.withNewChildren(newChildren) + + case sf @ (_: BinaryExpression | _: In | _: SortOrder) + if shouldCast(sf.children.map(_.dataType)) => + val newChildren = collateToSingleType(sf.children) + sf.withNewChildren(newChildren) } def shouldCast(types: Seq[DataType]): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index f9d30d1a463d1..943c76df8902b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -233,7 +233,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" + s"`string collate $leftCollationName`.`string collate $rightCollationName`" ) ) // startsWith @@ -247,7 +247,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" + s"`string collate $leftCollationName`.`string collate $rightCollationName`" ) ) // endsWith @@ -261,7 +261,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string COLLATE $leftCollationName`.`string COLLATE $rightCollationName`" + s"`string collate $leftCollationName`.`string collate $rightCollationName`" ) ) } @@ -568,7 +568,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE UNICODE`" + "explicitTypes" -> "`string`.`string collate UNICODE`" ) ) @@ -580,7 +580,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE UNICODE`" + "explicitTypes" -> "`string`.`string collate UNICODE`" ) ) checkError( @@ -590,7 +590,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string COLLATE UNICODE`.`string`" + "explicitTypes" -> "`string collate UNICODE`.`string`" ) ) @@ -610,7 +610,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`.`string COLLATE UNICODE`" + "explicitTypes" -> "`string`.`string collate UNICODE`" ) ) } From 94e5259f08a2149dbe2274b7bb3067e86f0e454b Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 14 Mar 2024 10:55:55 +0100 Subject: [PATCH 16/87] Fix array cast errors --- .../sql/catalyst/analysis/TypeCoercion.scala | 44 ++++++++++--------- .../org/apache/spark/sql/CollationSuite.scala | 8 ++++ 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 71e4e722a0b80..a22a61edafeb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -773,18 +773,27 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case snf @ (_: Concat) if shouldCast(snf.children.map(_.dataType)) => - val newChildren = collateToSingleType(snf.children, failOnIndeterminate = false) - snf.withNewChildren(newChildren) - - case sf @ (_: BinaryExpression | _: In | _: SortOrder) - if shouldCast(sf.children.map(_.dataType)) => - val newChildren = collateToSingleType(sf.children) - sf.withNewChildren(newChildren) + case checkCastWithIndeterminate @ (_: Concat) + if shouldCast(checkCastWithIndeterminate.children) => + val newChildren = + collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) + checkCastWithIndeterminate.withNewChildren(newChildren) + + case checkCastWithoutIndeterminate @ (_: BinaryExpression | _: In | _: SortOrder) + if shouldCast(checkCastWithoutIndeterminate.children) => + val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) + checkCastWithoutIndeterminate.withNewChildren(newChildren) + + case checkIndeterminate @ (_: BinaryExpression | _: In | _: SortOrder) + if hasIndeterminate(checkIndeterminate.children + .filter(e => hasStringType(e.dataType)) + .map(extractStringType)) => + throw QueryCompilationErrors.indeterminateCollationError() } - def shouldCast(types: Seq[DataType]): Boolean = { - types.filter(hasStringType).map(dt => extractStringType(dt).collationId).distinct.size > 1 + def shouldCast(types: Seq[Expression]): Boolean = { + types.filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e).collationId).distinct.size > 1 } /** @@ -801,7 +810,7 @@ abstract class TypeCoercionBase { /** * Extracts StringTypes from flitered hasStringType */ - private def extractStringType(dt: DataType): StringType = dt match { + private def extractStringType(expr: Expression): StringType = expr.dataType match { case st: StringType => st case ArrayType(et, _) => et.asInstanceOf[StringType] } @@ -834,7 +843,7 @@ abstract class TypeCoercionBase { */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { val explicitTypes = exprs.filter(hasExplicitCollation) - .map(e => extractStringType(e.dataType).collationId).distinct + .map(e => extractStringType(e).collationId).distinct explicitTypes.size match { case 1 => explicitTypes.head @@ -844,16 +853,9 @@ abstract class TypeCoercionBase { explicitTypes.map(t => StringType(t).typeName) ) case 0 => - val dataTypes = exprs.map(e => extractStringType(e.dataType)) + val dataTypes = exprs.filter(e => hasStringType(e.dataType)).map(extractStringType) - if (hasIndeterminate(dataTypes)) { - if (failOnIndeterminate) { - throw QueryCompilationErrors.indeterminateCollationError() - } else { - CollationFactory.INDETERMINATE_COLLATION_ID - } - } - else if (hasMultipleImplicits(dataTypes)) { + if (hasMultipleImplicits(dataTypes)) { if (failOnIndeterminate) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 943c76df8902b..8556a79e73f7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -602,6 +602,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "INDETERMINATE_COLLATION" ) + // concat on different implicit collations should fail + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") + }, + errorClass = "INDETERMINATE_COLLATION" + ) + // concat + in checkError( exception = intercept[AnalysisException] { From ccb52ba4a3bde25cb14aaa2596b8dab73af9d198 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 14 Mar 2024 11:52:46 +0100 Subject: [PATCH 17/87] Fix additional errors --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a22a61edafeb0..5f39bb071b2f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -787,13 +787,13 @@ abstract class TypeCoercionBase { case checkIndeterminate @ (_: BinaryExpression | _: In | _: SortOrder) if hasIndeterminate(checkIndeterminate.children .filter(e => hasStringType(e.dataType)) - .map(extractStringType)) => + .map(e => extractStringType(e.dataType))) => throw QueryCompilationErrors.indeterminateCollationError() } def shouldCast(types: Seq[Expression]): Boolean = { types.filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e).collationId).distinct.size > 1 + .map(e => extractStringType(e.dataType).collationId).distinct.size > 1 } /** @@ -810,9 +810,10 @@ abstract class TypeCoercionBase { /** * Extracts StringTypes from flitered hasStringType */ - private def extractStringType(expr: Expression): StringType = expr.dataType match { + @tailrec + private def extractStringType(dt: DataType): StringType = dt match { case st: StringType => st - case ArrayType(et, _) => et.asInstanceOf[StringType] + case ArrayType(et, _) => extractStringType(et) } /** @@ -843,7 +844,7 @@ abstract class TypeCoercionBase { */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { val explicitTypes = exprs.filter(hasExplicitCollation) - .map(e => extractStringType(e).collationId).distinct + .map(e => extractStringType(e.dataType).collationId).distinct explicitTypes.size match { case 1 => explicitTypes.head @@ -853,7 +854,8 @@ abstract class TypeCoercionBase { explicitTypes.map(t => StringType(t).typeName) ) case 0 => - val dataTypes = exprs.filter(e => hasStringType(e.dataType)).map(extractStringType) + val dataTypes = exprs.filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e.dataType)) if (hasMultipleImplicits(dataTypes)) { if (failOnIndeterminate) { From 9b1387b9091395bf87d94ee6c871be1484ebd27a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Sun, 17 Mar 2024 21:18:05 +0100 Subject: [PATCH 18/87] Fix explicit collation search --- .../sql/catalyst/analysis/TypeCoercion.scala | 78 +++++++++++-------- .../spark/sql/catalyst/expressions/Cast.scala | 70 ++++++++--------- .../expressions/stringExpressions.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 33 ++++---- 4 files changed, 100 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5f39bb071b2f0..5de737d8312cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ImplicitCastInputTypes, _} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -773,7 +772,7 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case checkCastWithIndeterminate @ (_: Concat) + case checkCastWithIndeterminate: Concat if shouldCast(checkCastWithIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) @@ -789,6 +788,22 @@ abstract class TypeCoercionBase { .filter(e => hasStringType(e.dataType)) .map(e => extractStringType(e.dataType))) => throw QueryCompilationErrors.indeterminateCollationError() + + case checkExpectsInputType: ExpectsInputTypes + if checkExpectsInputType.inputTypes.exists(hasStringType) => + val collationId: Int = getOutputCollation( + checkExpectsInputType.children.zip(checkExpectsInputType.inputTypes) + .filter {case (e, t) => hasStringType(t)} + .map {case (e, _) => e}) + val children: Seq[Expression] = checkExpectsInputType + .children.zip(checkExpectsInputType.inputTypes).map { + case (in, expected) + => if (hasStringType(expected)) { + castStringType(in, collationId) + } + else in + } + checkExpectsInputType.withNewChildren(children) } def shouldCast(types: Seq[Expression]): Boolean = { @@ -800,22 +815,42 @@ abstract class TypeCoercionBase { * Whether the data type contains StringType. */ @tailrec - def hasStringType(dt: DataType): Boolean = dt match { + def hasStringType(dt: AbstractDataType): Boolean = dt match { case _: StringType => true case ArrayType(et, _) => hasStringType(et) - // Add StructType if we support string promotion for struct fields in the future. + case tc: TypeCollection + => hasStringType(tc.defaultConcreteType) case _ => false } /** - * Extracts StringTypes from flitered hasStringType + * Extracts StringTypes from filtered hasStringType */ @tailrec - private def extractStringType(dt: DataType): StringType = dt match { + private def extractStringType(dt: AbstractDataType): StringType = dt match { case st: StringType => st case ArrayType(et, _) => extractStringType(et) + case tc: TypeCollection + => extractStringType(tc.defaultConcreteType) } + /** + * Casts to StringType expressions filtered from hasStringType + */ + private def castStringType(expr: Expression, collationId: Int): Expression = + expr.dataType match { + case st: StringType if st.collationId == collationId => + expr + case _: StringType => + Cast(expr, StringType(collationId)) + case at: ArrayType if at.elementType.isInstanceOf[StringType] + && at.elementType.asInstanceOf[StringType].collationId != collationId => + Cast(expr, ArrayType(StringType(collationId), at.containsNull)) + case at: ArrayType if at.elementType.isInstanceOf[StringType] => + expr + case _ => expr + } + /** * Collates the input expressions to a single collation. */ @@ -823,19 +858,7 @@ abstract class TypeCoercionBase { failOnIndeterminate: Boolean = true): Seq[Expression] = { val collationId = getOutputCollation(exprs, failOnIndeterminate) - exprs.map { expression => - expression.dataType match { - case st: StringType if st.collationId == collationId => - expression - case _: StringType => - Cast(expression, StringType(collationId)) - case at: ArrayType if at.elementType.isInstanceOf[StringType] - && at.elementType.asInstanceOf[StringType].collationId != collationId => - Cast(expression, ArrayType(StringType(collationId), at.containsNull)) - case at: ArrayType if at.elementType.isInstanceOf[StringType] => - expression - } - } + exprs.map(castStringType(_, collationId)) } /** @@ -883,7 +906,9 @@ abstract class TypeCoercionBase { private def hasExplicitCollation(expression: Expression): Boolean = { expression match { case _: Collate => true - case _ => expression.children.exists(hasExplicitCollation) + case e if e.dataType.isInstanceOf[ArrayType] + => expression.children.exists(hasExplicitCollation) + case _ => false } } } @@ -1205,17 +1230,6 @@ object TypeCoercion extends TypeCoercionBase { case _ => false } - /** - * Method to cast StringTypes that hasStringType filters. - */ - @tailrec - def castStringType(fromType: DataType, collationId: Int): DataType = fromType match { - case _: StringType if fromType != StringType(collationId) - => implicitCast(fromType, StringType(collationId)).get - case ArrayType(et, _) => castStringType(et, collationId) - case _ => fromType - } - /** * Promotes strings that appear in arithmetic expressions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0bfdf7386efa8..6d011d24e042e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -267,7 +267,7 @@ object Cast extends QueryErrorsBase { * * Cast.castToTimestamp */ def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { - case (StringType, TimestampType) => true + case (_: StringType, TimestampType) => true case (TimestampType, StringType) => true case (DateType, TimestampType) => true case (TimestampType, DateType) => true @@ -583,7 +583,7 @@ case class Cast( // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { true @@ -620,7 +620,7 @@ case class Cast( // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, utfs => { if (ansiEnabled) { DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, getContextOrNull()) @@ -662,7 +662,7 @@ case class Cast( } private[this] def castToTimestampNTZ(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, utfs => { if (ansiEnabled) { DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs, getContextOrNull()) @@ -696,7 +696,7 @@ case class Cast( // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => if (ansiEnabled) { buildCast[UTF8String](_, s => DateTimeUtils.stringToDateAnsi(s, getContextOrNull())) } else { @@ -712,14 +712,14 @@ case class Cast( // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s)) } private[this] def castToDayTimeInterval( from: DataType, it: DayTimeIntervalType): Any => Any = from match { - case StringType => buildCast[UTF8String](_, s => + case _: StringType => buildCast[UTF8String](_, s => IntervalUtils.castStringToDTInterval(s, it.startField, it.endField)) case _: DayTimeIntervalType => buildCast[Long](_, s => IntervalUtils.durationToMicros(IntervalUtils.microsToDuration(s), it.endField)) @@ -739,7 +739,7 @@ case class Cast( private[this] def castToYearMonthInterval( from: DataType, it: YearMonthIntervalType): Any => Any = from match { - case StringType => buildCast[UTF8String](_, s => + case _: StringType => buildCast[UTF8String](_, s => IntervalUtils.castStringToYMInterval(s, it.startField, it.endField)) case _: YearMonthIntervalType => buildCast[Int](_, s => IntervalUtils.periodToMonths(IntervalUtils.monthsToPeriod(s), it.endField)) @@ -758,9 +758,9 @@ case class Cast( // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => buildCast[UTF8String](_, v => UTF8StringUtils.toLongExact(v, getContextOrNull())) - case StringType => + case _: StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) case BooleanType => @@ -791,9 +791,9 @@ case class Cast( // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => buildCast[UTF8String](_, v => UTF8StringUtils.toIntExact(v, getContextOrNull())) - case StringType => + case _: StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => @@ -823,9 +823,9 @@ case class Cast( // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => buildCast[UTF8String](_, v => UTF8StringUtils.toShortExact(v, getContextOrNull())) - case StringType => + case _: StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort @@ -957,12 +957,12 @@ case class Cast( private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { - case StringType if !ansiEnabled => + case _: StringType if !ansiEnabled => buildCast[UTF8String](_, s => { val d = Decimal.fromString(s) if (d == null) null else changePrecision(d, target) }) - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => buildCast[UTF8String](_, s => changePrecision(Decimal.fromStringANSI(s, target, getContextOrNull()), target)) case BooleanType => @@ -1000,7 +1000,7 @@ case class Cast( // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, s => { val doubleStr = s.toString try doubleStr.toDouble catch { @@ -1027,7 +1027,7 @@ case class Cast( // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { - case StringType => + case _: StringType => buildCast[UTF8String](_, s => { val floatStr = s.toString try floatStr.toFloat catch { @@ -1263,7 +1263,7 @@ case class Cast( from: DataType, ctx: CodegenContext): CastFunction = { from match { - case StringType => + case _: StringType => val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) (c, evPrim, evNull) => if (ansiEnabled) { @@ -1347,7 +1347,7 @@ case class Cast( val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) from match { - case StringType if !ansiEnabled => + case _: StringType if !ansiEnabled => (c, evPrim, evNull) => code""" Decimal $tmp = Decimal.fromString($c); @@ -1357,7 +1357,7 @@ case class Cast( ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)} } """ - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => val errorContext = getContextOrNullCode(ctx) val toType = ctx.addReferenceObj("toType", target) (c, evPrim, evNull) => @@ -1427,7 +1427,7 @@ case class Cast( private[this] def castToTimestampCode( from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => + case _: StringType => val zoneIdClass = classOf[ZoneId] val zid = JavaCode.global( ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName), @@ -1504,7 +1504,7 @@ case class Cast( private[this] def castToTimestampNTZCode( from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => + case _: StringType => val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => if (ansiEnabled) { @@ -1535,7 +1535,7 @@ case class Cast( } private[this] def castToIntervalCode(from: DataType): CastFunction = from match { - case StringType => + case _: StringType => val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") (c, evPrim, evNull) => code"""$evPrim = $util.safeStringToInterval($c); @@ -1549,7 +1549,7 @@ case class Cast( private[this] def castToDayTimeIntervalCode( from: DataType, it: DayTimeIntervalType): CastFunction = from match { - case StringType => + case _: StringType => val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") (c, evPrim, _) => code""" @@ -1586,7 +1586,7 @@ case class Cast( private[this] def castToYearMonthIntervalCode( from: DataType, it: YearMonthIntervalType): CastFunction = from match { - case StringType => + case _: StringType => val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") (c, evPrim, _) => code""" @@ -1634,7 +1634,7 @@ case class Cast( private[this] def castToBooleanCode( from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => + case _: StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => val castFailureCode = if (ansiEnabled) { @@ -1812,11 +1812,11 @@ case class Cast( private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, $errorContext);" - case StringType => + case _: StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" @@ -1847,11 +1847,11 @@ case class Cast( } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, $errorContext);" - case StringType => + case _: StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" @@ -1882,11 +1882,11 @@ case class Cast( } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType if ansiEnabled => + case _: StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") val errorContext = getContextOrNullCode(ctx) (c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, $errorContext);" - case StringType => + case _: StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => code""" @@ -1917,7 +1917,7 @@ case class Cast( private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { - case StringType => + case _: StringType => val floatStr = ctx.freshVariable("floatStr", StringType) (c, evPrim, evNull) => val handleNull = if (ansiEnabled) { @@ -1955,7 +1955,7 @@ case class Cast( private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { - case StringType => + case _: StringType => val doubleStr = ctx.freshVariable("doubleStr", StringType) (c, evPrim, evNull) => val handleNull = if (ansiEnabled) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3481bc5baa617..aaa26e5d70e3c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1976,7 +1976,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { str.dataType match { - case StringType => string.asInstanceOf[UTF8String] + case _: StringType => string.asInstanceOf[UTF8String] .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) case BinaryType => ByteArray.subStringSQL(string.asInstanceOf[Array[Byte]], pos.asInstanceOf[Int], len.asInstanceOf[Int]) @@ -1987,7 +1987,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { - case StringType => s"$string.substringSQL($pos, $len)" + case _: StringType => s"$string.substringSQL($pos, $len)" case BinaryType => s"${classOf[ByteArray].getName}.subStringSQL($string, $pos, $len)" } }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 8556a79e73f7b..583b554173290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -521,10 +521,16 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // collate c1 to UTF8_BINARY because it is explicitly set checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = COLLATE('a', 'UTF8_BINARY')"), Seq(Row("a"))) - checkAnswer( - sql(s"SELECT c1 FROM $tableName " + - s"WHERE c1 = SUBSTR(COLLATE('a', 'UTF8_BINARY'), 0)"), - Seq(Row("a"))) + + // fail with implicit mismatch, as function return should be considered implicit + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 FROM $tableName " + + s"WHERE c1 = SUBSTR(COLLATE('a', 'UNICODE'), 0)") + }, + errorClass = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map.empty + ) // in operator checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('a')"), @@ -611,19 +617,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // concat + in - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UTF8_BINARY') IN " + - s"(COLLATE('a', 'UNICODE'))") - }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", - parameters = Map( - "explicitTypes" -> "`string`.`string collate UNICODE`" - ) - ) - } - + checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UTF8_BINARY') IN " + + s"(COLLATE('aa', 'UNICODE'))"), + Seq(Row("a"))) + // ImplicitInputTypeCast test + checkAnswer(sql(s"SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), + Seq(Row("aaaa"))) + } } test("cast of default collated string in IN expression") { From c9974e17ef2ba22bbfebb53a47ac97cb589f0a68 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 18 Mar 2024 07:50:56 +0100 Subject: [PATCH 19/87] Fix scala style errors --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5de737d8312cf..687d351d04a70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import scala.collection.mutable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 583b554173290..df2ff1b0e6111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -622,7 +622,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("a"))) // ImplicitInputTypeCast test - checkAnswer(sql(s"SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), + checkAnswer( + sql("SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), Seq(Row("aaaa"))) } } From fca9a655814dd5ccbfbaaca83b5790a0c305a449 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 18 Mar 2024 15:47:03 +0100 Subject: [PATCH 20/87] Add support for ImplicitCastInputTypes --- .../sql/catalyst/analysis/TypeCoercion.scala | 92 ++++++++++++------- .../expressions/collationExpressions.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 2 + 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 687d351d04a70..f6d7055521bcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -6,6 +6,7 @@ * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * + * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software @@ -23,7 +24,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{ImplicitCastInputTypes, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -707,7 +708,7 @@ abstract class TypeCoercionBase { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - case (expr: Expression, st2: StringType) if expr.dataType.isInstanceOf[StringType] => expr + case (expr: Expression, t) if CollationTypeCasts.hasStringType(t) => expr // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } @@ -779,30 +780,53 @@ abstract class TypeCoercionBase { collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) checkCastWithIndeterminate.withNewChildren(newChildren) - case checkCastWithoutIndeterminate @ (_: BinaryExpression | _: In | _: SortOrder) + case checkCastWithoutIndeterminate@(_: BinaryExpression | _: In | _: SortOrder) if shouldCast(checkCastWithoutIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) checkCastWithoutIndeterminate.withNewChildren(newChildren) - case checkIndeterminate @ (_: BinaryExpression | _: In | _: SortOrder) + case checkIndeterminate@(_: BinaryExpression | _: In | _: SortOrder) if hasIndeterminate(checkIndeterminate.children .filter(e => hasStringType(e.dataType)) .map(e => extractStringType(e.dataType))) => throw QueryCompilationErrors.indeterminateCollationError() + case checkImplicitCastInputTypes: ImplicitCastInputTypes + if checkImplicitCastInputTypes.children.exists(e => hasStringType(e.dataType)) + && checkImplicitCastInputTypes.inputTypes.nonEmpty => + val collationId: Int = + getOutputCollation(checkImplicitCastInputTypes + .children.filter { e => hasStringType(e.dataType) }) + val children: Seq[Expression] = checkImplicitCastInputTypes + .children.zip(checkImplicitCastInputTypes.inputTypes).map { + case (e, st) if hasStringType(st) => + castStringType(e, collationId).getOrElse(e) + case (nt, t) if hasStringType(t.defaultConcreteType) && nt.dataType == NullType => + castStringType(nt, collationId, t.defaultConcreteType).getOrElse(nt) + case (e, TypeCollection(types)) if types.exists(hasStringType) => + types.flatMap{ dt => + if (hasStringType(dt)) { + castStringType(e, collationId, dt) + } else { + TypeCoercion.implicitCast(e, dt) + } + }.headOption.getOrElse(e) + case (in, _) => in + } + checkImplicitCastInputTypes.withNewChildren(children) + case checkExpectsInputType: ExpectsInputTypes - if checkExpectsInputType.inputTypes.exists(hasStringType) => + if checkExpectsInputType.children.exists(e => hasStringType(e.dataType)) + && checkExpectsInputType.inputTypes.nonEmpty => val collationId: Int = getOutputCollation( - checkExpectsInputType.children.zip(checkExpectsInputType.inputTypes) - .filter {case (e, t) => hasStringType(t)} - .map {case (e, _) => e}) + checkExpectsInputType.children.filter {e => hasStringType(e.dataType)}) val children: Seq[Expression] = checkExpectsInputType .children.zip(checkExpectsInputType.inputTypes).map { - case (in, expected) - => if (hasStringType(expected)) { - castStringType(in, collationId) - } - else in + case (st, _) if hasStringType(st.dataType) => + castStringType(st, collationId).getOrElse(st) + case (nt, e) if hasStringType(e.defaultConcreteType) && nt.dataType == NullType => + castStringType(nt, collationId, e.defaultConcreteType).getOrElse(nt) + case (in, _) => in } checkExpectsInputType.withNewChildren(children) } @@ -819,8 +843,6 @@ abstract class TypeCoercionBase { def hasStringType(dt: AbstractDataType): Boolean = dt match { case _: StringType => true case ArrayType(et, _) => hasStringType(et) - case tc: TypeCollection - => hasStringType(tc.defaultConcreteType) case _ => false } @@ -831,26 +853,31 @@ abstract class TypeCoercionBase { private def extractStringType(dt: AbstractDataType): StringType = dt match { case st: StringType => st case ArrayType(et, _) => extractStringType(et) - case tc: TypeCollection - => extractStringType(tc.defaultConcreteType) } - /** - * Casts to StringType expressions filtered from hasStringType - */ - private def castStringType(expr: Expression, collationId: Int): Expression = - expr.dataType match { + private def castStringType(expr: Expression, + collationId: Int, + expected: AbstractDataType = NullType): Option[Expression] = + castStringType(expr.dataType, collationId, expected).map { dt => + if (dt == expr.dataType) expr else Cast(expr, dt) + } + + private def castStringType(inType: DataType, + collationId: Int, + expected: AbstractDataType): Option[DataType] = { + @Nullable val ret: DataType = inType match { case st: StringType if st.collationId == collationId => - expr - case _: StringType => - Cast(expr, StringType(collationId)) - case at: ArrayType if at.elementType.isInstanceOf[StringType] - && at.elementType.asInstanceOf[StringType].collationId != collationId => - Cast(expr, ArrayType(StringType(collationId), at.containsNull)) - case at: ArrayType if at.elementType.isInstanceOf[StringType] => - expr - case _ => expr + st + case _: AtomicType => + StringType(collationId) + case ArrayType(arrType, nullable) => + castStringType(arrType, collationId, expected).map(ArrayType(_, nullable)).orNull + case _: NullType => + castStringType(expected.defaultConcreteType, collationId, expected).orNull + case _ => null } + Option(ret) + } /** * Collates the input expressions to a single collation. @@ -859,7 +886,7 @@ abstract class TypeCoercionBase { failOnIndeterminate: Boolean = true): Seq[Expression] = { val collationId = getOutputCollation(exprs, failOnIndeterminate) - exprs.map(castStringType(_, collationId)) + exprs.map(e => castStringType(e, collationId).getOrElse(e)) } /** @@ -1117,7 +1144,6 @@ object TypeCoercion extends TypeCoercionBase { // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. // We wrap immediately an Option after this. @Nullable val ret: DataType = (inType, expectedType) match { - case (st1: StringType, st2: StringType) if st1 != st2 => st2 // If the expected type is already a parent of the input type, no need to cast. case _ if expectedType.acceptsType(inType) => inType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index b0f77bad44831..18e1312ef535b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -112,7 +112,8 @@ case class Collate(child: Expression, collationName: String) since = "4.0.0", group = "string_funcs") // scalastyle:on line.contains.tab -case class Collation(child: Expression) extends UnaryExpression with RuntimeReplaceable { +case class Collation(child: Expression) + extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes { override def dataType: DataType = StringType override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) override def replacement: Expression = { @@ -120,4 +121,5 @@ case class Collation(child: Expression) extends UnaryExpression with RuntimeRepl val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, StringType) } + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index df2ff1b0e6111..9a454f1217e9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -626,6 +626,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql("SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), Seq(Row("aaaa"))) } + + sql("select elt(4, 'aaa', 'bbb', 'ccc' COLLATE UNICODE, 'ddd')") } test("cast of default collated string in IN expression") { From 660d664f301bbb1d459dc1169c8d115c1f1bf4ab Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 18 Mar 2024 19:53:49 +0100 Subject: [PATCH 21/87] Fix accidental change in license header --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f6d7055521bcb..099f0007a85c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -6,7 +6,6 @@ * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * - * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software From c8edd9380f8fb53af95cae8a38ca5274531482cd Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 08:24:52 +0100 Subject: [PATCH 22/87] Fix null casting --- .../sql/catalyst/analysis/TypeCoercion.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 099f0007a85c5..f2b9bfd60708d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -799,13 +799,11 @@ abstract class TypeCoercionBase { val children: Seq[Expression] = checkImplicitCastInputTypes .children.zip(checkImplicitCastInputTypes.inputTypes).map { case (e, st) if hasStringType(st) => - castStringType(e, collationId).getOrElse(e) - case (nt, t) if hasStringType(t.defaultConcreteType) && nt.dataType == NullType => - castStringType(nt, collationId, t.defaultConcreteType).getOrElse(nt) + castStringType(e, collationId, Some(st)).getOrElse(e) case (e, TypeCollection(types)) if types.exists(hasStringType) => types.flatMap{ dt => if (hasStringType(dt)) { - castStringType(e, collationId, dt) + castStringType(e, collationId, Some(dt)) } else { TypeCoercion.implicitCast(e, dt) } @@ -824,7 +822,7 @@ abstract class TypeCoercionBase { case (st, _) if hasStringType(st.dataType) => castStringType(st, collationId).getOrElse(st) case (nt, e) if hasStringType(e.defaultConcreteType) && nt.dataType == NullType => - castStringType(nt, collationId, e.defaultConcreteType).getOrElse(nt) + castStringType(nt, collationId, Some(e.defaultConcreteType)).getOrElse(nt) case (in, _) => in } checkExpectsInputType.withNewChildren(children) @@ -856,23 +854,28 @@ abstract class TypeCoercionBase { private def castStringType(expr: Expression, collationId: Int, - expected: AbstractDataType = NullType): Option[Expression] = + expected: Option[AbstractDataType] = None): Option[Expression] = castStringType(expr.dataType, collationId, expected).map { dt => if (dt == expr.dataType) expr else Cast(expr, dt) } - private def castStringType(inType: DataType, + private def castStringType(inType: AbstractDataType, collationId: Int, - expected: AbstractDataType): Option[DataType] = { + expected: Option[AbstractDataType]): Option[DataType] = { @Nullable val ret: DataType = inType match { case st: StringType if st.collationId == collationId => st - case _: AtomicType => + case _: AtomicType + if expected.isEmpty || expected.get.defaultConcreteType.isInstanceOf[StringType] => StringType(collationId) - case ArrayType(arrType, nullable) => + case ArrayType(arrType, nullable) + if expected.isEmpty || expected.get.defaultConcreteType.isInstanceOf[ArrayType] => castStringType(arrType, collationId, expected).map(ArrayType(_, nullable)).orNull - case _: NullType => - castStringType(expected.defaultConcreteType, collationId, expected).orNull + case _: NullType if expected.nonEmpty => + castStringType( + expected.get.defaultConcreteType, + collationId, + expected).orNull case _ => null } Option(ret) From a91490bfb4f6b5fd2ae0a1c7ea76ec7f55acd590 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 09:02:17 +0100 Subject: [PATCH 23/87] Fix failing tests --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f2b9bfd60708d..e7542b768d129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -707,7 +707,10 @@ abstract class TypeCoercionBase { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - case (expr: Expression, t) if CollationTypeCasts.hasStringType(t) => expr + case (expr: Expression, t) + if CollationTypeCasts.hasStringType(t) + && CollationTypeCasts.hasStringType(expr.dataType) + => expr // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } @@ -821,7 +824,11 @@ abstract class TypeCoercionBase { .children.zip(checkExpectsInputType.inputTypes).map { case (st, _) if hasStringType(st.dataType) => castStringType(st, collationId).getOrElse(st) - case (nt, e) if hasStringType(e.defaultConcreteType) && nt.dataType == NullType => + case (nt, e) + if hasStringType(e) && nt.dataType == NullType => + castStringType(nt, collationId, Some(e)).getOrElse(nt) + case (nt, e: TypeCollection) + if hasStringType(e.defaultConcreteType) && nt.dataType == NullType => castStringType(nt, collationId, Some(e.defaultConcreteType)).getOrElse(nt) case (in, _) => in } From 49a8d613fb279862052c8c7137fabafa93918343 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 10:52:30 +0100 Subject: [PATCH 24/87] Move implicit casting when strings present --- .../sql/catalyst/analysis/TypeCoercion.scala | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e7542b768d129..b28308be0cbc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -706,15 +706,16 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - case (expr: Expression, t) - if CollationTypeCasts.hasStringType(t) - && CollationTypeCasts.hasStringType(expr.dataType) - => expr - // If we cannot do the implicit cast, just use the original input. - case (in, expected) => implicitCast(in, expected).getOrElse(in) + if (!e.children.exists(e => CollationTypeCasts.hasStringType(e.dataType))) { + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { + // If we cannot do the implicit cast, just use the original input. + case (in, expected) => implicitCast(in, expected).getOrElse(in) + } + e.withNewChildren(children) + } + else { + e } - e.withNewChildren(children) case e: ExpectsInputTypes if e.inputTypes.nonEmpty => // Convert NullType into some specific target type for ExpectsInputTypes that don't do @@ -808,10 +809,10 @@ abstract class TypeCoercionBase { if (hasStringType(dt)) { castStringType(e, collationId, Some(dt)) } else { - TypeCoercion.implicitCast(e, dt) + implicitCast(e, dt) } }.headOption.getOrElse(e) - case (in, _) => in + case (in, expected) => implicitCast(in, expected).getOrElse(in) } checkImplicitCastInputTypes.withNewChildren(children) From 4c4cd844e3e627391f347434150d874da6943722 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 14:03:07 +0100 Subject: [PATCH 25/87] Fix unintentional changes --- .../src/main/scala/org/apache/spark/sql/types/StringType.scala | 1 - .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index ad0129f28024e..4a0ed781b1e35 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -56,7 +56,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa else if (isIndeterminateCollation) s"string collate INDETERMINATE_COLLATION" else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" - override def equals(obj: Any): Boolean = obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b28308be0cbc0..f16ee2d682fa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -65,7 +65,7 @@ abstract class TypeCoercionBase { * is larger than decimal, and yet decimal is more precise than double, but in * union we would cast the decimal into double. */ - def findWiderCommonType(children: Seq[DataType]): Option[DataType] + def findWiderCommonType(types: Seq[DataType]): Option[DataType] /** * Given an expected data type, try to cast the expression and return the cast expression. From 66122a618f2926827e49ae45e37ebe3c1c3bcaba Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 20 Mar 2024 12:35:11 +0100 Subject: [PATCH 26/87] improve types.py --- python/pyspark/sql/types.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c79e539d9c96e..5843a781f47e2 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -260,11 +260,12 @@ def __init__(self, collationId: int = 0): self.collationId = collationId def collationIdToName(self) -> str: - return ( - " collate %s" % StringType.collationNames[self.collationId] - if self.collationId != 0 and self.collationId != -1 - else ("INDETERMINATE_COLLATION" if self.collationId == -1 else "") - ) + if self.collationId == 0: + return "" + elif self.collationId == 1: + return "collate INDETERMINATE_COLLATION" + else: + return "collate %s" % StringType.collationNames[self.collationId] @classmethod def collationNameToId(cls, collationName: str) -> int: From 50f46e4b9c5b784a67ec1b3b61423de62ad74905 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 08:51:06 +0100 Subject: [PATCH 27/87] Refactor code --- .../analysis/CollationTypeCasts.scala | 157 ++++++++++++ .../sql/catalyst/analysis/TypeCoercion.scala | 223 +++--------------- .../expressions/stringExpressions.scala | 7 +- 3 files changed, 189 insertions(+), 198 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala new file mode 100644 index 0000000000000..6db3ab423b10e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import javax.annotation.Nullable + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} + +object CollationTypeCasts extends TypeCoercionRule { + override val transform: PartialFunction[Expression, Expression] = { + case e if !e.childrenResolved => e + + case checkCastWithIndeterminate @ (_: Elt | _: ComplexTypeMergingExpression) + if shouldCast(checkCastWithIndeterminate.children) => + val newChildren = + collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) + checkCastWithIndeterminate.withNewChildren(newChildren) + + case checkCastWithoutIndeterminate @ (_: BinaryExpression + | _: Predicate + | _: SortOrder + | _: CreateArray + | _: ExpectsInputTypes) + if shouldCast(checkCastWithoutIndeterminate.children) => + val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) + checkCastWithoutIndeterminate.withNewChildren(newChildren) + + case checkIndeterminate@(_: BinaryExpression + | _: Predicate + | _: SortOrder + | _: CreateArray + | _: ExpectsInputTypes) + if hasIndeterminate(checkIndeterminate.children + .filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e.dataType))) => + throw QueryCompilationErrors.indeterminateCollationError() + } + + def shouldCast(types: Seq[Expression]): Boolean = { + types.filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e.dataType).collationId).distinct.size > 1 + } + + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt.existsRecursively { + case _: StringType => true + case _ => false + } + + /** + * Extracts StringTypes from filtered hasStringType + */ + @tailrec + private def extractStringType(dt: DataType): StringType = dt match { + case st: StringType => st + case ArrayType(et, _) => extractStringType(et) + } + + def castStringType(expr: Expression, collationId: Int): Option[Expression] = + castStringType(expr.dataType, collationId).map { dt => + if (dt == expr.dataType) expr else Cast(expr, dt) + } + + private def castStringType(inType: AbstractDataType, collationId: Int): Option[DataType] = { + @Nullable val ret: DataType = inType match { + case st: StringType if st.collationId == collationId => st + case _: StringType => StringType(collationId) + case ArrayType(arrType, nullable) => + castStringType(arrType, collationId).map(ArrayType(_, nullable)).orNull + case _ => null + } + Option(ret) + } + + /** + * Collates the input expressions to a single collation. + */ + def collateToSingleType(exprs: Seq[Expression], + failOnIndeterminate: Boolean = true): Seq[Expression] = { + val collationId = getOutputCollation(exprs, failOnIndeterminate) + + exprs.map(e => castStringType(e, collationId).getOrElse(e)) + } + + /** + * Based on the data types of the input expressions this method determines + * a collation type which the output will have. + */ + def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { + val explicitTypes = exprs.filter(hasExplicitCollation) + .map(e => extractStringType(e.dataType).collationId).distinct + + explicitTypes.size match { + case 1 => explicitTypes.head + case size if size > 1 => + throw QueryCompilationErrors + .explicitCollationMismatchError( + explicitTypes.map(t => StringType(t).typeName) + ) + case 0 => + val dataTypes = exprs.filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e.dataType)) + + if (hasMultipleImplicits(dataTypes)) { + if (failOnIndeterminate) { + throw QueryCompilationErrors.implicitCollationMismatchError() + } else { + CollationFactory.INDETERMINATE_COLLATION_ID + } + } + else { + dataTypes.find(!_.isDefaultCollation) + .getOrElse(StringType) + .collationId + } + } + } + + private def hasIndeterminate(dataTypes: Seq[DataType]): Boolean = + dataTypes.exists(dt => dt.isInstanceOf[StringType] + && dt.asInstanceOf[StringType].isIndeterminateCollation) + + + private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = + dataTypes.filter(!_.isDefaultCollation).map(_.collationId).distinct.size > 1 + + private def hasExplicitCollation(expression: Expression): Boolean = { + expression match { + case _: Collate => true + case e if e.dataType.isInstanceOf[ArrayType] + => expression.children.exists(hasExplicitCollation) + case _ => false + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f16ee2d682fa6..1faa4183d0613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -23,13 +23,13 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.{castStringType, getOutputCollation, hasStringType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -609,9 +609,9 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => + val collationId = getOutputCollation(c.children, failOnIndeterminate = false) val newChildren = c.children.map { e => - if (e.dataType.isInstanceOf[StringType]) e - else implicitCast(e, StringType).getOrElse(e) + implicitCast(e, StringType(collationId)).getOrElse(e) } c.copy(children = newChildren) @@ -659,9 +659,9 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { + val collationId = getOutputCollation(children, failOnIndeterminate = false) children.tail.map { e => - if (e.dataType.isInstanceOf[StringType]) e - else implicitCast(e, StringType).getOrElse(e) + implicitCast(e, StringType(collationId)).getOrElse(e) } } else { children.tail @@ -706,31 +706,39 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - if (!e.children.exists(e => CollationTypeCasts.hasStringType(e.dataType))) { - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - // If we cannot do the implicit cast, just use the original input. - case (in, expected) => implicitCast(in, expected).getOrElse(in) - } - e.withNewChildren(children) + val childrenBeforeCollations: Seq[Expression] = e.children.zip(e.inputTypes).map { + // If we cannot do the implicit cast, just use the original input. + case (in, expected) => implicitCast(in, expected).getOrElse(in) } - else { - e + val collationId = getOutputCollation(childrenBeforeCollations) + val children: Seq[Expression] = childrenBeforeCollations.map { + case in if hasStringType(in.dataType) => + castStringType(in, collationId).getOrElse(in) + case in => in } + e.withNewChildren(children) case e: ExpectsInputTypes if e.inputTypes.nonEmpty => // Convert NullType into some specific target type for ExpectsInputTypes that don't do // general implicit casting. - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + val childrenBeforeCollations: Seq[Expression] = + e.children.zip(e.inputTypes).map { case (in, expected) => if (in.dataType == NullType && !expected.acceptsType(NullType)) { Literal.create(null, expected.defaultConcreteType) } else { in } } + val collationId = getOutputCollation(childrenBeforeCollations) + val children: Seq[Expression] = childrenBeforeCollations.map { + case in if hasStringType(in.dataType) => + castStringType(in, collationId).getOrElse(in) + case in => in + } e.withNewChildren(children) case udf: ScalaUDF if udf.inputTypes.nonEmpty => - val children = udf.children.zip(udf.inputTypes).map { case (in, expected) => + val childrenBeforeCollations = udf.children.zip(udf.inputTypes).map { case (in, expected) => // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works. // In the future we should create types like `AbstractArrayType`, so that Scala UDF can // accept inputs of array type of arbitrary element type. @@ -742,7 +750,12 @@ abstract class TypeCoercionBase { udfInputToCastType(in.dataType, expected.asInstanceOf[DataType]) ).getOrElse(in) } - + } + val collationId = getOutputCollation(childrenBeforeCollations) + val children: Seq[Expression] = childrenBeforeCollations.map { + case in if hasStringType(in.dataType) => + castStringType(in, collationId).getOrElse(in) + case in => in } udf.copy(children = children) } @@ -773,184 +786,6 @@ abstract class TypeCoercionBase { } } - object CollationTypeCasts extends TypeCoercionRule { - override val transform: PartialFunction[Expression, Expression] = { - case e if !e.childrenResolved => e - - case checkCastWithIndeterminate: Concat - if shouldCast(checkCastWithIndeterminate.children) => - val newChildren = - collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) - checkCastWithIndeterminate.withNewChildren(newChildren) - - case checkCastWithoutIndeterminate@(_: BinaryExpression | _: In | _: SortOrder) - if shouldCast(checkCastWithoutIndeterminate.children) => - val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) - checkCastWithoutIndeterminate.withNewChildren(newChildren) - - case checkIndeterminate@(_: BinaryExpression | _: In | _: SortOrder) - if hasIndeterminate(checkIndeterminate.children - .filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType))) => - throw QueryCompilationErrors.indeterminateCollationError() - - case checkImplicitCastInputTypes: ImplicitCastInputTypes - if checkImplicitCastInputTypes.children.exists(e => hasStringType(e.dataType)) - && checkImplicitCastInputTypes.inputTypes.nonEmpty => - val collationId: Int = - getOutputCollation(checkImplicitCastInputTypes - .children.filter { e => hasStringType(e.dataType) }) - val children: Seq[Expression] = checkImplicitCastInputTypes - .children.zip(checkImplicitCastInputTypes.inputTypes).map { - case (e, st) if hasStringType(st) => - castStringType(e, collationId, Some(st)).getOrElse(e) - case (e, TypeCollection(types)) if types.exists(hasStringType) => - types.flatMap{ dt => - if (hasStringType(dt)) { - castStringType(e, collationId, Some(dt)) - } else { - implicitCast(e, dt) - } - }.headOption.getOrElse(e) - case (in, expected) => implicitCast(in, expected).getOrElse(in) - } - checkImplicitCastInputTypes.withNewChildren(children) - - case checkExpectsInputType: ExpectsInputTypes - if checkExpectsInputType.children.exists(e => hasStringType(e.dataType)) - && checkExpectsInputType.inputTypes.nonEmpty => - val collationId: Int = getOutputCollation( - checkExpectsInputType.children.filter {e => hasStringType(e.dataType)}) - val children: Seq[Expression] = checkExpectsInputType - .children.zip(checkExpectsInputType.inputTypes).map { - case (st, _) if hasStringType(st.dataType) => - castStringType(st, collationId).getOrElse(st) - case (nt, e) - if hasStringType(e) && nt.dataType == NullType => - castStringType(nt, collationId, Some(e)).getOrElse(nt) - case (nt, e: TypeCollection) - if hasStringType(e.defaultConcreteType) && nt.dataType == NullType => - castStringType(nt, collationId, Some(e.defaultConcreteType)).getOrElse(nt) - case (in, _) => in - } - checkExpectsInputType.withNewChildren(children) - } - - def shouldCast(types: Seq[Expression]): Boolean = { - types.filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType).collationId).distinct.size > 1 - } - - /** - * Whether the data type contains StringType. - */ - @tailrec - def hasStringType(dt: AbstractDataType): Boolean = dt match { - case _: StringType => true - case ArrayType(et, _) => hasStringType(et) - case _ => false - } - - /** - * Extracts StringTypes from filtered hasStringType - */ - @tailrec - private def extractStringType(dt: AbstractDataType): StringType = dt match { - case st: StringType => st - case ArrayType(et, _) => extractStringType(et) - } - - private def castStringType(expr: Expression, - collationId: Int, - expected: Option[AbstractDataType] = None): Option[Expression] = - castStringType(expr.dataType, collationId, expected).map { dt => - if (dt == expr.dataType) expr else Cast(expr, dt) - } - - private def castStringType(inType: AbstractDataType, - collationId: Int, - expected: Option[AbstractDataType]): Option[DataType] = { - @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId == collationId => - st - case _: AtomicType - if expected.isEmpty || expected.get.defaultConcreteType.isInstanceOf[StringType] => - StringType(collationId) - case ArrayType(arrType, nullable) - if expected.isEmpty || expected.get.defaultConcreteType.isInstanceOf[ArrayType] => - castStringType(arrType, collationId, expected).map(ArrayType(_, nullable)).orNull - case _: NullType if expected.nonEmpty => - castStringType( - expected.get.defaultConcreteType, - collationId, - expected).orNull - case _ => null - } - Option(ret) - } - - /** - * Collates the input expressions to a single collation. - */ - def collateToSingleType(exprs: Seq[Expression], - failOnIndeterminate: Boolean = true): Seq[Expression] = { - val collationId = getOutputCollation(exprs, failOnIndeterminate) - - exprs.map(e => castStringType(e, collationId).getOrElse(e)) - } - - /** - * Based on the data types of the input expressions this method determines - * a collation type which the output will have. - */ - def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { - val explicitTypes = exprs.filter(hasExplicitCollation) - .map(e => extractStringType(e.dataType).collationId).distinct - - explicitTypes.size match { - case 1 => explicitTypes.head - case size if size > 1 => - throw QueryCompilationErrors - .explicitCollationMismatchError( - explicitTypes.map(t => StringType(t).typeName) - ) - case 0 => - val dataTypes = exprs.filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType)) - - if (hasMultipleImplicits(dataTypes)) { - if (failOnIndeterminate) { - throw QueryCompilationErrors.implicitCollationMismatchError() - } else { - CollationFactory.INDETERMINATE_COLLATION_ID - } - } - else { - dataTypes.find(!_.isDefaultCollation) - .getOrElse(StringType) - .collationId - } - } - } - - private def hasIndeterminate(dataTypes: Seq[DataType]): Boolean = - dataTypes.exists(dt => dt.isInstanceOf[StringType] - && dt.asInstanceOf[StringType].isIndeterminateCollation) - - - private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.filter(!_.isDefaultCollation).map(_.collationId).distinct.size > 1 - - private def hasExplicitCollation(expression: Expression): Boolean = { - expression match { - case _: Collate => true - case e if e.dataType.isInstanceOf[ArrayType] - => expression.children.exists(hasExplicitCollation) - case _ => false - } - } - } - /** * Cast WindowFrame boundaries to the type they operate upon. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index aaa26e5d70e3c..86d0eea74a9de 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{CollationTypeCasts, ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -304,7 +304,7 @@ case class Elt( "inputSql" -> toSQLExpr(indexExpr), "inputType" -> toSQLType(indexType))) } - if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + if (inputTypes.exists(tpe => !tpe.isInstanceOf[StringType] && tpe != BinaryType)) { return DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( @@ -509,8 +509,7 @@ abstract class StringPredicate extends BinaryExpression return checkResult } // Additional check needed for collation compatibility - val outputCollationId: Int = TypeCoercion - .CollationTypeCasts + val outputCollationId: Int = CollationTypeCasts .getOutputCollation(Seq(left, right)) TypeCheckResult.TypeCheckSuccess } From c01e80c386ca5c6285eb9c9fddea0eb76520175f Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 11:05:10 +0100 Subject: [PATCH 28/87] Fix imports and failing tests --- .../spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 4 +--- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 8 ++++---- .../sql/catalyst/expressions/collationExpressions.scala | 2 +- .../sql/catalyst/expressions/collectionOperations.scala | 3 ++- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 99d289d4e05d6..fb37c81a96a59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -175,8 +175,6 @@ object AnsiTypeCoercion extends TypeCoercionBase { inType: DataType, expectedType: AbstractDataType): Option[DataType] = { (inType, expectedType) match { - case (_: StringType, st: StringType) => - Some(st) // If the expected type equals the input type, no need to cast. case _ if expectedType.acceptsType(inType) => Some(inType) @@ -192,7 +190,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { // If a function expects a StringType, no StringType instance should be implicitly cast to // StringType with a collation that's not accepted (aka. lockdown unsupported collations). - case (_: StringType, StringType) => None + case (_: StringType, _: StringType) => None case (_: StringType, _: StringTypeCollated) => None // This type coercion system will allow implicit converting String type as other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 74591df66733a..2ba9f5747acc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -710,7 +710,7 @@ abstract class TypeCoercionBase { // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } - val collationId = getOutputCollation(childrenBeforeCollations) + val collationId = getOutputCollation(e.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) @@ -729,7 +729,7 @@ abstract class TypeCoercionBase { in } } - val collationId = getOutputCollation(childrenBeforeCollations) + val collationId = getOutputCollation(e.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) @@ -751,7 +751,7 @@ abstract class TypeCoercionBase { ).getOrElse(in) } } - val collationId = getOutputCollation(childrenBeforeCollations) + val collationId = getOutputCollation(udf.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) @@ -1018,7 +1018,7 @@ object TypeCoercion extends TypeCoercionBase { case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, StringType) if !any.isInstanceOf[StringType] => StringType + case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType case (any: AtomicType, st: StringTypeCollated) if !any.isInstanceOf[StringType] => st.defaultConcreteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 4a94a0ec53e36..8d58a9518ccd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -121,5 +121,5 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, StringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4cf45b1b0cb00..fa40a90c35a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2593,7 +2593,8 @@ case class TryElementAt(left: Expression, right: Expression, replacement: Expres case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression with QueryErrorsBase { - private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9074cb135500f..def54cdf0c830 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{CollationTypeCasts, ExpressionBuilder, FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ From cc797a258a5fa76d9f665f79aab8b4e8c6f2daac Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 12:49:49 +0100 Subject: [PATCH 29/87] Disable casting of StructTypes --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 6db3ab423b10e..74d10c8172b83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, StructType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case checkCastWithIndeterminate @ (_: Elt | _: ComplexTypeMergingExpression) + case checkCastWithIndeterminate @ (_: Elt | _: ComplexTypeMergingExpression | _: CreateArray) if shouldCast(checkCastWithIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) @@ -64,8 +63,10 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Whether the data type contains StringType. */ - def hasStringType(dt: DataType): Boolean = dt.existsRecursively { + @tailrec + def hasStringType(dt: DataType): Boolean = dt match { case _: StringType => true + case ArrayType(et, _) => hasStringType(et) case _ => false } From 5d001eedfc40c8a7b6a2f42047ec6db1e4a31b9a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 12:52:59 +0100 Subject: [PATCH 30/87] Fix imports --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 74d10c8172b83..82c1ac8134e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, StructType} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { From c68fc7dd30ca7ec7f16b3942e121469881897c0d Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 13:41:01 +0100 Subject: [PATCH 31/87] Fix concat tests --- .../spark/sql/catalyst/expressions/StringExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1fbd1ac9a29fd..cda9676ca58b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -70,7 +70,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> ordinalNumber(0), - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"" ) From 1c926ab54dfc246e8a59ccf8e3ab1937fb660700 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 14:31:59 +0100 Subject: [PATCH 32/87] Fix unnecessary repetition --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 82c1ac8134e22..8e7d44f9150ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -30,7 +30,7 @@ object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case checkCastWithIndeterminate @ (_: Elt | _: ComplexTypeMergingExpression | _: CreateArray) + case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression | _: CreateArray) if shouldCast(checkCastWithIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) @@ -39,7 +39,6 @@ object CollationTypeCasts extends TypeCoercionRule { case checkCastWithoutIndeterminate @ (_: BinaryExpression | _: Predicate | _: SortOrder - | _: CreateArray | _: ExpectsInputTypes) if shouldCast(checkCastWithoutIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) @@ -48,7 +47,6 @@ object CollationTypeCasts extends TypeCoercionRule { case checkIndeterminate@(_: BinaryExpression | _: Predicate | _: SortOrder - | _: CreateArray | _: ExpectsInputTypes) if hasIndeterminate(checkIndeterminate.children .filter(e => hasStringType(e.dataType)) From dec39bf76c4ba7a5430173f9e97ce79e3b560564 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 14:40:01 +0100 Subject: [PATCH 33/87] Remove Elt test --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 68c7b95407fde..887a1c24462d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -590,8 +590,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql("SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), Seq(Row("aaaa"))) } - - sql("select elt(4, 'aaa', 'bbb', 'ccc' COLLATE UNICODE, 'ddd')") } test("cast of default collated string in IN expression") { From e8084466af2a828929a8aadb31821f2c86340a23 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 21 Mar 2024 19:45:45 +0100 Subject: [PATCH 34/87] Remove tests for Repeat --- python/pyspark/sql/types.py | 4 ++-- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 2 +- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 5 ----- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5843a781f47e2..73376583184ff 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -263,9 +263,9 @@ def collationIdToName(self) -> str: if self.collationId == 0: return "" elif self.collationId == 1: - return "collate INDETERMINATE_COLLATION" + return " collate INDETERMINATE_COLLATION" else: - return "collate %s" % StringType.collationNames[self.collationId] + return " collate %s" % StringType.collationNames[self.collationId] @classmethod def collationNameToId(cls, collationName: str) -> int: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 8e7d44f9150ec..176114c0d462b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -21,7 +21,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 887a1c24462d1..8b3f488e55099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -584,11 +584,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UTF8_BINARY') IN " + s"(COLLATE('aa', 'UNICODE'))"), Seq(Row("a"))) - - // ImplicitInputTypeCast test - checkAnswer( - sql("SELECT REPEAT('aa' collate unicode_ci, MONTH(\"2024-02-02\" COLLATE UNICODE));"), - Seq(Row("aaaa"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e42f397cbfc29..988830eb6633e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1296,7 +1296,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", - "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", "functionName" -> "`map_concat`"), context = ExpectedContext( fragment = "map_concat(map1, map2)", @@ -1312,7 +1312,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", - "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", "functionName" -> "`map_concat`"), context = ExpectedContext( From af487a21c44da3c33c0d8d2344051070bd854a68 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 22 Mar 2024 09:26:48 +0100 Subject: [PATCH 35/87] Fix failing tests --- python/pyspark/sql/types.py | 2 +- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 73376583184ff..a8d215d56eda1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -262,7 +262,7 @@ def __init__(self, collationId: int = 0): def collationIdToName(self) -> str: if self.collationId == 0: return "" - elif self.collationId == 1: + elif self.collationId == -1: return " collate INDETERMINATE_COLLATION" else: return " collate %s" % StringType.collationNames[self.collationId] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 988830eb6633e..450c8dc2dc394 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1296,7 +1296,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", - "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", "functionName" -> "`map_concat`"), context = ExpectedContext( fragment = "map_concat(map1, map2)", @@ -1312,7 +1312,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", - "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", "functionName" -> "`map_concat`"), context = ExpectedContext( @@ -2552,7 +2552,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"concat(map(1, 2), map(3, 4))\"", "paramIndex" -> "first", - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"map(1, 2)\"", "inputType" -> "\"MAP\"" ), From 4ba70556b45e5d20199bd5d2ca275eaa8778f6d1 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 22 Mar 2024 12:01:37 +0100 Subject: [PATCH 36/87] Fix nullability for StringType->StringType --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 6d011d24e042e..458c5d500e3a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -341,7 +341,7 @@ object Cast extends QueryErrorsBase { case (NullType, _) => false // empty array or map case case (_, _) if from == to => false - case (_: StringType, BinaryType) => false + case (_: StringType, BinaryType | _: StringType) => false case (_: StringType, _) => true case (_, _: StringType) => false From e490e42e0ec52fa1a983ec40779adf90c5b37141 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Sun, 24 Mar 2024 21:28:57 +0100 Subject: [PATCH 37/87] Improve comments and switch tests from E2E to unit tests --- .../analysis/CollationTypeCasts.scala | 47 +- .../sql/CollationRegexpExpressionsSuite.scala | 589 +++++++++--------- .../sql/CollationStringExpressionsSuite.scala | 72 ++- 3 files changed, 364 insertions(+), 344 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 176114c0d462b..67fbc14e592b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -21,7 +21,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, Concat, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -29,13 +29,13 @@ import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, String object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - - case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression | _: CreateArray) + // Case when we do not fail if resulting collation is indeterminate + case checkCastWithIndeterminate @ (_: Concat | _: CreateArray) if shouldCast(checkCastWithIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) checkCastWithIndeterminate.withNewChildren(newChildren) - + // Case when we do fail if resulting collation is indeterminate case checkCastWithoutIndeterminate @ (_: BinaryExpression | _: Predicate | _: SortOrder @@ -43,7 +43,7 @@ object CollationTypeCasts extends TypeCoercionRule { if shouldCast(checkCastWithoutIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) checkCastWithoutIndeterminate.withNewChildren(newChildren) - + // Case if casting is not needed, but we only have indeterminate collations case checkIndeterminate@(_: BinaryExpression | _: Predicate | _: SortOrder @@ -60,7 +60,7 @@ object CollationTypeCasts extends TypeCoercionRule { } /** - * Whether the data type contains StringType. + * Checks whether given data type contains StringType. */ @tailrec def hasStringType(dt: DataType): Boolean = dt match { @@ -78,6 +78,13 @@ object CollationTypeCasts extends TypeCoercionRule { case ArrayType(et, _) => extractStringType(et) } + /** + * Casts given expression to collated StringType with id equal to collationId only + * if expression has StringType in the first place. + * @param expr + * @param collationId + * @return + */ def castStringType(expr: Expression, collationId: Int): Option[Expression] = castStringType(expr.dataType, collationId).map { dt => if (dt == expr.dataType) expr else Cast(expr, dt) @@ -95,7 +102,7 @@ object CollationTypeCasts extends TypeCoercionRule { } /** - * Collates the input expressions to a single collation. + * Collates input expressions to a single collation. */ def collateToSingleType(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Seq[Expression] = { @@ -106,19 +113,24 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Based on the data types of the input expressions this method determines - * a collation type which the output will have. + * a collation type which the output will have. This function accepts Seq of + * any expressions, but will only be affected by collated StringTypes or + * complex DataTypes with collated StringTypes (e.g. ArrayType) */ def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { val explicitTypes = exprs.filter(hasExplicitCollation) .map(e => extractStringType(e.dataType).collationId).distinct explicitTypes.size match { + // We have 1 explicit collation case 1 => explicitTypes.head + // Multiple explicit collations occurred case size if size > 1 => throw QueryCompilationErrors .explicitCollationMismatchError( explicitTypes.map(t => StringType(t).typeName) ) + // Only implicit or default collations present case 0 => val dataTypes = exprs.filter(e => hasStringType(e.dataType)) .map(e => extractStringType(e.dataType)) @@ -138,14 +150,31 @@ object CollationTypeCasts extends TypeCoercionRule { } } + /** + * Checks if there exists an input with input type StringType(-1) + * @param dataTypes + * @return + */ private def hasIndeterminate(dataTypes: Seq[DataType]): Boolean = dataTypes.exists(dt => dt.isInstanceOf[StringType] && dt.asInstanceOf[StringType].isIndeterminateCollation) - + /** + * This check is always preformed when we have no explicit collation. It returns true + * if there are more than one implicit collations. Collations are distinguished by their + * collationId. + * @param dataTypes + * @return + */ private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = dataTypes.filter(!_.isDefaultCollation).map(_.collationId).distinct.size > 1 + /** + * Checks if a given expression has explicitly set collation. For complex DataTypes + * we need to check nested children. + * @param expression + * @return + */ private def hasExplicitCollation(expression: Expression): Boolean = { expression match { case _: Collate => true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 9a8ffb6efa6b1..a99e2f16bf3f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -20,421 +20,406 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType -class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession { +class CollationRegexpExpressionsSuite + extends QueryTest + with SharedSparkSession + with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support Like string expression with Collation") { + def prepareLike(input: String, + regExp: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + Like(inputExpr, regExpExpr, '\\') + } // Supported collations val checks = Seq( CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareLike(ct.s1, ct.s2, ct.collation), ct.expectedResult)) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%B%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"like collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 48 + 2 * ct.collation.length + CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), + CollationTestFail("ABC", "%B%", "UNICODE"), + CollationTestFail("ABC", "%b%", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareLike(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support ILike string expression with Collation") { + def prepareILike(input: String, + regExp: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + ILike(inputExpr, regExpExpr, '\\').replacement + } + // Supported collations val checks = Seq( CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareILike(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%b%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"ilike collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 49 + 2 * ct.collation.length + CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), + CollationTestFail("ABC", "%b%", "UNICODE"), + CollationTestFail("ABC", "%b%", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareILike(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""lower(${ct.s1})"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RLike string expression with Collation") { + def prepareRLike(input: String, + regExp: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RLike(inputExpr, regExpExpr) + } // Supported collations val checks = Seq( CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareRLike(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", ".B.", "UNICODE", true), - CollationTestCase("ABC", ".b.", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"rlike collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 49 + 2 * ct.collation.length + CollationTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABC", ".B.", "UNICODE"), + CollationTestFail("ABC", ".b.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRLike(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support StringSplit string expression with Collation") { + def prepareStringSplit(input: String, + splitBy: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val splitByExpr = Literal.create(splitBy, StringType(collationId)) + StringSplit(inputExpr, splitByExpr, Literal(-1)) + } + // Supported collations val checks = Seq( - CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2) + CollationTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")) + ) + checks.foreach(ct => + checkEvaluation(prepareStringSplit(ct.s1, ct.s2, ct.collation), ct.expectedResult) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult)) - }) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABC", "[B]", "UNICODE", 2), - CollationTestCase("ABC", "[b]", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"split(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 12, - stop = 55 + 2 * ct.collation.length + CollationTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"), + CollationTestFail("ABC", "[B]", "UNICODE"), + CollationTestFail("ABC", "[b]", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareStringSplit(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpReplace string expression with Collation") { + def prepareRegExpReplace(input: String, + regExp: String, + collation: String): RegExpReplace = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpReplace(inputExpr, regExpExpr, Literal.create("FFF", StringType(collationId))) + } + // Supported collations val checks = Seq( CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')" + - s",collate('FFF', '${ct.collation}'))"), - Row(ct.expectedResult) - ) - }) + checks.foreach(ct => + checkEvaluation(prepareRegExpReplace(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')" + - s",collate('FFF', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," + - s" '${ct.collation}'),collate('FFF', '${ct.collation}'))", - start = 7, - stop = 80 + 3 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpReplace(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpExtract string expression with Collation") { + def prepareRegExpExtract(input: String, + regExp: String, + collation: String): RegExpExtract = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpExtract(inputExpr, regExpExpr, Literal(0)) + } + // Supported collations val checks = Seq( CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0)"), - Row(ct.expectedResult) - ) - }) + checks.foreach(ct => + checkEvaluation(prepareRegExpExtract(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0)") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'),0)", - start = 7, - stop = 63 + 2 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpExtract(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpExtractAll string expression with Collation") { + def prepareRegExpExtractAll(input: String, + regExp: String, + collation: String): RegExpExtractAll = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpExtractAll(inputExpr, regExpExpr, Literal(0)) + } + // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")) + ) + checks.foreach(ct => + checkEvaluation(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation), ct.expectedResult) ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0))"), - Row(ct.expectedResult) - ) - }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + - s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'),0)", - start = 12, - stop = 72 + 2 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpCount string expression with Collation") { + def prepareRegExpCount(input: String, + regExp: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpCount(inputExpr, regExpExpr).replacement + } + // Supported collations val checks = Seq( CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareRegExpCount(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 59 + 2 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpCount(ct.s1, ct.s2, ct.collation).asInstanceOf[Size].child + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpSubStr string expression with Collation") { + def prepareRegExpSubStr(input: String, + regExp: String, + collation: String): Expression = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpSubStr(inputExpr, regExpExpr).replacement.asInstanceOf[NullIf].left + } + // Supported collations val checks = Seq( CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 60 + 2 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } test("Support RegExpInStr string expression with Collation") { + def prepareRegExpInStr(input: String, + regExp: String, + collation: String): RegExpInStr = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExpr = Literal.create(input, StringType(collationId)) + val regExpExpr = Literal.create(regExp, StringType(collationId)) + RegExpInStr(inputExpr, regExpExpr, Literal(0)) + } + // Supported collations val checks = Seq( CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) + checks.foreach(ct => + checkEvaluation(prepareRegExpInStr(ct.s1, ct.s2, ct.collation), ct.expectedResult) + ) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 2), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 59 + 2 * ct.collation.length + CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), + CollationTestFail("ABCDE", ".C.", "UNICODE"), + CollationTestFail("ABCDE", ".c.", "UNICODE_CI") + ) + fails.foreach(ct => + assert(prepareRegExpInStr(ct.s1, ct.s2, ct.collation) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> s""""${ct.s1}"""", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 04f3781a92cf3..ecaaae8acec06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,54 +20,60 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.{ConcatWs, ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType -class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { +class CollationStringExpressionsSuite + extends QueryTest + with SharedSparkSession + with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + test("Support ConcatWs string expression with Collation") { - // Supported collations + def prepareConcatWs(sep: String, + collation: String, + inputs: Any*): ConcatWs = { + val collationId = CollationFactory.collationNameToId(collation) + val inputExprs = inputs.map(s => Literal.create(s, StringType(collationId))) + val sepExpr = Literal.create(sep, StringType(collationId)) + ConcatWs(sepExpr +: inputExprs) + } + // Supported Collations val checks = Seq( CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " + - s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"), - Row(ct.expectedResult)) - }) - // Unsupported collations + checks.foreach(ct => + checkEvaluation(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2), ct.expectedResult) + ) + + // Unsupported Collations val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%B%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), + CollationTestFail("ABC", "%B%", "UNICODE"), + CollationTestFail("ABC", "%b%", "UNICODE_CI") ) - fails.foreach(ct => { - val expr = s"concat_ws(collate(' ', '${ct.collation}'), " + - s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))" - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT $expr") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate( )\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"$expr", - start = 7, - stop = 73 + 3 * ct.collation.length + fails.foreach(ct => + assert(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "first", + "requiredType" -> """"STRING"""", + "inputSql" -> """" """", + "inputType" -> s""""STRING COLLATE ${ct.collation}"""" + ) ) ) - }) + ) } // TODO: Add more tests for other string expressions From 00e88e78c51325ea528df4a49c9d2f4ed1da1fec Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 25 Mar 2024 08:42:18 +0100 Subject: [PATCH 38/87] Add new tests and remove compatibility test --- .../analysis/CollationTypeCasts.scala | 4 +-- ...traints.scala => StringTypeCollated.scala} | 23 ---------------- .../expressions/stringExpressions.scala | 8 ------ .../org/apache/spark/sql/CollationSuite.scala | 26 ++++++++++++++----- 4 files changed, 22 insertions(+), 39 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{CollationTypeConstraints.scala => StringTypeCollated.scala} (70%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 67fbc14e592b9..f8ee267e042c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -21,7 +21,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, Concat, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -30,7 +30,7 @@ object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e // Case when we do not fail if resulting collation is indeterminate - case checkCastWithIndeterminate @ (_: Concat | _: CreateArray) + case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression | _: CreateArray) if shouldCast(checkCastWithIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala similarity index 70% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala index cd909a45c1ed6..50e843081a3c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/StringTypeCollated.scala @@ -17,31 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} -object CollationTypeConstraints { - - def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = { - val collationName = CollationFactory.fetchCollation(collationId).collationName - // Additional check needed for collation compatibility - dataTypes.collectFirst { - case stringType: StringType if stringType.collationId != collationId => - val collation = CollationFactory.fetchCollation(stringType.collationId) - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> collationName, - "collationNameRight" -> collation.collationName - ) - ) - } getOrElse TypeCheckResult.TypeCheckSuccess - } - -} - /** * StringTypeCollated is an abstract class for StringType with collation support. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index def54cdf0c830..df933eed27833 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -504,14 +504,6 @@ abstract class StringPredicate extends BinaryExpression override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation) - override def checkInputDataTypes(): TypeCheckResult = { - val defaultCheck = super.checkInputDataTypes() - if (defaultCheck.isFailure) { - return defaultCheck - } - CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) - } - protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 5b309d3eb2002..1196f42aff980 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -465,7 +465,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("implicit cast of default collated strings") { + test("implicit casting of collated strings") { val tableName = "parquet_dummy_implicit_cast_t22" withTable(tableName) { spark.sql( @@ -504,14 +504,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('b', COLLATE('a', 'UTF8_BINARY'))"), Seq(Row("a"))) - // concat should not change collation + // concat without type mismatch checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || 'a' || 'a' = 'aaa'"), Seq(Row("a"), Row("A"))) checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE(c2, 'UTF8_BINARY') = 'aa'"), Seq(Row("a"))) // concat of columns of different collations is allowed - // as long as we don't use binary comparison on the result + // as long as we don't use the result in an unsupported function sql(s"SELECT c1 || c3 from $tableName") // concat + in @@ -565,7 +565,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) ) - // concat on different implicit collations should fail + // concat on different implicit collations should succeed, + // but should fail on try of comparison checkError( exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") @@ -573,7 +574,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "INDETERMINATE_COLLATION" ) - // concat on different implicit collations should fail + // concat on different implicit collations should succeed, + // but should fail on try of ordering checkError( exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") @@ -585,10 +587,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a', 'UTF8_BINARY') IN " + s"(COLLATE('aa', 'UNICODE'))"), Seq(Row("a"))) + + // array creation supports implicit casting + checkAnswer(sql(s"SELECT typeof(array('a' COLLATE UNICODE, 'b')[1])"), + Seq(Row("string collate UNICODE"))) + + // contains fails with indeterminate collation + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT * FROM $tableName WHERE contains(c1||c3, 'a')") + }, + errorClass = "INDETERMINATE_COLLATION" + ) } } - test("cast of default collated string in IN expression") { + test("cast of default collated strings in IN expression") { val tableName = "t1" withTable(tableName) { spark.sql( From 85b4d168c78356fb6a42eca085d434c92fd3fab1 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 25 Mar 2024 09:56:13 +0100 Subject: [PATCH 39/87] Fix conflict resolution mistake --- .../apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 958c7933f4358..653eefb23f005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -187,9 +187,6 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (_: StringType, _: StringType) => None case (_: StringType, _: StringTypeCollated) => None - case (DateType, AnyTimestampType) => - Some(AnyTimestampType.defaultConcreteType) - // If a function expects integral type, fractional input is not allowed. case (_: FractionalType, IntegralType) => None From e89a3544c1c23b432c410d09f36ec777c335db6f Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 26 Mar 2024 08:42:58 +0100 Subject: [PATCH 40/87] Add indeterminate collation tests --- .../org/apache/spark/sql/CollationSuite.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 0892e91839655..336bf89880173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -657,6 +657,36 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("create table on indeterminate result should fail") { + val tableName = "t1" + withTable(tableName) { + spark.sql( + s""" + | CREATE TABLE $tableName(c1 STRING COLLATE UTF8_BINARY, + | c2 STRING COLLATE UTF8_BINARY_LCASE) + | USING PARQUET + |""".stripMargin) + sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')") + sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')") + sql(s"INSERT INTO $tableName VALUES ('bbb', 'bbb')") + sql(s"INSERT INTO $tableName VALUES ('BBB', 'BBB')") + + checkError( + exception = intercept[AnalysisException] { + sql(s"CREATE VIEW v AS SELECT c1 || c3 FROM $tableName") + }, + errorClass = "INDETERMINATE_COLLATION" + ) + + checkError( + exception = intercept[AnalysisException] { + sql(s"CREATE TABLE t2 AS SELECT c1 || c3 FROM $tableName") + }, + errorClass = "INDETERMINATE_COLLATION" + ) + } + } + test("create v2 table with collation column") { val tableName = "testcat.table_name" val collationName = "UTF8_BINARY_LCASE" From 788dc0695602cc111ed6a21d801c7625bddbb413 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 26 Mar 2024 10:14:00 +0100 Subject: [PATCH 41/87] Fix test --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 336bf89880173..b986924241937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -673,14 +673,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[AnalysisException] { - sql(s"CREATE VIEW v AS SELECT c1 || c3 FROM $tableName") + sql(s"CREATE VIEW v AS SELECT c1 || c2 FROM $tableName") }, errorClass = "INDETERMINATE_COLLATION" ) checkError( exception = intercept[AnalysisException] { - sql(s"CREATE TABLE t2 AS SELECT c1 || c3 FROM $tableName") + sql(s"CREATE TABLE t2 AS SELECT c1 || c2 FROM $tableName") }, errorClass = "INDETERMINATE_COLLATION" ) From 75c01408990af3fc88d7783231a2c7d528fa7f1b Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 27 Mar 2024 14:46:58 +0100 Subject: [PATCH 42/87] Block Alias on Indeterminate --- .../analysis/CollationTypeCasts.scala | 18 +++++++++---- .../expressions/collationExpressions.scala | 7 +++-- .../org/apache/spark/sql/CollationSuite.scala | 27 +++++++++---------- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index f8ee267e042c2..dc7f832292e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -30,8 +30,8 @@ object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e // Case when we do not fail if resulting collation is indeterminate - case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression | _: CreateArray) - if shouldCast(checkCastWithIndeterminate.children) => + case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression + | _: CreateArray) => val newChildren = collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) checkCastWithIndeterminate.withNewChildren(newChildren) @@ -43,11 +43,19 @@ object CollationTypeCasts extends TypeCoercionRule { if shouldCast(checkCastWithoutIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) checkCastWithoutIndeterminate.withNewChildren(newChildren) + // Case if casting is not needed, but we only have indeterminate + // collations and we do not want to fail + case checkIndeterminateWithoutFail@(_: Invoke) + if hasIndeterminate(checkIndeterminateWithoutFail.children + .filter(e => hasStringType(e.dataType)) + .map(e => extractStringType(e.dataType))) => + checkIndeterminateWithoutFail // Case if casting is not needed, but we only have indeterminate collations case checkIndeterminate@(_: BinaryExpression | _: Predicate | _: SortOrder - | _: ExpectsInputTypes) + | _: ExpectsInputTypes + | _: Alias) if hasIndeterminate(checkIndeterminate.children .filter(e => hasStringType(e.dataType)) .map(e => extractStringType(e.dataType))) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 8d58a9518ccd5..07dbbee26a6d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -114,12 +114,15 @@ case class Collate(child: Expression, collationName: String) // scalastyle:on line.contains.tab case class Collation(child: Expression) extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes { - override def dataType: DataType = StringType + override def dataType: DataType = child.dataType override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) override def replacement: Expression = { val collationId = child.dataType.asInstanceOf[StringType].collationId + if (collationId == CollationFactory.INDETERMINATE_COLLATION_ID) { + throw QueryCompilationErrors.indeterminateCollationError() + } val collationName = CollationFactory.fetchCollation(collationId).collationName - Literal.create(collationName, StringType) + Literal.create(collationName, StringType(collationId)) } override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b986924241937..e349692f006ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -543,7 +543,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat of columns of different collations is allowed // as long as we don't use the result in an unsupported function - sql(s"SELECT c1 || c3 from $tableName") + checkAnswer(sql(s"SELECT c1 || c2 FROM $tableName"), Seq(Row("aa"), Row("AA"))) // concat + in checkAnswer(sql(s"SELECT c1 FROM $tableName where c1 || 'a' " + @@ -657,12 +657,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("create table on indeterminate result should fail") { + test("indeterminate collation checks") { val tableName = "t1" + val newTableName = "t2" withTable(tableName) { spark.sql( s""" - | CREATE TABLE $tableName(c1 STRING COLLATE UTF8_BINARY, + | CREATE TABLE $tableName(c1 STRING COLLATE UNICODE, | c2 STRING COLLATE UTF8_BINARY_LCASE) | USING PARQUET |""".stripMargin) @@ -671,19 +672,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"INSERT INTO $tableName VALUES ('bbb', 'bbb')") sql(s"INSERT INTO $tableName VALUES ('BBB', 'BBB')") - checkError( - exception = intercept[AnalysisException] { - sql(s"CREATE VIEW v AS SELECT c1 || c2 FROM $tableName") - }, - errorClass = "INDETERMINATE_COLLATION" - ) + sql(s"SET spark.sql.legacy.createHiveTableByDefault=false") - checkError( - exception = intercept[AnalysisException] { - sql(s"CREATE TABLE t2 AS SELECT c1 || c2 FROM $tableName") - }, - errorClass = "INDETERMINATE_COLLATION" - ) + withTable(newTableName) { + checkError( + exception = intercept[AnalysisException] { + sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName") + }, + errorClass = "INDETERMINATE_COLLATION") + } } } From f6ed55a73f7f6d7c7ea2ee38ddc0279977979d47 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 28 Mar 2024 10:17:44 +0100 Subject: [PATCH 43/87] Remove introduction of indeterminate collation --- python/pyspark/sql/types.py | 2 - .../analysis/CollationTypeCasts.scala | 61 ++++--------------- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../expressions/collationExpressions.scala | 3 - .../org/apache/spark/sql/CollationSuite.scala | 18 ++++-- 5 files changed, 28 insertions(+), 60 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ac231c2e3ea98..c60cd61d649b7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -266,8 +266,6 @@ def fromCollationId(self, collationId: int) -> "StringType": def collationIdToName(self) -> str: if self.collationId == 0: return "" - elif self.collationId == -1: - return " collate INDETERMINATE_COLLATION" else: return " collate %s" % StringType.collationNames[self.collationId] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index dc7f832292e9f..a6d2bd645a6fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -20,46 +20,24 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{Alias, BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} -import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - // Case when we do not fail if resulting collation is indeterminate - case checkCastWithIndeterminate @ (_: ComplexTypeMergingExpression - | _: CreateArray) => - val newChildren = - collateToSingleType(checkCastWithIndeterminate.children, failOnIndeterminate = false) - checkCastWithIndeterminate.withNewChildren(newChildren) // Case when we do fail if resulting collation is indeterminate case checkCastWithoutIndeterminate @ (_: BinaryExpression | _: Predicate | _: SortOrder - | _: ExpectsInputTypes) + | _: ExpectsInputTypes + | _: ComplexTypeMergingExpression + | _: CreateArray) if shouldCast(checkCastWithoutIndeterminate.children) => val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) checkCastWithoutIndeterminate.withNewChildren(newChildren) - // Case if casting is not needed, but we only have indeterminate - // collations and we do not want to fail - case checkIndeterminateWithoutFail@(_: Invoke) - if hasIndeterminate(checkIndeterminateWithoutFail.children - .filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType))) => - checkIndeterminateWithoutFail - // Case if casting is not needed, but we only have indeterminate collations - case checkIndeterminate@(_: BinaryExpression - | _: Predicate - | _: SortOrder - | _: ExpectsInputTypes - | _: Alias) - if hasIndeterminate(checkIndeterminate.children - .filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType))) => - throw QueryCompilationErrors.indeterminateCollationError() } def shouldCast(types: Seq[Expression]): Boolean = { @@ -112,9 +90,8 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Collates input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression], - failOnIndeterminate: Boolean = true): Seq[Expression] = { - val collationId = getOutputCollation(exprs, failOnIndeterminate) + def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { + val collationId = getOutputCollation(exprs) exprs.map(e => castStringType(e, collationId).getOrElse(e)) } @@ -125,7 +102,7 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(exprs: Seq[Expression], failOnIndeterminate: Boolean = true): Int = { + def getOutputCollation(exprs: Seq[Expression]): Int = { val explicitTypes = exprs.filter(hasExplicitCollation) .map(e => extractStringType(e.dataType).collationId).distinct @@ -144,29 +121,16 @@ object CollationTypeCasts extends TypeCoercionRule { .map(e => extractStringType(e.dataType)) if (hasMultipleImplicits(dataTypes)) { - if (failOnIndeterminate) { - throw QueryCompilationErrors.implicitCollationMismatchError() - } else { - CollationFactory.INDETERMINATE_COLLATION_ID - } + throw QueryCompilationErrors.implicitCollationMismatchError() } else { - dataTypes.find(!_.isDefaultCollation) - .getOrElse(StringType) + dataTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) + .getOrElse(SQLConf.get.defaultStringType) .collationId } } } - /** - * Checks if there exists an input with input type StringType(-1) - * @param dataTypes - * @return - */ - private def hasIndeterminate(dataTypes: Seq[DataType]): Boolean = - dataTypes.exists(dt => dt.isInstanceOf[StringType] - && dt.asInstanceOf[StringType].isIndeterminateCollation) - /** * This check is always preformed when we have no explicit collation. It returns true * if there are more than one implicit collations. Collations are distinguished by their @@ -175,7 +139,8 @@ object CollationTypeCasts extends TypeCoercionRule { * @return */ private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.filter(!_.isDefaultCollation).map(_.collationId).distinct.size > 1 + dataTypes.filter(dt => !(dt == SQLConf.get.defaultStringType)) + .map(_.collationId).distinct.size > 1 /** * Checks if a given expression has explicitly set collation. For complex DataTypes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2ba9f5747acc2..7e25afc84b9aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -609,7 +609,7 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val collationId = getOutputCollation(c.children, failOnIndeterminate = false) + val collationId = getOutputCollation(c.children) val newChildren = c.children.map { e => implicitCast(e, StringType(collationId)).getOrElse(e) } @@ -659,7 +659,7 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - val collationId = getOutputCollation(children, failOnIndeterminate = false) + val collationId = getOutputCollation(children) children.tail.map { e => implicitCast(e, StringType(collationId)).getOrElse(e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 07dbbee26a6d3..f167fa6e87e20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -118,9 +118,6 @@ case class Collation(child: Expression) override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) override def replacement: Expression = { val collationId = child.dataType.asInstanceOf[StringType].collationId - if (collationId == CollationFactory.INDETERMINATE_COLLATION_ID) { - throw QueryCompilationErrors.indeterminateCollationError() - } val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, StringType(collationId)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 0e6bb8e63756a..a6e6dfc428f04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -544,7 +544,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat of columns of different collations is allowed // as long as we don't use the result in an unsupported function - checkAnswer(sql(s"SELECT c1 || c2 FROM $tableName"), Seq(Row("aa"), Row("AA"))) + // TODO: (SPARK-47210) Add indeterminate support + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT c1 || c2 FROM $tableName") + }, + errorClass = "COLLATION_MISMATCH.IMPLICIT" + ) + // concat + in checkAnswer(sql(s"SELECT c1 FROM $tableName where c1 || 'a' " + @@ -603,7 +610,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") }, - errorClass = "INDETERMINATE_COLLATION" + errorClass = "COLLATION_MISMATCH.IMPLICIT" ) // concat on different implicit collations should succeed, @@ -612,7 +619,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") }, - errorClass = "INDETERMINATE_COLLATION" + errorClass = "COLLATION_MISMATCH.IMPLICIT" ) // concat + in @@ -629,7 +636,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName WHERE contains(c1||c3, 'a')") }, - errorClass = "INDETERMINATE_COLLATION" + errorClass = "COLLATION_MISMATCH.IMPLICIT" ) } } @@ -658,6 +665,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + // TODO: (SPARK-47210) Add indeterminate support test("indeterminate collation checks") { val tableName = "t1" val newTableName = "t2" @@ -680,7 +688,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName") }, - errorClass = "INDETERMINATE_COLLATION") + errorClass = "COLLATION_MISMATCH.IMPLICIT") } } } From 98960c04e51441f41903f53c9bea09ee330ce9a8 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 28 Mar 2024 10:32:41 +0100 Subject: [PATCH 44/87] Fix import problem --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index a6d2bd645a6fa..d58ef8c3e7d45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} From de623c8d45f5d161b8cff3e6c509a9918c1230fc Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 28 Mar 2024 14:10:33 +0100 Subject: [PATCH 45/87] Fix failing tests --- .../org/apache/spark/sql/catalyst/expressions/literals.scala | 2 +- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1b20da0b5cbcd..3c4e74421250d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -485,7 +485,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def sql: String = (value, dataType) match { case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" case _ if value == null => s"CAST(NULL AS ${dataType.sql})" - case (v: UTF8String, _: StringType) => + case (v: UTF8String, StringType) => // Escapes all backslashes and single quotes. "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" case (v: Byte, ByteType) => s"${v}Y" diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index bd1b6f0cb7537..cbc107451dcf0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -81,7 +81,7 @@ | org.apache.spark.sql.catalyst.expressions.Chr | char | SELECT char(65) | struct | | org.apache.spark.sql.catalyst.expressions.Chr | chr | SELECT chr(65) | struct | | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT coalesce(NULL, 1, NULL) | struct | -| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate | SELECT COLLATION('Spark SQL' collate UTF8_BINARY_LCASE) | struct | +| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate | SELECT COLLATION('Spark SQL' collate UTF8_BINARY_LCASE) | struct | | org.apache.spark.sql.catalyst.expressions.Collation | collation | SELECT collation('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT concat('Spark', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT concat_ws(' ', 'Spark', 'SQL') | struct | From a92b4e17f026bf7a4d74a98eafa627634c964729 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 28 Mar 2024 19:15:30 +0100 Subject: [PATCH 46/87] Fix pyspark error --- .../src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala | 4 ++-- .../org/apache/spark/sql/catalyst/expressions/literals.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index f01014e1edbb3..92a4c687362da 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -47,9 +47,9 @@ private[sql] object ArrowUtils { case LongType => new ArrowType.Int(8 * 8, true) case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3c4e74421250d..1b20da0b5cbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -485,7 +485,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def sql: String = (value, dataType) match { case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" case _ if value == null => s"CAST(NULL AS ${dataType.sql})" - case (v: UTF8String, StringType) => + case (v: UTF8String, _: StringType) => // Escapes all backslashes and single quotes. "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" case (v: Byte, ByteType) => s"${v}Y" From f67808ef2d85ae53f439f735b2252e56489c4de4 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 29 Mar 2024 07:12:10 +0100 Subject: [PATCH 47/87] Fix errors --- .../main/resources/error/error-classes.json | 2 +- ...nditions-collation-mismatch-error-class.md | 2 +- docs/sql-error-conditions.md | 2 +- .../analysis/CollationTypeCasts.scala | 24 +++++++++++-------- .../expressions/collationExpressions.scala | 4 ++-- 5 files changed, 19 insertions(+), 15 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index a94a5b443fe9c..d095f149aeefa 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -469,7 +469,7 @@ }, "COLLATION_MISMATCH" : { "message" : [ - "Could not determine which collation to use for string comparison." + "Could not determine which collation to use for string functions and operators." ], "subClass" : { "EXPLICIT" : { diff --git a/docs/sql-error-conditions-collation-mismatch-error-class.md b/docs/sql-error-conditions-collation-mismatch-error-class.md index 616aed2029759..b6a63d87b36a0 100644 --- a/docs/sql-error-conditions-collation-mismatch-error-class.md +++ b/docs/sql-error-conditions-collation-mismatch-error-class.md @@ -26,7 +26,7 @@ license: | [SQLSTATE: 42P21](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -Could not determine which collation to use for string comparison. +Could not determine which collation to use for string functions and operators. This error class has the following derived error classes: diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 890615d7492be..d2aaaa65dcd5c 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -394,7 +394,7 @@ The value `` does not represent a correct collation name. Suggest [SQLSTATE: 42P21](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -Could not determine which collation to use for string comparison. +Could not determine which collation to use for string functions and operators. For more details see [COLLATION_MISMATCH](sql-error-conditions-collation-mismatch-error-class.html) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index d58ef8c3e7d45..e2c764116947b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -29,18 +29,22 @@ import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, String object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - // Case when we do fail if resulting collation is indeterminate - case checkCastWithoutIndeterminate @ (_: BinaryExpression - | _: Predicate - | _: SortOrder - | _: ExpectsInputTypes - | _: ComplexTypeMergingExpression - | _: CreateArray) - if shouldCast(checkCastWithoutIndeterminate.children) => - val newChildren = collateToSingleType(checkCastWithoutIndeterminate.children) - checkCastWithoutIndeterminate.withNewChildren(newChildren) + case sc @ (_: BinaryExpression + | _: Predicate + | _: SortOrder + | _: ExpectsInputTypes + | _: ComplexTypeMergingExpression + | _: CreateArray) + if shouldCast(sc.children) => + val newChildren = collateToSingleType(sc.children) + sc.withNewChildren(newChildren) } + /** + * Checks whether we have differently collated strings in the given DataTypes + * @param types + * @return + */ def shouldCast(types: Seq[Expression]): Boolean = { types.filter(e => hasStringType(e.dataType)) .map(e => extractStringType(e.dataType).collationId).distinct.size > 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index f167fa6e87e20..ca13508c30c33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -114,12 +114,12 @@ case class Collate(child: Expression, collationName: String) // scalastyle:on line.contains.tab case class Collation(child: Expression) extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes { - override def dataType: DataType = child.dataType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) override def replacement: Expression = { val collationId = child.dataType.asInstanceOf[StringType].collationId val collationName = CollationFactory.fetchCollation(collationId).collationName - Literal.create(collationName, StringType(collationId)) + Literal.create(collationName, SQLConf.get.defaultStringType) } override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) } From 815ce4263202670ef7bbb47c13626055456eab06 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 29 Mar 2024 07:56:46 +0100 Subject: [PATCH 48/87] Fix schema error --- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index cbc107451dcf0..bd1b6f0cb7537 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -81,7 +81,7 @@ | org.apache.spark.sql.catalyst.expressions.Chr | char | SELECT char(65) | struct | | org.apache.spark.sql.catalyst.expressions.Chr | chr | SELECT chr(65) | struct | | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT coalesce(NULL, 1, NULL) | struct | -| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate | SELECT COLLATION('Spark SQL' collate UTF8_BINARY_LCASE) | struct | +| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate | SELECT COLLATION('Spark SQL' collate UTF8_BINARY_LCASE) | struct | | org.apache.spark.sql.catalyst.expressions.Collation | collation | SELECT collation('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT concat('Spark', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT concat_ws(' ', 'Spark', 'SQL') | struct | From b19b0ebc68dc48ec435bdf5bdd89f3d7a408a493 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 29 Mar 2024 09:08:13 +0100 Subject: [PATCH 49/87] Fix collated tests --- .../sql/CollationRegexpExpressionsSuite.scala | 20 +++++++++---------- .../sql/CollationStringExpressionsSuite.scala | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index a99e2f16bf3f7..089e3106a0d3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -64,7 +64,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -103,7 +103,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""lower(${ct.s1})"""", + "inputSql" -> s""""lower('${ct.s1}')"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -141,7 +141,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -180,7 +180,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -219,7 +219,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -258,7 +258,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -297,7 +297,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -336,7 +336,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -375,7 +375,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -414,7 +414,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""${ct.s1}"""", + "inputSql" -> s""""'${ct.s1}'"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 1d50eda09a5d1..97647a9ae02ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -68,7 +68,7 @@ class CollationStringExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> """" """", + "inputSql" -> """"' '"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) From a111f031ab6bc44e9e5c4aacef4109e26fa90262 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 29 Mar 2024 19:05:05 +0100 Subject: [PATCH 50/87] Add isExplicit flag --- .../apache/spark/sql/types/StringType.scala | 6 ++ .../analysis/CollationTypeCasts.scala | 72 +++++-------------- .../sql/catalyst/analysis/TypeCoercion.scala | 18 ++--- .../expressions/collationExpressions.scala | 2 +- 4 files changed, 35 insertions(+), 63 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 47d85b2c645c8..1eb9f2ffd074c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory */ @Stable class StringType private(val collationId: Int) extends AtomicType with Serializable { + var isExplicit: Boolean = false /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered @@ -78,6 +79,11 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa @Stable case object StringType extends StringType(0) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) + private[spark] def apply(collationId: Int, isExplicit: Boolean): StringType = { + val st = new StringType(collationId) + st.isExplicit = isExplicit + st + } def apply(collation: String): StringType = { val collationId = CollationFactory.collationNameToId(collation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index e2c764116947b..a89403f213150 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -34,32 +34,11 @@ object CollationTypeCasts extends TypeCoercionRule { | _: SortOrder | _: ExpectsInputTypes | _: ComplexTypeMergingExpression - | _: CreateArray) - if shouldCast(sc.children) => + | _: CreateArray) => val newChildren = collateToSingleType(sc.children) sc.withNewChildren(newChildren) } - /** - * Checks whether we have differently collated strings in the given DataTypes - * @param types - * @return - */ - def shouldCast(types: Seq[Expression]): Boolean = { - types.filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType).collationId).distinct.size > 1 - } - - /** - * Checks whether given data type contains StringType. - */ - @tailrec - def hasStringType(dt: DataType): Boolean = dt match { - case _: StringType => true - case ArrayType(et, _) => hasStringType(et) - case _ => false - } - /** * Extracts StringTypes from filtered hasStringType */ @@ -77,14 +56,12 @@ object CollationTypeCasts extends TypeCoercionRule { * @return */ def castStringType(expr: Expression, collationId: Int): Option[Expression] = - castStringType(expr.dataType, collationId).map { dt => - if (dt == expr.dataType) expr else Cast(expr, dt) - } + castStringType(expr.dataType, collationId).map { dt => Cast(expr, dt)} private def castStringType(inType: AbstractDataType, collationId: Int): Option[DataType] = { @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId == collationId => st - case _: StringType => StringType(collationId) + case st: StringType if st.collationId == collationId && !st.isExplicit => null + case _: StringType => StringType(collationId, isExplicit = false) case ArrayType(arrType, nullable) => castStringType(arrType, collationId).map(ArrayType(_, nullable)).orNull case _ => null @@ -96,7 +73,7 @@ object CollationTypeCasts extends TypeCoercionRule { * Collates input expressions to a single collation. */ def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { - val collationId = getOutputCollation(exprs) + val collationId = getOutputCollation(exprs.map(_.dataType)) exprs.map(e => castStringType(e, collationId).getOrElse(e)) } @@ -107,9 +84,13 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(exprs: Seq[Expression]): Int = { - val explicitTypes = exprs.filter(hasExplicitCollation) - .map(e => extractStringType(e.dataType).collationId).distinct + def getOutputCollation(dataTypes: Seq[DataType]): Int = { + val explicitTypes = + dataTypes.filter(hasStringType) + .map(extractStringType) + .filter(_.isExplicit) + .map(_.collationId) + .distinct explicitTypes.size match { // We have 1 explicit collation @@ -122,14 +103,13 @@ object CollationTypeCasts extends TypeCoercionRule { ) // Only implicit or default collations present case 0 => - val dataTypes = exprs.filter(e => hasStringType(e.dataType)) - .map(e => extractStringType(e.dataType)) + val implicitTypes = dataTypes.filter(hasStringType).map(extractStringType) - if (hasMultipleImplicits(dataTypes)) { + if (hasMultipleImplicits(implicitTypes)) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - dataTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) + implicitTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) .getOrElse(SQLConf.get.defaultStringType) .collationId } @@ -144,21 +124,7 @@ object CollationTypeCasts extends TypeCoercionRule { * @return */ private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.filter(dt => !(dt == SQLConf.get.defaultStringType)) - .map(_.collationId).distinct.size > 1 + dataTypes.map(_.collationId) + .filter(dt => !(dt == SQLConf.get.defaultStringType.collationId)).distinct.size > 1 - /** - * Checks if a given expression has explicitly set collation. For complex DataTypes - * we need to check nested children. - * @param expression - * @return - */ - private def hasExplicitCollation(expression: Expression): Boolean = { - expression match { - case _: Collate => true - case e if e.dataType.isInstanceOf[ArrayType] - => expression.children.exists(hasExplicitCollation) - case _ => false - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 7e25afc84b9aa..fdb70f5715602 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.{castStringType, getOutputCollation, hasStringType} +import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.{castStringType, getOutputCollation} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -609,9 +609,9 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val collationId = getOutputCollation(c.children) + val collationId = getOutputCollation(c.children.map(_.dataType)) val newChildren = c.children.map { e => - implicitCast(e, StringType(collationId)).getOrElse(e) + implicitCast(e, StringType(collationId, isExplicit = false)).getOrElse(e) } c.copy(children = newChildren) @@ -659,9 +659,9 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - val collationId = getOutputCollation(children) + val collationId = getOutputCollation(children.map(_.dataType)) children.tail.map { e => - implicitCast(e, StringType(collationId)).getOrElse(e) + implicitCast(e, StringType(collationId, isExplicit = false)).getOrElse(e) } } else { children.tail @@ -710,7 +710,7 @@ abstract class TypeCoercionBase { // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } - val collationId = getOutputCollation(e.children) + val collationId = getOutputCollation(e.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) @@ -729,7 +729,7 @@ abstract class TypeCoercionBase { in } } - val collationId = getOutputCollation(e.children) + val collationId = getOutputCollation(e.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) @@ -751,7 +751,7 @@ abstract class TypeCoercionBase { ).getOrElse(in) } } - val collationId = getOutputCollation(udf.children) + val collationId = getOutputCollation(udf.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, collationId).getOrElse(in) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index ca13508c30c33..775ba196ba9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -81,7 +81,7 @@ object CollateExpressionBuilder extends ExpressionBuilder { case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) - override def dataType: DataType = StringType(collationId) + override def dataType: DataType = StringType(collationId, isExplicit = true) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( From 55bdd9b607b00e2ea34f73cf4f8c433c98428c89 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 29 Mar 2024 19:11:53 +0100 Subject: [PATCH 51/87] Fix import error --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index a89403f213150..fef0cec377db4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType From a7228be0616fc18e4defd290b73e973777b650a1 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Sun, 31 Mar 2024 10:23:00 +0200 Subject: [PATCH 52/87] Fix imports in TypeCoercion --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index fdb70f5715602..bd6b2e99f9e05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import scala.collection.mutable From 18ada04c53d680e2946bc8432ff7d57d81e0c2ea Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 10:18:33 +0200 Subject: [PATCH 53/87] Add support for explicit propagation in arrays --- .../analysis/CollationTypeCasts.scala | 32 +++++++++++-------- .../sql/catalyst/analysis/TypeCoercion.scala | 20 ++++++------ .../expressions/complexTypeCreator.scala | 2 +- .../analyzer-results/collations.sql.out | 20 ++++++------ .../org/apache/spark/sql/CollationSuite.scala | 10 ++++++ 5 files changed, 49 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index fef0cec377db4..0591183e229df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -34,10 +34,12 @@ object CollationTypeCasts extends TypeCoercionRule { | _: Predicate | _: SortOrder | _: ExpectsInputTypes - | _: ComplexTypeMergingExpression - | _: CreateArray) => + | _: ComplexTypeMergingExpression) => val newChildren = collateToSingleType(sc.children) sc.withNewChildren(newChildren) + case pesc @ (_: CreateArray) => + val newChildren = collateToSingleType(pesc.children, true) + pesc.withNewChildren(newChildren) } /** @@ -56,15 +58,16 @@ object CollationTypeCasts extends TypeCoercionRule { * @param collationId * @return */ - def castStringType(expr: Expression, collationId: Int): Option[Expression] = - castStringType(expr.dataType, collationId).map { dt => Cast(expr, dt)} + def castStringType(expr: Expression, st: StringType): Option[Expression] = + castStringType(expr.dataType, st).map { dt => Cast(expr, dt)} - private def castStringType(inType: AbstractDataType, collationId: Int): Option[DataType] = { + private def castStringType(inType: AbstractDataType, st: StringType): Option[DataType] = { @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId == collationId && !st.isExplicit => null - case _: StringType => StringType(collationId, isExplicit = false) + case ost: StringType if ost.collationId == st.collationId + && ost.isExplicit == st.isExplicit => null + case _: StringType => st case ArrayType(arrType, nullable) => - castStringType(arrType, collationId).map(ArrayType(_, nullable)).orNull + castStringType(arrType, st).map(ArrayType(_, nullable)).orNull case _ => null } Option(ret) @@ -73,10 +76,11 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Collates input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { - val collationId = getOutputCollation(exprs.map(_.dataType)) + def collateToSingleType(exprs: Seq[Expression], + preserveExplicit: Boolean = false): Seq[Expression] = { + val st = getOutputCollation(exprs.map(_.dataType), preserveExplicit) - exprs.map(e => castStringType(e, collationId).getOrElse(e)) + exprs.map(e => castStringType(e, st).getOrElse(e)) } /** @@ -85,7 +89,8 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(dataTypes: Seq[DataType]): Int = { + def getOutputCollation(dataTypes: Seq[DataType], + preserveExplicit: Boolean = false): StringType = { val explicitTypes = dataTypes.filter(hasStringType) .map(extractStringType) @@ -95,7 +100,7 @@ object CollationTypeCasts extends TypeCoercionRule { explicitTypes.size match { // We have 1 explicit collation - case 1 => explicitTypes.head + case 1 => StringType(explicitTypes.head, preserveExplicit) // Multiple explicit collations occurred case size if size > 1 => throw QueryCompilationErrors @@ -112,7 +117,6 @@ object CollationTypeCasts extends TypeCoercionRule { else { implicitTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) .getOrElse(SQLConf.get.defaultStringType) - .collationId } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index bd6b2e99f9e05..a3fcdcd93d596 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -610,9 +610,9 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val collationId = getOutputCollation(c.children.map(_.dataType)) + val st = getOutputCollation(c.children.map(_.dataType)) val newChildren = c.children.map { e => - implicitCast(e, StringType(collationId, isExplicit = false)).getOrElse(e) + implicitCast(e, st).getOrElse(e) } c.copy(children = newChildren) @@ -660,9 +660,9 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - val collationId = getOutputCollation(children.map(_.dataType)) + val st = getOutputCollation(children.map(_.dataType)) children.tail.map { e => - implicitCast(e, StringType(collationId, isExplicit = false)).getOrElse(e) + implicitCast(e, st).getOrElse(e) } } else { children.tail @@ -711,10 +711,10 @@ abstract class TypeCoercionBase { // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } - val collationId = getOutputCollation(e.children.map(_.dataType)) + val st = getOutputCollation(e.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => - castStringType(in, collationId).getOrElse(in) + castStringType(in, st).getOrElse(in) case in => in } e.withNewChildren(children) @@ -730,10 +730,10 @@ abstract class TypeCoercionBase { in } } - val collationId = getOutputCollation(e.children.map(_.dataType)) + val st = getOutputCollation(e.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => - castStringType(in, collationId).getOrElse(in) + castStringType(in, st).getOrElse(in) case in => in } e.withNewChildren(children) @@ -752,10 +752,10 @@ abstract class TypeCoercionBase { ).getOrElse(in) } } - val collationId = getOutputCollation(udf.children.map(_.dataType)) + val st = getOutputCollation(udf.children.map(_.dataType)) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => - castStringType(in, collationId).getOrElse(in) + castStringType(in, st).getOrElse(in) case in => in } udf.copy(children = children) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 993684f2c1ed4..3eb6225b5426e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -78,7 +78,7 @@ case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolea private val defaultElementType: DataType = { if (useStringTypeWhenEmpty) { - StringType + SQLConf.get.defaultStringType } else { NullType } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index d242a60a17c18..74de072ba4ad6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -72,7 +72,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase = 'aaa' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x = collate(aaa, utf8_binary_lcase)) ++- Filter (utf8_binary_lcase#x = cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -90,7 +90,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase < 'bbb' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x < collate(bbb, utf8_binary_lcase)) ++- Filter (utf8_binary_lcase#x < cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -254,14 +254,14 @@ DropTable false, false -- !query select array_contains(ARRAY('aaa' collate utf8_binary_lcase),'AAA' collate utf8_binary_lcase) -- !query analysis -Project [array_contains(array(collate(aaa, utf8_binary_lcase)), collate(AAA, utf8_binary_lcase)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +Project [array_contains(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +- OneRowRelation -- !query select array_position(ARRAY('aaa' collate utf8_binary_lcase, 'bbb' collate utf8_binary_lcase),'BBB' collate utf8_binary_lcase) -- !query analysis -Project [array_position(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)), collate(BBB, utf8_binary_lcase)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +Project [array_position(cast(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)) as array), cast(collate(BBB, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +- OneRowRelation @@ -275,40 +275,40 @@ Project [nullif(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase) -- !query select least('aaa' COLLATE utf8_binary_lcase, 'AAA' collate utf8_binary_lcase, 'a' collate utf8_binary_lcase) -- !query analysis -Project [least(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase), collate(a, utf8_binary_lcase)) AS least(collate(aaa), collate(AAA), collate(a))#x] +Project [least(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(a, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS least(collate(aaa), collate(AAA), collate(a))#x] +- OneRowRelation -- !query select arrays_overlap(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [arrays_overlap(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +Project [arrays_overlap(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_distinct(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +Project [array_distinct(cast(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase)) as array)) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +- OneRowRelation -- !query select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_union(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_union(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_intersect(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_intersect(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_except(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index a6e6dfc428f04..639f39bc699e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -638,6 +638,16 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.IMPLICIT" ) + + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE UNICODE_CI)") + }, + errorClass = "COLLATION_MISMATCH.EXPLICIT", + parameters = Map( + "explicitTypes" -> "`string collate UNICODE`.`string collate UNICODE_CI`" + ) + ) } } From 38670afe2b5edc62fba4a1997a1085f07632cb5d Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 11:59:24 +0200 Subject: [PATCH 54/87] Fix tests to follow recent changes --- .../sql/CollationRegexpExpressionsSuite.scala | 20 +++++++++---------- .../sql/CollationStringExpressionsSuite.scala | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 089e3106a0d3c..9fdc2acc085d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -64,7 +64,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -103,7 +103,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""lower('${ct.s1}')"""", + "inputSql" -> s""""lower('${ct.s1}' collate ${ct.collation})"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -141,7 +141,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -180,7 +180,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -219,7 +219,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -258,7 +258,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -297,7 +297,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -336,7 +336,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -375,7 +375,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) @@ -414,7 +414,7 @@ class CollationRegexpExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> s""""'${ct.s1}'"""", + "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 97647a9ae02ca..aa0a5367719f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -68,7 +68,7 @@ class CollationStringExpressionsSuite messageParameters = Map( "paramIndex" -> "first", "requiredType" -> """"STRING"""", - "inputSql" -> """"' '"""", + "inputSql" -> s""""' ' collate ${ct.collation}"""", "inputType" -> s""""STRING COLLATE ${ct.collation}"""" ) ) From 01d891ea5baf68ca996c2c85ec9b728a9a8e62e0 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 16:34:04 +0200 Subject: [PATCH 55/87] Incorporate changes --- .../apache/spark/sql/types/StringType.scala | 5 +++-- .../analysis/CollationTypeCasts.scala | 22 ++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1eb9f2ffd074c..61f6b43612eba 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int) extends AtomicType with Serializable { - var isExplicit: Boolean = false +class StringType private(val collationId: Int, var isExplicit: Boolean = false) + extends AtomicType + with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 0591183e229df..df9c4582da082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, ComplexTypeMergingExpression, CreateArray, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -31,10 +30,11 @@ object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e case sc @ (_: BinaryExpression - | _: Predicate - | _: SortOrder - | _: ExpectsInputTypes - | _: ComplexTypeMergingExpression) => + | _: ComplexTypeMergingExpression + | _: Elt + | _: ExpectsInputTypes + | _: Predicate + | _: SortOrder) => val newChildren = collateToSingleType(sc.children) sc.withNewChildren(newChildren) case pesc @ (_: CreateArray) => @@ -76,8 +76,9 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Collates input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression], - preserveExplicit: Boolean = false): Seq[Expression] = { + def collateToSingleType( + exprs: Seq[Expression], + preserveExplicit: Boolean = false): Seq[Expression] = { val st = getOutputCollation(exprs.map(_.dataType), preserveExplicit) exprs.map(e => castStringType(e, st).getOrElse(e)) @@ -89,8 +90,9 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(dataTypes: Seq[DataType], - preserveExplicit: Boolean = false): StringType = { + def getOutputCollation( + dataTypes: Seq[DataType], + preserveExplicit: Boolean = false): StringType = { val explicitTypes = dataTypes.filter(hasStringType) .map(extractStringType) From c5daf86ad6253bbd43d77c74f93e6a502c016ce3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 16:35:10 +0200 Subject: [PATCH 56/87] Fix error --- .../src/main/scala/org/apache/spark/sql/types/StringType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 61f6b43612eba..a98cda850b540 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -78,7 +78,7 @@ class StringType private(val collationId: Int, var isExplicit: Boolean = false) * @since 1.3.0 */ @Stable -case object StringType extends StringType(0) { +case object StringType extends StringType(0, false) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) private[spark] def apply(collationId: Int, isExplicit: Boolean): StringType = { val st = new StringType(collationId) From 9ac5678cb8c9862ee8292f27d81624bf4c76bdea Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 16:38:34 +0200 Subject: [PATCH 57/87] Change var to val in StringType --- .../main/scala/org/apache/spark/sql/types/StringType.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index a98cda850b540..16b498917f22f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int, var isExplicit: Boolean = false) +class StringType private(val collationId: Int, val isExplicit: Boolean = false) extends AtomicType with Serializable { /** @@ -81,9 +81,7 @@ class StringType private(val collationId: Int, var isExplicit: Boolean = false) case object StringType extends StringType(0, false) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) private[spark] def apply(collationId: Int, isExplicit: Boolean): StringType = { - val st = new StringType(collationId) - st.isExplicit = isExplicit - st + new StringType(collationId, isExplicit) } def apply(collation: String): StringType = { From 0f1757d3339c6ad31e3b1e0b4f5fdace87dd3043 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 16:58:18 +0200 Subject: [PATCH 58/87] Fix import style --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index df9c4582da082..0cf6f599c2eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType From 506c8c008a7c9586fa57954014f0dbc3adc81f37 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Apr 2024 22:50:26 +0200 Subject: [PATCH 59/87] Revert explicit flag addition --- .../apache/spark/sql/types/StringType.scala | 9 +---- .../analysis/CollationTypeCasts.scala | 39 +++++++------------ .../sql/catalyst/analysis/TypeCoercion.scala | 10 ++--- .../expressions/collationExpressions.scala | 2 +- .../analyzer-results/collations.sql.out | 20 +++++----- .../org/apache/spark/sql/CollationSuite.scala | 5 +-- 6 files changed, 34 insertions(+), 51 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 16b498917f22f..47d85b2c645c8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,9 +27,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int, val isExplicit: Boolean = false) - extends AtomicType - with Serializable { +class StringType private(val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered @@ -78,11 +76,8 @@ class StringType private(val collationId: Int, val isExplicit: Boolean = false) * @since 1.3.0 */ @Stable -case object StringType extends StringType(0, false) { +case object StringType extends StringType(0) { private[spark] def apply(collationId: Int): StringType = new StringType(collationId) - private[spark] def apply(collationId: Int, isExplicit: Boolean): StringType = { - new StringType(collationId, isExplicit) - } def apply(collation: String): StringType = { val collationId = CollationFactory.collationNameToId(collation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 0cf6f599c2eca..02ee8d36e646d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} @@ -32,15 +32,13 @@ object CollationTypeCasts extends TypeCoercionRule { case e if !e.childrenResolved => e case sc @ (_: BinaryExpression | _: ComplexTypeMergingExpression + | _: CreateArray | _: Elt | _: ExpectsInputTypes | _: Predicate | _: SortOrder) => val newChildren = collateToSingleType(sc.children) sc.withNewChildren(newChildren) - case pesc @ (_: CreateArray) => - val newChildren = collateToSingleType(pesc.children, true) - pesc.withNewChildren(newChildren) } /** @@ -62,13 +60,11 @@ object CollationTypeCasts extends TypeCoercionRule { def castStringType(expr: Expression, st: StringType): Option[Expression] = castStringType(expr.dataType, st).map { dt => Cast(expr, dt)} - private def castStringType(inType: AbstractDataType, st: StringType): Option[DataType] = { + private def castStringType(inType: AbstractDataType, castType: StringType): Option[DataType] = { @Nullable val ret: DataType = inType match { - case ost: StringType if ost.collationId == st.collationId - && ost.isExplicit == st.isExplicit => null - case _: StringType => st + case st: StringType if st.collationId != castType.collationId => castType case ArrayType(arrType, nullable) => - castStringType(arrType, st).map(ArrayType(_, nullable)).orNull + castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull case _ => null } Option(ret) @@ -77,10 +73,8 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Collates input expressions to a single collation. */ - def collateToSingleType( - exprs: Seq[Expression], - preserveExplicit: Boolean = false): Seq[Expression] = { - val st = getOutputCollation(exprs.map(_.dataType), preserveExplicit) + def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { + val st = getOutputCollation(exprs) exprs.map(e => castStringType(e, st).getOrElse(e)) } @@ -91,19 +85,14 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation( - dataTypes: Seq[DataType], - preserveExplicit: Boolean = false): StringType = { - val explicitTypes = - dataTypes.filter(hasStringType) - .map(extractStringType) - .filter(_.isExplicit) - .map(_.collationId) - .distinct + def getOutputCollation(expr: Seq[Expression]): StringType = { + val explicitTypes = expr.filter(_.isInstanceOf[Collate]) + .map(_.dataType.asInstanceOf[StringType].collationId) + .distinct explicitTypes.size match { // We have 1 explicit collation - case 1 => StringType(explicitTypes.head, preserveExplicit) + case 1 => StringType(explicitTypes.head) // Multiple explicit collations occurred case size if size > 1 => throw QueryCompilationErrors @@ -112,7 +101,9 @@ object CollationTypeCasts extends TypeCoercionRule { ) // Only implicit or default collations present case 0 => - val implicitTypes = dataTypes.filter(hasStringType).map(extractStringType) + val implicitTypes = expr.map(_.dataType) + .filter(hasStringType) + .map(extractStringType) if (hasMultipleImplicits(implicitTypes)) { throw QueryCompilationErrors.implicitCollationMismatchError() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a3fcdcd93d596..bf0c4ae667d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -610,7 +610,7 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val st = getOutputCollation(c.children.map(_.dataType)) + val st = getOutputCollation(c.children) val newChildren = c.children.map { e => implicitCast(e, st).getOrElse(e) } @@ -660,7 +660,7 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - val st = getOutputCollation(children.map(_.dataType)) + val st = getOutputCollation(children) children.tail.map { e => implicitCast(e, st).getOrElse(e) } @@ -711,7 +711,7 @@ abstract class TypeCoercionBase { // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } - val st = getOutputCollation(e.children.map(_.dataType)) + val st = getOutputCollation(e.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, st).getOrElse(in) @@ -730,7 +730,7 @@ abstract class TypeCoercionBase { in } } - val st = getOutputCollation(e.children.map(_.dataType)) + val st = getOutputCollation(e.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, st).getOrElse(in) @@ -752,7 +752,7 @@ abstract class TypeCoercionBase { ).getOrElse(in) } } - val st = getOutputCollation(udf.children.map(_.dataType)) + val st = getOutputCollation(udf.children) val children: Seq[Expression] = childrenBeforeCollations.map { case in if hasStringType(in.dataType) => castStringType(in, st).getOrElse(in) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 775ba196ba9da..ca13508c30c33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -81,7 +81,7 @@ object CollateExpressionBuilder extends ExpressionBuilder { case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) - override def dataType: DataType = StringType(collationId, isExplicit = true) + override def dataType: DataType = StringType(collationId) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 74de072ba4ad6..d242a60a17c18 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -72,7 +72,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase = 'aaa' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x = cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) ++- Filter (utf8_binary_lcase#x = collate(aaa, utf8_binary_lcase)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -90,7 +90,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase < 'bbb' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x < cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) ++- Filter (utf8_binary_lcase#x < collate(bbb, utf8_binary_lcase)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -254,14 +254,14 @@ DropTable false, false -- !query select array_contains(ARRAY('aaa' collate utf8_binary_lcase),'AAA' collate utf8_binary_lcase) -- !query analysis -Project [array_contains(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +Project [array_contains(array(collate(aaa, utf8_binary_lcase)), collate(AAA, utf8_binary_lcase)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +- OneRowRelation -- !query select array_position(ARRAY('aaa' collate utf8_binary_lcase, 'bbb' collate utf8_binary_lcase),'BBB' collate utf8_binary_lcase) -- !query analysis -Project [array_position(cast(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)) as array), cast(collate(BBB, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +Project [array_position(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)), collate(BBB, utf8_binary_lcase)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +- OneRowRelation @@ -275,40 +275,40 @@ Project [nullif(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase) -- !query select least('aaa' COLLATE utf8_binary_lcase, 'AAA' collate utf8_binary_lcase, 'a' collate utf8_binary_lcase) -- !query analysis -Project [least(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(a, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS least(collate(aaa), collate(AAA), collate(a))#x] +Project [least(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase), collate(a, utf8_binary_lcase)) AS least(collate(aaa), collate(AAA), collate(a))#x] +- OneRowRelation -- !query select arrays_overlap(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [arrays_overlap(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +Project [arrays_overlap(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_distinct(cast(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase)) as array)) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +Project [array_distinct(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +- OneRowRelation -- !query select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_union(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_union(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_intersect(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_intersect(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_except(cast(array(collate(aaa, utf8_binary_lcase)) as array), cast(array(collate(AAA, utf8_binary_lcase)) as array)) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 639f39bc699e3..d1baea9c7da5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -643,10 +643,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE UNICODE_CI)") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", - parameters = Map( - "explicitTypes" -> "`string collate UNICODE`.`string collate UNICODE_CI`" - ) + errorClass = "COLLATION_MISMATCH.IMPLICIT" ) } } From f743cf831d799d6be4278f95cfab73025291784a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Apr 2024 11:50:24 +0200 Subject: [PATCH 60/87] Narrow down expressions casting --- .../catalyst/analysis/AnsiTypeCoercion.scala | 5 +- .../analysis/CollationTypeCasts.scala | 53 +++++++++++++------ .../sql/catalyst/analysis/TypeCoercion.scala | 35 +++--------- 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 653eefb23f005..db427939d220c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -77,7 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new AnsiCombinedTypeCoercionRule( - CollationTypeCasts :: + PreCollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -93,7 +93,8 @@ object AnsiTypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - GetDateFieldOperations:: Nil) :: Nil + GetDateFieldOperations :: + PostCollationTypeCasts :: Nil) :: Nil val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 02ee8d36e646d..06882e0b63ad9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,25 +22,12 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, CreateArray, Elt, ExpectsInputTypes, Expression, Predicate, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, ConcatWs, CreateArray, Expression, In, InSubquery, Substring} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} -object CollationTypeCasts extends TypeCoercionRule { - override val transform: PartialFunction[Expression, Expression] = { - case e if !e.childrenResolved => e - case sc @ (_: BinaryExpression - | _: ComplexTypeMergingExpression - | _: CreateArray - | _: Elt - | _: ExpectsInputTypes - | _: Predicate - | _: SortOrder) => - val newChildren = collateToSingleType(sc.children) - sc.withNewChildren(newChildren) - } - +abstract class CollationTypeCasts extends TypeCoercionRule { /** * Extracts StringTypes from filtered hasStringType */ @@ -104,6 +91,7 @@ object CollationTypeCasts extends TypeCoercionRule { val implicitTypes = expr.map(_.dataType) .filter(hasStringType) .map(extractStringType) + .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) if (hasMultipleImplicits(implicitTypes)) { throw QueryCompilationErrors.implicitCollationMismatchError() @@ -127,3 +115,38 @@ object CollationTypeCasts extends TypeCoercionRule { .filter(dt => !(dt == SQLConf.get.defaultStringType.collationId)).distinct.size > 1 } + +/** + * This rule is used to collate all existing expressions related to StringType into a single + * collation. Arrays are handled using their elementType and should be cast for these expressions. + */ +object PreCollationTypeCasts extends CollationTypeCasts { + override val transform: PartialFunction[Expression, Expression] = { + case e if !e.childrenResolved => e + case sc@(_: In + | _: InSubquery + | _: CreateArray + | _: ComplexTypeMergingExpression + | _: ArrayJoin + | _: BinaryExpression + | _: ConcatWs + | _: Substring) => + val newChildren = collateToSingleType(sc.children) + sc.withNewChildren(newChildren) + } +} + +/** + * This rule is used for managing expressions that have possible implicit casts from different + * types in ImplicitTypeCasts rule. + */ +object PostCollationTypeCasts extends CollationTypeCasts { + override val transform: PartialFunction[Expression, Expression] = { + case e if !e.childrenResolved => e + case sc@(_: ArrayJoin + | _: BinaryExpression + | _: Substring) => + val newChildren = collateToSingleType(sc.children) + sc.withNewChildren(newChildren) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index bf0c4ae667d0d..c9c367415b83a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -23,8 +23,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.{castStringType, getOutputCollation} -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType +import org.apache.spark.sql.catalyst.analysis.PreCollationTypeCasts.getOutputCollation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -660,9 +659,8 @@ abstract class TypeCoercionBase { val newIndex = implicitCast(index, IntegerType).getOrElse(index) val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - val st = getOutputCollation(children) children.tail.map { e => - implicitCast(e, st).getOrElse(e) + implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) } } else { children.tail @@ -707,22 +705,16 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val childrenBeforeCollations: Seq[Expression] = e.children.zip(e.inputTypes).map { + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } - val st = getOutputCollation(e.children) - val children: Seq[Expression] = childrenBeforeCollations.map { - case in if hasStringType(in.dataType) => - castStringType(in, st).getOrElse(in) - case in => in - } e.withNewChildren(children) case e: ExpectsInputTypes if e.inputTypes.nonEmpty => // Convert NullType into some specific target type for ExpectsInputTypes that don't do // general implicit casting. - val childrenBeforeCollations: Seq[Expression] = + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => if (in.dataType == NullType && !expected.acceptsType(NullType)) { Literal.create(null, expected.defaultConcreteType) @@ -730,16 +722,10 @@ abstract class TypeCoercionBase { in } } - val st = getOutputCollation(e.children) - val children: Seq[Expression] = childrenBeforeCollations.map { - case in if hasStringType(in.dataType) => - castStringType(in, st).getOrElse(in) - case in => in - } e.withNewChildren(children) case udf: ScalaUDF if udf.inputTypes.nonEmpty => - val childrenBeforeCollations = udf.children.zip(udf.inputTypes).map { case (in, expected) => + val children = udf.children.zip(udf.inputTypes).map { case (in, expected) => // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works. // In the future we should create types like `AbstractArrayType`, so that Scala UDF can // accept inputs of array type of arbitrary element type. @@ -752,12 +738,6 @@ abstract class TypeCoercionBase { ).getOrElse(in) } } - val st = getOutputCollation(udf.children) - val children: Seq[Expression] = childrenBeforeCollations.map { - case in if hasStringType(in.dataType) => - castStringType(in, st).getOrElse(in) - case in => in - } udf.copy(children = children) } @@ -860,7 +840,7 @@ object TypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new CombinedTypeCoercionRule( - CollationTypeCasts :: + PreCollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -877,7 +857,8 @@ object TypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - StringLiteralCoercion :: Nil) :: Nil + StringLiteralCoercion :: + PostCollationTypeCasts :: Nil) :: Nil override def canCast(from: DataType, to: DataType): Boolean = Cast.canCast(from, to) From 3f46919c0542f9933f3e4580b8b3cd28b1a86b92 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Apr 2024 14:16:50 +0200 Subject: [PATCH 61/87] Add priority flag --- .../catalyst/parser/DataTypeAstBuilder.scala | 7 +++-- .../apache/spark/sql/types/StringType.scala | 18 +++++++++-- .../analysis/CollationTypeCasts.scala | 31 ++++++++++++------- .../sql/catalyst/analysis/TypeCoercion.scala | 8 +++-- .../expressions/collationExpressions.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++-- 6 files changed, 51 insertions(+), 24 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 38ecd29266db7..96861155edbae 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StringTypePriority, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -74,7 +74,10 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => SqlApiConf.get.defaultStringType + case Seq(_) => + val st = SqlApiConf.get.defaultStringType + st.priority = StringTypePriority.ImplicitST + st case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 47d85b2c645c8..4fd30e156536e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -19,6 +19,13 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types.StringTypePriority.{ImplicitST, StringTypePriority} + +object StringTypePriority extends Enumeration { + type StringTypePriority = Value + + val DefaultST, ImplicitST, ExplicitST = Value +} /** * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. @@ -27,7 +34,10 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int) extends AtomicType with Serializable { +class StringType private(val collationId: Int, + var priority: StringTypePriority = ImplicitST) + extends AtomicType + with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered @@ -76,8 +86,10 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * @since 1.3.0 */ @Stable -case object StringType extends StringType(0) { - private[spark] def apply(collationId: Int): StringType = new StringType(collationId) +case object StringType extends StringType(0, ImplicitST) { + private[spark] def apply(collationId: Int, + priority: StringTypePriority = ImplicitST): StringType = + new StringType(collationId, priority) def apply(collation: String): StringType = { val collationId = CollationFactory.collationNameToId(collation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 06882e0b63ad9..efe36b55bce76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,10 +22,10 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, ConcatWs, CreateArray, Expression, In, InSubquery, Substring} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, Cast, ComplexTypeMergingExpression, ConcatWs, CreateArray, Expression, In, InSubquery, String2StringExpression, Substring} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, StringTypePriority} abstract class CollationTypeCasts extends TypeCoercionRule { /** @@ -49,7 +49,8 @@ abstract class CollationTypeCasts extends TypeCoercionRule { private def castStringType(inType: AbstractDataType, castType: StringType): Option[DataType] = { @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId != castType.collationId => castType + case st: StringType if st.collationId != castType.collationId + || st.priority != castType.priority => castType case ArrayType(arrType, nullable) => castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull case _ => null @@ -73,8 +74,11 @@ abstract class CollationTypeCasts extends TypeCoercionRule { * complex DataTypes with collated StringTypes (e.g. ArrayType) */ def getOutputCollation(expr: Seq[Expression]): StringType = { - val explicitTypes = expr.filter(_.isInstanceOf[Collate]) - .map(_.dataType.asInstanceOf[StringType].collationId) + val explicitTypes = expr.map(_.dataType) + .filter(hasStringType) + .map(extractStringType) + .filter(dt => dt.priority == StringTypePriority.ExplicitST) + .map(_.collationId) .distinct explicitTypes.size match { @@ -91,14 +95,17 @@ abstract class CollationTypeCasts extends TypeCoercionRule { val implicitTypes = expr.map(_.dataType) .filter(hasStringType) .map(extractStringType) - .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) + .filter(dt => dt.priority == StringTypePriority.ImplicitST) if (hasMultipleImplicits(implicitTypes)) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - implicitTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) - .getOrElse(SQLConf.get.defaultStringType) + implicitTypes.headOption.getOrElse{ + val st = SQLConf.get.defaultStringType + st.priority = StringTypePriority.ImplicitST + st + } } } } @@ -110,9 +117,8 @@ abstract class CollationTypeCasts extends TypeCoercionRule { * @param dataTypes * @return */ - private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.map(_.collationId) - .filter(dt => !(dt == SQLConf.get.defaultStringType.collationId)).distinct.size > 1 + private def hasMultipleImplicits(implicitTypes: Seq[StringType]): Boolean = + implicitTypes.map(_.collationId).distinct.size > 1 } @@ -145,7 +151,8 @@ object PostCollationTypeCasts extends CollationTypeCasts { case e if !e.childrenResolved => e case sc@(_: ArrayJoin | _: BinaryExpression - | _: Substring) => + | _: Substring + | _: String2StringExpression) => val newChildren = collateToSingleType(sc.children) sc.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c9c367415b83a..99eb4e27ef332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1000,9 +1000,11 @@ object TypeCoercion extends TypeCoercionBase { case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType - case (any: AtomicType, st: StringTypeCollated) - if !any.isInstanceOf[StringType] => st.defaultConcreteType + case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => + SQLConf.get.defaultStringType + case (any: AtomicType, _: StringTypeCollated) + if !any.isInstanceOf[StringType] => + SQLConf.get.defaultStringType // When we reach here, input type is not acceptable for any types in this type collection, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index ca13508c30c33..5b66adc03412f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -81,7 +81,7 @@ object CollateExpressionBuilder extends ExpressionBuilder { case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) - override def dataType: DataType = StringType(collationId) + override def dataType: DataType = StringType(collationId, StringTypePriority.ExplicitST) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index af4498274620a..ecaf945e8af97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.logical.HintErrorHandler import org.apache.spark.sql.catalyst.util.{CollationFactory, DateTimeUtils} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{AtomicType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{AtomicType, StringType, StringTypePriority, TimestampNTZType, TimestampType} import org.apache.spark.storage.{StorageLevel, StorageLevelMapper} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.{Utils, VersionUtils} @@ -5063,9 +5063,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf { override def defaultStringType: StringType = { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { - StringType + val st = StringType + st.priority = StringTypePriority.DefaultST + st } else { - StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION))) + StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION)), + StringTypePriority.DefaultST) } } From 4f8fe1db82fea7345098dc9be5d81456bd6772b0 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Apr 2024 15:09:31 +0200 Subject: [PATCH 62/87] Incorporate minor changes --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 3 +-- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 06882e0b63ad9..56d718f109d12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -97,8 +97,7 @@ abstract class CollationTypeCasts extends TypeCoercionRule { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - implicitTypes.find(dt => !(dt == SQLConf.get.defaultStringType)) - .getOrElse(SQLConf.get.defaultStringType) + implicitTypes.headOption.getOrElse(SQLConf.get.defaultStringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c9c367415b83a..d5a62214eaecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -705,17 +705,16 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - case (in, expected) => implicitCast(in, expected).getOrElse(in) + implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) case e: ExpectsInputTypes if e.inputTypes.nonEmpty => // Convert NullType into some specific target type for ExpectsInputTypes that don't do // general implicit casting. - val children: Seq[Expression] = - e.children.zip(e.inputTypes).map { case (in, expected) => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => if (in.dataType == NullType && !expected.acceptsType(NullType)) { Literal.create(null, expected.defaultConcreteType) } else { @@ -1004,7 +1003,6 @@ object TypeCoercion extends TypeCoercionBase { case (any: AtomicType, st: StringTypeCollated) if !any.isInstanceOf[StringType] => st.defaultConcreteType - // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => From 52bf4dcb91a6548272896544ed10e3b83c55cfd6 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Apr 2024 21:48:03 +0200 Subject: [PATCH 63/87] Incorporate changes --- .../catalyst/analysis/AnsiTypeCoercion.scala | 5 +- .../analysis/CollationTypeCasts.scala | 60 +++++++------------ .../sql/catalyst/analysis/TypeCoercion.scala | 7 +-- 3 files changed, 27 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index db427939d220c..dd904be9a7601 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -77,7 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new AnsiCombinedTypeCoercionRule( - PreCollationTypeCasts :: + CollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -93,8 +93,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - GetDateFieldOperations :: - PostCollationTypeCasts :: Nil) :: Nil + GetDateFieldOperations :: Nil) :: Nil val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 56d718f109d12..b64ae7abf3ec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -21,13 +21,31 @@ import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.hasStringType -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, Cast, Collate, ComplexTypeMergingExpression, ConcatWs, CreateArray, Expression, In, InSubquery, Substring} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, Substring} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} -abstract class CollationTypeCasts extends TypeCoercionRule { +object CollationTypeCasts extends TypeCoercionRule { + override val transform: PartialFunction[Expression, Expression] = { + case e if !e.childrenResolved => e + case sc@(_: In + | _: InSubquery + | _: CreateArray + | _: If + | _: ArrayJoin + | _: CaseWhen + | _: Concat + | _: Greatest + | _: Least + | _: Coalesce + | _: BinaryExpression + | _: ConcatWs + | _: Substring) => + val newChildren = collateToSingleType(sc.children) + sc.withNewChildren(newChildren) + } /** * Extracts StringTypes from filtered hasStringType */ @@ -92,6 +110,7 @@ abstract class CollationTypeCasts extends TypeCoercionRule { .filter(hasStringType) .map(extractStringType) .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) + .distinctBy(_.collationId) if (hasMultipleImplicits(implicitTypes)) { throw QueryCompilationErrors.implicitCollationMismatchError() @@ -114,38 +133,3 @@ abstract class CollationTypeCasts extends TypeCoercionRule { .filter(dt => !(dt == SQLConf.get.defaultStringType.collationId)).distinct.size > 1 } - -/** - * This rule is used to collate all existing expressions related to StringType into a single - * collation. Arrays are handled using their elementType and should be cast for these expressions. - */ -object PreCollationTypeCasts extends CollationTypeCasts { - override val transform: PartialFunction[Expression, Expression] = { - case e if !e.childrenResolved => e - case sc@(_: In - | _: InSubquery - | _: CreateArray - | _: ComplexTypeMergingExpression - | _: ArrayJoin - | _: BinaryExpression - | _: ConcatWs - | _: Substring) => - val newChildren = collateToSingleType(sc.children) - sc.withNewChildren(newChildren) - } -} - -/** - * This rule is used for managing expressions that have possible implicit casts from different - * types in ImplicitTypeCasts rule. - */ -object PostCollationTypeCasts extends CollationTypeCasts { - override val transform: PartialFunction[Expression, Expression] = { - case e if !e.childrenResolved => e - case sc@(_: ArrayJoin - | _: BinaryExpression - | _: Substring) => - val newChildren = collateToSingleType(sc.children) - sc.withNewChildren(newChildren) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index d5a62214eaecc..23aaa240b63a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -23,7 +23,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.PreCollationTypeCasts.getOutputCollation +import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.getOutputCollation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -839,7 +839,7 @@ object TypeCoercion extends TypeCoercionBase { UnpivotCoercion :: WidenSetOperationTypes :: new CombinedTypeCoercionRule( - PreCollationTypeCasts :: + CollationTypeCasts :: InConversion :: PromoteStrings :: DecimalPrecision :: @@ -856,8 +856,7 @@ object TypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - StringLiteralCoercion :: - PostCollationTypeCasts :: Nil) :: Nil + StringLiteralCoercion :: Nil) :: Nil override def canCast(from: DataType, to: DataType): Boolean = Cast.canCast(from, to) From 7cbeafe5b6fe798f22e5b9cada0522be6d7f1009 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 08:18:19 +0200 Subject: [PATCH 64/87] Special case expressions --- .../analysis/CollationTypeCasts.scala | 57 ++++++++++--------- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index b64ae7abf3ec6..f1b101467df91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -30,21 +30,30 @@ import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, String object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e - case sc@(_: In - | _: InSubquery - | _: CreateArray - | _: If - | _: ArrayJoin - | _: CaseWhen - | _: Concat - | _: Greatest - | _: Least - | _: Coalesce - | _: BinaryExpression - | _: ConcatWs - | _: Substring) => - val newChildren = collateToSingleType(sc.children) - sc.withNewChildren(newChildren) + case ifExpr: If => + ifExpr.withNewChildren( + ifExpr.predicate +: collateToSingleType(Seq(ifExpr.trueValue, ifExpr.falseValue))) + case caseWhenExpr: CaseWhen => + val newValues = collateToSingleType( + caseWhenExpr.branches.map(b => b._2) ++ caseWhenExpr.elseValue) + caseWhenExpr.withNewChildren( + interleave(Seq.empty, caseWhenExpr.branches.map(b => b._1), newValues)) + case substrExpr: Substring => + // This case is necessary for changing Substring input to implicit collation + substrExpr.withNewChildren( + collateToSingleType(Seq(substrExpr.str)) :+ substrExpr.pos :+ substrExpr.len) + case otherExpr @ (_: In + | _: InSubquery + | _: CreateArray + | _: ArrayJoin + | _: Concat + | _: Greatest + | _: Least + | _: Coalesce + | _: BinaryExpression + | _: ConcatWs) => + val newChildren = collateToSingleType(otherExpr.children) + otherExpr.withNewChildren(newChildren) } /** * Extracts StringTypes from filtered hasStringType @@ -112,7 +121,7 @@ object CollationTypeCasts extends TypeCoercionRule { .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) .distinctBy(_.collationId) - if (hasMultipleImplicits(implicitTypes)) { + if (implicitTypes.length > 1) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { @@ -121,15 +130,9 @@ object CollationTypeCasts extends TypeCoercionRule { } } - /** - * This check is always preformed when we have no explicit collation. It returns true - * if there are more than one implicit collations. Collations are distinguished by their - * collationId. - * @param dataTypes - * @return - */ - private def hasMultipleImplicits(dataTypes: Seq[StringType]): Boolean = - dataTypes.map(_.collationId) - .filter(dt => !(dt == SQLConf.get.defaultStringType.collationId)).distinct.size > 1 - + @tailrec + final def interleave[A](base: Seq[A], a: Seq[A], b: Seq[A]): Seq[A] = a match { + case elt :: aTail => interleave(base :+ elt, b, aTail) + case _ => base ++ b + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 23aaa240b63a8..a365c9097e184 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -23,7 +23,6 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts.getOutputCollation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -609,11 +608,9 @@ abstract class TypeCoercionBase { case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val st = getOutputCollation(c.children) val newChildren = c.children.map { e => - implicitCast(e, st).getOrElse(e) + implicitCast(e, SQLConf.get.defaultStringType).getOrElse(e) } - c.copy(children = newChildren) } } From 3e92e926df4bdb3fd42987ed5122302a8ea12435 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 08:22:53 +0200 Subject: [PATCH 65/87] Return new line --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a365c9097e184..a999058efe8a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -733,6 +733,7 @@ abstract class TypeCoercionBase { udfInputToCastType(in.dataType, expected.asInstanceOf[DataType]) ).getOrElse(in) } + } udf.copy(children = children) } From b23e1060ae1b91f56176b7a3c612446d42405e80 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 08:31:01 +0200 Subject: [PATCH 66/87] Remove indentation cosmetic --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a999058efe8a0..615d21f676956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -733,7 +733,7 @@ abstract class TypeCoercionBase { udfInputToCastType(in.dataType, expected.asInstanceOf[DataType]) ).getOrElse(in) } - + } udf.copy(children = children) } From 880ebed5b42a2c03c378080c296102d99c871308 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 08:41:04 +0200 Subject: [PATCH 67/87] Add more cosmetic changes --- .../analysis/CollationTypeCasts.scala | 13 +--- .../sql/CollationRegexpExpressionsSuite.scala | 74 +++++++++++-------- .../sql/CollationStringExpressionsSuite.scala | 12 +-- .../org/apache/spark/sql/CollationSuite.scala | 4 +- 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index f1b101467df91..ccbef4c5145a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -42,16 +42,9 @@ object CollationTypeCasts extends TypeCoercionRule { // This case is necessary for changing Substring input to implicit collation substrExpr.withNewChildren( collateToSingleType(Seq(substrExpr.str)) :+ substrExpr.pos :+ substrExpr.len) - case otherExpr @ (_: In - | _: InSubquery - | _: CreateArray - | _: ArrayJoin - | _: Concat - | _: Greatest - | _: Least - | _: Coalesce - | _: BinaryExpression - | _: ConcatWs) => + case otherExpr @ ( + _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | + _: Coalesce | _: BinaryExpression | _: ConcatWs) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 9fdc2acc085d1..c547068a03c3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -29,16 +29,17 @@ import org.apache.spark.sql.types.StringType class CollationRegexpExpressionsSuite extends QueryTest - with SharedSparkSession - with ExpressionEvalHelper { + with SharedSparkSession + with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support Like string expression with Collation") { - def prepareLike(input: String, - regExp: String, - collation: String): Expression = { + def prepareLike( + input: String, + regExp: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -73,9 +74,10 @@ class CollationRegexpExpressionsSuite } test("Support ILike string expression with Collation") { - def prepareILike(input: String, - regExp: String, - collation: String): Expression = { + def prepareILike( + input: String, + regExp: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -112,9 +114,10 @@ class CollationRegexpExpressionsSuite } test("Support RLike string expression with Collation") { - def prepareRLike(input: String, - regExp: String, - collation: String): Expression = { + def prepareRLike( + input: String, + regExp: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -150,9 +153,10 @@ class CollationRegexpExpressionsSuite } test("Support StringSplit string expression with Collation") { - def prepareStringSplit(input: String, - splitBy: String, - collation: String): Expression = { + def prepareStringSplit( + input: String, + splitBy: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val splitByExpr = Literal.create(splitBy, StringType(collationId)) @@ -189,9 +193,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpReplace string expression with Collation") { - def prepareRegExpReplace(input: String, - regExp: String, - collation: String): RegExpReplace = { + def prepareRegExpReplace( + input: String, + regExp: String, + collation: String): RegExpReplace = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -228,9 +233,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpExtract string expression with Collation") { - def prepareRegExpExtract(input: String, - regExp: String, - collation: String): RegExpExtract = { + def prepareRegExpExtract( + input: String, + regExp: String, + collation: String): RegExpExtract = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -267,9 +273,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpExtractAll string expression with Collation") { - def prepareRegExpExtractAll(input: String, - regExp: String, - collation: String): RegExpExtractAll = { + def prepareRegExpExtractAll( + input: String, + regExp: String, + collation: String): RegExpExtractAll = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -306,9 +313,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpCount string expression with Collation") { - def prepareRegExpCount(input: String, - regExp: String, - collation: String): Expression = { + def prepareRegExpCount( + input: String, + regExp: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -345,9 +353,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpSubStr string expression with Collation") { - def prepareRegExpSubStr(input: String, - regExp: String, - collation: String): Expression = { + def prepareRegExpSubStr( + input: String, + regExp: String, + collation: String): Expression = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) @@ -384,9 +393,10 @@ class CollationRegexpExpressionsSuite } test("Support RegExpInStr string expression with Collation") { - def prepareRegExpInStr(input: String, - regExp: String, - collation: String): RegExpInStr = { + def prepareRegExpInStr( + input: String, + regExp: String, + collation: String): RegExpInStr = { val collationId = CollationFactory.collationNameToId(collation) val inputExpr = Literal.create(input, StringType(collationId)) val regExpExpr = Literal.create(regExp, StringType(collationId)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index aa0a5367719f0..c26f3ae02255f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -29,18 +29,18 @@ import org.apache.spark.sql.types.StringType class CollationStringExpressionsSuite extends QueryTest - with SharedSparkSession - with ExpressionEvalHelper { + with SharedSparkSession + with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support ConcatWs string expression with Collation") { - def prepareConcatWs(sep: String, - collation: String, - inputs: Any*): ConcatWs = { + def prepareConcatWs( + sep: String, + collation: String, + inputs: Any*): ConcatWs = { val collationId = CollationFactory.collationNameToId(collation) val inputExprs = inputs.map(s => Literal.create(s, StringType(collationId))) val sepExpr = Literal.create(sep, StringType(collationId)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index d1baea9c7da5f..a1a537fb10825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -544,7 +544,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat of columns of different collations is allowed // as long as we don't use the result in an unsupported function - // TODO: (SPARK-47210) Add indeterminate support + // TODO(SPARK-47210): Add indeterminate support checkError( exception = intercept[AnalysisException] { sql(s"SELECT c1 || c2 FROM $tableName") @@ -672,7 +672,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - // TODO: (SPARK-47210) Add indeterminate support + // TODO(SPARK-47210): Add indeterminate support test("indeterminate collation checks") { val tableName = "t1" val newTableName = "t2" From 00bd3613a6b3c328cec30759f830cce9003a7385 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 10:30:24 +0200 Subject: [PATCH 68/87] Propagate default collation --- .../catalyst/encoders/AgnosticEncoder.scala | 5 ++-- .../expressions/CallMethodViaReflection.scala | 3 ++- .../catalyst/expressions/ToPrettyString.scala | 5 ++-- .../expressions/collectionOperations.scala | 2 +- .../expressions/complexTypeCreator.scala | 2 +- .../catalyst/expressions/csvExpressions.scala | 4 +-- .../expressions/datetimeExpressions.scala | 10 +++---- .../spark/sql/catalyst/expressions/hash.scala | 6 ++--- .../catalyst/expressions/inputFileBlock.scala | 5 ++-- .../expressions/jsonExpressions.scala | 6 ++--- .../expressions/maskExpressions.scala | 3 ++- .../expressions/mathExpressions.scala | 6 ++--- .../spark/sql/catalyst/expressions/misc.scala | 10 +++---- .../expressions/numberFormatExpressions.scala | 5 ++-- .../expressions/regexpExpressions.scala | 7 ++--- .../expressions/stringExpressions.scala | 26 +++++++++---------- .../catalyst/expressions/urlExpressions.scala | 2 +- .../sql/catalyst/expressions/xml/xpath.scala | 3 ++- .../catalyst/expressions/xmlExpressions.scala | 4 +-- .../analysis/AnalysisErrorSuite.scala | 4 +-- .../connector/catalog/InMemoryBaseTable.scala | 3 ++- .../analyzer-results/collations.sql.out | 24 ++++++++--------- 22 files changed, 77 insertions(+), 68 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 9133abce88adc..1ecb2807618df 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -21,8 +21,9 @@ import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.util.concurrent.ConcurrentHashMap -import scala.reflect.{classTag, ClassTag} +import scala.reflect.{ClassTag, classTag} +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} @@ -162,7 +163,7 @@ object AgnosticEncoders { // Enums are special leafs because we need to capture the class. protected abstract class EnumEncoder[E] extends AgnosticEncoder[E] { override def isPrimitive: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SqlApiConf.get.defaultStringType } case class ScalaEnumEncoder[T, E]( parent: Class[T], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index c42b54222f171..da88f28c26fcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -134,7 +135,7 @@ case class CallMethodViaReflection( } override def nullable: Boolean = true - override val dataType: DataType = StringType + override val dataType: DataType = SQLConf.get.defaultStringType override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala index 8db08dbbcb813..6f71aff9742d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** @@ -32,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with ToStringBase { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a68071f9cfa35..9ced1bd774230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2149,7 +2149,7 @@ case class ArrayJoin( } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "array_join" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3eb6225b5426e..27169baf76d76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -349,7 +349,7 @@ case class MapFromArrays(left: Expression, right: Expression) case object NamePlaceholder extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "NamePlaceholder" override def toString: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 4714fc1ded9cd..ef35a757fd3b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -177,7 +177,7 @@ case class SchemaOfCsv( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false @@ -300,7 +300,7 @@ case class StructsToCsv( (row: Any) => UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow])) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index a9155e8daf101..91d2906d390a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -104,7 +104,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { since = "3.1.0") case class CurrentTimeZone() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "current_timezone" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } @@ -904,7 +904,7 @@ case class WeekOfYear(child: Expression) extends GetDateField { case class MonthName(child: Expression) extends GetDateField { override val func = DateTimeUtils.getMonthName override val funcName = "getMonthName" - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): MonthName = copy(child = newChild) } @@ -923,7 +923,7 @@ case class DayName(child: Expression) extends GetDateField { override val funcName = "getDayName" override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): DayName = copy(child = newChild) } @@ -951,7 +951,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti def this(left: Expression, right: Expression) = this(left, right, None) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) @@ -1429,7 +1429,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ this(unix, Literal(TimestampFormatter.defaultPattern())) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 436efa8924165..7a19ede79fabb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -63,7 +63,7 @@ import org.apache.spark.util.ArrayImplicits._ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -103,7 +103,7 @@ case class Md5(child: Expression) case class Sha2(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) @@ -169,7 +169,7 @@ case class Sha2(left: Expression, right: Expression) case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(BinaryType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 6cd88367aa9a0..61a423acd7a32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -39,7 +40,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "input_file_name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f35c6da4f8af9..3df290092d112 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -132,7 +132,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -820,7 +820,7 @@ case class StructsToJson( } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def checkInputDataTypes(): TypeCheckResult = inputSchema match { case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) => @@ -869,7 +869,7 @@ case class SchemaOfJson( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index e5157685a9a6d..264bd372d823b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -276,7 +277,7 @@ case class Mask( * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query * the dataType of an unresolved expression (i.e., when `resolved` == false). */ - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType /** * Returns a Seq of the children of this node. Children should not change. Immutability required diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 0c09e9be12e94..3387d8f9ed8bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -451,7 +451,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { @@ -1002,7 +1002,7 @@ case class Bin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) @@ -1110,7 +1110,7 @@ case class Hex(child: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType protected override def nullSafeEval(num: Any): Any = child.dataType match { case LongType => Hex.hex(num.asInstanceOf[Long]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index c7281e4e87378..bdbc5296f0845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -199,7 +199,7 @@ object AssertTrue { since = "1.6.0", group = "misc_funcs") case class CurrentDatabase() extends LeafExpression with Unevaluable { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false override def prettyName: String = "current_schema" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) @@ -218,7 +218,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { since = "3.1.0", group = "misc_funcs") case class CurrentCatalog() extends LeafExpression with Unevaluable { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false override def prettyName: String = "current_catalog" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) @@ -251,7 +251,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def stateful: Boolean = true @@ -311,7 +311,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable { case class TypeOf(child: Expression) extends UnaryExpression { override def nullable: Boolean = false override def foldable: Boolean = true - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def eval(input: InternalRow): Any = UTF8String.fromString(child.dataType.catalogString) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -334,7 +334,7 @@ case class TypeOf(child: Expression) extends UnaryExpression { // scalastyle:on line.size.limit case class CurrentUser() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_user") final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 6d95d7e620a2e..6ee35baaf3869 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -22,10 +22,11 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -279,7 +280,7 @@ case class ToCharacter(left: Expression, right: Expression) } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b33de303b5d55..da401b75bbec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException} +import java.util.regex.{MatchResult, Matcher, Pattern, PatternSyntaxException} import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -683,7 +684,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType, IntegerType) override def prettyName: String = "regexp_replace" @@ -848,7 +849,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ad15509dc6fda..d6d45d0ac93b6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -82,7 +82,7 @@ case class ConcatWs(children: Seq[Expression]) StringType +: Seq.fill(children.size - 1)(arrayOrStr) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = children.head.nullable override def foldable: Boolean = children.forall(_.foldable) @@ -742,7 +742,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp }) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def first: Expression = srcExpr override def second: Expression = searchExpr @@ -960,7 +960,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac }) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def first: Expression = srcExpr override def second: Expression = matchingExpr @@ -1020,7 +1020,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) override def nullable: Boolean = children.exists(_.nullable) @@ -1412,7 +1412,7 @@ case class StringInstr(str: Expression, substr: Expression) case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) override def first: Expression = strExpr override def second: Expression = delimExpr @@ -1715,7 +1715,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children(0).nullable - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = StringType :: List.fill(children.size - 1)(AnyDataType) @@ -1831,7 +1831,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(StringType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullSafeEval(string: Any): Any = { // scalastyle:off caselocale @@ -1895,7 +1895,7 @@ case class StringRepeat(str: Expression, times: Expression) case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(IntegerType) override def nullSafeEval(s: Any): Any = { @@ -2337,7 +2337,7 @@ case class Levenshtein( case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -2415,7 +2415,7 @@ case class Ascii(child: Expression) case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(LongType) protected override def nullSafeEval(lon: Any): Any = { @@ -2464,7 +2464,7 @@ case class Chr(child: Expression) case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(BinaryType) protected override def nullSafeEval(bytes: Any): Any = { @@ -2689,7 +2689,7 @@ case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Bo override def left: Expression = bin override def right: Expression = charset - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) private val supportedCharsets = Set( @@ -2946,7 +2946,7 @@ case class FormatNumber(x: Expression, d: Expression) override def left: Expression = x override def right: Expression = d - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, TypeCollection(IntegerType, StringType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 47b37a5edeba8..f13e41041ff81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -155,7 +155,7 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" // If the url is a constant, cache the URL object so that we don't need to convert url diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index c3a285178c110..8ab9cc1ed48c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -221,7 +222,7 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { // scalastyle:on line.size.limit case class XPathString(xml: Expression, path: Expression) extends XPathExtract { override def prettyName: String = "xpath_string" - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullSafeEval(xml: Any, path: Any): Any = { val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 415d55d19ded2..f2ba82c3dccb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -178,7 +178,7 @@ case class SchemaOfXml( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false @@ -320,7 +320,7 @@ case class StructsToXml( getAndReset() } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index f12d224096917..954a00fe4a9ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -61,7 +61,7 @@ case class TestFunction( inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } @@ -83,7 +83,7 @@ case class TestFunctionWithTypeCheckFailure( } override def nullable: Boolean = true - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 505a5a6169204..4ff1ce1efe938 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -62,7 +63,7 @@ abstract class InMemoryBaseTable( protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def comment: String = "Partition key used to store the row" } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index d242a60a17c18..a932f49cabad0 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -63,7 +63,7 @@ Aggregate [utf8_binary_lcase#x], [count(1) AS count(1)#xL] select * from t1 where utf8_binary = 'aaa' -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary#x = aaa) ++- Filter (utf8_binary#x = cast(aaa as string)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -72,7 +72,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase = 'aaa' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x = collate(aaa, utf8_binary_lcase)) ++- Filter (utf8_binary_lcase#x = cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -81,7 +81,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary < 'bbb' -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary#x < bbb) ++- Filter (utf8_binary#x < cast(bbb as string)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -90,7 +90,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase < 'bbb' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x < collate(bbb, utf8_binary_lcase)) ++- Filter (utf8_binary_lcase#x < cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -254,14 +254,14 @@ DropTable false, false -- !query select array_contains(ARRAY('aaa' collate utf8_binary_lcase),'AAA' collate utf8_binary_lcase) -- !query analysis -Project [array_contains(array(collate(aaa, utf8_binary_lcase)), collate(AAA, utf8_binary_lcase)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +Project [array_contains(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +- OneRowRelation -- !query select array_position(ARRAY('aaa' collate utf8_binary_lcase, 'bbb' collate utf8_binary_lcase),'BBB' collate utf8_binary_lcase) -- !query analysis -Project [array_position(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)), collate(BBB, utf8_binary_lcase)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +Project [array_position(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), cast(collate(BBB, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +- OneRowRelation @@ -275,40 +275,40 @@ Project [nullif(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase) -- !query select least('aaa' COLLATE utf8_binary_lcase, 'AAA' collate utf8_binary_lcase, 'a' collate utf8_binary_lcase) -- !query analysis -Project [least(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase), collate(a, utf8_binary_lcase)) AS least(collate(aaa), collate(AAA), collate(a))#x] +Project [least(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(a, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS least(collate(aaa), collate(AAA), collate(a))#x] +- OneRowRelation -- !query select arrays_overlap(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [arrays_overlap(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +Project [arrays_overlap(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_distinct(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +Project [array_distinct(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +- OneRowRelation -- !query select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_union(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_union(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_intersect(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_intersect(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_except(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation From f96ecd9c45e4aad1c37ce08fbd50d7f113b5917c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 11:36:59 +0200 Subject: [PATCH 69/87] Incorporate changes --- .../analysis/CollationTypeCasts.scala | 35 +++++++++---------- .../expressions/collationExpressions.scala | 1 - 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index ccbef4c5145a2..1a14b4227de8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -21,27 +21,30 @@ import javax.annotation.Nullable import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, Substring} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e + case ifExpr: If => ifExpr.withNewChildren( ifExpr.predicate +: collateToSingleType(Seq(ifExpr.trueValue, ifExpr.falseValue))) - case caseWhenExpr: CaseWhen => - val newValues = collateToSingleType( - caseWhenExpr.branches.map(b => b._2) ++ caseWhenExpr.elseValue) - caseWhenExpr.withNewChildren( - interleave(Seq.empty, caseWhenExpr.branches.map(b => b._1), newValues)) - case substrExpr: Substring => - // This case is necessary for changing Substring input to implicit collation - substrExpr.withNewChildren( - collateToSingleType(Seq(substrExpr.str)) :+ substrExpr.pos :+ substrExpr.len) + + case caseWhenExpr: CaseWhen if !haveSameType(caseWhenExpr.inputTypesForMerging) => + val outputStringType = + getOutputCollation(caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue) + val newBranches = caseWhenExpr.branches.map { case (condition, value) => + (condition, castStringType(value, outputStringType).getOrElse(value)) + } + val newElseValue = + caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e)) + CaseWhen(newBranches, newElseValue) + case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => @@ -67,7 +70,7 @@ object CollationTypeCasts extends TypeCoercionRule { def castStringType(expr: Expression, st: StringType): Option[Expression] = castStringType(expr.dataType, st).map { dt => Cast(expr, dt)} - private def castStringType(inType: AbstractDataType, castType: StringType): Option[DataType] = { + private def castStringType(inType: DataType, castType: StringType): Option[DataType] = { @Nullable val ret: DataType = inType match { case st: StringType if st.collationId != castType.collationId => castType case ArrayType(arrType, nullable) => @@ -122,10 +125,4 @@ object CollationTypeCasts extends TypeCoercionRule { } } } - - @tailrec - final def interleave[A](base: Seq[A], a: Seq[A], b: Seq[A]): Seq[A] = a match { - case elt :: aTail => interleave(base :+ elt, b, aTail) - case _ => base ++ b - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index ca13508c30c33..e8b738921b73b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -114,7 +114,6 @@ case class Collate(child: Expression, collationName: String) // scalastyle:on line.contains.tab case class Collation(child: Expression) extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes { - override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) override def replacement: Expression = { val collationId = child.dataType.asInstanceOf[StringType].collationId From 5002028dc3454272a5727824cc677aedebfe6718 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Apr 2024 14:58:56 +0200 Subject: [PATCH 70/87] Fix priority casting --- .../catalyst/encoders/AgnosticEncoder.scala | 4 +-- .../analysis/CollationTypeCasts.scala | 34 +++++++++++++------ .../catalyst/expressions/inputFileBlock.scala | 4 +-- .../expressions/numberFormatExpressions.scala | 2 +- .../expressions/regexpExpressions.scala | 2 +- 5 files changed, 30 insertions(+), 16 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 1ecb2807618df..04cdfc02b4ec7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -21,10 +21,10 @@ import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.util.concurrent.ConcurrentHashMap -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} -import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.SparkClassUtils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index af47109ea6f7d..1cf22c4084614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, Substring} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StringTypePriority} @@ -35,16 +35,28 @@ object CollationTypeCasts extends TypeCoercionRule { ifExpr.withNewChildren( ifExpr.predicate +: collateToSingleType(Seq(ifExpr.trueValue, ifExpr.falseValue))) - case caseWhenExpr: CaseWhen if !haveSameType(caseWhenExpr.inputTypesForMerging) => + case caseWhenExpr: CaseWhen + if !haveSameType(caseWhenExpr.inputTypesForMerging) + || caseWhenExpr.inputTypesForMerging.exists( + dt => dt.isInstanceOf[StringType] + && dt.asInstanceOf[StringType].priority != StringTypePriority.ImplicitST) => val outputStringType = getOutputCollation(caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue) val newBranches = caseWhenExpr.branches.map { case (condition, value) => - (condition, castStringType(value, outputStringType).getOrElse(value)) + (condition, castStringType(value, outputStringType)) } val newElseValue = - caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e)) + caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType)) CaseWhen(newBranches, newElseValue) + case substrExpr: Substring + if substrExpr.str.dataType.isInstanceOf[StringType] + && substrExpr.str.dataType.asInstanceOf[StringType].priority + != StringTypePriority.ImplicitST => + val st = substrExpr.dataType.asInstanceOf[StringType] + st.priority = StringTypePriority.ImplicitST + substrExpr.withNewChildren(Seq(Cast(substrExpr.str, st), substrExpr.pos, substrExpr.len)) + case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => @@ -67,16 +79,18 @@ object CollationTypeCasts extends TypeCoercionRule { * @param collationId * @return */ - def castStringType(expr: Expression, st: StringType): Option[Expression] = - castStringType(expr.dataType, st).map { dt => Cast(expr, dt)} + def castStringType(expr: Expression, st: StringType): Expression = + castStringType(expr.dataType, st).map { dt => + if (dt == expr.dataType) expr else Cast(expr, dt) + }.getOrElse {Cast(expr, st)} private def castStringType(inType: DataType, castType: StringType): Option[DataType] = { @Nullable val ret: DataType = inType match { - case st: StringType if st.collationId != castType.collationId - || st.priority != castType.priority => castType + case st: StringType if st.collationId != castType.collationId => castType + case st: StringType if st.priority != castType.priority => null case ArrayType(arrType, nullable) => castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull - case _ => null + case other => other } Option(ret) } @@ -87,7 +101,7 @@ object CollationTypeCasts extends TypeCoercionRule { def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { val st = getOutputCollation(exprs) - exprs.map(e => castStringType(e, st).getOrElse(e)) + exprs.map(e => castStringType(e, st)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 61a423acd7a32..65eb995ff32ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, LongType, StringType} +import org.apache.spark.sql.types.{DataType, LongType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off whitespace.end.of.line diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 6ee35baaf3869..52f9b26a99785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index da401b75bbec3..dd14cca88f9b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import java.util.regex.{MatchResult, Matcher, Pattern, PatternSyntaxException} +import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException} import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ From a96f3aa41c764fad427ca6e71b5820f9cd345c2c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 4 Apr 2024 08:49:14 +0200 Subject: [PATCH 71/87] Fix cosmetics in StringType --- .../org/apache/spark/sql/types/StringType.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 4fd30e156536e..1e92a3e835bd0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -34,10 +34,11 @@ object StringTypePriority extends Enumeration { * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int, - var priority: StringTypePriority = ImplicitST) +class StringType private( + val collationId: Int, + var priority: StringTypePriority = ImplicitST) extends AtomicType - with Serializable { + with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered @@ -87,8 +88,9 @@ class StringType private(val collationId: Int, */ @Stable case object StringType extends StringType(0, ImplicitST) { - private[spark] def apply(collationId: Int, - priority: StringTypePriority = ImplicitST): StringType = + private[spark] def apply( + collationId: Int, + priority: StringTypePriority = ImplicitST): StringType = new StringType(collationId, priority) def apply(collation: String): StringType = { From 3f35fd7db7f797616cadf676a7439b032b9c6b61 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 5 Apr 2024 08:40:01 +0200 Subject: [PATCH 72/87] Comment out Substring casting --- .../catalyst/analysis/CollationTypeCasts.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1cf22c4084614..1bdbbd4367b38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, Substring} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StringTypePriority} @@ -48,15 +48,15 @@ object CollationTypeCasts extends TypeCoercionRule { val newElseValue = caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType)) CaseWhen(newBranches, newElseValue) - + /* case substrExpr: Substring - if substrExpr.str.dataType.isInstanceOf[StringType] - && substrExpr.str.dataType.asInstanceOf[StringType].priority - != StringTypePriority.ImplicitST => - val st = substrExpr.dataType.asInstanceOf[StringType] + if substrExpr.str.dataType.isInstanceOf[StringType] && + !substrExpr.str.dataType.asInstanceOf[StringType] + .priority.equals(StringTypePriority.ImplicitST) => + val st = substrExpr.str.dataType.asInstanceOf[StringType] st.priority = StringTypePriority.ImplicitST - substrExpr.withNewChildren(Seq(Cast(substrExpr.str, st), substrExpr.pos, substrExpr.len)) - + Substring(Cast(substrExpr.str, st), substrExpr.pos, substrExpr.len) + */ case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => From f601f8f6a78d20e11c74044b54dae168b277141c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 5 Apr 2024 12:37:04 +0200 Subject: [PATCH 73/87] Fix substring error --- .../catalyst/analysis/CollationTypeCasts.scala | 17 ++++++----------- .../catalyst/expressions/mathExpressions.scala | 6 +++--- .../org/apache/spark/sql/CollationSuite.scala | 8 ++++++++ 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1bdbbd4367b38..83a9ddff8d789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, Substring} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StringTypePriority} @@ -48,15 +47,11 @@ object CollationTypeCasts extends TypeCoercionRule { val newElseValue = caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType)) CaseWhen(newBranches, newElseValue) - /* - case substrExpr: Substring - if substrExpr.str.dataType.isInstanceOf[StringType] && - !substrExpr.str.dataType.asInstanceOf[StringType] - .priority.equals(StringTypePriority.ImplicitST) => - val st = substrExpr.str.dataType.asInstanceOf[StringType] - st.priority = StringTypePriority.ImplicitST - Substring(Cast(substrExpr.str, st), substrExpr.pos, substrExpr.len) - */ + + case substrExpr: Substring => + substrExpr.withNewChildren( + collateToSingleType(Seq(substrExpr.str)) :+ substrExpr.pos :+ substrExpr.len) + case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 3387d8f9ed8bc..e63d02a1164fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1108,21 +1108,21 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringType)) + Seq(TypeCollection(LongType, BinaryType, StringTypeAnyCollation)) override def dataType: DataType = SQLConf.get.defaultStringType protected override def nullSafeEval(num: Any): Any = child.dataType match { case LongType => Hex.hex(num.asInstanceOf[Long]) case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]]) - case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) + case _: StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { - case StringType => s"""$hex.hex($c.getBytes());""" + case _: StringType => s"""$hex.hex($c.getBytes());""" case _ => s"""$hex.hex($c);""" }) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index a1a537fb10825..3120e0e37345e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -645,6 +645,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.IMPLICIT" ) + + // check if substring passes through implicit collation + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT substr('a' COLLATE UNICODE, 0, 1) == substr('b' COLLATE UNICODE_CI, 0, 1)") + }, + errorClass = "COLLATION_MISMATCH.IMPLICIT" + ) } } From 736c931ad81cb4ab486a190a51be0adbef702ab3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 5 Apr 2024 14:00:58 +0200 Subject: [PATCH 74/87] Fix import error --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 83a9ddff8d789..923ed78d71ad1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} From b439ad1f442c1d7b23185dd293a84919738facc3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 8 Apr 2024 11:42:11 +0200 Subject: [PATCH 75/87] Add support for parameter markers --- .../sql/catalyst/analysis/parameters.scala | 48 +++++++++++++++++-- .../org/apache/spark/sql/CollationSuite.scala | 21 ++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index f1cc44b270bc5..c2ef8f42b692b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec + import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH} import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StringTypePriority} sealed trait Parameter extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false @@ -130,6 +133,25 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { }) } + @tailrec + private def hasStringRelatedParameterWithFaultyPriority(dt: DataType): Boolean = { + dt match { + case st: StringType if st.priority != StringTypePriority.DefaultST => true + case ArrayType(elementType, _) => hasStringRelatedParameterWithFaultyPriority(elementType) + case _ => false + } + } + + private def getStringParameterWithRightPriority(dt: DataType): Option[DataType] = { + val ret = dt match { + case _: StringType => SQLConf.get.defaultStringType + case ArrayType(et, nullable) => + getStringParameterWithRightPriority(et).map(ArrayType(_, nullable)).orNull + case other => other + } + Option(ret) + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (plan.containsPattern(PARAMETERIZED_QUERY)) { // One unresolved plan can have at most one ParameterizedQuery. @@ -148,7 +170,17 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { } val args = argNames.zip(argValues).toMap checkArgs(args) - bind(child) { case NamedParameter(name) if args.contains(name) => args(name) } + bind(child) { case NamedParameter(name) if args.contains(name) => + val bindValue = args(name) + bindValue.dataType match { + case dt if hasStringRelatedParameterWithFaultyPriority(dt) => + Cast(bindValue, + getStringParameterWithRightPriority( + bindValue.dataType).getOrElse(bindValue.dataType)) + case _ => + bindValue + } + } case PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) => @@ -161,7 +193,15 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { bind(child) { case PosParameter(pos) if posToIndex.contains(pos) && args.size > posToIndex(pos) => - args(posToIndex(pos)) + val bindValue = args(posToIndex(pos)) + bindValue.dataType match { + case dt if hasStringRelatedParameterWithFaultyPriority(dt) => + Cast(bindValue, + getStringParameterWithRightPriority( + bindValue.dataType).getOrElse(bindValue.dataType)) + case _ => + bindValue + } } case _ => plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 3120e0e37345e..ea3d82ec246db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -653,6 +654,26 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.IMPLICIT" ) + + checkAnswer(spark.sql("SELECT collation(:var1 || :var2)", + Map( + "var1" -> Literal.create("a", StringType(1)), + "var2" -> Literal.create("b", StringType(2)) + ) + ), + Seq(Row("UTF8_BINARY")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer(spark.sql("SELECT collation(:var1 || :var2)", + Map( + "var1" -> Literal.create("a", StringType(1)), + "var2" -> Literal.create("b", StringType(2)) + ) + ), + Seq(Row("UNICODE")) + ) + } } } From 008a7956c4f4ac3131c0838703af330da0cfb9b4 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 11 Apr 2024 06:57:19 +0200 Subject: [PATCH 76/87] Improve test --- .../org/apache/spark/sql/types/DataType.scala | 2 ++ .../org/apache/spark/sql/CollationSuite.scala | 25 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 16cf6224ce27b..8cb0f3d633f30 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -182,6 +182,8 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { name match { + case "string collate INDETERMINATE" => + StringType(-1) case COLLATED_STRING_TYPE(collation) => val collationId = CollationFactory.collationNameToId(collation) StringType(collationId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 6b5fe9c8e8606..1d8a487d027dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -22,7 +22,6 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -655,22 +654,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "COLLATION_MISMATCH.IMPLICIT" ) - checkAnswer(spark.sql("SELECT collation(:var1 || :var2)", - Map( - "var1" -> Literal.create("a", StringType(1)), - "var2" -> Literal.create("b", StringType(2)) - ) - ), + sql(s"DECLARE stmtStr = 'SELECT collation(:var1 || :var2)';") + + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr USING + | 'a' COLLATE UNICODE AS var1, + | 'b' COLLATE UNICODE_CI AS var2;""".stripMargin), Seq(Row("UTF8_BINARY")) ) withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer(spark.sql("SELECT collation(:var1 || :var2)", - Map( - "var1" -> Literal.create("a", StringType(1)), - "var2" -> Literal.create("b", StringType(2)) - ) - ), + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr USING + | 'a' COLLATE UNICODE AS var1, + | 'b' COLLATE UNICODE_CI AS var2;""".stripMargin), Seq(Row("UNICODE")) ) } From d4b72cf09c3b4154eddd30f6759d1e5723c51b62 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 11 Apr 2024 13:16:21 +0200 Subject: [PATCH 77/87] Resolve conflicts --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../apache/spark/sql/catalyst/expressions/mathExpressions.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 45ac78a7589c0..3f045b4063bbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -999,7 +999,7 @@ object TypeCoercion extends TypeCoercionBase { // Cast any atomic type to string. case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => SQLConf.get.defaultStringType - case (any: AtomicType, _: StringTypeCollated) + case (any: AtomicType, _: AbstractStringType) if !any.isInstanceOf[StringType] => SQLConf.get.defaultStringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e63d02a1164fa..92c5486a80435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String From de3c66048a011c340637fdd963c40e70a9e5042c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 12 Apr 2024 09:44:53 +0200 Subject: [PATCH 78/87] Fix test --- .../test/scala/org/apache/spark/sql/CollationSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index cf50f8b493624..8b63dca1d0eb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -578,8 +578,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer( sql( """EXECUTE IMMEDIATE stmtStr USING - | 'a' COLLATE UNICODE AS var1, - | 'b' COLLATE UNICODE_CI AS var2;""".stripMargin), + | 'a' AS var1, + | 'b' AS var2;""".stripMargin), Seq(Row("UTF8_BINARY")) ) @@ -587,8 +587,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer( sql( """EXECUTE IMMEDIATE stmtStr USING - | 'a' COLLATE UNICODE AS var1, - | 'b' COLLATE UNICODE_CI AS var2;""".stripMargin), + | 'a' AS var1, + | 'b' AS var2;""".stripMargin), Seq(Row("UNICODE")) ) } From c4a61a2fb6251a7622d5e039917247430ddbe43e Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 19 Apr 2024 16:01:50 +0200 Subject: [PATCH 79/87] Rework default collation meaning --- .../catalyst/parser/DataTypeAstBuilder.scala | 7 +-- .../internal/types/AbstractStringType.scala | 3 +- .../org/apache/spark/sql/types/DataType.scala | 2 - .../apache/spark/sql/types/StringType.scala | 20 ++------ .../analysis/CollationTypeCasts.scala | 27 +++++++---- .../sql/catalyst/analysis/TypeCoercion.scala | 7 ++- .../sql/catalyst/analysis/parameters.scala | 47 ++----------------- .../expressions/collationExpressions.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++-- .../analyzer-results/collations.sql.out | 24 +++++----- 10 files changed, 47 insertions(+), 101 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 96861155edbae..38ecd29266db7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StringTypePriority, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -74,10 +74,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => typeCtx.children.asScala.toSeq match { - case Seq(_) => - val st = SqlApiConf.get.defaultStringType - st.priority = StringTypePriority.ImplicitST - st + case Seq(_) => SqlApiConf.get.defaultStringType case Seq(_, ctx: CollateClauseContext) => val collationName = visitCollateClause(ctx) val collationId = CollationFactory.collationNameToId(collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 6403295fe20c4..0828c2d6fc104 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.internal.types +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** * StringTypeCollated is an abstract class for StringType with collation support. */ abstract class AbstractStringType extends AbstractDataType { - override private[sql] def defaultConcreteType: DataType = StringType + override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 8cb0f3d633f30..16cf6224ce27b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -182,8 +182,6 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { name match { - case "string collate INDETERMINATE" => - StringType(-1) case COLLATED_STRING_TYPE(collation) => val collationId = CollationFactory.collationNameToId(collation) StringType(collationId) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1e92a3e835bd0..47d85b2c645c8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -19,13 +19,6 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.types.StringTypePriority.{ImplicitST, StringTypePriority} - -object StringTypePriority extends Enumeration { - type StringTypePriority = Value - - val DefaultST, ImplicitST, ExplicitST = Value -} /** * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. @@ -34,11 +27,7 @@ object StringTypePriority extends Enumeration { * @param collationId The id of collation for this StringType. */ @Stable -class StringType private( - val collationId: Int, - var priority: StringTypePriority = ImplicitST) - extends AtomicType - with Serializable { +class StringType private(val collationId: Int) extends AtomicType with Serializable { /** * Support for Binary Equality implies that strings are considered equal only if * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered @@ -87,11 +76,8 @@ class StringType private( * @since 1.3.0 */ @Stable -case object StringType extends StringType(0, ImplicitST) { - private[spark] def apply( - collationId: Int, - priority: StringTypePriority = ImplicitST): StringType = - new StringType(collationId, priority) +case object StringType extends StringType(0) { + private[spark] def apply(collationId: Int): StringType = new StringType(collationId) def apply(collation: String): StringType = { val collationId = CollationFactory.collationNameToId(collation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index cffdd28722241..d6c0ac764637e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -48,9 +48,9 @@ object CollationTypeCasts extends TypeCoercionRule { case eltExpr: Elt => eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail)) - case overlay: Overlay => - overlay.withNewChildren(collateToSingleType(Seq(overlay.input, overlay.replace)) - ++ Seq(overlay.pos, overlay.len)) + case overlayExpr: Overlay => + overlayExpr.withNewChildren(collateToSingleType(Seq(overlayExpr.input, overlayExpr.replace)) + ++ Seq(overlayExpr.pos, overlayExpr.len)) case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | @@ -103,7 +103,10 @@ object CollationTypeCasts extends TypeCoercionRule { * complex DataTypes with collated StringTypes (e.g. ArrayType) */ def getOutputCollation(expr: Seq[Expression]): StringType = { - val explicitTypes = expr.filter(_.isInstanceOf[Collate]) + val explicitTypes = expr.filter { + case _: Collate => true + case _ => false + } .map(_.dataType.asInstanceOf[StringType].collationId) .distinct @@ -118,17 +121,21 @@ object CollationTypeCasts extends TypeCoercionRule { ) // Only implicit or default collations present case 0 => - val implicitTypes = expr.map(_.dataType) + val implicitTypes = expr.filter { + case Literal(_, _: StringType) => false + case Cast(child, _: StringType, _, _) => child.dataType.isInstanceOf[StringType] + case _ => true + } + .map(_.dataType) .filter(hasStringType) - .map(extractStringType) - .filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId) - .distinctBy(_.collationId) + .map(extractStringType(_).collationId) + .distinct if (implicitTypes.length > 1) { throw QueryCompilationErrors.implicitCollationMismatchError() } else { - implicitTypes.headOption.getOrElse(SQLConf.get.defaultStringType) + implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3cadc9bf3ac4a..506314effde33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -998,11 +998,10 @@ object TypeCoercion extends TypeCoercionBase { case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => - SQLConf.get.defaultStringType - case (any: AtomicType, _: AbstractStringType) + case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st + case (any: AtomicType, st: AbstractStringType) if !any.isInstanceOf[StringType] => - SQLConf.get.defaultStringType + st.defaultConcreteType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index c2ef8f42b692b..d74b713e4cb4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec - import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH} import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StringTypePriority} +import org.apache.spark.sql.types.DataType sealed trait Parameter extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false @@ -133,25 +130,6 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { }) } - @tailrec - private def hasStringRelatedParameterWithFaultyPriority(dt: DataType): Boolean = { - dt match { - case st: StringType if st.priority != StringTypePriority.DefaultST => true - case ArrayType(elementType, _) => hasStringRelatedParameterWithFaultyPriority(elementType) - case _ => false - } - } - - private def getStringParameterWithRightPriority(dt: DataType): Option[DataType] = { - val ret = dt match { - case _: StringType => SQLConf.get.defaultStringType - case ArrayType(et, nullable) => - getStringParameterWithRightPriority(et).map(ArrayType(_, nullable)).orNull - case other => other - } - Option(ret) - } - override def apply(plan: LogicalPlan): LogicalPlan = { if (plan.containsPattern(PARAMETERIZED_QUERY)) { // One unresolved plan can have at most one ParameterizedQuery. @@ -170,16 +148,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { } val args = argNames.zip(argValues).toMap checkArgs(args) - bind(child) { case NamedParameter(name) if args.contains(name) => - val bindValue = args(name) - bindValue.dataType match { - case dt if hasStringRelatedParameterWithFaultyPriority(dt) => - Cast(bindValue, - getStringParameterWithRightPriority( - bindValue.dataType).getOrElse(bindValue.dataType)) - case _ => - bindValue - } + bind(child) { case NamedParameter(name) if args.contains(name) => args(name) } case PosParameterizedQuery(child, args) @@ -193,15 +162,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { bind(child) { case PosParameter(pos) if posToIndex.contains(pos) && args.size > posToIndex(pos) => - val bindValue = args(posToIndex(pos)) - bindValue.dataType match { - case dt if hasStringRelatedParameterWithFaultyPriority(dt) => - Cast(bindValue, - getStringParameterWithRightPriority( - bindValue.dataType).getOrElse(bindValue.dataType)) - case _ => - bindValue - } + args(posToIndex(pos)) } case _ => plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 72314925fc588..6af00e193d94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -82,7 +82,7 @@ object CollateExpressionBuilder extends ExpressionBuilder { case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) - override def dataType: DataType = StringType(collationId, StringTypePriority.ExplicitST) + override def dataType: DataType = StringType(collationId) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04d928dac91a8..1c7ae3d0bfa83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.logical.HintErrorHandler import org.apache.spark.sql.catalyst.util.{CollationFactory, DateTimeUtils} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{AtomicType, StringType, StringTypePriority, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{AtomicType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.storage.{StorageLevel, StorageLevelMapper} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.{Utils, VersionUtils} @@ -5118,12 +5118,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { override def defaultStringType: StringType = { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { - val st = StringType - st.priority = StringTypePriority.DefaultST - st + StringType } else { - StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION)), - StringTypePriority.DefaultST) + StringType(CollationFactory.collationNameToId(getConf(DEFAULT_COLLATION))) } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index a932f49cabad0..d242a60a17c18 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -63,7 +63,7 @@ Aggregate [utf8_binary_lcase#x], [count(1) AS count(1)#xL] select * from t1 where utf8_binary = 'aaa' -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary#x = cast(aaa as string)) ++- Filter (utf8_binary#x = aaa) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -72,7 +72,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase = 'aaa' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x = cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) ++- Filter (utf8_binary_lcase#x = collate(aaa, utf8_binary_lcase)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -81,7 +81,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary < 'bbb' -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary#x < cast(bbb as string)) ++- Filter (utf8_binary#x < bbb) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -90,7 +90,7 @@ Project [utf8_binary#x, utf8_binary_lcase#x] select * from t1 where utf8_binary_lcase < 'bbb' collate utf8_binary_lcase -- !query analysis Project [utf8_binary#x, utf8_binary_lcase#x] -+- Filter (utf8_binary_lcase#x < cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) ++- Filter (utf8_binary_lcase#x < collate(bbb, utf8_binary_lcase)) +- SubqueryAlias spark_catalog.default.t1 +- Relation spark_catalog.default.t1[utf8_binary#x,utf8_binary_lcase#x] parquet @@ -254,14 +254,14 @@ DropTable false, false -- !query select array_contains(ARRAY('aaa' collate utf8_binary_lcase),'AAA' collate utf8_binary_lcase) -- !query analysis -Project [array_contains(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +Project [array_contains(array(collate(aaa, utf8_binary_lcase)), collate(AAA, utf8_binary_lcase)) AS array_contains(array(collate(aaa)), collate(AAA))#x] +- OneRowRelation -- !query select array_position(ARRAY('aaa' collate utf8_binary_lcase, 'bbb' collate utf8_binary_lcase),'BBB' collate utf8_binary_lcase) -- !query analysis -Project [array_position(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(bbb, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), cast(collate(BBB, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +Project [array_position(array(collate(aaa, utf8_binary_lcase), collate(bbb, utf8_binary_lcase)), collate(BBB, utf8_binary_lcase)) AS array_position(array(collate(aaa), collate(bbb)), collate(BBB))#xL] +- OneRowRelation @@ -275,40 +275,40 @@ Project [nullif(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase) -- !query select least('aaa' COLLATE utf8_binary_lcase, 'AAA' collate utf8_binary_lcase, 'a' collate utf8_binary_lcase) -- !query analysis -Project [least(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(a, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)) AS least(collate(aaa), collate(AAA), collate(a))#x] +Project [least(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase), collate(a, utf8_binary_lcase)) AS least(collate(aaa), collate(AAA), collate(a))#x] +- OneRowRelation -- !query select arrays_overlap(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [arrays_overlap(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +Project [arrays_overlap(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS arrays_overlap(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_distinct(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE), cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +Project [array_distinct(array(collate(aaa, utf8_binary_lcase), collate(AAA, utf8_binary_lcase))) AS array_distinct(array(collate(aaa), collate(AAA)))#x] +- OneRowRelation -- !query select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_union(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_union(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_union(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_intersect(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_intersect(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_intersect(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation -- !query select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)) -- !query analysis -Project [array_except(array(cast(collate(aaa, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE)), array(cast(collate(AAA, utf8_binary_lcase) as string collate UTF8_BINARY_LCASE))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation From 5864a9a824ca0664bb5b05cb5766c45a3db92077 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 19 Apr 2024 17:39:04 +0200 Subject: [PATCH 80/87] Revert unnecessary changes --- .../sql/catalyst/encoders/AgnosticEncoder.scala | 3 +-- .../spark/sql/catalyst/analysis/parameters.scala | 3 +-- .../expressions/CallMethodViaReflection.scala | 3 +-- .../sql/catalyst/expressions/ToPrettyString.scala | 5 ++--- .../catalyst/expressions/complexTypeCreator.scala | 2 +- .../sql/catalyst/expressions/csvExpressions.scala | 4 ++-- .../catalyst/expressions/datetimeExpressions.scala | 8 ++++---- .../spark/sql/catalyst/expressions/hash.scala | 6 +++--- .../sql/catalyst/expressions/inputFileBlock.scala | 5 ++--- .../sql/catalyst/expressions/jsonExpressions.scala | 6 +++--- .../sql/catalyst/expressions/maskExpressions.scala | 3 +-- .../sql/catalyst/expressions/mathExpressions.scala | 13 ++++++------- .../spark/sql/catalyst/expressions/misc.scala | 10 +++++----- .../expressions/numberFormatExpressions.scala | 3 +-- .../catalyst/expressions/regexpExpressions.scala | 5 ++--- .../catalyst/expressions/stringExpressions.scala | 10 +++++----- .../sql/catalyst/expressions/urlExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/xml/xpath.scala | 3 +-- .../sql/catalyst/expressions/xmlExpressions.scala | 4 ++-- .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 4 ++-- .../sql/connector/catalog/InMemoryBaseTable.scala | 3 +-- 21 files changed, 47 insertions(+), 58 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 04cdfc02b4ec7..9133abce88adc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -24,7 +24,6 @@ import java.util.concurrent.ConcurrentHashMap import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.{Encoder, Row} -import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.SparkClassUtils @@ -163,7 +162,7 @@ object AgnosticEncoders { // Enums are special leafs because we need to capture the class. protected abstract class EnumEncoder[E] extends AgnosticEncoder[E] { override def isPrimitive: Boolean = false - override def dataType: DataType = SqlApiConf.get.defaultStringType + override def dataType: DataType = StringType } case class ScalaEnumEncoder[T, E]( parent: Class[T], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index d74b713e4cb4d..5356225b209c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -148,8 +148,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { } val args = argNames.zip(argValues).toMap checkArgs(args) - bind(child) { case NamedParameter(name) if args.contains(name) => args(name) - } + bind(child) { case NamedParameter(name) if args.contains(name) => args(name)} case PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index da88f28c26fcb..c42b54222f171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -135,7 +134,7 @@ case class CallMethodViaReflection( } override def nullable: Boolean = true - override val dataType: DataType = SQLConf.get.defaultStringType + override val dataType: DataType = StringType override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala index 6f71aff9742d6..8db08dbbcb813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** @@ -33,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with ToStringBase { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 27169baf76d76..3eb6225b5426e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -349,7 +349,7 @@ case class MapFromArrays(left: Expression, right: Expression) case object NamePlaceholder extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false override def nullable: Boolean = false - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = "NamePlaceholder" override def toString: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index ef35a757fd3b6..4714fc1ded9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -177,7 +177,7 @@ case class SchemaOfCsv( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false @@ -300,7 +300,7 @@ case class StructsToCsv( (row: Any) => UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow])) } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 91d2906d390a6..af41149803e33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -904,7 +904,7 @@ case class WeekOfYear(child: Expression) extends GetDateField { case class MonthName(child: Expression) extends GetDateField { override val func = DateTimeUtils.getMonthName override val funcName = "getMonthName" - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override protected def withNewChildInternal(newChild: Expression): MonthName = copy(child = newChild) } @@ -923,7 +923,7 @@ case class DayName(child: Expression) extends GetDateField { override val funcName = "getDayName" override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override protected def withNewChildInternal(newChild: Expression): DayName = copy(child = newChild) } @@ -951,7 +951,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti def this(left: Expression, right: Expression) = this(left, right, None) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) @@ -1429,7 +1429,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ this(unix, Literal(TimestampFormatter.defaultPattern())) } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index fa342f6415097..5089cea136a8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -63,7 +63,7 @@ import org.apache.spark.util.ArrayImplicits._ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -103,7 +103,7 @@ case class Md5(child: Expression) case class Sha2(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) @@ -169,7 +169,7 @@ case class Sha2(left: Expression, right: Expression) case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 65eb995ff32ff..6cd88367aa9a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -21,8 +21,7 @@ import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, LongType} +import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off whitespace.end.of.line @@ -40,7 +39,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = "input_file_name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 12fd521da165c..35e30ceb45cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -133,7 +133,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -824,7 +824,7 @@ case class StructsToJson( } } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def checkInputDataTypes(): TypeCheckResult = inputSchema match { case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) => @@ -873,7 +873,7 @@ case class SchemaOfJson( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index 264bd372d823b..e5157685a9a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -277,7 +276,7 @@ case class Mask( * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query * the dataType of an unresolved expression (i.e., when `resolved` == false). */ - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType /** * Returns a Seq of the children of this node. Children should not change. Immutability required diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 92c5486a80435..2e8f53bde970d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -452,7 +451,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { @@ -1003,7 +1002,7 @@ case class Bin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) @@ -1109,21 +1108,21 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeAnyCollation)) + Seq(TypeCollection(LongType, BinaryType, StringType)) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType protected override def nullSafeEval(num: Any): Any = child.dataType match { case LongType => Hex.hex(num.asInstanceOf[Long]) case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]]) - case _: StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) + case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { - case _: StringType => s"""$hex.hex($c.getBytes());""" + case StringType => s"""$hex.hex($c.getBytes());""" case _ => s"""$hex.hex($c);""" }) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index bdbc5296f0845..c7281e4e87378 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -199,7 +199,7 @@ object AssertTrue { since = "1.6.0", group = "misc_funcs") case class CurrentDatabase() extends LeafExpression with Unevaluable { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false override def prettyName: String = "current_schema" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) @@ -218,7 +218,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { since = "3.1.0", group = "misc_funcs") case class CurrentCatalog() extends LeafExpression with Unevaluable { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false override def prettyName: String = "current_catalog" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) @@ -251,7 +251,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def nullable: Boolean = false - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def stateful: Boolean = true @@ -311,7 +311,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable { case class TypeOf(child: Expression) extends UnaryExpression { override def nullable: Boolean = false override def foldable: Boolean = true - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def eval(input: InternalRow): Any = UTF8String.fromString(child.dataType.catalogString) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -334,7 +334,7 @@ case class TypeOf(child: Expression) extends UnaryExpression { // scalastyle:on line.size.limit case class CurrentUser() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_user") final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 52f9b26a99785..6d95d7e620a2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -280,7 +279,7 @@ case class ToCharacter(left: Expression, right: Expression) } } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index dd14cca88f9b0..b33de303b5d55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -684,7 +683,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType, IntegerType) override def prettyName: String = "regexp_replace" @@ -849,7 +848,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c6e6a778d3973..3c9888940221a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -725,7 +725,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp }) } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def first: Expression = srcExpr override def second: Expression = searchExpr @@ -944,7 +944,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac }) } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def first: Expression = srcExpr override def second: Expression = matchingExpr @@ -1004,7 +1004,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) override def nullable: Boolean = children.exists(_.nullable) @@ -1396,7 +1396,7 @@ case class StringInstr(str: Expression, substr: Expression) case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) override def first: Expression = strExpr override def second: Expression = delimExpr @@ -1879,7 +1879,7 @@ case class StringRepeat(str: Expression, times: Expression) case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) override def nullSafeEval(s: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index f13e41041ff81..47b37a5edeba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -155,7 +155,7 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = "parse_url" // If the url is a constant, cache the URL object so that we don't need to convert url diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 8ab9cc1ed48c9..c3a285178c110 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -222,7 +221,7 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { // scalastyle:on line.size.limit case class XPathString(xml: Expression, path: Expression) extends XPathExtract { override def prettyName: String = "xpath_string" - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullSafeEval(xml: Any, path: Any): Any = { val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index f2ba82c3dccb3..415d55d19ded2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -178,7 +178,7 @@ case class SchemaOfXml( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def nullable: Boolean = false @@ -320,7 +320,7 @@ case class StructsToXml( getAndReset() } - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 954a00fe4a9ec..f12d224096917 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -61,7 +61,7 @@ case class TestFunction( inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } @@ -83,7 +83,7 @@ case class TestFunctionWithTypeCheckFailure( } override def nullable: Boolean = true - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 4ff1ce1efe938..505a5a6169204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -63,7 +62,7 @@ abstract class InMemoryBaseTable( protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def comment: String = "Partition key used to store the row" } From 5cd6da39a703ec99809539e6e48484486c855647 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 19 Apr 2024 17:40:52 +0200 Subject: [PATCH 81/87] Remove more unrelated changes --- .../org/apache/spark/sql/catalyst/analysis/parameters.scala | 2 +- .../spark/sql/catalyst/expressions/datetimeExpressions.scala | 2 +- .../apache/spark/sql/catalyst/expressions/mathExpressions.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index 5356225b209c6..f1cc44b270bc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -148,7 +148,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { } val args = argNames.zip(argValues).toMap checkArgs(args) - bind(child) { case NamedParameter(name) if args.contains(name) => args(name)} + bind(child) { case NamedParameter(name) if args.contains(name) => args(name) } case PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index af41149803e33..a9155e8daf101 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -104,7 +104,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { since = "3.1.0") case class CurrentTimeZone() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = SQLConf.get.defaultStringType + override def dataType: DataType = StringType override def prettyName: String = "current_timezone" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 2e8f53bde970d..0c09e9be12e94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1110,7 +1110,7 @@ case class Hex(child: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) - override def dataType: DataType = StringType + override def dataType: DataType = StringType protected override def nullSafeEval(num: Any): Any = child.dataType match { case LongType => Hex.hex(num.asInstanceOf[Long]) From 5daff51a87d212af42532fcbec3c2b9c055cb8a9 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 22 Apr 2024 08:06:00 +0200 Subject: [PATCH 82/87] Improve tests --- .../org/apache/spark/sql/CollationSuite.scala | 88 +++++++++++++++---- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 8b63dca1d0eb4..9994b57b76146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -413,7 +414,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("implicit casting of collated strings") { + test("SPARK-47210: Implicit casting of collated strings") { val tableName = "parquet_dummy_implicit_cast_t22" withTable(tableName) { spark.sql( @@ -572,30 +573,87 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, errorClass = "COLLATION_MISMATCH.IMPLICIT" ) + } + } + + test("SPARK-47692: Parameter marker with EXECUTE IMMEDIATE implicit casting") { + sql(s"DECLARE stmtStr1 = 'SELECT collation(:var1 || :var2)';") + sql(s"DECLARE stmtStr2 = 'SELECT collation(:var1 || (\\\'a\\\' COLLATE UNICODE))';") - sql(s"DECLARE stmtStr = 'SELECT collation(:var1 || :var2)';") + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr1 USING + | 'a' AS var1, + | 'b' AS var2;""".stripMargin), + Seq(Row("UTF8_BINARY")) + ) + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { checkAnswer( sql( - """EXECUTE IMMEDIATE stmtStr USING + """EXECUTE IMMEDIATE stmtStr1 USING | 'a' AS var1, | 'b' AS var2;""".stripMargin), - Seq(Row("UTF8_BINARY")) + Seq(Row("UNICODE")) ) + } - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer( - sql( - """EXECUTE IMMEDIATE stmtStr USING - | 'a' AS var1, - | 'b' AS var2;""".stripMargin), - Seq(Row("UNICODE")) - ) - } + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr2 USING + | 'a' AS var1;""".stripMargin), + Seq(Row("UNICODE")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr2 USING + | 'a' AS var1;""".stripMargin), + Seq(Row("UNICODE")) + ) + } + + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr1 USING + | array('a') AS var1, + | array('b') AS var2;""".stripMargin), + Seq(Row("false")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + sql( + """EXECUTE IMMEDIATE stmtStr1 USING + | array('a') AS var1, + | array('a') AS var2;""".stripMargin), + Seq(Row("true")) + ) + } + } + + test("SPARK-47692: Parameter markers with variable mapping") { + checkAnswer( + spark.sql( + "SELECT collation(:var1 || :var2)", + Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), + "var2" -> Literal.create('b', StringType("UNICODE")))), + Seq(Row("UTF8_BINARY")) + ) + + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + checkAnswer( + spark.sql( + "SELECT collation(:var1 || :var2)", + Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")), + "var2" -> Literal.create('b', StringType("UNICODE")))), + Seq(Row("UNICODE")) + ) } } - test("cast of default collated strings in IN expression") { + test("SPARK-47210: Cast of default collated strings in IN expression") { val tableName = "t1" withTable(tableName) { spark.sql( @@ -620,7 +678,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } // TODO(SPARK-47210): Add indeterminate support - test("indeterminate collation checks") { + test("SPARK-47210: Indeterminate collation checks") { val tableName = "t1" val newTableName = "t2" withTable(tableName) { From 5be0a166904cfc44d5c5f2c7033dba3b9acc326e Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 22 Apr 2024 08:07:18 +0200 Subject: [PATCH 83/87] Remove unnecessary test --- .../test/scala/org/apache/spark/sql/CollationSuite.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9994b57b76146..e915fbcef23e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -565,14 +565,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"), Seq(Row("acb"))) - - // check if substring passes through implicit collation - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT substr('a' COLLATE UNICODE, 0, 1) == substr('b' COLLATE UNICODE_CI, 0, 1)") - }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" - ) } } From 6a6175dd38b8103d1b56f8ef828cae36aaa06474 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 22 Apr 2024 08:55:38 +0200 Subject: [PATCH 84/87] Fix imports --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index e915fbcef23e0..b1a296ad7c4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} From ff58fa0d966d9f49c721ed95a84ff10a54df13fb Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 22 Apr 2024 11:39:12 +0200 Subject: [PATCH 85/87] Remove incorrect test --- .../org/apache/spark/sql/CollationSuite.scala | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b1a296ad7c4c8..924f50eda243b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -605,24 +605,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("UNICODE")) ) } - - checkAnswer( - sql( - """EXECUTE IMMEDIATE stmtStr1 USING - | array('a') AS var1, - | array('b') AS var2;""".stripMargin), - Seq(Row("false")) - ) - - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { - checkAnswer( - sql( - """EXECUTE IMMEDIATE stmtStr1 USING - | array('a') AS var1, - | array('a') AS var2;""".stripMargin), - Seq(Row("true")) - ) - } } test("SPARK-47692: Parameter markers with variable mapping") { From 5e378b0f490f52840f40a890aa891992f1a7b110 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 23 Apr 2024 12:02:51 +0200 Subject: [PATCH 86/87] Fix Cast meaning in collations casting --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index a130b0ad29b6a..046a96fa0ed5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -110,6 +110,7 @@ object CollationTypeCasts extends TypeCoercionRule { def getOutputCollation(expr: Seq[Expression]): StringType = { val explicitTypes = expr.filter { case _: Collate => true + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined => true case _ => false } .map(_.dataType.asInstanceOf[StringType].collationId) @@ -128,7 +129,8 @@ object CollationTypeCasts extends TypeCoercionRule { case 0 => val implicitTypes = expr.filter { case Literal(_, _: StringType) => false - case Cast(child, _: StringType, _, _) => child.dataType.isInstanceOf[StringType] + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty + => cast.child.dataType.isInstanceOf[StringType] case _ => true } .map(_.dataType) From a7d24811ed96d0906254f8b5f618b127ad1c1a77 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 23 Apr 2024 13:05:38 +0200 Subject: [PATCH 87/87] Fix casting --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 046a96fa0ed5f..c6232a870dff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -110,7 +110,8 @@ object CollationTypeCasts extends TypeCoercionRule { def getOutputCollation(expr: Seq[Expression]): StringType = { val explicitTypes = expr.filter { case _: Collate => true - case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined => true + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined => + cast.dataType.isInstanceOf[StringType] case _ => false } .map(_.dataType.asInstanceOf[StringType].collationId) @@ -129,8 +130,8 @@ object CollationTypeCasts extends TypeCoercionRule { case 0 => val implicitTypes = expr.filter { case Literal(_, _: StringType) => false - case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty - => cast.child.dataType.isInstanceOf[StringType] + case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty => + cast.child.dataType.isInstanceOf[StringType] case _ => true } .map(_.dataType)