diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index d046195bcfd1b..2b88f9a01a73c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -40,6 +40,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * equality and hashing). */ def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation + def isLowercaseCollation: Boolean = collationId == CollationFactory.LOWERCASE_COLLATION_ID /** * Type name that is shown to the customer. @@ -54,8 +55,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa override def hashCode(): Int = collationId.hashCode() - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] - /** * The default size of a value of the StringType is 20 bytes. */ @@ -65,6 +64,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa } /** + * Use StringType for expressions supporting only binary collation. + * * @since 1.3.0 */ @Stable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 8857f0b5a25ec..c70d6696ad06c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -186,6 +186,11 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (NullType, target) if !target.isInstanceOf[TypeCollection] => Some(target.defaultConcreteType) + // If a function expects a StringType, no StringType instance should be implicitly cast to + // StringType with a collation that's not accepted (aka. 
lockdown unsupported collations). + case (_: StringType, StringType) => None + case (_: StringType, _: StringTypeCollated) => None + // This type coercion system will allow implicit converting String type as other // primitive types, in case of breaking too many existing Spark SQL queries. case (StringType, a: AtomicType) => @@ -215,6 +220,16 @@ object AnsiTypeCoercion extends TypeCoercionBase { None } + // "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but + // ANSIStoreAssign is generally expected to work with StringTypes + case (_, st: StringTypeCollated) => + if (Cast.canANSIStoreAssign(inType, st.defaultConcreteType)) { + Some(st.defaultConcreteType) + } + else { + None + } + // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 56e8843fda537..ecc54976f2db4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -995,7 +995,9 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => StringType + case (any: AtomicType, StringType) if !any.isInstanceOf[StringType] => StringType + case (any: AtomicType, st: StringTypeCollated) + if !any.isInstanceOf[StringType] => st.defaultConcreteType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala new file mode 100644 index 0000000000000..cd909a45c1ed6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} + +object CollationTypeConstraints { + + def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = { + val collationName = CollationFactory.fetchCollation(collationId).collationName + // Additional check needed for collation compatibility + dataTypes.collectFirst { + case stringType: StringType if stringType.collationId != collationId => + val collation = CollationFactory.fetchCollation(stringType.collationId) + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> collationName, + "collationNameRight" -> collation.collationName + ) + ) + } getOrElse TypeCheckResult.TypeCheckSuccess + } + +} + +/** + * StringTypeCollated is an abstract class for StringType with collation support. + */ +abstract class StringTypeCollated extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = StringType +} + +/** + * Use StringTypeBinary for expressions supporting only binary collation. + */ +case object StringTypeBinary extends StringTypeCollated { + override private[sql] def simpleString: String = "string_binary" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation +} + +/** + * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. 
+ */ +case object StringTypeBinaryLcase extends StringTypeCollated { + override private[sql] def simpleString: String = "string_binary_lcase" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation || + other.asInstanceOf[StringType].isLowercaseCollation) +} + +/** + * Use StringTypeAnyCollation for expressions supporting all possible collation types. + */ +case object StringTypeAnyCollation extends StringTypeCollated { + override private[sql] def simpleString: String = "string_any_collation" + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index b0f77bad44831..8ef0280b728e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -82,7 +82,7 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7403c52ece909..742db0ed5a474 100755 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -427,8 +427,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -501,26 +501,15 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { - val checkResult = super.checkInputDataTypes() - if (checkResult.isFailure) { - return checkResult - } - // Additional check needed for collation compatibility - val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId - if (collationId != rightCollationId) { - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, - "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName - ) - ) - } else { - TypeCheckResult.TypeCheckSuccess + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck } + CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) } protected override def nullSafeEval(input1: Any, input2: Any): Any = @@ -1976,7 +1965,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: 
DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala new file mode 100644 index 0000000000000..9a8ffb6efa6b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession { + + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + + test("Support Like string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"like collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 48 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support ILike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) + ) + checks.foreach(ct => 
{ + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"ilike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RLike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", ".B.", "UNICODE", true), + CollationTestCase("ABC", ".b.", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> 
s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"rlike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support StringSplit string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABC", "[B]", "UNICODE", 2), + CollationTestCase("ABC", "[b]", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"split(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 12, + stop = 55 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpReplace string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT regexp_replace(collate('${ct.s1}', 
'${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," + + s" '${ct.collation}'),collate('FFF', '${ct.collation}'))", + start = 7, + stop = 80 + 3 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpExtract string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0)"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + 
s",collate('${ct.s2}', '${ct.collation}'),0)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 7, + stop = 63 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpExtractAll string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0))"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + + s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 12, + stop = 72 + 2 * ct.collation.length + ) + ) + 
}) + } + + test("Support RegExpCount string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpSubStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = 
intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 60 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpInStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 2), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + 
stop = 59 + 2 * ct.collation.length + ) + ) + }) + } +} + +class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala new file mode 100644 index 0000000000000..04f3781a92cf3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { + + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + + test("Support ConcatWs string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"), + Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + val expr = s"concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))" + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT $expr") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate( )\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"$expr", + start = 7, + stop = 73 + 3 * ct.collation.length + ) + ) + }) + } + + // TODO: Add more tests for other string expressions + +} + +class CollationStringExpressionsANSISuite extends 
CollationStringExpressionsSuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 42506950149dc..ee2b34706e0ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -121,7 +121,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "paramIndex" -> "first", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\""), + "requiredType" -> "\"STRING_ANY_COLLATION\""), context = ExpectedContext( fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31)) } @@ -611,7 +611,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s""" |CREATE TABLE testcat.test_table( | c1 STRING COLLATE UNICODE, - | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE) + | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (LOWER(c1)) |) |USING $v2Source |""".stripMargin) @@ -619,7 +619,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", - "expressionStr" -> "c1 || 'a' COLLATE UNICODE", + "expressionStr" -> "LOWER(c1)", "reason" -> "generation expression cannot contain non-default collated string type")) checkError( @@ -628,7 +628,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s""" |CREATE TABLE testcat.test_table( | struct1 STRUCT, - | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1)) + | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (UCASE(struct1.a)) |) |USING $v2Source |""".stripMargin) @@ -636,7 +636,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = 
"UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", - "expressionStr" -> "SUBSTRING(struct1.a, 0, 1)", + "expressionStr" -> "UCASE(struct1.a)", "reason" -> "generation expression cannot contain non-default collated string type")) } }