From 334253f1847f29ee0b8661b2f44d77d134516e1b Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:04:30 +0800 Subject: [PATCH] [SPARK-48682][SQL] Use ICU in InitCap expression for UTF8_BINARY strings ### What changes were proposed in this pull request? Update `InitCap` Spark expressions to use ICU case mappings for UTF8_BINARY collation, instead of the currently used JVM case mappings. This behaviour is put under the `ICU_CASE_MAPPINGS_ENABLED` flag in SQLConf, which is true by default. Note: the same flag is used for `Lower` & `Upper` expressions, with changes introduced in: https://github.com/apache/spark/pull/47043. ### Why are the changes needed? To keep the consistency between collations - all collations shouls use ICU-based case mappings, including the UTF8_BINARY collation. ### Does this PR introduce _any_ user-facing change? Yes, the behaviour of `initcap` string function for UTF8_BINARY will now rely on ICU-based case mappings. However, by turning the `ICU_CASE_MAPPINGS_ENABLED` flag off, users can get the old JVM-based case mappings. Note that the difference between the two is really subtle. ### How was this patch tested? Existing tests, with extended `CollationSupport` unit tests for InitCap to verify both ICU and JVM behaviour. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47100 from uros-db/change-initcap. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationSupport.java | 12 ++-- .../apache/spark/unsafe/types/UTF8String.java | 57 +++++++++++++++---- .../unsafe/types/CollationSupportSuite.java | 6 +- .../expressions/stringExpressions.scala | 7 ++- 4 files changed, 64 insertions(+), 18 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index a5bb1fe715bb9..0f10955c986a5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -280,10 +280,10 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { } public static class InitCap { - public static UTF8String exec(final UTF8String v, final int collationId) { + public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(v); + return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.supportsLowercaseEquality) { return execLowercase(v); } else { @@ -291,11 +291,12 @@ public static UTF8String exec(final UTF8String v, final int collationId) { } } - public static String genCode(final String v, final int collationId) { + public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s)", v); + String funcName = useICU ? "BinaryICU" : "Binary"; + return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); } else { @@ -305,6 +306,9 @@ public static String genCode(final String v, final int collationId) { public static UTF8String execBinary(final UTF8String v) { return v.toLowerCase().toTitleCase(); } + public static UTF8String execBinaryICU(final UTF8String v) { + return CollationAwareUTF8String.toLowerCase(v).toTitleCaseICU(); + } public static UTF8String execLowercase(final UTF8String v) { return CollationAwareUTF8String.toTitleCase(v); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 49d3088f8a2f0..38b9b803acbe4 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.ibm.icu.lang.UCharacter; import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.unsafe.Platform; @@ -558,24 +559,35 @@ public UTF8String toLowerCaseAscii() { } /** - * Returns the title case of this string, that could be used as title. + * Returns the title case of this string, that could be used as title. There are essentially two + * different version of this method - one using the JVM case mapping rules, and the other using + * the ICU case mapping rules. ASCII implementation is the same for both, but please refer to the + * respective methods for the slow (non-ASCII) implementation for more details on the differences. */ public UTF8String toTitleCase() { if (numBytes == 0) { return EMPTY_UTF8; } - // Optimization - in case of ASCII chars we can skip copying the data to and from StringBuilder - byte prev = ' ', curr; - for (int i = 0; i < numBytes; i++) { - curr = getByte(i); - if (prev == ' ' && curr < 0) { - // non-ASCII - return toTitleCaseSlow(); - } - prev = curr; + + return isFullAscii() ? toTitleCaseAscii() : toTitleCaseSlow(); + } + + public UTF8String toTitleCaseICU() { + if (numBytes == 0) { + return EMPTY_UTF8; } + + return isFullAscii() ? toTitleCaseAscii() : toTitleCaseSlowICU(); + } + + /* + * Fast path to return the title case of this string, given that all characters are ASCII. + * This implementation essentially works for all collations currently supported in Spark. + * This method is more efficient, because it skips copying the data to and from StringBuilder. + */ + private UTF8String toTitleCaseAscii() { byte[] bytes = new byte[numBytes]; - prev = ' '; + byte prev = ' ', curr; for (int i = 0; i < numBytes; i++) { curr = getByte(i); if (prev == ' ') { @@ -588,6 +600,11 @@ public UTF8String toTitleCase() { return fromBytes(bytes); } + /* + * Slow path to return the title case of this string, according to JVM case mapping rules. + * This is considered the "old" behaviour for UTF8_BINARY collation, and is not recommended. + * To use this, set the spark.sql.ICU_CASE_MAPPINGS_ENABLED configuration to `false`. + */ private UTF8String toTitleCaseSlow() { StringBuilder sb = new StringBuilder(); String s = toString(); @@ -601,6 +618,24 @@ private UTF8String toTitleCaseSlow() { return fromString(sb.toString()); } + /* + * Slow path to return the title case of this string, according to ICU case mapping rules. + * This is considered the "new" behaviour for UTF8_BINARY collation, and is recommended. + * This is used by default, since spark.sql.ICU_CASE_MAPPINGS_ENABLED is set to `true`. + */ + private UTF8String toTitleCaseSlowICU() { + StringBuilder sb = new StringBuilder(); + String s = toString(); + sb.append(s); + sb.setCharAt(0, (char) UCharacter.toTitleCase(sb.charAt(0))); + for (int i = 1; i < s.length(); i++) { + if (sb.charAt(i - 1) == ' ') { + sb.setCharAt(i, (char) UCharacter.toTitleCase(sb.charAt(i))); + } + } + return fromString(sb.toString()); + } + /* * Returns the index of the string `match` in this String. This string has to be a comma separated * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 9602c83c6c801..d027f67c08ffc 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -731,7 +731,11 @@ private void assertInitCap(String target, String collationName, String expected) UTF8String target_utf8 = UTF8String.fromString(target); UTF8String expected_utf8 = UTF8String.fromString(expected); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, collationId)); + // Testing the new ICU-based implementation of the Lower function. + assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, true)); + // Testing the old JVM-based implementation of the Lower function. + assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, false)); + // Note: results should be the same in these tests for both ICU and JVM-based implementations. } @Test diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e951d40d4d463..a0c796274f761 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2048,14 +2048,17 @@ case class InitCap(child: Expression) final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. + private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { - CollationSupport.InitCap.exec(string.asInstanceOf[UTF8String], collationId) + CollationSupport.InitCap.exec(string.asInstanceOf[UTF8String], collationId, useICU) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, str => CollationSupport.InitCap.genCode(str, collationId)) + defineCodeGen(ctx, ev, str => CollationSupport.InitCap.genCode(str, collationId, useICU)) } override protected def withNewChildInternal(newChild: Expression): InitCap =