diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index fe1952921b7fb..a09ccf7e23d93 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -143,6 +143,41 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
    * Collation-aware regexp expressions.
    */
 
+  public static class StringSplit {
+    public static UTF8String[] exec(final UTF8String string, final UTF8String regex,
+        final int limit, final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(string, regex, limit);
+      } else {
+        assert(collation.supportsLowercaseEquality);
+        return execLowercase(string, regex, limit);
+      }
+    }
+    public static String genCode(final String string, final String regex, final String limit,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringSplit.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s, %s)", string, regex, limit);
+      } else {
+        return String.format(expr + "Lowercase(%s, %s, %s)", string, regex, limit);
+      }
+    }
+    public static UTF8String[] execBinary(final UTF8String string, final UTF8String regex,
+        final int limit) {
+      return string.split(regex, limit);
+    }
+    public static UTF8String[] execLowercase(final UTF8String string, final UTF8String regex,
+        final int limit) {
+      if (string.numBytes() != 0 && regex.numBytes() == 0) {
+        return string.split(regex, limit);
+      } else {
+        return string.split(CollationAwareUTF8String.getLowercaseRegex(regex), limit);
+      }
+    }
+  }
+
   // TODO: Add more collation-aware regexp expressions.
 
   /**
@@ -169,6 +204,13 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern
         pos, pos + pattern.numChars()), pattern, collationId).last() == 0;
     }
 
+    // ui flags toggle unicode case-insensitive matching
+    private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)");
+
+    private static UTF8String getLowercaseRegex(UTF8String regex) {
+      return UTF8String.concat(lowercaseRegexPrefix, regex);
+    }
+
   }
 
 }
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index bfb696c35fff6..a390b8410e314 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.unsafe.types;
 
+import java.util.Arrays;
+
 import org.apache.spark.SparkException;
 import org.apache.spark.sql.catalyst.util.CollationFactory;
 import org.apache.spark.sql.catalyst.util.CollationSupport;
@@ -255,6 +257,82 @@ public void testEndsWith() throws SparkException {
    * Collation-aware regexp expressions.
   */
 
+  @Test
+  public void testStringSplit() throws SparkException {
+    // binary equality
+    assertStringSplit("ABC", "[B]", "UTF8_BINARY", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", "UTF8_BINARY", new String[]{"ABC"});
+    assertStringSplit("aaaa", "", "UTF8_BINARY", new String[]{"a", "a", "a", "a"});
+    assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
+    assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY", new String[]{"aaaa"});
+    assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
+    assertStringSplit("ABC", "[B]", "UNICODE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", "UNICODE", new String[]{"ABC"});
+    assertStringSplit("aaaa", "", "UNICODE", new String[]{"a", "a", "a", "a"});
+    assertStringSplit("aaaa", "[a-z]", "UNICODE", new String[]{"", "", "", "", ""});
+    assertStringSplit("aaaa", "[0-9]", "UNICODE", new String[]{"aaaa"});
+    assertStringSplit("a1b2", "[a-z0-9]", "UNICODE", new String[]{"", "", "", "", ""});
+    // non-binary equality (lowercase)
+    assertStringSplit("ABC", "[B]", "UTF8_BINARY_LCASE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", "UTF8_BINARY_LCASE", new String[]{"A", "C"});
+    assertStringSplit("aaaa", "", "UTF8_BINARY_LCASE", new String[]{"a", "a", "a", "a"});
+    assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY_LCASE", new String[]{"aaaa"});
+    assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("AAA", "[a]", "UTF8_BINARY_LCASE", new String[]{"", "", "", ""});
+    assertStringSplit("AAA", "[b]", "UTF8_BINARY_LCASE", new String[]{"AAA"});
+    assertStringSplit("aAbB", "[ab]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("", "", "UTF8_BINARY_LCASE", new String[]{""});
+    assertStringSplit("", "[a]", "UTF8_BINARY_LCASE", new String[]{""});
+    assertStringSplit("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE",
+      new String[]{"x", "x", "x", "x", "x"});
+    assertStringSplit("ABC", "", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C"});
+    // special characters
+    assertStringSplit("ä", "", "UTF8_BINARY", new String[]{"ä"});
+    assertStringSplit("ääää", "", "UTF8_BINARY", new String[]{"ä", "ä", "ä", "ä"});
+    assertStringSplit("äbćδ", "", "UTF8_BINARY", new String[]{"ä", "b", "ć", "δ"});
+    assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
+    assertStringSplit("ä", "", "UTF8_BINARY_LCASE", new String[]{"ä"});
+    assertStringSplit("ääää", "", "UTF8_BINARY_LCASE", new String[]{"ä", "ä", "ä", "ä"});
+    assertStringSplit("äbćδ", "", "UTF8_BINARY_LCASE", new String[]{"ä", "b", "ć", "δ"});
+    assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("äbćδ", "[ÄBĆΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("äbćδ", "[äBćΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("ääää", "Ä", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
+    assertStringSplit("AäBÄCä", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", ""});
+    assertStringSplit("AäBÄCäD", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", "D"});
+    assertStringSplit("ä", "", "UNICODE", new String[]{"ä"});
+    assertStringSplit("ääää", "", "UNICODE", new String[]{"ä", "ä", "ä", "ä"});
+    assertStringSplit("äbćδ", "", "UNICODE", new String[]{"ä", "b", "ć", "δ"});
+    assertStringSplit("äbćδ", "[äbćδ]", "UNICODE", new String[]{"", "", "", "", ""});
+    // set limit
+    assertStringSplit("ABC", "[B]", 0, "UTF8_BINARY", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[B]", 1, "UTF8_BINARY", new String[]{"ABC"});
+    assertStringSplit("ABC", "[B]", 2, "UTF8_BINARY", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[B]", 3, "UTF8_BINARY", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", 0, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", 1, "UTF8_BINARY_LCASE", new String[]{"ABC"});
+    assertStringSplit("ABC", "[b]", 2, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[b]", 3, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[B]", 0, "UNICODE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[B]", 1, "UNICODE", new String[]{"ABC"});
+    assertStringSplit("ABC", "[B]", 2, "UNICODE", new String[]{"A", "C"});
+    assertStringSplit("ABC", "[B]", 3, "UNICODE", new String[]{"A", "C"});
+  }
+
+  private void assertStringSplit(String string, String regex, int limit, String collationName,
+      String[] value) throws SparkException {
+    UTF8String[] result = CollationSupport.StringSplit.exec(UTF8String.fromString(string),
+      UTF8String.fromString(regex), limit, CollationFactory.collationNameToId(collationName));
+    String[] actual = Arrays.stream(result).map(UTF8String::toString).toArray(String[]::new);
+    assertArrayEquals(value, actual);
+  }
+
+  private void assertStringSplit(String string, String regex, String collationName,
+      String[] value) throws SparkException {
+    assertStringSplit(string, regex, -1, collationName, value);
+  }
+
   // TODO: Test more collation-aware regexp expressions.
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index b33de303b5d55..a91cc74276df9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
+import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -543,25 +544,28 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
 case class StringSplit(str: Expression, regex: Expression, limit: Expression)
   extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
-  override def dataType: DataType = ArrayType(StringType, containsNull = false)
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
   override def first: Expression = str
   override def second: Expression = regex
   override def third: Expression = limit
 
+  final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId
+
   def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))
 
   override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
-    val strings = string.asInstanceOf[UTF8String].split(
-      regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
+    val strings = CollationSupport.StringSplit.exec(string.asInstanceOf[UTF8String],
+      regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int], collationId)
     new GenericArrayData(strings.asInstanceOf[Array[Any]])
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val arrayClass = classOf[GenericArrayData].getName
-    nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
+    defineCodeGen(ctx, ev, (str, regex, limit) => {
       // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
-      s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
+      s"new $arrayClass(${CollationSupport.StringSplit.genCode(str, regex, limit, collationId)})"
     })
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
index 0876425847bbb..0a5a5055e6fa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
@@ -116,30 +116,31 @@ class CollationRegexpExpressionsSuite
 
   test("Support StringSplit string expression with collation") {
     // Supported collations
-    case class StringSplitTestCase[R](l: String, r: String, c: String, result: R)
+    case class StringSplitTestCase[R](l: String, r: String, c: String, result: R, limit: Int = -1)
     val testCases = Seq(
-      StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C"))
+      StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")),
+      StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC")),
+      StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")),
+      StringSplitTestCase("AAA", "[a]", "UTF8_BINARY_LCASE", Seq("", "", "", "")),
+      StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")),
+      StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC"), 1),
+      StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1),
+      StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2)
    )
     testCases.foreach(t => {
-      val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+      val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})"
       // Result & data type
       checkAnswer(sql(query), Row(t.result))
       assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c))))
-      // TODO: Implicit casting (not currently supported)
     })
     // Unsupported collations
     case class StringSplitTestFail(l: String, r: String, c: String)
-    val failCases = Seq(
-      StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"),
-      StringSplitTestFail("ABC", "[B]", "UNICODE"),
-      StringSplitTestFail("ABC", "[b]", "UNICODE_CI")
-    )
+    val failCases = Seq(StringSplitTestFail("ABC", "[b]", "UNICODE_CI"))
     failCases.foreach(t => {
-      val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+      val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')"
       val unsupportedCollation = intercept[AnalysisException] { sql(query) }
       assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
     })
-    // TODO: Collation mismatch (not currently supported)
   }
 
   test("Support RegExpReplace string expression with collation") {
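
Note on the approach: the execLowercase path in CollationSupport.StringSplit handles the UTF8_BINARY_LCASE case by prepending the "(?ui)" inline flags to the user-supplied pattern, which enables case-insensitive matching together with Unicode-aware case folding in Java regular expressions. Below is a minimal standalone sketch of that behavior; it is illustrative only (the class name is not part of the patch) and assumes UTF8String.split follows standard java.util.regex pattern semantics, as the tests above suggest.

import java.util.Arrays;

public class LowercaseRegexSketch {
  public static void main(String[] args) {
    // Without any flags, "[b]" does not match the uppercase 'B'.
    System.out.println(Arrays.toString("ABC".split("[b]", -1)));       // prints [ABC]
    // With the "(?ui)" prefix the same pattern matches case-insensitively, so "ABC" splits on 'B'.
    System.out.println(Arrays.toString("ABC".split("(?ui)[b]", -1)));  // prints [A, C]
    // The 'u' flag matters for non-ASCII case folding: with it, 'ä' also matches 'Ä'.
    System.out.println(Arrays.toString("AäBÄC".split("(?ui)ä", -1)));  // prints [A, B, C]
  }
}

This mirrors the UTF8_BINARY_LCASE expectations in testStringSplit above, e.g. splitting "ABC" on "[b]" yielding "A" and "C".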