diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index d5bcc61bac2af..89d689473e606 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -206,19 +206,20 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
   }
 
   public static class Upper {
-    public static UTF8String exec(final UTF8String v, final int collationId) {
+    public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) {
       CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
       if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
-        return execUTF8(v);
+        return useICU ? execUTF8ICU(v) : execUTF8(v);
       } else {
         return execICU(v, collationId);
       }
     }
-    public static String genCode(final String v, final int collationId) {
+    public static String genCode(final String v, final int collationId, boolean useICU) {
       CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
       String expr = "CollationSupport.Upper.exec";
+      String icuStr = useICU ? "ICU" : "";
       if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
-        return String.format(expr + "UTF8(%s)", v);
+        return String.format(expr + "UTF8" + icuStr + "(%s)", v);
       } else {
         return String.format(expr + "ICU(%s, %d)", v, collationId);
       }
@@ -226,25 +227,29 @@ public static String genCode(final String v, final int collationId) {
     public static UTF8String execUTF8(final UTF8String v) {
       return v.toUpperCase();
     }
+    public static UTF8String execUTF8ICU(final UTF8String v) {
+      return v.toUpperCaseICU();
+    }
     public static UTF8String execICU(final UTF8String v, final int collationId) {
       return UTF8String.fromString(CollationAwareUTF8String.toUpperCase(v.toString(), collationId));
     }
   }
 
   public static class Lower {
-    public static UTF8String exec(final UTF8String v, final int collationId) {
+    public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) {
       CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
       if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
-        return execUTF8(v);
+        return useICU ? execUTF8ICU(v) : execUTF8(v);
       } else {
         return execICU(v, collationId);
       }
     }
-    public static String genCode(final String v, final int collationId) {
+    public static String genCode(final String v, final int collationId, boolean useICU) {
       CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
-      String expr = "CollationSupport.Lower.exec";
+      String expr = "CollationSupport.Lower.exec";
+      String icuStr = useICU ? "ICU" : "";
       if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
-        return String.format(expr + "UTF8(%s)", v);
+        return String.format(expr + "UTF8" + icuStr + "(%s)", v);
       } else {
         return String.format(expr + "ICU(%s, %d)", v, collationId);
       }
@@ -252,6 +257,9 @@ public static String genCode(final String v, final int collationId) {
     public static UTF8String execUTF8(final UTF8String v) {
       return v.toLowerCase();
     }
+    public static UTF8String execUTF8ICU(final UTF8String v) {
+      return v.toLowerCaseICU();
+    }
     public static UTF8String execICU(final UTF8String v, final int collationId) {
       return UTF8String.fromString(CollationAwareUTF8String.toLowerCase(v.toString(), collationId));
     }
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e28dfa910b59e..0c4f4c461eef7 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -29,6 +29,7 @@
 import com.esotericsoftware.kryo.KryoSerializable;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
+import com.ibm.icu.lang.UCharacter;
 
 import org.apache.spark.sql.catalyst.util.CollationFactory;
 import org.apache.spark.unsafe.Platform;
@@ -370,24 +371,34 @@ public UTF8String toUpperCase() {
     if (numBytes == 0) {
       return EMPTY_UTF8;
     }
-    // Optimization - do char level uppercase conversion in case of chars in ASCII range
-    for (int i = 0; i < numBytes; i++) {
-      if (getByte(i) < 0) {
-        // non-ASCII
-        return toUpperCaseSlow();
-      }
+
+    return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlowJVM();
+  }
+
+  public UTF8String toUpperCaseICU() {
+    if (numBytes == 0) {
+      return EMPTY_UTF8;
     }
-    byte[] bytes = new byte[numBytes];
-    for (int i = 0; i < numBytes; i++) {
+
+    return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlowICU();
+  }
+
+  private UTF8String toUpperCaseAscii() {
+    final var bytes = new byte[numBytes];
+    for (var i = 0; i < numBytes; i++) {
       bytes[i] = (byte) Character.toUpperCase(getByte(i));
     }
     return fromBytes(bytes);
   }
 
-  private UTF8String toUpperCaseSlow() {
+  private UTF8String toUpperCaseSlowJVM() {
     return fromString(toString().toUpperCase());
   }
 
+  private UTF8String toUpperCaseSlowICU() {
+    return fromString(UCharacter.toUpperCase(toString()));
+  }
+
   /**
    * Optimized lowercase comparison for UTF8_BINARY_LCASE collation
    * a.compareLowerCase(b) is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase())
@@ -413,7 +424,7 @@ private int compareLowerCaseSuffixSlow(UTF8String other, int pref) {
       numBytes - pref);
     UTF8String suffixRight = UTF8String.fromAddress(other.base, other.offset + pref,
       other.numBytes - pref);
-    return suffixLeft.toLowerCaseSlow().binaryCompare(suffixRight.toLowerCaseSlow());
+    return suffixLeft.toLowerCaseSlowICU().binaryCompare(suffixRight.toLowerCaseSlowICU());
   }
 
   /**
@@ -424,7 +435,15 @@ public UTF8String toLowerCase() {
       return EMPTY_UTF8;
     }
 
-    return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlow();
+    return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlowJVM();
+  }
+
+  public UTF8String toLowerCaseICU() {
+    if (numBytes == 0) {
+      return EMPTY_UTF8;
+    }
+
+    return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlowICU();
   }
 
   private boolean isFullAscii() {
@@ -436,10 +455,14 @@ private boolean isFullAscii() {
     return true;
   }
 
-  private UTF8String toLowerCaseSlow() {
+  private UTF8String toLowerCaseSlowJVM() {
     return fromString(toString().toLowerCase());
   }
 
+  private UTF8String toLowerCaseSlowICU() {
+    return fromString(UCharacter.toLowerCase(toString()));
+  }
+
   private UTF8String toLowerCaseAscii() {
     final var bytes = new byte[numBytes];
     for (var i = 0; i < numBytes; i++) {
diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
index 3f02d72611128..badbe6b675698 100644
--- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
+++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
@@ -66,14 +66,24 @@ class UTF8StringPropertyCheckSuite extends AnyFunSuite with ScalaCheckDrivenPropertyChecks {
 
   // scalastyle:off caselocale
   test("toUpperCase") {
-    forAll { (s: String) =>
-      assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
+    val useICU = SQLConf.get.getConf(ICU_CASE_MAPPINGS_ENABLED)
+    forAll { (s: String) =>
+      if (useICU) {
+        assert(toUTF8(s).toUpperCaseICU === toUTF8(UCharacter.toUpperCase(s)))
+      } else {
+        assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
+      }
     }
   }
 
   test("toLowerCase") {
-    forAll { (s: String) =>
-      assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
+    val useICU = SQLConf.get.getConf(ICU_CASE_MAPPINGS_ENABLED)
+    forAll { (s: String) =>
+      if (useICU) {
+        assert(toUTF8(s).toLowerCaseICU === toUTF8(UCharacter.toLowerCase(s)))
+      } else {
+        assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
+      }
     }
   }
   // scalastyle:on caselocale
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 09ec501311ade..9d2ffaf01b934 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -453,14 +453,17 @@ trait String2StringExpression extends ImplicitCastInputTypes {
 case class Upper(child: Expression)
   extends UnaryExpression with String2StringExpression with NullIntolerant {
 
+  private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED)
+
   final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId
 
-  override def convert(v: UTF8String): UTF8String = CollationSupport.Upper.exec(v, collationId)
+  override def convert(v: UTF8String): UTF8String =
+    CollationSupport.Upper.exec(v, collationId, useICU)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId))
+    defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId, useICU))
   }
 
   override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild)
@@ -481,14 +484,17 @@ case class Upper(child: Expression)
 case class Lower(child: Expression)
   extends UnaryExpression with String2StringExpression with NullIntolerant {
 
+  private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED)
+
   final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId
 
-  override def convert(v: UTF8String): UTF8String = CollationSupport.Lower.exec(v, collationId)
+  override def convert(v: UTF8String): UTF8String =
+    CollationSupport.Lower.exec(v, collationId, useICU)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId))
+    defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId, useICU))
   }
 
   override def prettyName: String =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 88c2228e640c4..cfb7e07774e94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -785,6 +785,14 @@
       _ => Map())
     .createWithDefault("UTF8_BINARY")
 
+  val ICU_CASE_MAPPINGS_ENABLED =
+    buildConf("spark.sql.icu.caseMappings.enabled")
+      .doc("When enabled, Spark uses the ICU library (instead of the JVM) to implement case mappings" +
+        " for strings.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val FETCH_SHUFFLE_BLOCKS_IN_BATCH = buildConf("spark.sql.adaptive.fetchShuffleBlocksInBatch")
     .internal()
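
Review note: below is a minimal, self-contained sketch of how the two case-mapping paths introduced by this patch are expected to behave side by side. The method names toUpperCaseICU/toLowerCaseICU and the config key spark.sql.icu.caseMappings.enabled come from the diff above; the demo object name and the sample input are illustrative only, and no specific outputs are asserted.

import com.ibm.icu.lang.UCharacter
import org.apache.spark.unsafe.types.UTF8String

// Illustrative-only demo object; not part of the patch.
object IcuCaseMappingSketch {
  def main(args: Array[String]): Unit = {
    val s = UTF8String.fromString("straße")

    // JVM-backed path: non-ASCII input falls through to String.toUpperCase/toLowerCase.
    val jvmUpper = s.toUpperCase
    val jvmLower = s.toLowerCase

    // ICU-backed path added by this patch: non-ASCII input goes through
    // UCharacter.toUpperCase/UCharacter.toLowerCase instead.
    val icuUpper = s.toUpperCaseICU
    val icuLower = s.toLowerCaseICU

    // Fully ASCII input takes the same byte-level toUpperCaseAscii/toLowerCaseAscii
    // fast path in both cases, so only non-ASCII strings can ever differ.
    println(s"JVM: $jvmUpper / $jvmLower")
    println(s"ICU: $icuUpper / $icuLower")
    println(s"ICU reference: ${UCharacter.toUpperCase("straße")}")
  }
}

At the SQL layer, UPPER and LOWER choose between these two paths per query via spark.sql.icu.caseMappings.enabled (default true in this patch), read through SQLConf.get when the Upper/Lower expressions are constructed.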