Skip to content

Commit

Permalink
Correct code style
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Mar 21, 2024
1 parent a4d3592 commit eb2d7c5
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,48 @@ public int findInSet(UTF8String match) {
return 0;
}

/**
public int findInSet(UTF8String match, int collationId) {
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
return this.findInSet(match);
}
if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
return this.toLowerCase().findInSet(match.toLowerCase());
}
return collatedFindInSet(match, collationId);
}

/**
 * Finds `match` as a whole element of this comma-separated string, using the StringSearch
 * that CollationFactory.getStringSearch builds for the given collation.
 * @param match the element to look for; a value containing a comma never matches.
 * @param collationId id of the collation handed to CollationFactory.getStringSearch.
 * @return the 1-based position of the first matching element, or 0 if there is none.
 */
private int collatedFindInSet(UTF8String match, int collationId) {
  // An element cannot contain the separator, so such a pattern can never be found.
  if (match.contains(COMMA_UTF8)) {
    return 0;
  }

  final String target = this.toString();
  final StringSearch search = CollationFactory.getStringSearch(this, match, collationId);
  final int patternLength = search.getPattern().length();

  for (int hit = search.next(); hit != StringSearch.DONE; hit = search.next()) {
    final int hitLength = search.getMatchLength();
    // NOTE(review): hits whose length differs from the pattern length are skipped; confirm
    // this is intended for collations under which equal strings may differ in length.
    if (hitLength != patternLength) {
      continue;
    }

    // The hit only counts when it spans an entire element, i.e. it is delimited by the
    // string boundaries or by commas on both sides.
    final int hitEnd = hit + hitLength;
    final boolean startsElement = hit == 0 || target.charAt(hit - 1) == ',';
    final boolean endsElement = hitEnd == target.length() || target.charAt(hitEnd) == ',';
    if (!(startsElement && endsElement)) {
      continue;
    }

    // The element's position is one more than the number of commas preceding the hit.
    int commasBefore = 0;
    final int scanLimit = Math.min(hit, target.length());
    for (int i = 0; i < scanLimit; i++) {
      if (target.charAt(i) == ',') {
        commasBefore++;
      }
    }
    return commasBefore + 1;
  }

  return 0;
}

/**
* Copy the bytes from the current UTF8String, and make a new UTF8String.
* @param start the start position of the current UTF8String in bytes.
* @param end the end position of the current UTF8String in bytes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,17 +1002,30 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

// Both the word and the set accept string types of any collation. The stale
// `Seq(StringType, StringType)` definition is dropped: two overrides of the same member
// cannot coexist, and the collation-aware one is what nullSafeEval and
// checkInputDataTypes rely on.
override def inputTypes: Seq[AbstractDataType] =
  Seq(StringTypeAnyCollation, StringTypeAnyCollation)

// Interpreted evaluation: looks up `word` in the comma-separated `set`, using the collation
// carried by the left child's string type. The stale single-argument variant is dropped:
// it ignored the collation and duplicated this override.
override protected def nullSafeEval(word: Any, set: Any): Any = {
  val collationId = left.dataType.asInstanceOf[StringType].collationId
  set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId)
}

// Generated-code evaluation. The collation id is resolved here, from the left child's string
// type exactly as in nullSafeEval, and embedded as an int literal so that the generated Java
// calls the collation-aware findInSet(word, collationId). Calling the single-argument
// findInSet here would make codegen results ignore the collation and diverge from the
// interpreted path.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val collationId = left.dataType.asInstanceOf[StringType].collationId
  nullSafeCodeGen(ctx, ev, (word, set) =>
    s"${ev.value} = $set.findInSet($word, $collationId);"
  )
}

// Runs the inherited input type check first; only when it passes, verifies that all
// children are compatible with the collation of the left child.
override def checkInputDataTypes(): TypeCheckResult = {
  val baseCheck = super.checkInputDataTypes()
  if (baseCheck.isFailure) {
    baseCheck
  } else {
    val collationId = left.dataType.asInstanceOf[StringType].collationId
    CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
  }
}

// The result type of this expression is IntegerType.
override def dataType: DataType = IntegerType

// Name reported for this expression; it matches the `find_in_set` spelling.
override def prettyName: String = "find_in_set"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,51 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession
})
}

test("Support FindInSet with Collation") {
  // Word/set pairs evaluated with the same collation on both arguments.
  val supportedCases = Seq(
    CollationTestCase("a", "abc,b,ab,c,def", "UTF8_BINARY", 0),
    CollationTestCase("c", "abc,b,ab,c,def", "UTF8_BINARY", 4),
    CollationTestCase("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1),
    CollationTestCase("ab", "abc,b,ab,c,def", "UTF8_BINARY", 3),
    CollationTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0),
    CollationTestCase("Ab", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3),
    CollationTestCase("ab", "abc,b,ab,c,def", "UNICODE", 3),
    CollationTestCase("aB", "abc,b,ab,c,def", "UNICODE", 0),
    CollationTestCase("AB", "abc,b,ab,c,def", "UNICODE_CI", 3)
  )
  supportedCases.foreach { tc =>
    val query = s"SELECT find_in_set(collate('${tc.s1}', '${tc.collation}'), " +
      s"collate('${tc.s2}', '${tc.collation}'))"
    checkAnswer(sql(query), Row(tc.expectedResult))
  }

  // Arguments carrying different collations are expected to fail with COLLATION_MISMATCH.
  val mismatchedCases = Seq(
    SubstringIndexTestFail("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", "UTF8_BINARY"),
    SubstringIndexTestFail("a", "abc,b,ab,c,def", "UNICODE_CI", "UNICODE")
  )
  mismatchedCases.foreach { tc =>
    val expr = s"find_in_set(collate('${tc.s1}', '${tc.c1}'), collate('${tc.s2}', '${tc.c2}'))"
    val query = s"SELECT $expr"
    val thrown = intercept[ExtendedAnalysisException] {
      sql(query)
    }
    checkError(
      exception = thrown,
      errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
      sqlState = "42K09",
      parameters = Map(
        "sqlExpr" -> s"\"find_in_set(collate(${tc.s1}), collate(${tc.s2}))\"",
        "collationNameLeft" -> s"${tc.c1}",
        "collationNameRight" -> s"${tc.c2}"
      ),
      context = ExpectedContext(
        fragment = expr,
        // The fragment begins right after "SELECT " and ends with the query's last character.
        start = 7,
        stop = query.length - 1
      )
    )
  }
}

// TODO: Add more tests for other string expressions

}
Expand Down

0 comments on commit eb2d7c5

Please sign in to comment.