From 9b0e196cdd0044dc6b2446f71ee2ddbf42a35d59 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 7 Oct 2024 16:48:48 +0200 Subject: [PATCH 01/14] initial support for hashing and comparison. --- .../sql/catalyst/util/CollationFactory.java | 71 +++++++++++++++---- 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 113c5f866fd88..019a695803a64 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -161,6 +161,7 @@ public Collation( Comparator comparator, String version, ToLongFunction hashFunction, + BiFunction equalsFunction, boolean supportsBinaryEquality, boolean supportsBinaryOrdering, boolean supportsLowercaseEquality) { @@ -173,6 +174,7 @@ public Collation( this.supportsBinaryEquality = supportsBinaryEquality; this.supportsBinaryOrdering = supportsBinaryOrdering; this.supportsLowercaseEquality = supportsLowercaseEquality; + this.equalsFunction = equalsFunction; // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality assert(!supportsBinaryOrdering || supportsBinaryEquality); @@ -180,12 +182,6 @@ public Collation( assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); - - if (supportsBinaryEquality) { - this.equalsFunction = UTF8String::equals; - } else { - this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; - } } /** @@ -538,24 +534,57 @@ private static boolean isValidCollationId(int collationId) { @Override protected Collation buildCollation() { if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + Comparator comparator; + ToLongFunction hashFunction; + BiFunction equalsFunction; + + if(spaceTrimming == SpaceTrimming.NONE) { + comparator 
= UTF8String::binaryCompare; + hashFunction = s -> (long) s.hashCode(); + equalsFunction = UTF8String::equals; + }else { + comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( + applyTrimmingPolicy(s2, spaceTrimming)); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - UTF8String::binaryCompare, + comparator, "1.0", - s -> (long) s.hashCode(), + hashFunction, + equalsFunction, /* supportsBinaryEquality = */ true, /* supportsBinaryOrdering = */ true, /* supportsLowercaseEquality = */ false); } else { + Comparator comparator; + ToLongFunction hashFunction; + + if(spaceTrimming == SpaceTrimming.NONE ) { + comparator = CollationAwareUTF8String::compareLowerCase; + hashFunction = s -> + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + }else{ + comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) CollationAwareUTF8String. + lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - CollationAwareUTF8String::compareLowerCase, + comparator, "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, /* supportsLowercaseEquality = */ true); @@ -917,13 +946,31 @@ protected Collation buildCollation() { Collator collator = Collator.getInstance(resultLocale); // Freeze ICU collator to ensure thread safety. 
collator.freeze(); + + Comparator comparator; + ToLongFunction hashFunction; + + if(spaceTrimming == SpaceTrimming.NONE){ + hashFunction = s -> (long) collator.getCollationKey( + s.toValidString()).hashCode(); + comparator = (s1, s2) -> + collator.compare(s1.toValidString(), s2.toValidString()); + } else { + comparator = (s1, s2) -> collator.compare( + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + hashFunction = s -> (long) collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_ICU, collator, - (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()), + comparator, ICU_COLLATOR_VERSION, - s -> (long) collator.getCollationKey(s.toValidString()).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, /* supportsLowercaseEquality = */ false); From 9f4e1044a9a6c0833a741b1cd2222b6ebcb91629 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 7 Oct 2024 17:02:37 +0200 Subject: [PATCH 02/14] deprecate uses trim collation. 
--- .../sql/catalyst/util/CollationFactory.java | 34 +++++++++++-------- .../apache/spark/sql/types/StringType.scala | 2 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 019a695803a64..85db5b02a6992 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -154,6 +154,12 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** + * Support for Space Trimming implies that, based on specifier (for now only right trim), + * leading, trailing or both spaces are removed from the input string before comparison. + */ + public final boolean supportsSpaceTrimming; + public Collation( String collationName, String provider, @@ -164,7 +170,8 @@ public Collation( BiFunction equalsFunction, boolean supportsBinaryEquality, boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { + boolean supportsLowercaseEquality, + boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; this.collator = collator; @@ -175,6 +182,7 @@ public Collation( this.supportsBinaryOrdering = supportsBinaryOrdering; this.supportsLowercaseEquality = supportsLowercaseEquality; this.equalsFunction = equalsFunction; + this.supportsSpaceTrimming = supportsSpaceTrimming; // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality assert(!supportsBinaryOrdering || supportsBinaryEquality); @@ -537,6 +545,7 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; BiFunction equalsFunction; + boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; if(spaceTrimming == SpaceTrimming.NONE) { comparator = UTF8String::binaryCompare; @@ -560,7 
+569,8 @@ protected Collation buildCollation() { equalsFunction, /* supportsBinaryEquality = */ true, /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } else { Comparator comparator; ToLongFunction hashFunction; @@ -587,7 +597,8 @@ protected Collation buildCollation() { (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true); + /* supportsLowercaseEquality = */ true, + spaceTrimming != SpaceTrimming.NONE); } } @@ -973,7 +984,8 @@ protected Collation buildCollation() { (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } @Override @@ -1150,14 +1162,6 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { Collation.CollationSpecICU.AccentSensitivity.AI; } - /** - * Returns whether the collation uses trim collation for the given collation id. 
- */ - public static boolean usesTrimCollation(int collationId) { - return Collation.CollationSpec.getSpaceTrimming(collationId) != - Collation.CollationSpec.SpaceTrimming.NONE; - } - public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -1184,10 +1188,10 @@ public static String[] getICULocaleNames() { public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { return input; } else if (collation.supportsLowercaseEquality) { return CollationAwareUTF8String.lowerCaseCodePoints(input); @@ -1200,7 +1204,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } if (collation.supportsBinaryEquality) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 29d48e3d1f47f..b920f32200c7f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -48,7 +48,7 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) private[sql] def usesTrimCollation: Boolean = - CollationFactory.usesTrimCollation(collationId) + 
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID From 20389753aab212fb2595f11b6d3cb2745b2972e9 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 8 Oct 2024 15:44:00 +0200 Subject: [PATCH 03/14] add tests. --- .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +-- .../sql/catalyst/util/UnsafeRowUtils.scala | 4 +- .../aggregate/HashMapGenerator.scala | 5 +- .../sql/CollationSQLExpressionsSuite.scala | 50 ++++++++++++++++--- .../org/apache/spark/sql/CollationSuite.scala | 27 ++++++++-- 6 files changed, 77 insertions(+), 17 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index b920f32200c7f..1c93c2ad550e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -48,7 +48,7 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) private[sql] def usesTrimCollation: Boolean = - CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + CollationFactory.fetchCollation(collationId).supportsSpaceTrimming private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 3a667f370428e..7128190902550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression { protected 
def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction { hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) case s: UTF8String => val st = dataType.asInstanceOf[StringType] - if (st.supportsBinaryEquality) { + if (st.supportsBinaryEquality && !st.usesTrimCollation) { hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) } else { val stringHash = CollationFactory @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index e296b5be6134b..a60a3d3854b4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -205,7 +205,9 @@ object UnsafeRowUtils { * can lead to rows being semantically equal even though their binary representations differ). 
*/ def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { - case st: StringType => !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality + case st: StringType => + (!CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality || + CollationFactory.fetchCollation(st.collationId).supportsSpaceTrimming) case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 45a71b4da7287..3b1f349520f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,8 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality => + case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + hashBytes(s"$input.getBytes()") + case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 4c3cd93873bd4..fd83408da7f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -49,9 +49,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Md5TestCase("Spark", "UTF8_BINARY", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_BINARY_RTRIM", "8cde774d6f7333752ed72cacddb05126"), 
Md5TestCase("Spark", "UTF8_LCASE", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_LCASE_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("SQL", "UNICODE", "9778840a0100cb30c982876741b0b5a2"), - Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2") + Md5TestCase("SQL", "UNICODE_RTRIM", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2") ) // Supported collations @@ -81,11 +85,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha2TestCase("Spark", "UTF8_BINARY", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_BINARY_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("Spark", "UTF8_LCASE", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_LCASE_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("SQL", "UNICODE", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_RTRIM", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), Sha2TestCase("SQL", "UNICODE_CI", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_CI_RTRIM", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35") ) @@ -114,9 +126,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha1TestCase("Spark", "UTF8_BINARY", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_BINARY_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("Spark", "UTF8_LCASE", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_LCASE_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("SQL", "UNICODE", 
"2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), - Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") + Sha1TestCase("SQL", "UNICODE_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") ) // Supported collations @@ -144,9 +160,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Crc321TestCase("Spark", "UTF8_BINARY", 1557323817), + Crc321TestCase("Spark", "UTF8_BINARY_RTRIM", 1557323817), Crc321TestCase("Spark", "UTF8_LCASE", 1557323817), + Crc321TestCase("Spark", "UTF8_LCASE_RTRIM", 1557323817), Crc321TestCase("SQL", "UNICODE", 1299261525), - Crc321TestCase("SQL", "UNICODE_CI", 1299261525) + Crc321TestCase("SQL", "UNICODE_RTRIM", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI_RTRIM", 1299261525) ) // Supported collations @@ -172,9 +192,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Murmur3HashTestCase("Spark", "UTF8_BINARY", 228093765), + Murmur3HashTestCase("Spark ", "UTF8_BINARY_RTRIM", 1779328737), Murmur3HashTestCase("Spark", "UTF8_LCASE", -1928694360), + Murmur3HashTestCase("Spark ", "UTF8_LCASE_RTRIM", -1928694360), Murmur3HashTestCase("SQL", "UNICODE", -1923567940), - Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950) + Murmur3HashTestCase("SQL ", "UNICODE_RTRIM", -1923567940), + Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950), + Murmur3HashTestCase("SQL ", "UNICODE_CI_RTRIM", 1029527950) ) // Supported collations @@ -200,9 +224,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( XxHash64TestCase("Spark", "UTF8_BINARY", -4294468057691064905L), + XxHash64TestCase("Spark ", "UTF8_BINARY_RTRIM", 6480371823304753502L), XxHash64TestCase("Spark", "UTF8_LCASE", -3142112654825786434L), + XxHash64TestCase("Spark ", "UTF8_LCASE_RTRIM", -3142112654825786434L), XxHash64TestCase("SQL", 
"UNICODE", 5964849564945649886L), - XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L) + XxHash64TestCase("SQL ", "UNICODE_RTRIM", 5964849564945649886L), + XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L), + XxHash64TestCase("SQL ", "UNICODE_CI_RTRIM", 3732497619779520590L) ) // Supported collations @@ -469,9 +497,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( BinTestCase("13", "UTF8_BINARY", "1101"), + BinTestCase("13", "UTF8_BINARY_RTRIM", "1101"), BinTestCase("13", "UTF8_LCASE", "1101"), + BinTestCase("13", "UTF8_LCASE_RTRIM", "1101"), BinTestCase("13", "UNICODE", "1101"), - BinTestCase("13", "UNICODE_CI", "1101") + BinTestCase("13", "UNICODE_RTRIM", "1101"), + BinTestCase("13", "UNICODE_CI", "1101"), + BinTestCase("13", "UNICODE_CI_RTRIM", "1101"), ) testCases.foreach(t => { val query = @@ -494,9 +526,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( HexTestCase("13", "UTF8_BINARY", "D"), + HexTestCase("13", "UTF8_BINARY_RTRIM", "D"), HexTestCase("13", "UTF8_LCASE", "D"), + HexTestCase("13", "UTF8_LCASE_RTRIM", "D"), HexTestCase("13", "UNICODE", "D"), - HexTestCase("13", "UNICODE_CI", "D") + HexTestCase("13", "UNICODE_RTRIM", "D"), + HexTestCase("13", "UNICODE_CI", "D"), + HexTestCase("13", "UNICODE_CI_RTRIM", "D") ) testCases.foreach(t => { val query = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index ef01f71c68bf9..e12c2838b88ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -101,8 +101,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function syntax") { assert(sql(s"select collate('aaa', 'utf8_binary')").schema(0).dataType == StringType("UTF8_BINARY")) + assert(sql(s"select collate('aaa', 'utf8_binary_rtrim')").schema(0).dataType == + 
StringType("UTF8_BINARY_RTRIM")) assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == StringType("UTF8_LCASE")) + assert(sql(s"select collate('aaa', 'utf8_lcase_rtrim')").schema(0).dataType == + StringType("UTF8_LCASE_RTRIM")) } test("collate function syntax with default collation set") { @@ -260,14 +264,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq( ("utf8_binary", "aaa", "AAA", false), ("utf8_binary", "aaa", "aaa", true), + ("utf8_binary_rtrim", "aaa", "AAA", false), + ("utf8_binary_rtrim", "aaa", "aaa ", true), ("utf8_lcase", "aaa", "aaa", true), ("utf8_lcase", "aaa", "AAA", true), ("utf8_lcase", "aaa", "bbb", false), + ("utf8_lcase_rtrim", "aaa", "AAA ", true), + ("utf8_lcase_rtrim", "aaa", "bbb", false), ("unicode", "aaa", "aaa", true), ("unicode", "aaa", "AAA", false), + ("unicode_rtrim", "aaa ", "aaa ", true), + ("unicode_rtrim", "aaa", "AAA", false), ("unicode_CI", "aaa", "aaa", true), ("unicode_CI", "aaa", "AAA", true), - ("unicode_CI", "aaa", "bbb", false) + ("unicode_CI", "aaa", "bbb", false), + ("unicode_CI_rtrim", "aaa", "aaa", true), + ("unicode_CI_rtrim", "aaa ", "AAA ", true), + ("unicode_CI_rtrim", "aaa", "bbb", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -284,15 +297,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("utf8_binary", "AAA", "aaa", true), ("utf8_binary", "aaa", "aaa", false), ("utf8_binary", "aaa", "BBB", false), + ("utf8_binary_rtrim", "aaa ", "aaa ", false), ("utf8_lcase", "aaa", "aaa", false), ("utf8_lcase", "AAA", "aaa", false), ("utf8_lcase", "aaa", "bbb", true), + ("utf8_lcase_rtrim", "AAA ", "aaa", false), ("unicode", "aaa", "aaa", false), ("unicode", "aaa", "AAA", true), ("unicode", "aaa", "BBB", true), + ("unicode_rtrim", "aaa ", "aaa", false), ("unicode_CI", "aaa", "aaa", false), ("unicode_CI", "aaa", "AAA", false), - ("unicode_CI", "aaa", "bbb", true) + ("unicode_CI", "aaa", "bbb", 
true), + ("unicode_CI_rtrim", "aaa ", "aaa", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -355,18 +372,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( + ("utf8_binary_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), ("utf8_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), ("utf8_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("utf8_lcase_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))), ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) + ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_CI_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))) ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => checkAnswer(sql( From c513451fdb8f09940bf6ff95ef603555ab4e520b Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 8 Oct 2024 16:05:16 +0200 Subject: [PATCH 04/14] add more tests. 
--- .../unsafe/types/CollationFactorySuite.scala | 21 +++++++++++++++++-- .../apache/spark/sql/types/StringType.scala | 3 --- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index ff40f16e5a052..491abeab58e01 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -127,6 +127,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -134,15 +135,18 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), - CollationTestCase("UNICODE_CI", "Å", "a\u030A", true) + CollationTestCase("UNICODE_CI", "Å", "a\u030A", 
true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true), ) checks.foreach(testCase => { @@ -162,19 +166,32 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa", "AAA ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), - CollationTestCase("UNICODE_CI", "aaa", "bbb", -1)) + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "aaa ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + ) checks.foreach(testCase => { val collation = fetchCollation(testCase.collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala 
index 1c93c2ad550e9..e07471a15b6a3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -41,9 +41,6 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsBinaryEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsBinaryEquality - private[sql] def supportsLowercaseEquality: Boolean = - CollationFactory.fetchCollation(collationId).supportsLowercaseEquality - private[sql] def isNonCSAI: Boolean = !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) From 2f6705c81bbd0799f982003acedb01a5f6355212 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 8 Oct 2024 16:14:06 +0200 Subject: [PATCH 05/14] nit fix. --- .../org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index a60a3d3854b4a..40b8bccafaad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -206,8 +206,8 @@ object UnsafeRowUtils { */ def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { case st: StringType => - (!CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality || - CollationFactory.fetchCollation(st.collationId).supportsSpaceTrimming) + val collation = CollationFactory.fetchCollation(st.collationId) + (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) case _ => false } } From 6cdcf5b001b23791504dcdb964474727527e563b Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Wed, 9 Oct 2024 11:06:42 +0200 Subject: [PATCH 06/14] fix scala style. 
--- .../sql/catalyst/util/CollationFactory.java | 30 +++++++++---------- .../unsafe/types/CollationFactorySuite.scala | 4 +-- .../sql/CollationSQLExpressionsSuite.scala | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 85db5b02a6992..b1d11e96d9bbd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -551,12 +551,12 @@ protected Collation buildCollation() { comparator = UTF8String::binaryCompare; hashFunction = s -> (long) s.hashCode(); equalsFunction = UTF8String::equals; - }else { + } else { comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s2, spaceTrimming)); hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s2, spaceTrimming)); } return new Collation( @@ -575,16 +575,16 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if(spaceTrimming == SpaceTrimming.NONE ) { + if (spaceTrimming == SpaceTrimming.NONE ) { comparator = CollationAwareUTF8String::compareLowerCase; hashFunction = s -> - (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); - }else{ + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( - applyTrimmingPolicy(s1, spaceTrimming), - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); hashFunction = s -> (long) CollationAwareUTF8String. 
- lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); } return new Collation( @@ -961,17 +961,17 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if(spaceTrimming == SpaceTrimming.NONE){ + if (spaceTrimming == SpaceTrimming.NONE){ hashFunction = s -> (long) collator.getCollationKey( - s.toValidString()).hashCode(); + s.toValidString()).hashCode(); comparator = (s1, s2) -> - collator.compare(s1.toValidString(), s2.toValidString()); + collator.compare(s1.toValidString(), s2.toValidString()); } else { comparator = (s1, s2) -> collator.compare( - applyTrimmingPolicy(s1, spaceTrimming).toValidString(), - applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); hashFunction = s -> (long) collator.getCollationKey( - applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); } return new Collation( diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 491abeab58e01..88ef9a3c2d83f 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -146,7 +146,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true) ) checks.foreach(testCase => { @@ -190,7 +190,7 @@ class 
CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa", "aaa ", 0), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1) ) checks.foreach(testCase => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index fd83408da7f74..ac8ad69dd55d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -503,7 +503,7 @@ class CollationSQLExpressionsSuite BinTestCase("13", "UNICODE", "1101"), BinTestCase("13", "UNICODE_RTRIM", "1101"), BinTestCase("13", "UNICODE_CI", "1101"), - BinTestCase("13", "UNICODE_CI_RTRIM", "1101"), + BinTestCase("13", "UNICODE_CI_RTRIM", "1101") ) testCases.foreach(t => { val query = From fef3a7167aff4ff9b9f561378a152503663ad00f Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Wed, 9 Oct 2024 19:37:26 +0200 Subject: [PATCH 07/14] fix bug. 
--- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index b1d11e96d9bbd..0868fbf6da4b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -1191,7 +1191,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsSpaceTrimming) { + if (collation.supportsBinaryEquality) { return input; } else if (collation.supportsLowercaseEquality) { return CollationAwareUTF8String.lowerCaseCodePoints(input); From c9c33d9f9c65ec731fe26946c3bdac14a53fe965 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Thu, 10 Oct 2024 08:22:55 +0200 Subject: [PATCH 08/14] fix style. --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 0868fbf6da4b8..6c6594c0b94af 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -583,8 +583,8 @@ protected Collation buildCollation() { comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( applyTrimmingPolicy(s1, spaceTrimming), applyTrimmingPolicy(s2, spaceTrimming)); - hashFunction = s -> (long) CollationAwareUTF8String. 
- lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); } return new Collation( From d77b2932eb4b07e6274f64adcdfccbf8bbb5f565 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 11:58:10 +0200 Subject: [PATCH 09/14] add more tests. --- .../sql/catalyst/util/CollationFactory.java | 4 +- .../unsafe/types/CollationFactorySuite.scala | 38 +++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 6c6594c0b94af..636fe09b0f3b3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -547,7 +547,7 @@ protected Collation buildCollation() { BiFunction equalsFunction; boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; - if(spaceTrimming == SpaceTrimming.NONE) { + if (spaceTrimming == SpaceTrimming.NONE) { comparator = UTF8String::binaryCompare; hashFunction = s -> (long) s.hashCode(); equalsFunction = UTF8String::equals; @@ -575,7 +575,7 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if (spaceTrimming == SpaceTrimming.NONE ) { + if (spaceTrimming == SpaceTrimming.NONE) { comparator = CollationAwareUTF8String::compareLowerCase; hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 88ef9a3c2d83f..f6ac21d951f83 100644 --- 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -127,7 +127,10 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -135,18 +138,27 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), - CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), 
CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true) + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false) ) checks.foreach(testCase => { @@ -167,8 +179,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), @@ -176,21 +191,30 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), - 
CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), - CollationTestCase("UNICODE_RTRIM", "aaa", "AAA ", -1), - CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), - CollationTestCase("UNICODE_CI_RTRIM", "aaa", "aaa ", 0), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1) + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1) ) checks.foreach(testCase => { From 7e781775dd284bf3af91c8b621117cead681011e Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 12:08:33 +0200 Subject: [PATCH 10/14] nit fixes. 
--- .../org/apache/spark/sql/types/StringType.scala | 3 +++ .../spark/sql/CollationSQLExpressionsSuite.scala | 12 ++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index e07471a15b6a3..1c93c2ad550e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -41,6 +41,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsBinaryEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsBinaryEquality + private[sql] def supportsLowercaseEquality: Boolean = + CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + private[sql] def isNonCSAI: Boolean = !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index ac8ad69dd55d6..ce6818652d2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -497,13 +497,9 @@ class CollationSQLExpressionsSuite val testCases = Seq( BinTestCase("13", "UTF8_BINARY", "1101"), - BinTestCase("13", "UTF8_BINARY_RTRIM", "1101"), BinTestCase("13", "UTF8_LCASE", "1101"), - BinTestCase("13", "UTF8_LCASE_RTRIM", "1101"), BinTestCase("13", "UNICODE", "1101"), - BinTestCase("13", "UNICODE_RTRIM", "1101"), - BinTestCase("13", "UNICODE_CI", "1101"), - BinTestCase("13", "UNICODE_CI_RTRIM", "1101") + BinTestCase("13", "UNICODE_CI", "1101") ) testCases.foreach(t => { val query = @@ -526,13 +522,9 @@ class CollationSQLExpressionsSuite val testCases = Seq( HexTestCase("13", "UTF8_BINARY", "D"), - 
HexTestCase("13", "UTF8_BINARY_RTRIM", "D"), HexTestCase("13", "UTF8_LCASE", "D"), - HexTestCase("13", "UTF8_LCASE_RTRIM", "D"), HexTestCase("13", "UNICODE", "D"), - HexTestCase("13", "UNICODE_RTRIM", "D"), - HexTestCase("13", "UNICODE_CI", "D"), - HexTestCase("13", "UNICODE_CI_RTRIM", "D") + HexTestCase("13", "UNICODE_CI", "D") ) testCases.foreach(t => { val query = From af27d43000f1f40b494c47155127fefda4dc03de Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 15:10:06 +0200 Subject: [PATCH 11/14] nit fixes. --- .../spark/sql/catalyst/util/CollationFactory.java | 2 +- .../spark/unsafe/types/CollationFactorySuite.scala | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 636fe09b0f3b3..01f6c7e0331b0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -961,7 +961,7 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if (spaceTrimming == SpaceTrimming.NONE){ + if (spaceTrimming == SpaceTrimming.NONE) { hashFunction = s -> (long) collator.getCollationKey( s.toValidString()).hashCode(); comparator = (s1, s2) -> diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index f6ac21d951f83..039babcbb01c3 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -131,6 +131,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY_RTRIM", "aaa", 
"aaa ", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -142,6 +143,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), @@ -150,6 +152,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), @@ -158,7 +161,8 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false) + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true) ) checks.foreach(testCase => { @@ -184,6 +188,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), 
CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), @@ -196,6 +201,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), @@ -206,6 +212,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), @@ -214,7 +221,8 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1) + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) ) checks.foreach(testCase => { From 1d73ad6bc92a4811b16292d263a6fe9c1ad7b68e Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 18:46:49 +0200 Subject: [PATCH 12/14] init commit. 
--- .../util/CollationAwareUTF8String.java | 4 +- .../sql/catalyst/util/CollationFactory.java | 54 ++++---- .../sql/catalyst/util/CollationSupport.java | 122 +++++++++--------- .../unsafe/types/CollationFactorySuite.scala | 8 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +- .../sql/catalyst/util/UnsafeRowUtils.scala | 2 +- .../expressions/HashExpressionsSuite.scala | 2 +- .../aggregate/HashMapGenerator.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 12 +- 9 files changed, 111 insertions(+), 103 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index fb610a5d96f17..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 01f6c7e0331b0..b9dce20cf0345 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -154,12 +154,25 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** * Support for Space Trimming implies that that based on specifier (for now only right trim) * leading, trailing or both spaces are removed from the input string before comparison. */ public final boolean supportsSpaceTrimming; + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. + */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. + */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -168,9 +181,8 @@ public Collation( String version, ToLongFunction hashFunction, BiFunction equalsFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; @@ -178,16 +190,15 @@ public Collation( this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; this.equalsFunction = equalsFunction; this.supportsSpaceTrimming = supportsSpaceTrimming; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.supportsBinaryEquality = !supportsSpaceTrimming && 
isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality - assert(!supportsBinaryEquality || !supportsLowercaseEquality); + assert(!isUtf8BinaryType || !isUtf8LcaseType); assert(SUPPORTED_PROVIDERS.contains(provider)); } @@ -567,9 +578,8 @@ protected Collation buildCollation() { "1.0", hashFunction, equalsFunction, - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } else { Comparator comparator; @@ -595,9 +605,8 @@ protected Collation buildCollation() { "1.0", hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, spaceTrimming != SpaceTrimming.NONE); } } @@ -982,9 +991,8 @@ protected Collation buildCollation() { ICU_COLLATOR_VERSION, hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } @@ -1191,9 +1199,9 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { CollationKey collationKey = 
collation.collator.getCollationKey( @@ -1207,9 +1215,9 @@ public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { return collation.collator.getCollationKey( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f05d9e512568f..978b663cc25c9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -37,9 +37,9 @@ public final class CollationSupport { public static class StringSplitSQL { public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); @@ -48,9 +48,9 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in public static String genCode(final String s, final String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if 
(collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", s, d); } else { return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); @@ -71,9 +71,9 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -82,9 +82,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -109,9 +109,9 @@ public static class StartsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -120,9 +120,9 @@ public static 
boolean exec(final UTF8String l, final UTF8String r, public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -146,9 +146,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class EndsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -157,9 +157,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -184,9 +184,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public 
static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -195,10 +195,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -221,9 +221,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? 
execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -232,10 +232,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -258,9 +258,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -270,10 +270,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -296,7 +296,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); @@ -305,7 +305,7 @@ public static int exec(final UTF8String word, final UTF8String set, final int co public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -324,9 +324,9 @@ public static class StringInstr { public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -336,9 +336,9 @@ public static String genCode(final String string, final String substring, final int collationId) { 
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); @@ -360,9 +360,9 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -372,9 +372,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -398,9 +398,9 @@ public static class StringLocate { public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation 
collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -410,9 +410,9 @@ public static String genCode(final String string, final String substring, final final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); @@ -436,9 +436,9 @@ public static class SubstringIndex { public static UTF8String exec(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); @@ -448,9 +448,9 @@ public static String genCode(final String string, final String delimiter, final String count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { + if 
(collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count, collationId); @@ -474,9 +474,9 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); @@ -503,9 +503,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -520,9 +520,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", 
srcString, trimString, collationId); @@ -559,9 +559,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -576,9 +576,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -614,9 +614,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -631,9 +631,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + 
"Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -669,7 +669,7 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 039babcbb01c3..4672c39d9be8a 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -38,22 +38,22 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) assert(UTF8_LCASE_COLLATION_ID == 1) val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) assert(utf8Lcase.collationName == "UTF8_LCASE") - assert(!utf8Lcase.supportsBinaryEquality) + assert(!utf8Lcase.isUtf8BinaryType) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) 
assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) } test("UTF8_BINARY and ICU root locale collation names") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 7128190902550..3a667f370428e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression { protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { + if (stringType.supportsBinaryEquality) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction { hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) case s: UTF8String => val st = dataType.asInstanceOf[StringType] - if (st.supportsBinaryEquality && !st.usesTrimCollation) { + if (st.supportsBinaryEquality) { hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) } else { val stringHash = CollationFactory @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { + if (stringType.supportsBinaryEquality) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 40b8bccafaad2..118dd92c3ed54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -207,7 +207,7 @@ object UnsafeRowUtils { def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { case st: StringType => val collation = CollationFactory.fetchCollation(st.collationId) - (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) + (!collation.supportsBinaryEquality) case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 6f3890cafd2ac..92ef24bb8ec63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -636,7 +636,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(murmur3Hash1, interpretedHash1) checkEvaluation(murmur3Hash2, interpretedHash2) - if (CollationFactory.fetchCollation(collation).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collation).isUtf8BinaryType) { assert(interpretedHash1 != interpretedHash2) } else { assert(interpretedHash1 == interpretedHash2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 3b1f349520f39..19a36483abe6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,9 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => + case st: StringType if !st.supportsBinaryEquality => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index e12c2838b88ab..25e1197bea4a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1127,7 +1127,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).supportsBinaryEquality + CollationFactory.fetchCollation(collation).isUtf8BinaryType test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1352,7 +1352,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1409,7 +1409,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. @@ -1470,7 +1470,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. @@ -1529,7 +1529,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1588,7 +1588,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) From 5715d5c7aac3170f3a3d2a811b85918ea195b310 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 15 Oct 2024 10:57:53 +0200 Subject: [PATCH 13/14] nit fixes. --- .../spark/sql/catalyst/util/CollationAwareUTF8String.java | 4 ++-- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index d67697eaea38b..fb610a5d96f17 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { + } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 887e554fcd246..bdb971663c380 100644 --- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -153,7 +153,6 @@ public static class Collation { * expressions, as this particular collation is not supported by the external ICU library. */ public final boolean supportsLowercaseEquality; - /** * Support for Space Trimming implies that that based on specifier (for now only right trim) * leading, trailing or both spaces are removed from the input string before comparison. @@ -197,7 +196,7 @@ public Collation( this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality - assert(!isUtf8BinaryType || !isUtf8LcaseType); + assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); } From 8ead603ee3cafb322f5cc4c6e6a228c7f93261e8 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 15 Oct 2024 12:42:29 +0200 Subject: [PATCH 14/14] revert unnecessary changes. 
--- .../spark/sql/catalyst/util/CollationAwareUTF8String.java | 4 ++-- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index fb610a5d96f17..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index bdb971663c380..50bb93465921e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -153,6 +153,7 @@ public static class Collation { * expressions, as this particular collation is not supported by the external ICU library. 
*/ public final boolean supportsLowercaseEquality; + /** * Support for Space Trimming implies that that based on specifier (for now only right trim) * leading, trailing or both spaces are removed from the input string before comparison.