diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 49d3088f8a2f0..32875c0d96396 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.function.Function; +import java.util.Iterator; import java.util.Map; import java.util.regex.Pattern; @@ -405,6 +406,105 @@ public boolean isValid() { return true; } + /** + * Code point iteration over a UTF8String can be done using one of two modes: + * 1. CODE_POINT_ITERATOR_ASSUME_VALID: The caller ensures that the UTF8String is valid and does + * not contain any invalid UTF-8 byte sequences. In this case, the code point iterator will + * return the code points in the current string one by one, as integers. If an invalid code + * point is found within the string during iteration, an exception will be thrown. This mode + * is more dangerous, but faster - since no scan is needed prior to beginning iteration. + * 2. CODE_POINT_ITERATOR_MAKE_VALID: The caller does not ensure that the UTF8String is valid, + * but instead expects the code point iterator to first check whether the current UTF8String + * is valid, then perform the invalid byte sequence replacement using `makeValid`, and finally + * begin the code point iteration over the resulting valid UTF8String. However, the original + * UTF8String stays unchanged. This mode is safer, but slower - due to initial validation. + * The default mode is CODE_POINT_ITERATOR_ASSUME_VALID. + */ + public enum CodePointIteratorType { + CODE_POINT_ITERATOR_ASSUME_VALID, // USE ONLY WITH VALID STRINGS + CODE_POINT_ITERATOR_MAKE_VALID + } + + /** + * Returns a code point iterator for this UTF8String. + */ + public Iterator codePointIterator() { + return codePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID); + } + + public Iterator codePointIterator(CodePointIteratorType iteratorMode) { + if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) { + return makeValid().codePointIterator(); + } + return new CodePointIterator(); + } + + /** + * Code point iterator implementation for the UTF8String class. The iterator will return code + * points in the current string one by one, as integers. However, the code point iterator is only + * guaranteed to work if the current UTF8String does not contain any invalid UTF-8 byte sequences. + * If the current string contains any invalid UTF-8 byte sequences, exceptions will be thrown. + */ + private class CodePointIterator implements Iterator { + // Byte index used to iterate over the current UTF8String. + private int byteIndex = 0; + + @Override + public boolean hasNext() { + return byteIndex < numBytes; + } + + @Override + public Integer next() { + if (!hasNext()) { + throw new IndexOutOfBoundsException(); + } + int codePoint = codePointFrom(byteIndex); + byteIndex += numBytesForFirstByte(getByte(byteIndex)); + return codePoint; + } + } + + /** + * Reverse version of the code point iterator for this UTF8String, returns code points in the + * current string one by one, as integers, in reverse order. The logic is similar to the above. + */ + + public Iterator reverseCodePointIterator() { + return reverseCodePointIterator(CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID); + } + + public Iterator reverseCodePointIterator(CodePointIteratorType iteratorMode) { + if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) { + return makeValid().reverseCodePointIterator(); + } + return new ReverseCodePointIterator(); + } + + private class ReverseCodePointIterator implements Iterator { + private int byteIndex = numBytes - 1; + + @Override + public boolean hasNext() { + return byteIndex >= 0; + } + + @Override + public Integer next() { + if (!hasNext()) { + throw new IndexOutOfBoundsException(); + } + while (byteIndex > 0 && isContinuationByte(getByte(byteIndex))) { + --byteIndex; + } + return codePointFrom(byteIndex--); + } + + private boolean isContinuationByte(byte b) { + return (b & 0xC0) == 0x80; + } + } + /** * Returns a substring of this. * @param start the position of first code point @@ -477,10 +577,53 @@ public boolean contains(final UTF8String substring) { } /** - * Returns the byte at position `i`. + * Returns the byte at (byte) position `byteIndex`. If byte index is invalid, returns 0. + */ + public byte getByte(int byteIndex) { + return Platform.getByte(base, offset + byteIndex); + } + + /** + * Returns the code point at (char) position `charIndex`. If char index is invalid, throws + * exception. Note that this method is not efficient as it needs to traverse the UTF-8 string. + * If `byteIndex` of the first byte in the code point is known, use `codePointFrom` instead. + */ + public int getChar(int charIndex) { + if (charIndex < 0 || charIndex >= numChars()) { + throw new IndexOutOfBoundsException(); + } + int charCount = 0, byteCount = 0; + while (charCount < charIndex) { + byteCount += numBytesForFirstByte(getByte(byteCount)); + charCount += 1; + } + return codePointFrom(byteCount); + } + + /** + * Returns the code point starting from the byte at position `byteIndex`. + * If byte index is invalid, throws exception. */ - public byte getByte(int i) { - return Platform.getByte(base, offset + i); + public int codePointFrom(int byteIndex) { + if (byteIndex < 0 || byteIndex >= numBytes) { + throw new IndexOutOfBoundsException(); + } + byte b = getByte(byteIndex); + int numBytes = numBytesForFirstByte(b); + return switch (numBytes) { + case 1 -> + b & 0x7F; + case 2 -> + ((b & 0x1F) << 6) | (getByte(byteIndex + 1) & 0x3F); + case 3 -> + ((b & 0x0F) << 12) | ((getByte(byteIndex + 1) & 0x3F) << 6) | + (getByte(byteIndex + 2) & 0x3F); + case 4 -> + ((b & 0x07) << 18) | ((getByte(byteIndex + 1) & 0x3F) << 12) | + ((getByte(byteIndex + 2) & 0x3F) << 6) | (getByte(byteIndex + 3) & 0x3F); + default -> + throw new IllegalStateException("Error in UTF-8 code point"); + }; } public boolean matchAt(final UTF8String s, int pos) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 07793a24e5eed..d690da53c7c66 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -28,6 +28,7 @@ import org.apache.spark.unsafe.Platform; import org.junit.jupiter.api.Test; +import static org.apache.spark.unsafe.types.UTF8String.fromString; import static org.junit.jupiter.api.Assertions.*; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -1110,6 +1111,235 @@ public void isValid() { testIsValid("0x9C 0x76 0x17", "0xEF 0xBF 0xBD 0x76 0x17"); } + @Test + public void testGetByte() { + // Valid UTF-8 string + String validString = "abcde"; + UTF8String validUTF8String = fromString(validString); + // Valid byte index handling + for (int i = 0; i < validString.length(); ++i) { + assertEquals(validString.charAt(i), validUTF8String.getByte(i)); + } + // Invalid byte index handling + assertEquals(0, validUTF8String.getByte(-1)); + assertEquals(0, validUTF8String.getByte(validString.length())); + assertEquals(0, validUTF8String.getByte(validString.length() + 1)); + + // Invalid UTF-8 string + byte[] invalidString = new byte[] {(byte) 0x41, (byte) 0x42, (byte) 0x80}; + UTF8String invalidUTF8String = fromBytes(invalidString); + // Valid byte index handling + for (int i = 0; i < invalidString.length; ++i) { + assertEquals(invalidString[i], invalidUTF8String.getByte(i)); + } + // Invalid byte index handling + assertEquals(0, invalidUTF8String.getByte(-1)); + assertEquals(0, invalidUTF8String.getByte(invalidString.length)); + assertEquals(0, invalidUTF8String.getByte(invalidString.length + 1)); + } + + @Test + public void testGetChar() { + // Valid UTF-8 string + String str = "abcde"; + UTF8String s = fromString(str); + // Valid character index handling + for (int i = 0; i < str.length(); ++i) { + assertEquals(str.charAt(i), s.getChar(i)); + } + // Invalid character index handling + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(str.length())); + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(str.length() + 1)); + + // Invalid UTF-8 string + byte[] invalidString = new byte[] {(byte) 0x41, (byte) 0x42, (byte) 0x80}; + UTF8String invalidUTF8String = fromBytes(invalidString); + // Valid byte index handling + for (int i = 0; i < invalidString.length; ++i) { + if (Character.isValidCodePoint(invalidString[i])) { + assertEquals(invalidString[i], invalidUTF8String.getChar(i)); + } else { + assertEquals(0, invalidUTF8String.getChar(i)); + } + } + // Invalid byte index handling + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(str.length())); + assertThrows(IndexOutOfBoundsException.class, () -> s.getChar(str.length() + 1)); + } + + @Test + public void testCodePointFrom() { + // Valid UTF-8 string + String str = "abcde"; + UTF8String s = fromString(str); + // Valid character index handling + for (int i = 0; i < str.length(); ++i) { + assertEquals(str.charAt(i), s.codePointFrom(i)); + } + // Invalid character index handling + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length())); + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length() + 1)); + + // Invalid UTF-8 string + byte[] invalidString = new byte[] {(byte) 0x41, (byte) 0x42, (byte) 0x80}; + UTF8String invalidUTF8String = fromBytes(invalidString); + // Valid byte index handling + for (int i = 0; i < invalidString.length; ++i) { + if (Character.isValidCodePoint(invalidString[i])) { + assertEquals(invalidString[i], invalidUTF8String.codePointFrom(i)); + } else { + assertEquals(0, invalidUTF8String.codePointFrom(i)); + } + } + // Invalid byte index handling + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length())); + assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length() + 1)); + } + + @Test + public void utf8StringCodePoints() { + String s = "aéह 日å!"; + UTF8String s0 = fromString(s); + for (int i = 0; i < s.length(); ++i) { + assertEquals(s.codePointAt(i), s0.getChar(i)); + } + + UTF8String s1 = fromBytes(new byte[] {0x41, (byte) 0xC3, (byte) 0xB1, (byte) 0xE2, + (byte) 0x82, (byte) 0xAC, (byte) 0xF0, (byte) 0x90, (byte) 0x8D, (byte) 0x88}); + // numBytesForFirstByte + assertEquals(1, UTF8String.numBytesForFirstByte(s1.getByte(0))); + assertEquals(2, UTF8String.numBytesForFirstByte(s1.getByte(1))); + assertEquals(3, UTF8String.numBytesForFirstByte(s1.getByte(3))); + assertEquals(4, UTF8String.numBytesForFirstByte(s1.getByte(6))); + // getByte + assertEquals((byte) 0x41, s1.getByte(0)); + assertEquals((byte) 0xC3, s1.getByte(1)); + assertEquals((byte) 0xE2, s1.getByte(3)); + assertEquals((byte) 0xF0, s1.getByte(6)); + // codePointFrom + assertEquals(0x41, s1.codePointFrom(0)); + assertEquals(0xF1, s1.codePointFrom(1)); + assertEquals(0x20AC, s1.codePointFrom(3)); + assertEquals(0x10348, s1.codePointFrom(6)); + assertThrows(IndexOutOfBoundsException.class, () -> s1.codePointFrom(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s1.codePointFrom(99)); + // getChar + assertEquals(0x41, s1.getChar(0)); + assertEquals(0xF1, s1.getChar(1)); + assertEquals(0x20AC, s1.getChar(2)); + assertEquals(0x10348, s1.getChar(3)); + assertThrows(IndexOutOfBoundsException.class, () -> s1.getChar(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s1.getChar(99)); + + UTF8String s2 = fromString("Añ€𐍈"); + // numBytesForFirstByte + assertEquals(1, UTF8String.numBytesForFirstByte(s2.getByte(0))); + assertEquals(2, UTF8String.numBytesForFirstByte(s2.getByte(1))); + assertEquals(3, UTF8String.numBytesForFirstByte(s2.getByte(3))); + assertEquals(4, UTF8String.numBytesForFirstByte(s2.getByte(6))); + // getByte + assertEquals((byte) 0x41, s2.getByte(0)); + assertEquals((byte) 0xC3, s2.getByte(1)); + assertEquals((byte) 0xE2, s2.getByte(3)); + assertEquals((byte) 0xF0, s2.getByte(6)); + // codePointFrom + assertEquals(0x41, s2.codePointFrom(0)); + assertEquals(0xF1, s2.codePointFrom(1)); + assertEquals(0x20AC, s2.codePointFrom(3)); + assertEquals(0x10348, s2.codePointFrom(6)); + assertThrows(IndexOutOfBoundsException.class, () -> s2.codePointFrom(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s2.codePointFrom(99)); + // getChar + assertEquals(0x41, s2.getChar(0)); + assertEquals(0xF1, s2.getChar(1)); + assertEquals(0x20AC, s2.getChar(2)); + assertEquals(0x10348, s2.getChar(3)); + assertThrows(IndexOutOfBoundsException.class, () -> s2.getChar(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s2.getChar(99)); + + UTF8String s3 = EMPTY_UTF8; + // codePointFrom + assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(0)); + assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s3.codePointFrom(99)); + // getChar + assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(0)); + assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> s3.getChar(99)); + } + + private void testCodePointIterator(UTF8String utf8String) { + CodePointIteratorType iteratorMode = utf8String.isValid() ? + CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID : + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID; + Iterator iterator = utf8String.codePointIterator(iteratorMode); + for (int i = 0; i < utf8String.numChars(); ++i) { + assertTrue(iterator.hasNext()); + int codePoint = (utf8String.isValid() ? utf8String : utf8String.makeValid()).getChar(i); + assertEquals(codePoint, (int) iterator.next()); + } + assertFalse(iterator.hasNext()); + } + @Test + public void codePointIterator() { + // Valid UTF8 strings. + testCodePointIterator(fromString("")); + testCodePointIterator(fromString("abc")); + testCodePointIterator(fromString("a!2&^R")); + testCodePointIterator(fromString("aéह 日å!")); + testCodePointIterator(fromBytes(new byte[] {(byte) 0x41})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0xC2, (byte) 0xA3})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0xE2, (byte) 0x82, (byte) 0xAC})); + // Invalid UTF8 strings. + testCodePointIterator(fromBytes(new byte[] {(byte) 0xFF})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0x80})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0xC2, (byte) 0x80})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0xE2, (byte) 0x82, (byte) 0x80})); + testCodePointIterator(fromBytes(new byte[] {(byte) 0x41, (byte) 0x80, (byte) 0x42})); + testCodePointIterator(fromBytes(new byte[] { + (byte) 0x41, (byte) 0xC2, (byte) 0x80, (byte) 0x42})); + testCodePointIterator(fromBytes(new byte[] { + (byte) 0x41, (byte) 0xE2, (byte) 0x82, (byte) 0x80, (byte) 0x42})); + } + + private void testReverseCodePointIterator(UTF8String utf8String) { + CodePointIteratorType iteratorMode = utf8String.isValid() ? + CodePointIteratorType.CODE_POINT_ITERATOR_ASSUME_VALID : + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID; + Iterator iterator = utf8String.codePointIterator(iteratorMode); + for (int i = 0; i < utf8String.numChars(); ++i) { + assertTrue(iterator.hasNext()); + int codePoint = (utf8String.isValid() ? utf8String : utf8String.makeValid()).getChar(i); + assertEquals(codePoint, (int) iterator.next()); + } + assertFalse(iterator.hasNext()); + } + @Test + public void reverseCodePointIterator() { + // Valid UTF8 strings + testReverseCodePointIterator(fromString("")); + testReverseCodePointIterator(fromString("abc")); + testReverseCodePointIterator(fromString("a!2&^R")); + testReverseCodePointIterator(fromString("aéह 日å!")); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0x41})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0xC2, (byte) 0xA3})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0xE2, (byte) 0x82, (byte) 0xAC})); + // Invalid UTF8 strings + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0xFF})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0x80})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0xC2, (byte) 0x80})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0xE2, (byte) 0x82, (byte) 0x80})); + testReverseCodePointIterator(fromBytes(new byte[] {(byte) 0x41, (byte) 0x80, (byte) 0x42})); + testReverseCodePointIterator(fromBytes(new byte[] { + (byte) 0x41, (byte) 0xC2, (byte) 0x80, (byte) 0x42})); + testReverseCodePointIterator(fromBytes(new byte[] { + (byte) 0x41, (byte) 0xE2, (byte) 0x82, (byte) 0x80, (byte) 0x42})); + } + @Test public void toBinaryString() { assertEquals(ZERO_UTF8, UTF8String.toBinaryString(0));