From 3d3d3da19da50c5effd2932ae559a3d73664c244 Mon Sep 17 00:00:00 2001
From: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Date: Wed, 20 Mar 2024 22:57:37 +0800
Subject: [PATCH] [SPARK-47296][SQL][COLLATION] Fail unsupported functions for
 non-binary collations

### What changes were proposed in this pull request?

### Why are the changes needed?
Currently, all `StringType` arguments passed to built-in string functions in Spark SQL are treated as binary strings. This behaviour is incorrect for almost all collation IDs other than the default (0), so we should instead warn users when they try to use an unsupported collation with a given function. Over time, we should implement the appropriate support for these (function, collation) pairs, but until then we need a way to fail unsupported statements during query analysis.

### Does this PR introduce _any_ user-facing change?
Yes, users will now get appropriate errors when they try to use an unsupported collation with a given string function.
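For example, a binary collation keeps working while a non-binary one is now rejected at analysis time. A minimal sketch, assuming a local `SparkSession` with this patch applied (the `local[*]` master is an illustrative assumption; the exact error parameters are asserted in the new suites below):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Analyzes as before: LIKE still declares plain StringType, and UTF8_BINARY
// is the default binary collation (collation ID 0).
spark.sql("SELECT collate('ABC', 'UTF8_BINARY') LIKE '%B%'").show()

// Now fails analysis with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE instead of
// silently comparing the two strings as binary.
spark.sql("SELECT collate('ABC', 'UNICODE_CI') LIKE '%b%'").show()
```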
### How was this patch tested?
Tests in CollationSuite check that these functions work for binary collations and throw exceptions for others.

### Was this patch authored or co-authored using generative AI tooling?
Yes.

Closes #45422 from uros-db/regexp-functions.

Lead-authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Co-authored-by: Mihailo Milosevic
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/types/StringType.scala   |   5 +-
 .../catalyst/analysis/AnsiTypeCoercion.scala  |  15 +
 .../sql/catalyst/analysis/TypeCoercion.scala  |   4 +-
 .../CollationTypeConstraints.scala            |  77 +++
 .../expressions/collationExpressions.scala    |   2 +-
 .../expressions/stringExpressions.scala       |  29 +-
 .../sql/CollationRegexpExpressionsSuite.scala | 444 ++++++++++++++++++
 .../sql/CollationStringExpressionsSuite.scala |  80 ++++
 .../org/apache/spark/sql/CollationSuite.scala |  10 +-
 9 files changed, 637 insertions(+), 29 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index d046195bcfd1b..2b88f9a01a73c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -40,6 +40,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
    * equality and hashing).
    */
   def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation
+  def isLowercaseCollation: Boolean = collationId == CollationFactory.LOWERCASE_COLLATION_ID
 
   /**
    * Type name that is shown to the customer.
@@ -54,8 +55,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
 
   override def hashCode(): Int = collationId.hashCode()
 
-  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
-
   /**
    * The default size of a value of the StringType is 20 bytes.
    */
@@ -65,6 +64,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
 }
 
 /**
+ * Use StringType for expressions supporting only binary collation.
+ *
  * @since 1.3.0
  */
 @Stable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 8857f0b5a25ec..c70d6696ad06c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -186,6 +186,11 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     case (NullType, target) if !target.isInstanceOf[TypeCollection] =>
       Some(target.defaultConcreteType)
 
+    // If a function expects a StringType, no StringType instance should be implicitly cast to
+    // StringType with a collation that's not accepted (aka. lockdown unsupported collations).
+    case (_: StringType, StringType) => None
+    case (_: StringType, _: StringTypeCollated) => None
+
     // This type coercion system will allow implicit converting String type as other
     // primitive types, in case of breaking too many existing Spark SQL queries.
     case (StringType, a: AtomicType) =>
@@ -215,6 +220,16 @@ object AnsiTypeCoercion extends TypeCoercionBase {
         None
       }
 
+    // "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but
+    // ANSIStoreAssign is generally expected to work with StringTypes
+    case (_, st: StringTypeCollated) =>
+      if (Cast.canANSIStoreAssign(inType, st.defaultConcreteType)) {
+        Some(st.defaultConcreteType)
+      }
+      else {
+        None
+      }
+
     // When we reach here, input type is not acceptable for any types in this type collection,
     // try to find the first one we can implicitly cast.
     case (_, TypeCollection(types)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 56e8843fda537..ecc54976f2db4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -995,7 +995,9 @@ object TypeCoercion extends TypeCoercionBase {
     case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType
     case (StringType, BinaryType) => BinaryType
     // Cast any atomic type to string.
-    case (any: AtomicType, StringType) if any != StringType => StringType
+    case (any: AtomicType, StringType) if !any.isInstanceOf[StringType] => StringType
+    case (any: AtomicType, st: StringTypeCollated)
+      if !any.isInstanceOf[StringType] => st.defaultConcreteType
 
     // When we reach here, input type is not acceptable for any types in this type collection,
     // try to find the first one we can implicitly cast.
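Taken together, the two coercion changes above mean that non-string atomic types still implicitly cast to collated string targets, while a collated string is never silently re-collated. A hedged sketch, reusing the `spark` session from the example above (behaviour inferred from the coercion rules in this patch):

```scala
// Still allowed: an INT coerces to a string target through the new
// (AtomicType, StringTypeCollated) rule, so e.g. upper(123) should resolve.
spark.sql("SELECT upper(123)").show()

// Not allowed: a collated string never implicitly casts to plain STRING,
// so binary-only expressions fail analysis for non-default collations.
spark.sql("SELECT collate('ABC', 'UNICODE') LIKE '%B%'").show()
```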
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala
new file mode 100644
index 0000000000000..cd909a45c1ed6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
+
+object CollationTypeConstraints {
+
+  def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = {
+    val collationName = CollationFactory.fetchCollation(collationId).collationName
+    // Additional check needed for collation compatibility
+    dataTypes.collectFirst {
+      case stringType: StringType if stringType.collationId != collationId =>
+        val collation = CollationFactory.fetchCollation(stringType.collationId)
+        DataTypeMismatch(
+          errorSubClass = "COLLATION_MISMATCH",
+          messageParameters = Map(
+            "collationNameLeft" -> collationName,
+            "collationNameRight" -> collation.collationName
+          )
+        )
+    } getOrElse TypeCheckResult.TypeCheckSuccess
+  }
+
+}
+
+/**
+ * StringTypeCollated is an abstract class for StringType with collation support.
+ */
+abstract class StringTypeCollated extends AbstractDataType {
+  override private[sql] def defaultConcreteType: DataType = StringType
+}
+
+/**
+ * Use StringTypeBinary for expressions supporting only binary collation.
+ */
+case object StringTypeBinary extends StringTypeCollated {
+  override private[sql] def simpleString: String = "string_binary"
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation
+}
+
+/**
+ * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
+ */
+case object StringTypeBinaryLcase extends StringTypeCollated {
+  override private[sql] def simpleString: String = "string_binary_lcase"
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation ||
+      other.asInstanceOf[StringType].isLowercaseCollation)
+}
+
+/**
+ * Use StringTypeAnyCollation for expressions supporting all possible collation types.
+ */
+case object StringTypeAnyCollation extends StringTypeCollated {
+  override private[sql] def simpleString: String = "string_any_collation"
+  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
+}
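To make the intended usage of these markers concrete, a hedged sketch of a hypothetical expression (`MyStringExpr` is not part of the patch) that opts into binary and lowercase collations only; any other collation is rejected by `checkInputDataTypes` during analysis. It assumes compilation against the catalyst module:

```scala
import org.apache.spark.sql.catalyst.expressions.{
  Expression, ImplicitCastInputTypes, StringTypeBinaryLcase, UnaryExpression}
import org.apache.spark.sql.types.{AbstractDataType, DataType}

// Hypothetical expression, for illustration only. Declaring
// StringTypeBinaryLcase means UTF8_BINARY and UTF8_BINARY_LCASE inputs pass
// analysis, while e.g. UNICODE_CI should fail with
// DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE.
case class MyStringExpr(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeBinaryLcase)
  override def dataType: DataType = child.dataType
  // Identity evaluation keeps the sketch self-contained.
  override protected def nullSafeEval(input: Any): Any = input
  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}
```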
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
index b0f77bad44831..8ef0280b728e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -82,7 +82,7 @@ case class Collate(child: Expression, collationName: String)
   extends UnaryExpression with ExpectsInputTypes {
   private val collationId = CollationFactory.collationNameToId(collationName)
   override def dataType: DataType = StringType(collationId)
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
   override protected def withNewChildInternal(
     newChild: Expression): Expression = copy(newChild)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 7403c52ece909..742db0ed5a474 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -427,8 +427,8 @@ trait String2StringExpression extends ImplicitCastInputTypes {
 
   def convert(v: UTF8String): UTF8String
 
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[DataType] = Seq(StringType)
+  override def dataType: DataType = child.dataType
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
 
   protected override def nullSafeEval(input: Any): Any =
     convert(input.asInstanceOf[UTF8String])
@@ -501,26 +501,15 @@ abstract class StringPredicate extends BinaryExpression
 
   def compare(l: UTF8String, r: UTF8String): Boolean
 
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    val checkResult = super.checkInputDataTypes()
-    if (checkResult.isFailure) {
-      return checkResult
-    }
-    // Additional check needed for collation compatibility
-    val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId
-    if (collationId != rightCollationId) {
-      DataTypeMismatch(
-        errorSubClass = "COLLATION_MISMATCH",
-        messageParameters = Map(
-          "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName,
-          "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName
-        )
-      )
-    } else {
-      TypeCheckResult.TypeCheckSuccess
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      return defaultCheck
     }
+    CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
   }
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any =
@@ -1976,7 +1965,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
 
   override def dataType: DataType = str.dataType
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
+    Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType)
 
   override def first: Expression = str
   override def second: Expression = pos
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
new file mode 100644
index 0000000000000..9a8ffb6efa6b1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession {
+
+  case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)
+  case class CollationTestFail[R](s1: String, s2: String, collation: String)
+
+  test("Support Like string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABC", "%B%", "UTF8_BINARY", true)
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " +
+        s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false),
+      CollationTestCase("ABC", "%B%", "UNICODE", true),
+      CollationTestCase("ABC", "%b%", "UNICODE_CI", false)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " +
+            s"collate('${ct.s2}', '${ct.collation}')")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"like collate('${ct.s2}', '${ct.collation}')",
+          start = 26 + ct.collation.length,
+          stop = 48 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support ILike string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABC", "%b%", "UTF8_BINARY", true)
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " +
+        s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false),
+      CollationTestCase("ABC", "%b%", "UNICODE", true),
+      CollationTestCase("ABC", "%b%", "UNICODE_CI", false)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " +
+            s"collate('${ct.s2}', '${ct.collation}')")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"ilike collate('${ct.s2}', '${ct.collation}')",
+          start = 26 + ct.collation.length,
+          stop = 49 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support RLike string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABC", ".B.", "UTF8_BINARY", true)
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " +
+        s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false),
+      CollationTestCase("ABC", ".B.", "UNICODE", true),
+      CollationTestCase("ABC", ".b.", "UNICODE_CI", false)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " +
+            s"collate('${ct.s2}', '${ct.collation}')")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"rlike collate('${ct.s2}', '${ct.collation}')",
+          start = 26 + ct.collation.length,
+          stop = 49 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support StringSplit string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2)
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" +
+        s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0),
+      CollationTestCase("ABC", "[B]", "UNICODE", 2),
+      CollationTestCase("ABC", "[b]", "UNICODE_CI", 0)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" +
+            s",collate('${ct.s2}', '${ct.collation}')))")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"split(collate('${ct.s1}', '${ct.collation}')," +
+            s"collate('${ct.s2}', '${ct.collation}'))",
+          start = 12,
+          stop = 55 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support RegExpReplace string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE")
+    )
+    checks.foreach(ct => {
+      checkAnswer(
+        sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" +
+          s",collate('${ct.s2}', '${ct.collation}')" +
+          s",collate('FFF', '${ct.collation}'))"),
+        Row(ct.expectedResult)
+      )
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""),
+      CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"),
+      CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "")
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" +
+            s",collate('${ct.s2}', '${ct.collation}')" +
+            s",collate('FFF', '${ct.collation}'))")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," +
+            s" '${ct.collation}'),collate('FFF', '${ct.collation}'))",
+          start = 7,
+          stop = 80 + 3 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support RegExpExtract string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD")
+    )
+    checks.foreach(ct => {
+      checkAnswer(
+        sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" +
+          s",collate('${ct.s2}', '${ct.collation}'),0)"),
+        Row(ct.expectedResult)
+      )
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""),
+      CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"),
+      CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "")
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" +
+            s",collate('${ct.s2}', '${ct.collation}'),0)")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," +
+            s"collate('${ct.s2}', '${ct.collation}'),0)",
+          start = 7,
+          stop = 63 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  test("Support RegExpExtractAll string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1)
+    )
+    checks.foreach(ct => {
+      checkAnswer(
+        sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" +
+          s",collate('${ct.s2}', '${ct.collation}'),0))"),
+        Row(ct.expectedResult)
+      )
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0),
+      CollationTestCase("ABCDE", ".C.", "UNICODE", 1),
+      CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + + s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 12, + stop = 72 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpCount string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpSubStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 60 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpInStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) 
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" +
+        s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0),
+      CollationTestCase("ABCDE", ".C.", "UNICODE", 2),
+      CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0)
+    )
+    fails.foreach(ct => {
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" +
+            s",collate('${ct.s2}', '${ct.collation}'))")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate(${ct.s1})\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," +
+            s"collate('${ct.s2}', '${ct.collation}'))",
+          start = 7,
+          stop = 59 + 2 * ct.collation.length
+        )
+      )
+    })
+  }
+}
+
+class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite {
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.ANSI_ENABLED, true)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
new file mode 100644
index 0000000000000..04f3781a92cf3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession {
+
+  case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)
+  case class CollationTestFail[R](s1: String, s2: String, collation: String)
+
+  test("Support ConcatWs string expression with Collation") {
+    // Supported collations
+    val checks = Seq(
+      CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL")
+    )
+    checks.foreach(ct => {
+      checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " +
+        s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"),
+        Row(ct.expectedResult))
+    })
+    // Unsupported collations
+    val fails = Seq(
+      CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false),
+      CollationTestCase("ABC", "%B%", "UNICODE", true),
+      CollationTestCase("ABC", "%b%", "UNICODE_CI", false)
+    )
+    fails.foreach(ct => {
+      val expr = s"concat_ws(collate(' ', '${ct.collation}'), " +
+        s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          sql(s"SELECT $expr")
+        },
+        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+        sqlState = "42K09",
+        parameters = Map(
+          "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"",
+          "paramIndex" -> "first",
+          "inputSql" -> s"\"collate( )\"",
+          "inputType" -> s"\"STRING COLLATE ${ct.collation}\"",
+          "requiredType" -> "\"STRING\""
+        ),
+        context = ExpectedContext(
+          fragment = s"$expr",
+          start = 7,
+          stop = 73 + 3 * ct.collation.length
+        )
+      )
+    })
+  }
+
+  // TODO: Add more tests for other string expressions
+
+}
+
+class CollationStringExpressionsANSISuite extends CollationStringExpressionsSuite {
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.ANSI_ENABLED, true)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 42506950149dc..ee2b34706e0ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -121,7 +121,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "paramIndex" -> "first",
         "inputSql" -> "\"1\"",
         "inputType" -> "\"INT\"",
-        "requiredType" -> "\"STRING\""),
+        "requiredType" -> "\"STRING_ANY_COLLATION\""),
       context = ExpectedContext(
         fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
   }
@@ -611,7 +611,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       s"""
          |CREATE TABLE testcat.test_table(
          |  c1 STRING COLLATE UNICODE,
-         |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE)
+         |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (LOWER(c1))
          |)
          |USING $v2Source
          |""".stripMargin)
@@ -619,7 +619,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN",
       parameters = Map(
         "fieldName" -> "c2",
-        "expressionStr" -> "c1 || 'a' COLLATE UNICODE",
+        "expressionStr" -> "LOWER(c1)",
         "reason" -> "generation expression cannot contain non-default collated string type"))
 
     checkError(
@@ -628,7 +628,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       s"""
          |CREATE TABLE testcat.test_table(
          |  struct1 STRUCT<a: STRING COLLATE UNICODE>,
-         |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1))
+         |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (UCASE(struct1.a))
          |)
          |USING $v2Source
          |""".stripMargin)
@@ -636,7 +636,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN",
       parameters = Map(
         "fieldName" -> "c2",
-        "expressionStr" -> "SUBSTRING(struct1.a, 0, 1)",
+        "expressionStr" -> "UCASE(struct1.a)",
         "reason" -> "generation expression cannot contain non-default collated string type"))
   }
 }