From 43a738f98c9abde1bb7f585ed8ce197d3845e0b7 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 11 Apr 2024 14:26:09 +0200 Subject: [PATCH 1/3] Add support for Upper, Lower, InitCap --- .../expressions/stringExpressions.scala | 4 +- .../sql/CollationStringExpressionsSuite.scala | 41 ++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cf6c9d4f1d942..4fe02fb4da5f0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1831,8 +1831,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def inputTypes: Seq[DataType] = Seq(StringType) - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { // scalastyle:off caselocale diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index c26f3ae02255f..c341fdbfa6619 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat} +import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, InitCap, Literal, Lower, StringRepeat, Upper} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -89,6 +89,45 @@ class CollationStringExpressionsSuite testRepeat("UNICODE_CI", 3, "abc", 2) } + test("UPPER check output type on collated string") { + def testUpper(expected: String, collationId: Int, input: String): Unit = { + val s = Literal.create(input, StringType(collationId)) + + checkEvaluation(Collation(Upper(s)).replacement, expected) + } + + testUpper("UTF8_BINARY", 0, "abc") + testUpper("UTF8_BINARY_LCASE", 1, "abc") + testUpper("UNICODE", 2, "abc") + testUpper("UNICODE_CI", 3, "abc") + } + + test("LOWER check output type on collated string") { + def testLower(expected: String, collationId: Int, input: String): Unit = { + val s = Literal.create(input, StringType(collationId)) + + checkEvaluation(Collation(Lower(s)).replacement, expected) + } + + testLower("UTF8_BINARY", 0, "abc") + testLower("UTF8_BINARY_LCASE", 1, "abc") + testLower("UNICODE", 2, "abc") + testLower("UNICODE_CI", 3, "abc") + } + + test("INITCAP check output type on collated string") { + def testInitCap(expected: String, collationId: Int, input: String): Unit = { + val s = Literal.create(input, StringType(collationId)) + + checkEvaluation(Collation(InitCap(s)).replacement, expected) + } + + testInitCap("UTF8_BINARY", 0, "abc") + testInitCap("UTF8_BINARY_LCASE", 1, "abc") + testInitCap("UNICODE", 2, "abc") + testInitCap("UNICODE_CI", 3, "abc") + } + // TODO: Add more tests for other string expressions } From 7069f53198549990b662e2427a793714dbb1e2ac Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 12 Apr 2024 14:57:48 +0200 Subject: [PATCH 2/3] Fix imports --- .../apache/spark/sql/CollationStringExpressionsSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index f9fe9e73e8529..828fa807e94b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, InitCap, Literal, Lower, StringRepeat, Upper} -import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, InitCap, Literal, Lower, Upper} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} From 9f1f686ce66b07d80453920699d780c3dc5971a9 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 15 Apr 2024 08:39:25 +0200 Subject: [PATCH 3/3] Refactor tests --- .../sql/CollationStringExpressionsSuite.scala | 82 +++++++++++-------- 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 828fa807e94b6..0dbd4c0ba713f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, InitCap, Literal, Lower, Upper} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} class CollationStringExpressionsSuite extends QueryTest - with SharedSparkSession - with ExpressionEvalHelper { + with SharedSparkSession { test("Support ConcatWs string expression with collation") { // Supported collations @@ -163,43 +161,55 @@ class CollationStringExpressionsSuite }) } - test("UPPER check output type on collated string") { - def testUpper(expected: String, collationId: Int, input: String): Unit = { - val s = Literal.create(input, StringType(collationId)) - - checkEvaluation(Collation(Upper(s)).replacement, expected) - } - - testUpper("UTF8_BINARY", 0, "abc") - testUpper("UTF8_BINARY_LCASE", 1, "abc") - testUpper("UNICODE", 2, "abc") - testUpper("UNICODE_CI", 3, "abc") + test("SPARK-47357: Support Upper string expression with collation") { + // Supported collations + case class UpperTestCase[R](s: String, c: String, result: R) + val testCases = Seq( + UpperTestCase("aBc", "UTF8_BINARY", "ABC"), + UpperTestCase("aBc", "UTF8_BINARY_LCASE", "ABC"), + UpperTestCase("aBc", "UNICODE", "ABC"), + UpperTestCase("aBc", "UNICODE_CI", "ABC") + ) + testCases.foreach(t => { + val query = s"SELECT upper(collate('${t.s}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) } - test("LOWER check output type on collated string") { - def testLower(expected: String, collationId: Int, input: String): Unit = { - val s = Literal.create(input, StringType(collationId)) - - checkEvaluation(Collation(Lower(s)).replacement, expected) - } - - testLower("UTF8_BINARY", 0, "abc") - testLower("UTF8_BINARY_LCASE", 1, "abc") - testLower("UNICODE", 2, "abc") - testLower("UNICODE_CI", 3, "abc") + test("SPARK-47357: Support Lower string expression with collation") { + // Supported collations + case class LowerTestCase[R](s: String, c: String, result: R) + val testCases = Seq( + LowerTestCase("aBc", "UTF8_BINARY", "abc"), + LowerTestCase("aBc", "UTF8_BINARY_LCASE", "abc"), + LowerTestCase("aBc", "UNICODE", "abc"), + LowerTestCase("aBc", "UNICODE_CI", "abc") + ) + testCases.foreach(t => { + val query = s"SELECT lower(collate('${t.s}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) } - test("INITCAP check output type on collated string") { - def testInitCap(expected: String, collationId: Int, input: String): Unit = { - val s = Literal.create(input, StringType(collationId)) - - checkEvaluation(Collation(InitCap(s)).replacement, expected) - } - - testInitCap("UTF8_BINARY", 0, "abc") - testInitCap("UTF8_BINARY_LCASE", 1, "abc") - testInitCap("UNICODE", 2, "abc") - testInitCap("UNICODE_CI", 3, "abc") + test("SPARK-47357: Support InitCap string expression with collation") { + // Supported collations + case class InitCapTestCase[R](s: String, c: String, result: R) + val testCases = Seq( + InitCapTestCase("aBc ABc", "UTF8_BINARY", "Abc Abc"), + InitCapTestCase("aBc ABc", "UTF8_BINARY_LCASE", "Abc Abc"), + InitCapTestCase("aBc ABc", "UNICODE", "Abc Abc"), + InitCapTestCase("aBc ABc", "UNICODE_CI", "Abc Abc") + ) + testCases.foreach(t => { + val query = s"SELECT initcap(collate('${t.s}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) } // TODO: Add more tests for other string expressions