Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SQL] Add collations support to split regex expression #45856

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.util;

import java.util.regex.Pattern;

import com.ibm.icu.text.StringSearch;

import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -143,6 +145,36 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
* Collation-aware regexp expressions.
*/

public static class StringSplit {
public static UTF8String[] exec(final UTF8String l, final UTF8String r, final int limit,
nikolamand-db marked this conversation as resolved.
Show resolved Hide resolved
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(l, r, limit);
} else {
return execLowercase(l, r, limit);
}
}
public static String genCode(final String l, final String r, final String limit,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.StringSplit.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s, %s)", l, r, limit);
} else {
return String.format(expr + "Lowercase(%s, %s, %s)", l, r, limit);
}
}
public static UTF8String[] execBinary(final UTF8String l, final UTF8String r,
final int limit) {
return l.split(r, limit);
}
public static UTF8String[] execLowercase(final UTF8String l, final UTF8String r,
final int limit) {
return l.split(r, limit, Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE);
}
}

// TODO: Add more collation-aware regexp expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
return fromBytes(result);
}

public UTF8String[] split(UTF8String pattern, int limit) {
public UTF8String[] split(UTF8String pattern, int limit, int regexFlags) {
// For the empty `pattern` a `split` function ignores trailing empty strings unless original
// string is empty.
if (numBytes() != 0 && pattern.numBytes() == 0) {
Expand All @@ -1044,7 +1044,11 @@ public UTF8String[] split(UTF8String pattern, int limit) {
}
return result;
}
return split(pattern.toString(), limit);
return split(pattern.toString(), limit, regexFlags);
}

/**
 * Splits this string around matches of {@code pattern}, applying no regex flags.
 * Equivalent to calling the three-argument overload with a flags value of {@code 0}.
 */
public UTF8String[] split(UTF8String pattern, int limit) {
  final int noRegexFlags = 0; // Pattern without regex flags
  return split(pattern, limit, noRegexFlags);
}

public UTF8String[] splitSQL(UTF8String delimiter, int limit) {
Expand All @@ -1061,21 +1065,31 @@ public UTF8String[] splitSQL(UTF8String delimiter, int limit) {
}
}

private UTF8String[] split(String delimiter, int limit) {
private UTF8String[] split(String delimiter, int limit, int regexFlags) {
// Java String's split method supports "ignore empty string" behavior when the limit is 0
// whereas other languages do not. To avoid this java specific behavior, we fall back to
// -1 when the limit is 0.
if (limit == 0) {
limit = -1;
}
String[] splits = toString().split(delimiter, limit);
String[] splits;
if (regexFlags == 0) {
nikolamand-db marked this conversation as resolved.
Show resolved Hide resolved
// Pattern without regex flags
splits = toString().split(delimiter, limit);
} else {
splits = Pattern.compile(delimiter, regexFlags).split(toString(), limit);
}
UTF8String[] res = new UTF8String[splits.length];
for (int i = 0; i < res.length; i++) {
res[i] = fromString(splits[i]);
}
return res;
}

/**
 * Splits this string around matches of the regex {@code delimiter}, applying no regex flags.
 * Equivalent to calling the three-argument overload with a flags value of {@code 0}.
 */
private UTF8String[] split(String delimiter, int limit) {
  final int noRegexFlags = 0; // Pattern without regex flags
  return split(delimiter, limit, noRegexFlags);
}

public UTF8String replace(UTF8String search, UTF8String replace) {
// This implementation is loosely based on commons-lang3's StringUtils.replace().
if (numBytes == 0 || search.numBytes == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -543,25 +544,29 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean for "regex" to be of type StringTypeAnyCollation? It doesn't seem to me that collation is respected, or even needed, for this parameter to begin with. For example, consider the pattern [,] with the UNICODE_CI collation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, why would we allow any collation for the regex string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support the session-level default collation. If the user changes it and passes a regex string literal, the literal will be interpreted as a collated string. We don't want to throw an exception in such cases.

override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId

def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))

override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
val strings = CollationSupport.StringSplit.exec(string.asInstanceOf[UTF8String],
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int], collationId)
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
nikolamand-db marked this conversation as resolved.
Show resolved Hide resolved
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
val genCode = CollationSupport.StringSplit.genCode(str, regex, limit, collationId)
s"""${ev.value} = new $arrayClass($genCode);""".stripMargin
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,37 @@ class CollationRegexpExpressionsSuite

test("Support StringSplit string expression with collation") {
// Supported collations
case class StringSplitTestCase[R](l: String, r: String, c: String, result: R)
case class StringSplitTestCase[R](l: String, r: String, c: String, result: R, limit: Int = -1)
val testCases = Seq(
StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C"))
StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC")),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")),
StringSplitTestCase("AAA", "[a]", "UTF8_BINARY_LCASE", Seq("", "", "", "")),
StringSplitTestCase("AAA", "[b]", "UTF8_BINARY_LCASE", Seq("AAA")),
StringSplitTestCase("aAbB", "[ab]", "UTF8_BINARY_LCASE", Seq("", "", "", "", "")),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write unit tests for all corner cases instead of end-to-end tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rewritten the tests; please check.

StringSplitTestCase("", "", "UTF8_BINARY_LCASE", Seq("")),
StringSplitTestCase("", "[a]", "UTF8_BINARY_LCASE", Seq("")),
StringSplitTestCase("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE", Seq("x", "x", "x", "x", "x")),
StringSplitTestCase("ABC", "", "UTF8_BINARY_LCASE", Seq("A", "B", "C")),
// test split with limit
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 3),
StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")),
StringSplitTestCase("ABC", "[b]", "UNICODE", Seq("ABC"))
)
testCases.foreach(t => {
val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c))))
// TODO: Implicit casting (not currently supported)
nikolamand-db marked this conversation as resolved.
Show resolved Hide resolved
})
// Unsupported collations
case class StringSplitTestFail(l: String, r: String, c: String)
val failCases = Seq(
StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"),
StringSplitTestFail("ABC", "[B]", "UNICODE"),
StringSplitTestFail("ABC", "[b]", "UNICODE_CI")
)
val failCases = Seq(StringSplitTestFail("ABC", "[b]", "UNICODE_CI"))
failCases.foreach(t => {
val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')"
val unsupportedCollation = intercept[AnalysisException] { sql(query) }
assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
})
Expand Down