Skip to content

Commit

Permalink
[SPARK-48441][SQL] Fix StringTrim behaviour for non-UTF8_BINARY colla…
Browse files Browse the repository at this point in the history
…tions

### What changes were proposed in this pull request?
String searching in UTF8_LCASE now works at the character level, rather than at the byte level. For example: `ltrim("İ", "i")` now returns `"İ"`, because there exist **no characters** in `"İ"`, starting from the left, such that the lowercased version of those characters is equal to `"i"`. Note, however, that there is a byte subsequence of `"İ"` such that the lowercased version of that UTF-8 byte sequence equals `"i"` (so the new behaviour is different from the old behaviour).

Also, trimming for ICU collations works by repeatedly removing the longest possible substring that matches a character in the trim string, starting from the left side of the input string (or from the right side, for right trims), until trimming is done.

### Why are the changes needed?
Fix functions that give unusable results due to one-to-many case mapping when performing string search under UTF8_LCASE (see example above).

### Does this PR introduce _any_ user-facing change?
Yes, behaviour of `trim*` expressions is changed for collated strings for edge cases with one-to-many case mapping.

### How was this patch tested?
New unit tests in `CollationSupportSuite` and new e2e sql tests in `CollationStringExpressionsSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46762 from uros-db/alter-trim.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and attilapiros committed Oct 4, 2024
1 parent 7723b81 commit 24cb7ae
Show file tree
Hide file tree
Showing 5 changed files with 922 additions and 237 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

Expand Down Expand Up @@ -841,117 +842,268 @@ public static UTF8String translate(final UTF8String input,
return UTF8String.fromString(sb.toString());
}

/**
* Trims the `srcString` string from both ends of the string using the specified `trimString`
* characters, with respect to the UTF8_LCASE collation. String trimming is performed by
* first trimming the left side of the string, and then trimming the right side of the string.
* The method returns the trimmed string. If the `trimString` is null, the method returns null.
*
* @param srcString the input string to be trimmed from both ends of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrim(
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}
return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString);
}

UTF8String leftTrimmed = lowercaseTrimLeft(srcString, trimString);
return lowercaseTrimRight(leftTrimmed, trimString);
/**
 * Removes the characters of `trimString` from both ends of `srcString`, with respect to the
 * ICU collations in Spark. The work is delegated to `trimLeft` followed by `trimRight`, so the
 * left side is always processed before the right side. Because `trimLeft` and `trimRight` both
 * return null for a null `trimString`, this method returns null in that case as well.
 *
 * @param srcString the input string to be trimmed from both ends of the string
 * @param trimString the trim string characters to trim
 * @param collationId the collation ID to use for string trimming
 * @return the trimmed string (for ICU collations)
 */
public static UTF8String trim(
    final UTF8String srcString,
    final UTF8String trimString,
    final int collationId) {
  final UTF8String leftTrimmed = trimLeft(srcString, trimString, collationId);
  return trimRight(leftTrimmed, trimString, collationId);
}

/**
* Trims the `srcString` string from the left side using the specified `trimString` characters,
* with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash
* set of lowercased code points in `trimString`, and then iterates over the `srcString` from
* the left side, until reaching a character whose lowercased code point is not in the hash set.
* Finally, the method returns the substring from that position to the end of `srcString`.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the left end of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrimLeft(
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
// Matching the default UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

// The searching byte position in the srcString.
int searchIdx = 0;
// The byte position of a first non-matching character in the srcString.
int trimByteIdx = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

while (searchIdx < numBytes) {
UTF8String searchChar = srcString.copyUTF8String(
searchIdx,
searchIdx + UTF8String.numBytesForFirstByte(srcString.getByte(searchIdx)) - 1);
int searchCharBytes = searchChar.numBytes();

// Try to find the matching for the searchChar in the trimString.
if (lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx += searchCharBytes;
searchIdx += searchCharBytes;
} else {
// No matching, exit the search.
// Create a hash set of lowercased code points for all characters of `trimString`.
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the left to find the first character that is not in the set.
int searchIndex = 0, codePoint;
Iterator<Integer> srcIter = srcString.codePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_COMBINING_DOT) {
searchIndex += 2;
}
else {
if (trimChars.contains(codePoint)) ++searchIndex;
break;
}
} else if (trimChars.contains(codePoint)) {
++searchIndex;
}
else {
break;
}
}

if (searchIdx == 0) {
// Nothing trimmed - return original string (not converted to lowercase).
return srcString;
// Return the substring from that position to the end of the string.
return searchIndex == 0 ? srcString : srcString.substring(searchIndex, srcString.numChars());
}

/**
* Trims the `srcString` string from the left side using the specified `trimString` characters,
* with respect to ICU collations. For these collations, the method iterates over `srcString`
* from left to right, and repeatedly skips the longest possible substring that matches any
* character in `trimString`, until reaching a character that is not found in `trimString`.
* Finally, the method returns the substring from that position to the end of `srcString`.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the left end of the string
* @param trimString the trim string characters to trim
* @param collationId the collation ID to use for string trimming
* @return the trimmed string (for ICU collations)
*/
public static UTF8String trimLeft(
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
// Short-circuit for base cases.
if (trimString == null) return null;
if (srcString.numBytes() == 0) return srcString;

// Create an array of Strings for all characters of `trimString`.
Map<Integer, String> trimChars = new HashMap<>();
Iterator<Integer> trimIter = trimString.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
while (trimIter.hasNext()) {
int codePoint = trimIter.next();
trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint));
}
if (trimByteIdx >= numBytes) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;

// Iterate over srcString from the left and find the first character that is not in trimChars.
String src = srcString.toValidString();
CharacterIterator target = new StringCharacterIterator(src);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
int charIndex = 0, longestMatchLen;
while (charIndex < src.length()) {
longestMatchLen = 0;
for (String trim : trimChars.values()) {
StringSearch stringSearch = new StringSearch(trim, target, (RuleBasedCollator) collator);
stringSearch.setIndex(charIndex);
int matchIndex = stringSearch.next();
if (matchIndex == charIndex) {
int matchLen = stringSearch.getMatchLength();
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
}
}
}
if (longestMatchLen == 0) break;
else charIndex += longestMatchLen;
}
return srcString.copyUTF8String(trimByteIdx, numBytes - 1);

// Return the substring from the calculated position until the end of the string.
return UTF8String.fromString(src.substring(charIndex));
}

/**
* Trims the `srcString` string from the right side using the specified `trimString` characters,
* with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash
* set of lowercased code points in `trimString`, and then iterates over the `srcString` from
* the right side, until reaching a character whose lowercased code point is not in the hash set.
* Finally, the method returns the substring from the start of `srcString` until that position.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the right end of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrimRight(
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
// Matching the default UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

// Number of bytes iterated from the srcString.
int byteIdx = 0;
// Number of characters iterated from the srcString.
int numChars = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Array of character length for the srcString.
int[] stringCharLen = new int[numBytes];
// Array of the first byte position for each character in the srcString.
int[] stringCharPos = new int[numBytes];
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

// Build the position and length array.
while (byteIdx < numBytes) {
stringCharPos[numChars] = byteIdx;
stringCharLen[numChars] = UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx));
byteIdx += stringCharLen[numChars];
numChars++;
}

// Index trimEnd points to the first no matching byte position from the right side of
// the source string.
int trimByteIdx = numBytes - 1;

while (numChars > 0) {
UTF8String searchChar = srcString.copyUTF8String(
stringCharPos[numChars - 1],
stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1);

if(lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx -= stringCharLen[numChars - 1];
numChars--;
} else {
// Create a hash set of lowercased code points for all characters of `trimString`.
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the right to find the first character that is not in the set.
int searchIndex = srcString.numChars(), codePoint;
Iterator<Integer> srcIter = srcString.reverseCodePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_LOWERCASE_I) {
searchIndex -= 2;
}
else {
if (trimChars.contains(codePoint)) --searchIndex;
break;
}
} else if (trimChars.contains(codePoint)) {
--searchIndex;
}
else {
break;
}
}

if (trimByteIdx == numBytes - 1) {
// Nothing trimmed.
return srcString;
// Return the substring from the start of the string to the calculated position.
return searchIndex == srcString.numChars() ? srcString : srcString.substring(0, searchIndex);
}

/**
* Trims the `srcString` string from the right side using the specified `trimString` characters,
* with respect to ICU collations. For these collations, the method iterates over `srcString`
* from right to left, and repeatedly skips the longest possible substring that matches any
* character in `trimString`, until reaching a character that is not found in `trimString`.
* Finally, the method returns the substring from the start of `srcString` until that position.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the right end of the string
* @param trimString the trim string characters to trim
* @param collationId the collation ID to use for string trimming
* @return the trimmed string (for ICU collations)
*/
public static UTF8String trimRight(
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
// Short-circuit for base cases.
if (trimString == null) return null;
if (srcString.numBytes() == 0) return srcString;

// Create an array of Strings for all characters of `trimString`.
Map<Integer, String> trimChars = new HashMap<>();
Iterator<Integer> trimIter = trimString.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
while (trimIter.hasNext()) {
int codePoint = trimIter.next();
trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint));
}
if (trimByteIdx < 0) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;

// Iterate over srcString from the left and find the first character that is not in trimChars.
String src = srcString.toValidString();
CharacterIterator target = new StringCharacterIterator(src);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
int charIndex = src.length(), longestMatchLen;
while (charIndex >= 0) {
longestMatchLen = 0;
for (String trim : trimChars.values()) {
StringSearch stringSearch = new StringSearch(trim, target, (RuleBasedCollator) collator);
// Note: stringSearch.previous() is NOT consistent with stringSearch.next()!
// Example: StringSearch("İ", "i\\u0307İi\\u0307İi\\u0307İ", "UNICODE_CI")
// stringSearch.next() gives: [0, 2, 3, 5, 6, 8].
// stringSearch.previous() gives: [8, 6, 3, 0].
// Since 1 character can map to at most 3 characters in Unicode, we can begin the search
// from character position: `charIndex` - 3, and use `next()` to find the longest match.
stringSearch.setIndex(Math.max(charIndex - 3, 0));
int matchIndex = stringSearch.next();
int matchLen = stringSearch.getMatchLength();
while (matchIndex != StringSearch.DONE && matchIndex < charIndex - matchLen) {
matchIndex = stringSearch.next();
matchLen = stringSearch.getMatchLength();
}
if (matchIndex == charIndex - matchLen) {
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
}
}
}
if (longestMatchLen == 0) break;
else charIndex -= longestMatchLen;
}
return srcString.copyUTF8String(0, trimByteIdx);

// Return the substring from the start of the string until that position.
return UTF8String.fromString(src.substring(0, charIndex));
}

// TODO: Add more collation-aware UTF8String operations here.
Expand Down
Loading

0 comments on commit 24cb7ae

Please sign in to comment.