Skip to content

Commit

Permalink
[SPARK-47411][SQL] Support StringInstr & FindInSet functions to work …
Browse files Browse the repository at this point in the history
…with collated strings

### What changes were proposed in this pull request?
Extend built-in string functions to support non-binary, non-lowercase collation for: instr & find_in_set.

### Why are the changes needed?
Update collation support for built-in string functions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use COLLATE within arguments for built-in string functions INSTR and FIND_IN_SET in Spark SQL queries, using non-binary collations such as UNICODE_CI.

### How was this patch tested?
Unit tests for queries using "collate" (CollationSuite).

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45643 from miland-db/miland-db/substr-functions.

Authored-by: Milan Dankovic <milan.dankovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
miland-db authored and cloud-fan committed Apr 22, 2024
1 parent 61dc9d9 commit 256fc51
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,21 @@ public static StringSearch getStringSearch(
final UTF8String targetUTF8String,
final UTF8String patternUTF8String,
final int collationId) {
String pattern = patternUTF8String.toString();
CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString());
return getStringSearch(targetUTF8String.toString(), patternUTF8String.toString(), collationId);
}

/**
* Returns a StringSearch object for the given pattern and target strings, under collation
* rules corresponding to the given collationId. The external ICU library StringSearch object can
* be used to find occurrences of the pattern in the target string, while respecting collation.
*/
public static StringSearch getStringSearch(
final String targetString,
final String patternString,
final int collationId) {
CharacterIterator target = new StringCharacterIterator(targetString);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
return new StringSearch(pattern, target, (RuleBasedCollator) collator);
return new StringSearch(patternString, target, (RuleBasedCollator) collator);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,76 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
}
}

public static class FindInSet {
public static int exec(final UTF8String word, final UTF8String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(word, set);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(word, set);
} else {
return execICU(word, set, collationId);
}
}
public static String genCode(final String word, final String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.FindInSet.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", word, set);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", word, set);
} else {
return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId);
}
}
public static int execBinary(final UTF8String word, final UTF8String set) {
return set.findInSet(word);
}
public static int execLowercase(final UTF8String word, final UTF8String set) {
return set.toLowerCase().findInSet(word.toLowerCase());
}
public static int execICU(final UTF8String word, final UTF8String set,
final int collationId) {
return CollationAwareUTF8String.findInSet(word, set, collationId);
}
}

public static class StringInstr {
public static int exec(final UTF8String string, final UTF8String substring,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(string, substring);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(string, substring);
} else {
return execICU(string, substring, collationId);
}
}
public static String genCode(final String string, final String substring,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.StringInstr.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", string, substring);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", string, substring);
} else {
return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId);
}
}
public static int execBinary(final UTF8String string, final UTF8String substring) {
return string.indexOf(substring, 0);
}
public static int execLowercase(final UTF8String string, final UTF8String substring) {
return string.toLowerCase().indexOf(substring.toLowerCase(), 0);
}
public static int execICU(final UTF8String string, final UTF8String substring,
final int collationId) {
return CollationAwareUTF8String.indexOf(string, substring, 0, collationId);
}
}

// TODO: Add more collation-aware string expressions.

/**
Expand All @@ -164,6 +234,48 @@ public static boolean execICU(final UTF8String l, final UTF8String r,

private static class CollationAwareUTF8String {

private static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
if (match.contains(UTF8String.fromString(","))) {
return 0;
}

String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
collationId);

int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
}

return pos + 1;
}
}

return 0;
}

private static int indexOf(final UTF8String target, final UTF8String pattern,
final int start, final int collationId) {
if (pattern.numBytes() == 0) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
stringSearch.setIndex(start);

return stringSearch.next();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,88 @@ public void testEndsWith() throws SparkException {
assertEndsWith("The i̇o", "İo", "UNICODE_CI", true);
}

private void assertStringInstr(String string, String substring, String collationName,
Integer expected) throws SparkException {
UTF8String str = UTF8String.fromString(string);
UTF8String substr = UTF8String.fromString(substring);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected, CollationSupport.StringInstr.exec(str, substr, collationId) + 1);
}

@Test
public void testStringInstr() throws SparkException {
assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0);
assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0);
assertStringInstr("aaads", "ds", "UTF8_BINARY", 4);
assertStringInstr("xxxx", "", "UTF8_BINARY", 1);
assertStringInstr("", "xxxx", "UTF8_BINARY", 0);
assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5);
assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8);
assertStringInstr("aaads", "Aa", "UTF8_BINARY_LCASE", 1);
assertStringInstr("aaaDs", "de", "UTF8_BINARY_LCASE", 0);
assertStringInstr("aaaDs", "ds", "UTF8_BINARY_LCASE", 4);
assertStringInstr("xxxx", "", "UTF8_BINARY_LCASE", 1);
assertStringInstr("", "xxxx", "UTF8_BINARY_LCASE", 0);
assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY_LCASE", 5);
assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8);
assertStringInstr("aaads", "Aa", "UNICODE", 0);
assertStringInstr("aaads", "aa", "UNICODE", 1);
assertStringInstr("aaads", "de", "UNICODE", 0);
assertStringInstr("xxxx", "", "UNICODE", 1);
assertStringInstr("", "xxxx", "UNICODE", 0);
assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0);
assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8);
assertStringInstr("aaads", "AD", "UNICODE_CI", 3);
assertStringInstr("aaads", "dS", "UNICODE_CI", 4);
assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0);
assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8);
assertStringInstr("abİo12", "i̇o", "UNICODE_CI", 3);
assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3);
}

private void assertFindInSet(String word, String set, String collationName,
Integer expected) throws SparkException {
UTF8String w = UTF8String.fromString(word);
UTF8String s = UTF8String.fromString(set);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId));
}

@Test
public void testFindInSet() throws SparkException {
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4);
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3);
assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1);
assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("xx", "xx", "UNICODE", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("XX", "xx", "UNICODE_CI", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2);
assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
}

// TODO: Test more collation-aware string expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -978,15 +978,19 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

override protected def nullSafeEval(word: Any, set: Any): Any =
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override protected def nullSafeEval(word: Any, set: Any): Any = {
CollationSupport.FindInSet.
exec(word.asInstanceOf[UTF8String], set.asInstanceOf[UTF8String], collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (word, set) =>
s"${ev.value} = $set.findInSet($word);"
)
defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet.
genCode(word, set, collationId))
}

override def dataType: DataType = IntegerType
Expand Down Expand Up @@ -1350,20 +1354,24 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
case class StringInstr(str: Expression, substr: Expression)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

override def left: Expression = str
override def right: Expression = substr
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def nullSafeEval(string: Any, sub: Any): Any = {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
CollationSupport.StringInstr.
exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) + 1
}

override def prettyName: String = "instr"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (l, r) =>
s"($l).indexOf($r, 0) + 1")
defineCodeGen(ctx, ev, (string, substring) =>
CollationSupport.StringInstr.genCode(string, substring, collationId) + " + 1")
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,68 @@ class CollationStringExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StringInStr string expression with collation") {
case class StringInStrTestCase[R](string: String, substring: String, c: String, result: R)
val testCases = Seq(
// scalastyle:off
StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5),
StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8),
StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0),
StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0),
StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8),
StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3)
// scalastyle:on
)
testCases.foreach(t => {
val query = s"SELECT instr(collate('${t.string}','${t.c}')," +
s"collate('${t.substring}','${t.c}'))"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
// Implicit casting
checkAnswer(sql(s"SELECT instr(collate('${t.string}','${t.c}')," +
s"'${t.substring}')"), Row(t.result))
checkAnswer(sql(s"SELECT instr('${t.string}'," +
s"collate('${t.substring}','${t.c}'))"), Row(t.result))
})
// Collation mismatch
val collationMismatch = intercept[AnalysisException] {
sql(s"SELECT instr(collate('aaads','UTF8_BINARY'), collate('Aa','UTF8_BINARY_LCASE'))")
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support FindInSet string expression with collation") {
case class FindInSetTestCase[R](word: String, set: String, c: String, result: R)
val testCases = Seq(
FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0),
FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4),
FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0),
// scalastyle:off
FindInSetTestCase("i̇o", "ab,İo,12", "UNICODE_CI", 2),
FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2)
// scalastyle:on
)
testCases.foreach(t => {
val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
s"collate('${t.set}', '${t.c}'))"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
// Implicit casting
checkAnswer(sql(s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +
s"'${t.set}')"), Row(t.result))
checkAnswer(sql(s"SELECT find_in_set('${t.word}'," +
s"collate('${t.set}', '${t.c}'))"), Row(t.result))
})
// Collation mismatch
val collationMismatch = intercept[AnalysisException] {
sql(s"SELECT find_in_set(collate('AB','UTF8_BINARY')," +
s"collate('ab,xyz,fgh','UTF8_BINARY_LCASE'))")
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StartsWith string expression with collation") {
// Supported collations
case class StartsWithTestCase[R](l: String, r: String, c: String, result: R)
Expand Down

0 comments on commit 256fc51

Please sign in to comment.