Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47411][SQL] Support StringInstr & FindInSet functions to work with collated strings #45643

Closed
wants to merge 53 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a4d3592
Add support for instr and unit test in CollationStringExpressionsSuit…
miland-db Mar 21, 2024
eb2d7c5
Correct code style
miland-db Mar 21, 2024
9340831
Remove blank line from CollationStringExpressionsSuite.scala
miland-db Mar 21, 2024
465e814
Correct comment indentation
miland-db Mar 21, 2024
f3f30d8
Add unit tests for INSTR operation
miland-db Mar 22, 2024
9cb92d3
Add doGenCode for FindInSet
miland-db Mar 22, 2024
834be70
Rewrite unit tests for INSTR and FIND_IN_SET
miland-db Mar 22, 2024
db2453a
Correct return value when substr is not found in INSTR method
miland-db Mar 22, 2024
91b648a
Update unit tests for StringInStr and FindInSet
miland-db Mar 25, 2024
1062521
Remove tests on non-explicit default collation
miland-db Mar 25, 2024
42700c7
Merge branch 'apache:master' into miland-db/substr-functions
miland-db Mar 25, 2024
427ea25
Improve signature of testInStr
miland-db Mar 26, 2024
546b3b0
Merge branch 'master' into miland-db/substr-functions
miland-db Mar 26, 2024
108d707
Remove E2E test for collation mismatch. This will be added in Implici…
miland-db Mar 26, 2024
822ecd2
Resolve merge problems with master
miland-db Mar 26, 2024
f730d05
Improve scala style
miland-db Mar 26, 2024
de7b591
Solve whitespace scala style problem
miland-db Mar 27, 2024
f0ee8fd
Add lazy val collationId
miland-db Apr 1, 2024
4ac6885
Remove repeated code
miland-db Apr 1, 2024
b931333
Improve test format
miland-db Apr 1, 2024
0fd51d5
Improve indexOf method
miland-db Apr 1, 2024
28fa7f0
Remove checks in return statement of collatedIndexOf method
miland-db Apr 1, 2024
bab96ac
Merge branch 'master' into substr-functions
miland-db Apr 2, 2024
4666aff
Add branch for collated findInSet
miland-db Apr 2, 2024
ca8a37c
Add branch for collation check in StringInstr
miland-db Apr 2, 2024
4ffab78
Improve naming of collation aware methods
miland-db Apr 2, 2024
037d6be
Improve java style
miland-db Apr 3, 2024
0a22909
Improve collationAwareIndexOf performance
miland-db Apr 3, 2024
b3be85d
Fix indentation
miland-db Apr 3, 2024
c4c0fe7
Add more tests for instr
miland-db Apr 3, 2024
5b29f76
Add more tests
miland-db Apr 3, 2024
877828e
Remove collation match type checks
miland-db Apr 4, 2024
038a071
Merge branch 'master' into substr-functions
miland-db Apr 4, 2024
b35a8ac
Merge with the latest master
miland-db Apr 4, 2024
8b06014
Remove checkInputDataTypes
miland-db Apr 4, 2024
2c454af
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
4dbc26e
Merge branch 'master' into substr-functions
miland-db Apr 11, 2024
fbd1c00
Refactor code and move it to CollationSupport
miland-db Apr 12, 2024
960af54
Improve codegen and run tests
miland-db Apr 12, 2024
05cd6c4
Unify collationAwareIndexOf for return value to have same semantics a…
miland-db Apr 12, 2024
053efa0
Break line at 100 chars
miland-db Apr 13, 2024
c65d68e
Add new version of getStringSearch
miland-db Apr 15, 2024
ae33a38
Rename StringInstr params and class in CollationSupport
miland-db Apr 15, 2024
be1b52c
Go from nullSafeCodeGen to defineCodeGen
miland-db Apr 15, 2024
5894d2f
Refactor testing
miland-db Apr 15, 2024
75dc0bd
Remove empty lines
miland-db Apr 15, 2024
c712b4b
Improve CollationAware indexOf to have the same semantics as UTF8Stri…
miland-db Apr 15, 2024
3c37f35
Add new e2e test
miland-db Apr 15, 2024
cd860b9
Revert unused import deletion
miland-db Apr 15, 2024
3fa2502
Fix codegen
miland-db Apr 16, 2024
b12f176
Merge branch 'master' into substr-functions
miland-db Apr 17, 2024
b35d718
Remove unused import
miland-db Apr 17, 2024
1ee5ad6
Add new tests
miland-db Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,47 @@ public int findInSet(UTF8String match) {
return 0;
}

public int findInSet(UTF8String match, int collationId) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
return this.findInSet(match);
}
if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
return this.toLowerCase().findInSet(match.toLowerCase());
}
return collatedFindInSet(match, collationId);
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

private int collatedFindInSet(UTF8String match, int collationId) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if (match.contains(COMMA_UTF8)) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, match, collationId);

String setString = this.toString();
int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if(isValidStart && isValidEnd) {
int pos = 0;
for(int i = 0; i < setString.length() && i < wordStart; i++) {
if(setString.charAt(i) == ',') {
pos++;
}
}

return pos + 1;
}
}
}

return 0;
}

/**
* Copy the bytes from the current UTF8String, and make a new UTF8String.
* @param start the start position of the current UTF8String in bytes.
Expand Down Expand Up @@ -835,6 +876,33 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

public int indexOf(UTF8String substring, int start, int collationId) {
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
return this.indexOf(substring, start);
}
if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
return this.toLowerCase().indexOf(substring.toLowerCase(), start);
}
return collatedIndexOf(substring, collationId);
}

private int collatedIndexOf(UTF8String substring, int collationId) {
if (substring.numBytes == 0) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) {
return pos;
}
}

return 0;
}

/**
* Find the `str` from left to right.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,15 +1002,32 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override protected def nullSafeEval(word: Any, set: Any): Any =
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
override protected def nullSafeEval(word: Any, set: Any): Any = {
val collationId = left.dataType.asInstanceOf[StringType].collationId
miland-db marked this conversation as resolved.
Show resolved Hide resolved
set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (word, set) =>
s"${ev.value} = $set.findInSet($word);"
)
val collationId = left.dataType.asInstanceOf[StringType].collationId

if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);")
} else {
nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word, $collationId);")
}
}

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
return defaultCheck
}

val collationId = left.dataType.asInstanceOf[StringType].collationId
CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}

override def dataType: DataType = IntegerType
Expand Down Expand Up @@ -1377,17 +1394,34 @@ case class StringInstr(str: Expression, substr: Expression)
override def left: Expression = str
override def right: Expression = substr
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def nullSafeEval(string: Any, sub: Any): Any = {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
val collationId = left.dataType.asInstanceOf[StringType].collationId
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1
}

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
return defaultCheck
}

val collationId = left.dataType.asInstanceOf[StringType].collationId
CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}

override def prettyName: String = "instr"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (l, r) =>
s"($l).indexOf($r, 0) + 1")
val collationId = left.dataType.asInstanceOf[StringType].collationId

if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1")
} else {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1")
}
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{Collate, ExpressionEvalHelper, FindInSet, Literal, StringInstr}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType

class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession {
class CollationStringExpressionsSuite extends QueryTest
with SharedSparkSession with ExpressionEvalHelper {

case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)
case class CollationTestFail[R](s1: String, s2: String, collation: String)
Expand Down Expand Up @@ -70,6 +74,148 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession
})
}

test("INSTR check result on non-explicit default collation") {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
checkEvaluation(StringInstr(Literal("aAads"), Literal("Aa")), 2)
}

test("INSTR check result on explicitly collated strings") {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
// UTF8_BINARY_LCASE
checkEvaluation(StringInstr(Literal.create("aaads", StringType(1)),
Literal.create("Aa", StringType(1))), 1)
checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"),
miland-db marked this conversation as resolved.
Show resolved Hide resolved
Collate(Literal("Aa"), "UTF8_BINARY_LCASE")), 1)
// UNICODE
checkEvaluation(StringInstr(Literal.create("aaads", StringType(2)),
Literal.create("Aa", StringType(2))), 0)
checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE"),
Collate(Literal("Aa"), "UNICODE")), 0)
// UNICODE_CI
checkEvaluation(StringInstr(Literal.create("aaads", StringType(3)),
Literal.create("de", StringType(3))), 0)
checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE_CI"),
Collate(Literal("Aa"), "UNICODE_CI")), 0)
}

test("INSTR fail mismatched collation types") {
// UNICODE and UNICODE_CI
val expr1 = StringInstr(Collate(Literal("aaads"), "UNICODE"),
Collate(Literal("Aa"), "UNICODE_CI"))
assert(expr1.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UNICODE",
"collationNameRight" -> "UNICODE_CI"
)
)
)
// DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE
val expr2 = StringInstr(Literal("aaads"),
Collate(Literal("Aa"), "UTF8_BINARY_LCASE"))
assert(expr2.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UTF8_BINARY",
"collationNameRight" -> "UTF8_BINARY_LCASE"
)
)
)
// UTF8_BINARY_LCASE and UNICODE_CI
val expr3 = StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"),
Collate(Literal("Aa"), "UNICODE_CI"))
assert(expr3.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UTF8_BINARY_LCASE",
"collationNameRight" -> "UNICODE_CI"
)
)
)
}

test("FIND_IN_SET check result on non-explicit default collation") {
checkEvaluation(FindInSet(Literal("def"), Literal("abc,b,ab,c,def")), 5)
miland-db marked this conversation as resolved.
Show resolved Hide resolved
checkEvaluation(FindInSet(Literal("defg"), Literal("abc,b,ab,c,def")), 0)
}

test("FIND_IN_SET check result on explicitly collated strings") {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
// UTF8_BINARY
checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0)
checkEvaluation(FindInSet(Collate(Literal("c"), "UTF8_BINARY"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 4)
checkEvaluation(FindInSet(Collate(Literal("AB"), "UTF8_BINARY"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0)
checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0)
// UTF8_BINARY_LCASE
checkEvaluation(FindInSet(Collate(Literal("aB"), "UTF8_BINARY_LCASE"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 3)
checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0)
checkEvaluation(FindInSet(Collate(Literal("abc"), "UTF8_BINARY_LCASE"),
Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 1)
checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY_LCASE"),
Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0)
// UNICODE
checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0)
checkEvaluation(FindInSet(Collate(Literal("ab"), "UNICODE"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 3)
checkEvaluation(FindInSet(Collate(Literal("Ab"), "UNICODE"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0)
// UNICODE_CI
checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE_CI"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0)
checkEvaluation(FindInSet(Collate(Literal("C"), "UNICODE_CI"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 4)
checkEvaluation(FindInSet(Collate(Literal("DeF"), "UNICODE_CI"),
Collate(Literal("abc,b,ab,c,dEf"), "UNICODE_CI")), 5)
checkEvaluation(FindInSet(Collate(Literal("DEFG"), "UNICODE_CI"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0)
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

test("FIND_IN_SET fail mismatched collation types") {
// UNICODE and UNICODE_CI
val expr1 = FindInSet(Collate(Literal("a"), "UNICODE"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI"))
assert(expr1.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UNICODE",
"collationNameRight" -> "UNICODE_CI"
)
)
)
// DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE
val expr2 = FindInSet(Collate(Literal("a"), "UTF8_BINARY"),
Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE"))
assert(expr2.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UTF8_BINARY",
"collationNameRight" -> "UTF8_BINARY_LCASE"
)
)
)
// UTF8_BINARY_LCASE and UNICODE_CI
val expr3 = FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"),
Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI"))
assert(expr3.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "COLLATION_MISMATCH",
messageParameters = Map(
"collationNameLeft" -> "UTF8_BINARY_LCASE",
"collationNameRight" -> "UNICODE_CI"
)
)
)
}
miland-db marked this conversation as resolved.
Show resolved Hide resolved

// TODO: Add more tests for other string expressions

}
Expand Down