diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9c862581bfe47..34e8f3f408599 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1810,8 +1810,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def inputTypes: Seq[DataType] = Seq(StringType) - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { // scalastyle:off caselocale diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 97dea66975410..0dbd4c0ba713f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} class CollationStringExpressionsSuite extends QueryTest - with SharedSparkSession - with ExpressionEvalHelper { + with SharedSparkSession { test("Support ConcatWs string expression with collation") { // Supported collations @@ -163,6 +161,57 @@ class 
CollationStringExpressionsSuite
     })
   }
 
+  test("SPARK-47357: Support Upper string expression with collation") {
+    // upper() must upper-case the input and report the input's collation as its result type.
+    case class UpperTestCase[R](s: String, c: String, result: R)
+    val testCases = Seq(
+      UpperTestCase("aBc", "UTF8_BINARY", "ABC"),
+      UpperTestCase("aBc", "UTF8_BINARY_LCASE", "ABC"),
+      UpperTestCase("aBc", "UNICODE", "ABC"),
+      UpperTestCase("aBc", "UNICODE_CI", "ABC")
+    )
+    testCases.foreach(t => {
+      // One DataFrame backs both checks, so the query is issued only once per case.
+      val df = sql(s"SELECT upper(collate('${t.s}', '${t.c}'))")
+      checkAnswer(df, Row(t.result))
+      assert(df.schema.fields.head.dataType.sameType(StringType(t.c)))
+    })
+  }
+
+  test("SPARK-47357: Support Lower string expression with collation") {
+    // lower() must lower-case the input and report the input's collation as its result type.
+    case class LowerTestCase[R](s: String, c: String, result: R)
+    val testCases = Seq(
+      LowerTestCase("aBc", "UTF8_BINARY", "abc"),
+      LowerTestCase("aBc", "UTF8_BINARY_LCASE", "abc"),
+      LowerTestCase("aBc", "UNICODE", "abc"),
+      LowerTestCase("aBc", "UNICODE_CI", "abc")
+    )
+    testCases.foreach(t => {
+      // One DataFrame backs both checks, so the query is issued only once per case.
+      val df = sql(s"SELECT lower(collate('${t.s}', '${t.c}'))")
+      checkAnswer(df, Row(t.result))
+      assert(df.schema.fields.head.dataType.sameType(StringType(t.c)))
+    })
+  }
+
+  test("SPARK-47357: Support InitCap string expression with collation") {
+    // initcap() must capitalize each word and keep the input's collation on the result type.
+    case class InitCapTestCase[R](s: String, c: String, result: R)
+    val testCases = Seq(
+      InitCapTestCase("aBc ABc", "UTF8_BINARY", "Abc Abc"),
+      InitCapTestCase("aBc ABc", "UTF8_BINARY_LCASE", "Abc Abc"),
+      InitCapTestCase("aBc ABc", "UNICODE", "Abc Abc"),
+      InitCapTestCase("aBc ABc", "UNICODE_CI", "Abc Abc")
+    )
+    testCases.foreach(t => {
+      // One DataFrame backs both checks, so the query is issued only once per case.
+      val df = sql(s"SELECT initcap(collate('${t.s}', '${t.c}'))")
+      checkAnswer(df, Row(t.result))
+      assert(df.schema.fields.head.dataType.sameType(StringType(t.c)))
+    })
+  }
+
   // TODO: Add more tests for other string expressions
 
 }