[SPARK-47411][SQL] Support StringInstr & FindInSet functions to work with collated strings #45643

Closed
wants to merge 53 commits into from
Changes from 45 commits
Commits
a4d3592
Add support for instr and unit test in CollationStringExpressionsSuit…
miland-db Mar 21, 2024
eb2d7c5
Correct code style
miland-db Mar 21, 2024
9340831
Remove blank line from CollationStringExpressionsSuite.scala
miland-db Mar 21, 2024
465e814
Correct comment indentation
miland-db Mar 21, 2024
f3f30d8
Add unit tests for INSTR operation
miland-db Mar 22, 2024
9cb92d3
Add doGenCode for FindInSet
miland-db Mar 22, 2024
834be70
Rewrite unit tests for INSTR and FIND_IN_SET
miland-db Mar 22, 2024
db2453a
Correct return value when substr is not found in INSTR method
miland-db Mar 22, 2024
91b648a
Update unit tests for StringInStr and FindInSet
miland-db Mar 25, 2024
1062521
Remove tests on non-explicit default collation
miland-db Mar 25, 2024
42700c7
Merge branch 'apache:master' into miland-db/substr-functions
miland-db Mar 25, 2024
427ea25
Improve signature of testInStr
miland-db Mar 26, 2024
546b3b0
Merge branch 'master' into miland-db/substr-functions
miland-db Mar 26, 2024
108d707
Remove E2E test for collation mismatch. This will be added in Implici…
miland-db Mar 26, 2024
822ecd2
Resolve merge problems with master
miland-db Mar 26, 2024
f730d05
Improve scala style
miland-db Mar 26, 2024
de7b591
Solve whitespace scala style problem
miland-db Mar 27, 2024
f0ee8fd
Add lazy val collationId
miland-db Apr 1, 2024
4ac6885
Remove repeated code
miland-db Apr 1, 2024
b931333
Improve test format
miland-db Apr 1, 2024
0fd51d5
Improve indexOf method
miland-db Apr 1, 2024
28fa7f0
Remove checks in return statement of collatedIndexOf method
miland-db Apr 1, 2024
bab96ac
Merge branch 'master' into substr-functions
miland-db Apr 2, 2024
4666aff
Add branch for collated findInSet
miland-db Apr 2, 2024
ca8a37c
Add branch for collation check in StringInstr
miland-db Apr 2, 2024
4ffab78
Improve naming of collation aware methods
miland-db Apr 2, 2024
037d6be
Improve java style
miland-db Apr 3, 2024
0a22909
Improve collationAwareIndexOf performance
miland-db Apr 3, 2024
b3be85d
Fix indentation
miland-db Apr 3, 2024
c4c0fe7
Add more tests for instr
miland-db Apr 3, 2024
5b29f76
Add more tests
miland-db Apr 3, 2024
877828e
Remove collation match type checks
miland-db Apr 4, 2024
038a071
Merge branch 'master' into substr-functions
miland-db Apr 4, 2024
b35a8ac
Merge with the latest master
miland-db Apr 4, 2024
8b06014
Remove checkInputDataTypes
miland-db Apr 4, 2024
2c454af
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
4dbc26e
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
fbd1c00
Refactor code and move it to CollationSupport
miland-db Apr 12, 2024
960af54
Improve codegen and run tests
miland-db Apr 12, 2024
05cd6c4
Unify collationAwareIndexOf for return value to have same semantics a…
miland-db Apr 12, 2024
053efa0
Break line at 100 chars
miland-db Apr 13, 2024
c65d68e
Add new version of getStringSearch
miland-db Apr 15, 2024
ae33a38
Rename StringInstr params and class in CollationSupport
miland-db Apr 15, 2024
be1b52c
Go from nullSafeCodeGen to defineCodeGen
miland-db Apr 15, 2024
5894d2f
Refactor testing
miland-db Apr 15, 2024
75dc0bd
Remove empty lines
miland-db Apr 15, 2024
c712b4b
Improve CollationAware indexOf to have the same semantics as UTF8Stri…
miland-db Apr 15, 2024
3c37f35
Add new e2e test
miland-db Apr 15, 2024
cd860b9
Revert unused import deletion
miland-db Apr 15, 2024
3fa2502
Fix codegen
miland-db Apr 16, 2024
b12f176
Merge branch 'master' into substr-functions
miland-db Apr 17, 2024
b35d718
Remove unused import
miland-db Apr 17, 2024
1ee5ad6
Add new tests
miland-db Apr 18, 2024
@@ -196,10 +196,21 @@ public static StringSearch getStringSearch(
final UTF8String targetUTF8String,
final UTF8String patternUTF8String,
final int collationId) {
- String pattern = patternUTF8String.toString();
- CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString());
+ return getStringSearch(targetUTF8String.toString(), patternUTF8String.toString(), collationId);
+ }
+
+ /**
+ * Returns a StringSearch object for the given pattern and target strings, under collation
+ * rules corresponding to the given collationId. The external ICU library StringSearch object can
+ * be used to find occurrences of the pattern in the target string, while respecting collation.
+ */
+ public static StringSearch getStringSearch(
+ final String targetString,
+ final String patternString,
+ final int collationId) {
+ CharacterIterator target = new StringCharacterIterator(targetString);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
- return new StringSearch(pattern, target, (RuleBasedCollator) collator);
+ return new StringSearch(patternString, target, (RuleBasedCollator) collator);
}

/**

@@ -137,6 +137,76 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
}
}

public static class FindInSet {
public static int exec(final UTF8String word, final UTF8String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(word, set);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(word, set);
} else {
return execICU(word, set, collationId);
}
}
public static String genCode(final String word, final String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.FindInSet.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", word, set);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", word, set);
} else {
return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId);
}
}
public static int execBinary(final UTF8String word, final UTF8String set) {
return set.findInSet(word);
}
public static int execLowercase(final UTF8String word, final UTF8String set) {
return set.toLowerCase().findInSet(word.toLowerCase());
}
public static int execICU(final UTF8String word, final UTF8String set,
final int collationId) {
return CollationAwareUTF8String.findInSet(word, set, collationId);
}
}

public static class StringInstr {
public static int exec(final UTF8String string, final UTF8String substring,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(string, substring);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(string, substring);
} else {
return execICU(string, substring, collationId);
}
}
public static String genCode(final String string, final String substring, final int start,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.StringInstr.exec";
// `start` is not forwarded: execBinary, execLowercase, and execICU below take no start argument.
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", string, substring);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", string, substring);
} else {
return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId);
}
}
public static int execBinary(final UTF8String string, final UTF8String substring) {
return string.indexOf(substring, 0);
}
public static int execLowercase(final UTF8String string, final UTF8String substring) {
return string.toLowerCase().indexOf(substring.toLowerCase(), 0);
}
public static int execICU(final UTF8String string, final UTF8String substring,
final int collationId) {
// Return the raw index (-1 when not found) so that the caller's "+ 1" yields 0 for no match.
return CollationAwareUTF8String.indexOf(string, substring, 0, collationId);
}
}
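
The `genCode` methods above pick an `exec*` variant once, when the code is generated, and emit a call expression as a string. A minimal, standalone sketch of that dispatch follows. It is an illustration only: `GenCodeDispatchSketch` and `findInSetCall` are hypothetical names, and the two boolean parameters stand in for the `supportsBinaryEquality` and `supportsLowercaseEquality` flags that the diff reads from `CollationFactory.Collation`.

```java
final class GenCodeDispatchSketch {
    // Hypothetical helper: builds the Java expression that generated code would contain.
    // The booleans are stand-ins for the flags on CollationFactory.Collation.
    static String findInSetCall(String word, String set, boolean supportsBinaryEquality,
            boolean supportsLowercaseEquality, int collationId) {
        String expr = "CollationSupport.FindInSet.exec";
        if (supportsBinaryEquality) {
            // Binary collations need no collation id at runtime.
            return String.format(expr + "Binary(%s, %s)", word, set);
        } else if (supportsLowercaseEquality) {
            return String.format(expr + "Lowercase(%s, %s)", word, set);
        } else {
            // Only the ICU path carries the collation id into the generated call.
            return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId);
        }
    }

    public static void main(String[] args) {
        System.out.println(findInSetCall("w", "s", true, false, 0));
        System.out.println(findInSetCall("w", "s", false, false, 42));
    }
}
```

Because the branch is resolved while building the string, the emitted expression names exactly one `exec*` method, so its argument list has to match that method's signature.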

// TODO: Add more collation-aware string expressions.

/**
@@ -169,6 +239,48 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern
pos, pos + pattern.numChars()), pattern, collationId).last() == 0;
}

private static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
if (match.contains(UTF8String.fromString(","))) {
return 0;
}

String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
collationId);

int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
}

return pos + 1;
}
}

return 0;
}
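
For readers unfamiliar with `find_in_set`, the contract exercised by this PR's tests is: return the 1-based position of the word among the comma-separated elements, or 0 when the word is absent or itself contains a comma. The sketch below illustrates that contract for the lowercase case using plain `java.lang.String`. It is not the PR's code: `FindInSetSketch` and `findInSetLowercase` are hypothetical names, and `String.toLowerCase(Locale.ROOT)` is assumed to be a close enough stand-in for `UTF8String.toLowerCase()` on these inputs.

```java
import java.util.Locale;

final class FindInSetSketch {
    /** Returns the 1-based position of word in the comma-separated set, or 0 if absent. */
    static int findInSetLowercase(String word, String set) {
        // A word containing the delimiter can never equal a single element.
        if (word.indexOf(',') >= 0) {
            return 0;
        }
        String needle = word.toLowerCase(Locale.ROOT);
        // A limit of -1 keeps trailing empty elements instead of dropping them.
        String[] elements = set.split(",", -1);
        for (int i = 0; i < elements.length; i++) {
            if (elements[i].toLowerCase(Locale.ROOT).equals(needle)) {
                return i + 1;
            }
        }
        return 0;
    }

    public static void main(String[] args) {
        System.out.println(findInSetLowercase("AB", "abc,b,ab,c,def"));
        System.out.println(findInSetLowercase("d,ef", "abc,b,ab,c,def"));
    }
}
```

The ICU variant above reaches the same positions differently: it searches the whole set string for collation-equal matches and then checks that each match is bounded by commas or by the ends of the string.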

private static int indexOf(final UTF8String target, final UTF8String pattern,
final int start, final int collationId) {
if (pattern.numBytes() == 0) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
stringSearch.setIndex(start);

return stringSearch.next();
}
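
`StringSearch` comes from ICU4J, not the JDK. As a rough JDK-only illustration of what a collation-aware `indexOf` does, the sketch below swaps in `java.text.Collator` and a sliding window. This is a different and much simpler technique than ICU's search: it assumes a match has the same UTF-16 length as the pattern, which ICU does not require. `CollatedIndexOfSketch`, `rootCollator`, and this `indexOf` are hypothetical names, not part of Spark or ICU.

```java
import java.text.Collator;
import java.util.Locale;

final class CollatedIndexOfSketch {
    /** Builds a root-locale collator: SECONDARY strength ignores case, TERTIARY does not. */
    static Collator rootCollator(boolean caseSensitive) {
        Collator collator = Collator.getInstance(Locale.ROOT);
        collator.setStrength(caseSensitive ? Collator.TERTIARY : Collator.SECONDARY);
        return collator;
    }

    /** Returns the 0-based index of the first collation-equal window, or -1 if none. */
    static int indexOf(String target, String pattern, int start, Collator collator) {
        if (pattern.isEmpty()) {
            return 0; // same convention as the diff's indexOf for an empty pattern
        }
        int n = pattern.length();
        // Simplifying assumption: only windows of exactly pattern.length() are compared.
        for (int i = Math.max(start, 0); i + n <= target.length(); i++) {
            if (collator.compare(target.substring(i, i + n), pattern) == 0) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        System.out.println(indexOf("aaads", "AD", 0, rootCollator(false)));
        System.out.println(indexOf("aaads", "Aa", 0, rootCollator(true)));
    }
}
```

A caller that wants SQL `instr` semantics adds 1 to the result, so a miss (-1) becomes 0 and a hit becomes a 1-based position.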

}

}

@@ -249,6 +249,86 @@ public void testEndsWith() throws SparkException {
assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false);
}

private void assertStringInstr(String string, String substring, String collationName,
Integer value) throws SparkException {
UTF8String str = UTF8String.fromString(string);
UTF8String substr = UTF8String.fromString(substring);
int collationId = CollationFactory.collationNameToId(collationName);

assertEquals(CollationSupport.StringInstr.exec(str, substr, collationId) + 1, value);
}

@Test
public void testStringInstr() throws SparkException {
assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0);
assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0);
assertStringInstr("aaads", "ds", "UTF8_BINARY", 4);
assertStringInstr("xxxx", "", "UTF8_BINARY", 1);
assertStringInstr("", "xxxx", "UTF8_BINARY", 0);
assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5);
assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8);
assertStringInstr("aaads", "Aa", "UTF8_BINARY_LCASE", 1);
assertStringInstr("aaaDs", "de", "UTF8_BINARY_LCASE", 0);
assertStringInstr("aaaDs", "ds", "UTF8_BINARY_LCASE", 4);
assertStringInstr("xxxx", "", "UTF8_BINARY_LCASE", 1);
assertStringInstr("", "xxxx", "UTF8_BINARY_LCASE", 0);
assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY_LCASE", 5);
assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8);
assertStringInstr("aaads", "Aa", "UNICODE", 0);
assertStringInstr("aaads", "aa", "UNICODE", 1);
assertStringInstr("aaads", "de", "UNICODE", 0);
assertStringInstr("xxxx", "", "UNICODE", 1);
assertStringInstr("", "xxxx", "UNICODE", 0);
assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0);
assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8);
assertStringInstr("aaads", "AD", "UNICODE_CI", 3);
assertStringInstr("aaads", "dS", "UNICODE_CI", 4);
assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8);
}

// word: String, set: String, collationName: String, value: Integer (expected result)
private void assertFindInSet(String word, String set, String collationName,
Integer value) throws SparkException {
UTF8String w = UTF8String.fromString(word);
UTF8String s = UTF8String.fromString(set);
int collationId = CollationFactory.collationNameToId(collationName);

assertEquals(CollationSupport.FindInSet.exec(w, s, collationId), value);
}

@Test
public void testFindInSet() throws SparkException {
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4);
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3);
assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1);
assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("xx", "xx", "UNICODE", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("XX", "xx", "UNICODE_CI", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
}

// TODO: Test more collation-aware string expressions.

/**

@@ -974,15 +974,19 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+ final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeAnyCollation, StringTypeAnyCollation)

- override protected def nullSafeEval(word: Any, set: Any): Any =
- set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
+ override protected def nullSafeEval(word: Any, set: Any): Any = {
+ CollationSupport.FindInSet.
+ exec(word.asInstanceOf[UTF8String], set.asInstanceOf[UTF8String], collationId)
+ }

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- nullSafeCodeGen(ctx, ev, (word, set) =>
- s"${ev.value} = $set.findInSet($word);"
- )
+ defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet.
+ genCode(word, set, collationId))
}

override def dataType: DataType = IntegerType
@@ -1346,20 +1350,24 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
case class StringInstr(str: Expression, substr: Expression)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

+ final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId
+
override def left: Expression = str
override def right: Expression = substr
override def dataType: DataType = IntegerType
- override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def nullSafeEval(string: Any, sub: Any): Any = {
- string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
+ CollationSupport.StringInstr.
+ exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) + 1
}

override def prettyName: String = "instr"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- defineCodeGen(ctx, ev, (l, r) =>
- s"($l).indexOf($r, 0) + 1")
+ defineCodeGen(ctx, ev, (string, substring) =>
+ CollationSupport.StringInstr.genCode(string, substring, 0, collationId) + " + 1")
}

override protected def withNewChildrenInternal(

@@ -17,13 +17,11 @@

package org.apache.spark.sql

- import scala.collection.immutable.Seq
-
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
- import org.apache.spark.sql.types.{BooleanType, StringType}
+ import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}

class CollationStringExpressionsSuite
extends QueryTest
@@ -96,6 +94,63 @@ class CollationStringExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StringInStr string expression with collation") {
case class StringInStrTestCase[R](string: String, substring: String, c: String, result: R)
val testCases = Seq(
// scalastyle:off
StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5),
StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8),
StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0),
StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8)
// scalastyle:on
)
testCases.foreach(t => {
val query = s"SELECT instr(collate('${t.string}','${t.c}')," +
s"collate('${t.substring}','${t.c}'))"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
// Implicit casting
checkAnswer(sql(s"SELECT instr(collate('${t.string}','${t.c}')," +
s"'${t.substring}')"), Row(t.result))
checkAnswer(sql(s"SELECT instr('${t.string}'," +
s"collate('${t.substring}','${t.c}'))"), Row(t.result))
})
// Collation mismatch
val collationMismatch = intercept[AnalysisException] {
sql(s"SELECT instr(collate('aaads','UTF8_BINARY'), collate('Aa','UTF8_BINARY_LCASE'))")
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support FindInSet string expression with collation") {
case class FindInSetTestCase[R](word: String, set: String, c: String, result: R)
val testCases = Seq(
FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0),
FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4),
FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0),
FindInSetTestCase("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5)
)
testCases.foreach(t => {
val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
s"collate('${t.set}', '${t.c}'))"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
// Implicit casting
checkAnswer(sql(s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
s"'${t.set}')"), Row(t.result))
checkAnswer(sql(s"SELECT find_in_set('${t.word}'," +
s"collate('${t.set}', '${t.c}'))"), Row(t.result))
})
// Collation mismatch
val collationMismatch = intercept[AnalysisException] {
sql(s"SELECT find_in_set(collate('AB','UTF8_BINARY')," +
s"collate('ab,xyz,fgh','UTF8_BINARY_LCASE'))")
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StartsWith string expression with collation") {
// Supported collations
case class StartsWithTestCase[R](l: String, r: String, c: String, result: R)