Skip to content

Commit

Permalink
Use ICU case mappings for upper and lower casing
Browse files Browse the repository at this point in the history
  • Loading branch information
mkaravel committed Jun 4, 2024
1 parent 560c083 commit a20e244
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,52 +206,60 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
}

public static class Upper {
public static UTF8String exec(final UTF8String v, final int collationId) {
public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return execUTF8(v);
return useICU ? execUTF8ICU(v) : execUTF8(v);
} else {
return execICU(v, collationId);
}
}
public static String genCode(final String v, final int collationId) {
public static String genCode(final String v, final int collationId, boolean useICU) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.Upper.exec";
String icuStr = useICU ? "ICU" : "";
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return String.format(expr + "UTF8(%s)", v);
return String.format(expr + "UTF8" + useICU + "(%s)", v);
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
public static UTF8String execUTF8(final UTF8String v) {
return v.toUpperCase();
}
public static UTF8String execUTF8ICU(final UTF8String v) {
return v.toUpperCaseICU();
}
public static UTF8String execICU(final UTF8String v, final int collationId) {
return UTF8String.fromString(CollationAwareUTF8String.toUpperCase(v.toString(), collationId));
}
}

public static class Lower {
public static UTF8String exec(final UTF8String v, final int collationId) {
public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return execUTF8(v);
return useICU ? execUTF8ICU(v) : execUTF8(v);
} else {
return execICU(v, collationId);
}
}
public static String genCode(final String v, final int collationId) {
public static String genCode(final String v, final int collationId, boolean useICU) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.Lower.exec";
String expr = "CollationSupport.Lower.exec";
String icuStr = useICU ? "ICU" : "";
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return String.format(expr + "UTF8(%s)", v);
return String.format(expr + "UTF8" + useICU + "(%s)", v);
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
public static UTF8String execUTF8(final UTF8String v) {
return v.toLowerCase();
}
public static UTF8String execUTF8ICU(final UTF8String v) {
return v.toLowerCaseICU();
}
public static UTF8String execICU(final UTF8String v, final int collationId) {
return UTF8String.fromString(CollationAwareUTF8String.toLowerCase(v.toString(), collationId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.ibm.icu.lang.UCharacter;

import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.unsafe.Platform;
Expand Down Expand Up @@ -370,24 +371,34 @@ public UTF8String toUpperCase() {
if (numBytes == 0) {
return EMPTY_UTF8;
}
// Optimization - do char level uppercase conversion in case of chars in ASCII range
for (int i = 0; i < numBytes; i++) {
if (getByte(i) < 0) {
// non-ASCII
return toUpperCaseSlow();
}

return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlowJVM();
}

public UTF8String toUpperCaseICU() {
if (numBytes == 0) {
return EMPTY_UTF8;
}
byte[] bytes = new byte[numBytes];
for (int i = 0; i < numBytes; i++) {

return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlowICU();
}

private UTF8String toUpperCaseAscii() {
final var bytes = new byte[numBytes];
for (var i = 0; i < numBytes; i++) {
bytes[i] = (byte) Character.toUpperCase(getByte(i));
}
return fromBytes(bytes);
}

private UTF8String toUpperCaseSlow() {
private UTF8String toUpperCaseSlowJVM() {
return fromString(toString().toUpperCase());
}

private UTF8String toUpperCaseSlowICU() {
return fromString(UCharacter.toUpperCase(toString()));
}

/**
* Optimized lowercase comparison for UTF8_BINARY_LCASE collation
* a.compareLowerCase(b) is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase())
Expand All @@ -413,7 +424,7 @@ private int compareLowerCaseSuffixSlow(UTF8String other, int pref) {
numBytes - pref);
UTF8String suffixRight = UTF8String.fromAddress(other.base, other.offset + pref,
other.numBytes - pref);
return suffixLeft.toLowerCaseSlow().binaryCompare(suffixRight.toLowerCaseSlow());
return suffixLeft.toLowerCaseSlowICU().binaryCompare(suffixRight.toLowerCaseSlowICU());
}

/**
Expand All @@ -424,7 +435,15 @@ public UTF8String toLowerCase() {
return EMPTY_UTF8;
}

return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlow();
return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlowJVM();
}

public UTF8String toLowerCaseICU() {
if (numBytes == 0) {
return EMPTY_UTF8;
}

return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlowICU();
}

private boolean isFullAscii() {
Expand All @@ -436,10 +455,14 @@ private boolean isFullAscii() {
return true;
}

private UTF8String toLowerCaseSlow() {
private UTF8String toLowerCaseSlowJVM() {
return fromString(toString().toLowerCase());
}

private UTF8String toLowerCaseSlowICU() {
return fromString(UCharacter.toLowerCase(toString()));
}

private UTF8String toLowerCaseAscii() {
final var bytes = new byte[numBytes];
for (var i = 0; i < numBytes; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,26 @@ class UTF8StringPropertyCheckSuite extends AnyFunSuite with ScalaCheckDrivenProp

// scalastyle:off caselocale
test("toUpperCase") {
forAll { (s: String) =>
assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
val useICU = SQLConf.conf.getConf(ICU_CASE_MAPPINGS_ENABLED).getKey
forAll { (s: String) => {
if (useICU) {
assert(toUTF8(s).toUpperCase === toUTF8(UCharacter.toUpperCase(s)))
} else {
assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
}
}
}
}

test("toLowerCase") {
forAll { (s: String) =>
assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
val useICU = SQLConf.conf.getConf(ICU_CASE_MAPPINGS_ENABLED).getKey
forAll { (s: String) => {
if (useICU) {
assert(toUTF8(s).toLowerCase === toUTF8(UCharacter.toLowerCase(s)))
} else {
assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
}
}
}
}
// scalastyle:on caselocale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,17 @@ trait String2StringExpression extends ImplicitCastInputTypes {
case class Upper(child: Expression)
extends UnaryExpression with String2StringExpression with NullIntolerant {

private final lazy val useICU = SQLConf.get.spark.sql.icu.caseMappings.enabled

final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def convert(v: UTF8String): UTF8String = CollationSupport.Upper.exec(v, collationId)
override def convert(v: UTF8String): UTF8String =
CollationSupport.Upper.exec(v, collationId, useICU)

final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId))
defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId, useICU))
}

override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild)
Expand All @@ -481,14 +484,17 @@ case class Upper(child: Expression)
case class Lower(child: Expression)
extends UnaryExpression with String2StringExpression with NullIntolerant {

private final lazy val useICU = SQLConf.get.spark.sql.icu.caseMappings.enabled

final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def convert(v: UTF8String): UTF8String = CollationSupport.Lower.exec(v, collationId)
override def convert(v: UTF8String): UTF8String =
CollationSupport.Lower.exec(v, collationId, useICU)

final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId))
defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId, useICU))
}

override def prettyName: String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,14 @@ object SQLConf {
_ => Map())
.createWithDefault("UTF8_BINARY")

val ICU_CASE_MAPPINGS_ENABLED =
buildConf("spark.sql.icu.caseMappings.enabled")
.doc("When enabled we use the ICU library (instead of the JVM) to implement case mappings" +
" for strings.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val FETCH_SHUFFLE_BLOCKS_IN_BATCH =
buildConf("spark.sql.adaptive.fetchShuffleBlocksInBatch")
.internal()
Expand Down

0 comments on commit a20e244

Please sign in to comment.