Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47411][SQL] Support StringInstr & FindInSet functions to work with collated strings #45643

Closed
wants to merge 53 commits into from
Closed
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a4d3592
Add support for instr and unit test in CollationStringExpressionsSuit…
miland-db Mar 21, 2024
eb2d7c5
Correct code style
miland-db Mar 21, 2024
9340831
Remove blank line from CollationStringExpressionsSuite.scala
miland-db Mar 21, 2024
465e814
Correct comment indentation
miland-db Mar 21, 2024
f3f30d8
Add unit tests for INSTR operation
miland-db Mar 22, 2024
9cb92d3
Add doGenCode for FindInSet
miland-db Mar 22, 2024
834be70
Rewrite unit tests for INSTR and FIND_IN_SET
miland-db Mar 22, 2024
db2453a
Correct return value when substr is not found in INSTR method
miland-db Mar 22, 2024
91b648a
Update unit tests for StringInStr and FindInSet
miland-db Mar 25, 2024
1062521
Remove tests on non-explicit default collation
miland-db Mar 25, 2024
42700c7
Merge branch 'apache:master' into miland-db/substr-functions
miland-db Mar 25, 2024
427ea25
Improve signature of testInStr
miland-db Mar 26, 2024
546b3b0
Merge branch 'master' into miland-db/substr-functions
miland-db Mar 26, 2024
108d707
Remove E2E test for collation mismatch. This will be added in Implici…
miland-db Mar 26, 2024
822ecd2
Resolve merge problems with master
miland-db Mar 26, 2024
f730d05
Improve scala style
miland-db Mar 26, 2024
de7b591
Solve whitespace scala style problem
miland-db Mar 27, 2024
f0ee8fd
Add lazy val collationId
miland-db Apr 1, 2024
4ac6885
Remove repeated code
miland-db Apr 1, 2024
b931333
Improve test format
miland-db Apr 1, 2024
0fd51d5
Improve indexOf method
miland-db Apr 1, 2024
28fa7f0
Remove checks in return statement of collatedIndexOf method
miland-db Apr 1, 2024
bab96ac
Merge branch 'master' into substr-functions
miland-db Apr 2, 2024
4666aff
Add branch for collated findInSet
miland-db Apr 2, 2024
ca8a37c
Add branch for collation check in StringInstr
miland-db Apr 2, 2024
4ffab78
Improve naming of collation aware methods
miland-db Apr 2, 2024
037d6be
Improve java style
miland-db Apr 3, 2024
0a22909
Improve collationAwareIndexOf performance
miland-db Apr 3, 2024
b3be85d
Fix indentation
miland-db Apr 3, 2024
c4c0fe7
Add more tests for instr
miland-db Apr 3, 2024
5b29f76
Add more tests
miland-db Apr 3, 2024
877828e
Remove collation match type checks
miland-db Apr 4, 2024
038a071
Merge branch 'master' into substr-functions
miland-db Apr 4, 2024
b35a8ac
Merge with the latest master
miland-db Apr 4, 2024
8b06014
Remove checkInputDataTypes
miland-db Apr 4, 2024
2c454af
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
4dbc26e
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
fbd1c00
Refactor code and move it to CollationSupport
miland-db Apr 12, 2024
960af54
Improve codegen and run tests
miland-db Apr 12, 2024
05cd6c4
Unify collationAwareIndexOf for return value to have same semantics a…
miland-db Apr 12, 2024
053efa0
Break line at 100 chars
miland-db Apr 13, 2024
c65d68e
Add new version of getStringSearch
miland-db Apr 15, 2024
ae33a38
Rename StringInstr params and class in CollationSupport
miland-db Apr 15, 2024
be1b52c
Go from nullSafeCodeGen to defineCodeGen
miland-db Apr 15, 2024
5894d2f
Refactor testing
miland-db Apr 15, 2024
75dc0bd
Remove empty lines
miland-db Apr 15, 2024
c712b4b
Improve CollationAware indexOf to have the same semantics as UTF8Stri…
miland-db Apr 15, 2024
3c37f35
Add new e2e test
miland-db Apr 15, 2024
cd860b9
Revert unused import deletion
miland-db Apr 15, 2024
3fa2502
Fix codegen
miland-db Apr 16, 2024
b12f176
Merge branch 'master' into substr-functions
miland-db Apr 17, 2024
b35d718
Remove unused import
miland-db Apr 17, 2024
1ee5ad6
Add new tests
miland-db Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,51 @@ public int findInSet(UTF8String match) {
return 0;
}

public int findInSet(UTF8String match, int collationId) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
return this.findInSet(match);
}
if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbatomic do we have a general principle for this special collation? Always lower-case first and then reuse existing UTF8String functions? Will we have more collations like this in the future?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have to expect more special collations besides UTF8_BINARY_LCASE. Everything else will be based on ICU.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That being said, I also don't like current direction of pushing everything into UTF8String. Let me see if we can come up with some cleaner approach.

Copy link
Contributor Author

@miland-db miland-db Apr 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can solve this in the same way we did it here: STRING_LOCATE. In the next PRs we updated CollationFactory to be able to return "lowercase collator" so we can unify the way we deal with collated strings. If you prefer, I can update this PR with that change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"lowercase collator" SGTM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - Let's put on hold string expression development until we provide proper design. @uros-db, @miland-db and I will follow up on this today.

return this.toLowerCase().findInSet(match.toLowerCase());
}
return collationAwareFindInSet(match, collationId);
}

/*
* Works on Strings with collationId other than UTF8_BINARY_COLLATION_ID. Returns the index
* of the string `match` in this String. This string has to be a comma separated
* list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String,
* 0 will be returned, else the index of match (1-based index)
*/
private int collationAwareFindInSet(UTF8String match, int collationId) {
if (match.contains(COMMA_UTF8)) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, match, collationId);

String setString = this.toString();
int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
}

return pos + 1;
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}
}

return 0;
}

/**
* Copy the bytes from the current UTF8String, and make a new UTF8String.
* @param start the start position of the current UTF8String in bytes.
Expand Down Expand Up @@ -835,6 +880,27 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

public int indexOf(UTF8String substring, int start, int collationId) {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
return this.indexOf(substring, start);
}
if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
return this.toLowerCase().indexOf(substring.toLowerCase(), start);
}
return collationAwareIndexOf(substring, start, collationId);
}

private int collationAwareIndexOf(UTF8String substring, int start, int collationId) {
if (substring.numBytes == 0) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId);
stringSearch.setIndex(start);

return stringSearch.next();
}

/**
* Find the `str` from left to right.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,15 +994,25 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override protected def nullSafeEval(word: Any, set: Any): Any =
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
override protected def nullSafeEval(word: Any, set: Any): Any = {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
} else {
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (word, set) =>
s"${ev.value} = $set.findInSet($word);"
)
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);")
} else {
nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word, $collationId);")
}
}

override def dataType: DataType = IntegerType
Expand Down Expand Up @@ -1366,20 +1376,30 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
case class StringInstr(str: Expression, substr: Expression)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

override def left: Expression = str
override def right: Expression = substr
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def nullSafeEval(string: Any, sub: Any): Any = {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
} else {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1
}
}

override def prettyName: String = "instr"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (l, r) =>
s"($l).indexOf($r, 0) + 1")
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1")
} else {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1")
}
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat}
import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -76,17 +76,128 @@ class CollationStringExpressionsSuite
)
}

test("INSTR check result on explicitly collated strings") {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
miland-db marked this conversation as resolved.
Show resolved Hide resolved
def testInStr(str: String, substr: String, collationId: Integer, expected: Integer): Unit = {
val string = Literal.create(str, StringType(collationId))
val substring = Literal.create(substr, StringType(collationId))

checkEvaluation(StringInstr(string, substring), expected)
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

var collationId = CollationFactory.collationNameToId("UTF8_BINARY")
testInStr("aaads", "Aa", collationId, 0)
testInStr("aaaDs", "de", collationId, 0)
testInStr("aaads", "ds", collationId, 4)
testInStr("xxxx", "", collationId, 1)
testInStr("", "xxxx", collationId, 0)
// scalastyle:off
testInStr("test大千世界X大千世界", "大千", collationId, 5)
testInStr("test大千世界X大千世界", "界X", collationId, 8)
// scalastyle:on

collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
testInStr("aaads", "Aa", collationId, 1)
testInStr("aaaDs", "de", collationId, 0)
testInStr("aaaDs", "ds", collationId, 4)
testInStr("xxxx", "", collationId, 1)
testInStr("", "xxxx", collationId, 0)
// scalastyle:off
testInStr("test大千世界X大千世界", "大千", collationId, 5)
testInStr("test大千世界X大千世界", "界x", collationId, 8)
// scalastyle:on

collationId = CollationFactory.collationNameToId("UNICODE")
testInStr("aaads", "Aa", collationId, 0)
testInStr("aaads", "aa", collationId, 1)
testInStr("aaads", "de", collationId, 0)
testInStr("xxxx", "", collationId, 1)
testInStr("", "xxxx", collationId, 0)
// scalastyle:off
testInStr("test大千世界X大千世界", "界x", collationId, 0)
testInStr("test大千世界X大千世界", "界X", collationId, 8)
// scalastyle:on

collationId = CollationFactory.collationNameToId("UNICODE_CI")
testInStr("aaads", "AD", collationId, 3)
testInStr("aaads", "dS", collationId, 4)
// scalastyle:off
testInStr("test大千世界X大千世界", "界x", collationId, 8)
// scalastyle:on
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}

test("FIND_IN_SET check result on explicitly collated strings") {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
def testFindInSet(word: String, set: String, collationId: Integer, expected: Integer): Unit = {
val w = Literal.create(word, StringType(collationId))
val s = Literal.create(set, StringType(collationId))

checkEvaluation(FindInSet(w, s), expected)
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

var collationId = CollationFactory.collationNameToId("UTF8_BINARY")
testFindInSet("AB", "abc,b,ab,c,def", collationId, 0)
testFindInSet("abc", "abc,b,ab,c,def", collationId, 1)
testFindInSet("def", "abc,b,ab,c,def", collationId, 5)
testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0)
testFindInSet("", "abc,b,ab,c,def", collationId, 0)

collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
testFindInSet("a", "abc,b,ab,c,def", collationId, 0)
testFindInSet("c", "abc,b,ab,c,def", collationId, 4)
testFindInSet("AB", "abc,b,ab,c,def", collationId, 3)
testFindInSet("AbC", "abc,b,ab,c,def", collationId, 1)
testFindInSet("abcd", "abc,b,ab,c,def", collationId, 0)
testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0)
testFindInSet("XX", "xx", collationId, 1)
testFindInSet("", "abc,b,ab,c,def", collationId, 0)
// scalastyle:off
testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4)
// scalastyle:on

collationId = CollationFactory.collationNameToId("UNICODE")
testFindInSet("a", "abc,b,ab,c,def", collationId, 0)
testFindInSet("ab", "abc,b,ab,c,def", collationId, 3)
testFindInSet("Ab", "abc,b,ab,c,def", collationId, 0)
testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0)
testFindInSet("xx", "xx", collationId, 1)
// scalastyle:off
testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 0)
testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5)
// scalastyle:on

collationId = CollationFactory.collationNameToId("UNICODE_CI")
testFindInSet("a", "abc,b,ab,c,def", collationId, 0)
testFindInSet("C", "abc,b,ab,c,def", collationId, 4)
testFindInSet("DeF", "abc,b,ab,c,dEf", collationId, 5)
testFindInSet("DEFG", "abc,b,ab,c,def", collationId, 0)
testFindInSet("XX", "xx", collationId, 1)
// scalastyle:off
testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4)
testFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", collationId, 5)
testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5)
// scalastyle:on
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

test("REPEAT check output type on explicitly collated string") {
def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = {
def testRepeat(input: String, n: Int, collationId: Int, expected: String): Unit = {
val s = Literal.create(input, StringType(collationId))

checkEvaluation(Collation(StringRepeat(s, Literal.create(n))).replacement, expected)
}

testRepeat("UTF8_BINARY", 0, "abc", 2)
testRepeat("UTF8_BINARY_LCASE", 1, "abc", 2)
testRepeat("UNICODE", 2, "abc", 2)
testRepeat("UNICODE_CI", 3, "abc", 2)
// Not important for this test
val repeatNum = 2;

var collationId = CollationFactory.collationNameToId("UTF8_BINARY")
testRepeat("abc", repeatNum, collationId, "UTF8_BINARY")

collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
testRepeat("abc", repeatNum, collationId, "UTF8_BINARY_LCASE")

collationId = CollationFactory.collationNameToId("UNICODE")
testRepeat("abc", repeatNum, collationId, "UNICODE")

collationId = CollationFactory.collationNameToId("UNICODE_CI")
testRepeat("abc", repeatNum, collationId, "UNICODE_CI")
}

// TODO: Add more tests for other string expressions
Expand Down