diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 9786c559da44b..93691e28c692b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -196,10 +196,21 @@ public static StringSearch getStringSearch( final UTF8String targetUTF8String, final UTF8String patternUTF8String, final int collationId) { - String pattern = patternUTF8String.toString(); - CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); + return getStringSearch(targetUTF8String.toString(), patternUTF8String.toString(), collationId); + } + + /** + * Returns a StringSearch object for the given pattern and target strings, under collation + * rules corresponding to the given collationId. The external ICU library StringSearch object can + * be used to find occurrences of the pattern in the target string, while respecting collation. + */ + public static StringSearch getStringSearch( + final String targetString, + final String patternString, + final int collationId) { + CharacterIterator target = new StringCharacterIterator(targetString); Collator collator = CollationFactory.fetchCollation(collationId).collator; - return new StringSearch(pattern, target, (RuleBasedCollator) collator); + return new StringSearch(patternString, target, (RuleBasedCollator) collator); } /** diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f54e6b162a933..d54e297413f49 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -144,6 +144,76 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } } + public static class FindInSet { + public static int exec(final UTF8String word, final UTF8String set, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(word, set); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(word, set); + } else { + return execICU(word, set, collationId); + } + } + public static String genCode(final String word, final String set, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.FindInSet.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", word, set); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", word, set); + } else { + return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId); + } + } + public static int execBinary(final UTF8String word, final UTF8String set) { + return set.findInSet(word); + } + public static int execLowercase(final UTF8String word, final UTF8String set) { + return set.toLowerCase().findInSet(word.toLowerCase()); + } + public static int execICU(final UTF8String word, final UTF8String set, + final int collationId) { + return CollationAwareUTF8String.findInSet(word, set, collationId); + } + } + + public static class StringInstr { + public static int exec(final UTF8String string, final UTF8String substring, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(string, substring); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(string, substring); + } else { + return execICU(string, substring, collationId); + } + } + public static String genCode(final String string, final String substring, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringInstr.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", string, substring); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", string, substring); + } else { + return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); + } + } + public static int execBinary(final UTF8String string, final UTF8String substring) { + return string.indexOf(substring, 0); + } + public static int execLowercase(final UTF8String string, final UTF8String substring) { + return string.toLowerCase().indexOf(substring.toLowerCase(), 0); + } + public static int execICU(final UTF8String string, final UTF8String substring, + final int collationId) { + return CollationAwareUTF8String.indexOf(string, substring, 0, collationId); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -164,6 +234,48 @@ public static boolean execICU(final UTF8String l, final UTF8String r, private static class CollationAwareUTF8String { + private static int findInSet(final UTF8String match, final UTF8String set, int collationId) { + if (match.contains(UTF8String.fromString(","))) { + return 0; + } + + String setString = set.toString(); + StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), + collationId); + + int wordStart = 0; + while ((wordStart = stringSearch.next()) != StringSearch.DONE) { + boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if (isValidStart && isValidEnd) { + int pos = 0; + for (int i = 0; i < setString.length() && i < wordStart; i++) { + if (setString.charAt(i) == ',') { + pos++; + } + } + + return pos + 1; + } + } + + return 0; + } + + private static int indexOf(final UTF8String target, final UTF8String pattern, + final int start, final int collationId) { + if (pattern.numBytes() == 0) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + stringSearch.setIndex(start); + + return stringSearch.next(); + } + } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 3c0d999089e7e..36acf1c9b7a66 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -261,6 +261,88 @@ public void testEndsWith() throws SparkException { assertEndsWith("The i̇o", "İo", "UNICODE_CI", true); } + private void assertStringInstr(String string, String substring, String collationName, + Integer expected) throws SparkException { + UTF8String str = UTF8String.fromString(string); + UTF8String substr = UTF8String.fromString(substring); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected, CollationSupport.StringInstr.exec(str, substr, collationId) + 1); + } + + @Test + public void testStringInstr() throws SparkException { + assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0); + assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0); + assertStringInstr("aaads", "ds", "UTF8_BINARY", 4); + assertStringInstr("xxxx", "", "UTF8_BINARY", 1); + assertStringInstr("", "xxxx", "UTF8_BINARY", 0); + assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5); + assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8); + assertStringInstr("aaads", "Aa", "UTF8_BINARY_LCASE", 1); + assertStringInstr("aaaDs", "de", "UTF8_BINARY_LCASE", 0); + assertStringInstr("aaaDs", "ds", "UTF8_BINARY_LCASE", 4); + assertStringInstr("xxxx", "", "UTF8_BINARY_LCASE", 1); + assertStringInstr("", "xxxx", "UTF8_BINARY_LCASE", 0); + assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY_LCASE", 5); + assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8); + assertStringInstr("aaads", "Aa", "UNICODE", 0); + assertStringInstr("aaads", "aa", "UNICODE", 1); + assertStringInstr("aaads", "de", "UNICODE", 0); + assertStringInstr("xxxx", "", "UNICODE", 1); + assertStringInstr("", "xxxx", "UNICODE", 0); + assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0); + assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8); + assertStringInstr("aaads", "AD", "UNICODE_CI", 3); + assertStringInstr("aaads", "dS", "UNICODE_CI", 4); + assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0); + assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); + assertStringInstr("abİo12", "i̇o", "UNICODE_CI", 3); + assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3); + } + + private void assertFindInSet(String word, String set, String collationName, + Integer expected) throws SparkException { + UTF8String w = UTF8String.fromString(word); + UTF8String s = UTF8String.fromString(set); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId)); + } + + @Test + public void testFindInSet() throws SparkException { + assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1); + assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4); + assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3); + assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1); + assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1); + assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4); + assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3); + assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("xx", "xx", "UNICODE", 1); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0); + assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5); + assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0); + assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4); + assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5); + assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0); + assertFindInSet("XX", "xx", "UNICODE_CI", 1); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4); + assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5); + assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5); + assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2); + assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2); + } + // TODO: Test more collation-aware string expressions. /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bd2c3baf4fe85..2b7703ed82b37 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -978,15 +978,19 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId - override protected def nullSafeEval(word: Any, set: Any): Any = - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) + + override protected def nullSafeEval(word: Any, set: Any): Any = { + CollationSupport.FindInSet. + exec(word.asInstanceOf[UTF8String], set.asInstanceOf[UTF8String], collationId) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (word, set) => - s"${ev.value} = $set.findInSet($word);" - ) + defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet. + genCode(word, set, collationId)) } override def dataType: DataType = IntegerType @@ -1350,20 +1354,24 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + override def left: Expression = str override def right: Expression = substr override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + CollationSupport.StringInstr. + exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) + 1 } override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (l, r) => - s"($l).indexOf($r, 0) + 1") + defineCodeGen(ctx, ev, (string, substring) => + CollationSupport.StringInstr.genCode(string, substring, collationId) + " + 1") } override protected def withNewChildrenInternal( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 07be8d48e8697..35f63ce010a90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -115,6 +115,68 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support StringInStr string expression with collation") { + case class StringInStrTestCase[R](string: String, substring: String, c: String, result: R) + val testCases = Seq( + // scalastyle:off + StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5), + StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8), + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0), + StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0), + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), + StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT instr(collate('${t.string}','${t.c}')," + + s"collate('${t.substring}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // Implicit casting + checkAnswer(sql(s"SELECT instr(collate('${t.string}','${t.c}')," + + s"'${t.substring}')"), Row(t.result)) + checkAnswer(sql(s"SELECT instr('${t.string}'," + + s"collate('${t.substring}','${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql(s"SELECT instr(collate('aaads','UTF8_BINARY'), collate('Aa','UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + test("Support FindInSet string expression with collation") { + case class FindInSetTestCase[R](word: String, set: String, c: String, result: R) + val testCases = Seq( + FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), + FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4), + FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0), + // scalastyle:off + FindInSetTestCase("i̇o", "ab,İo,12", "UNICODE_CI", 2), + FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2) + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," + + s"collate('${t.set}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // Implicit casting + checkAnswer(sql(s"SELECT find_in_set(collate('${t.word}', '${t.c}')," + + s"'${t.set}')"), Row(t.result)) + checkAnswer(sql(s"SELECT find_in_set('${t.word}'," + + s"collate('${t.set}', '${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql(s"SELECT find_in_set(collate('AB','UTF8_BINARY')," + + s"collate('ab,xyz,fgh','UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + test("Support StartsWith string expression with collation") { // Supported collations case class StartsWithTestCase[R](l: String, r: String, c: String, result: R)