Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47414][SQL] Lowercase collation support for regexp expressions #46077

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import org.apache.spark.unsafe.types.UTF8String;

import java.util.regex.Pattern;

/**
* Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and
* other expressions that require custom collation support), as well as private utility methods for
Expand Down Expand Up @@ -143,7 +145,24 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
* Collation-aware regexp expressions.
*/

// TODO: Add more collation-aware regexp expressions.
public static boolean supportsLowercaseRegex(final int collationId) {
uros-db marked this conversation as resolved.
Show resolved Hide resolved
// for regex, only Unicode case-insensitive matching is possible,
// so UTF8_BINARY_LCASE is treated as UNICODE_CI in this context
return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality;
}

/** {@link Pattern} flags combining case-insensitive matching with Unicode-aware case folding. */
private static final int lowercaseRegexFlags = Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE;

/**
 * Picks the {@link Pattern} compile flags for a collation.
 *
 * @param collationId id of the collation of the string being matched
 * @return the case-insensitive flag set when {@code supportsLowercaseRegex(collationId)} holds,
 *         otherwise {@code 0} (no flags)
 */
public static int collationAwareRegexFlags(final int collationId) {
  if (supportsLowercaseRegex(collationId)) {
    return lowercaseRegexFlags;
  }
  return 0;
}

/** Inline flag group that is put in front of a regex to request case-insensitive matching. */
private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)");

/**
 * Builds the case-insensitive form of a regex by concatenating the {@code (?ui)} prefix and
 * the given pattern text.
 *
 * @param regex the pattern text to prefix
 * @return the result of concatenating the prefix and {@code regex}
 */
public static UTF8String lowercaseRegex(final UTF8String regex) {
  final UTF8String prefixed = UTF8String.concat(lowercaseRegexPrefix, regex);
  return prefixed;
}
/**
 * Adapts a regex to a collation.
 *
 * @param regex the pattern text supplied by the caller
 * @param collationId id of the collation of the string being matched
 * @return {@code regex} unchanged when {@code supportsLowercaseRegex(collationId)} is false,
 *         otherwise the result of {@code lowercaseRegex(regex)}
 */
public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) {
  if (!supportsLowercaseRegex(collationId)) {
    // Nothing to adapt: the pattern is used exactly as given.
    return regex;
  }
  return lowercaseRegex(regex);
}

/**
* Other collation-aware expressions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import javax.annotation.Nullable
import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, RegExpReplace}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
Expand All @@ -45,6 +45,11 @@ object CollationTypeCasts extends TypeCoercionRule {
caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e))
CaseWhen(newBranches, newElseValue)

case regExpReplace: RegExpReplace =>
val singleType = collateToSingleType(Seq(regExpReplace.subject, regExpReplace.rep))
val newChildren = Seq(singleType.head, regExpReplace.regexp, singleType(1), regExpReplace.pos)
uros-db marked this conversation as resolved.
Show resolved Hide resolved
regExpReplace.withNewChildren(newChildren)

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -44,7 +45,10 @@ abstract class StringRegexExpression extends BinaryExpression
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

// try cache foldable pattern
private lazy val cache: Pattern = right match {
Expand All @@ -58,7 +62,7 @@ abstract class StringRegexExpression extends BinaryExpression
} else {
// Let it raise an exception if the regex string cannot be compiled
try {
Pattern.compile(escape(str))
Pattern.compile(escape(str), CollationSupport.collationAwareRegexFlags(collationId))
} catch {
case e: PatternSyntaxException =>
throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e)
Expand Down Expand Up @@ -158,7 +162,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
val pattern = ctx.addMutableState(patternClass, "patternLike",
v => s"""$v = $patternClass.compile("$regexStr");""")
v =>
s"""$v = $patternClass.compile("$regexStr",
|CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin)
uros-db marked this conversation as resolved.
Show resolved Hide resolved

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand Down Expand Up @@ -186,7 +192,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
s"""
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile(
$escapeFunc($rightStr, '$escapedEscapeChar'));
$escapeFunc($rightStr, '$escapedEscapeChar'),
CollationSupport.collationAwareRegexFlags($collationId));
uros-db marked this conversation as resolved.
Show resolved Hide resolved
${ev.value} = $pattern.matcher($eval1.toString()).matches();
"""
})
Expand Down Expand Up @@ -258,7 +265,8 @@ case class ILike(
def this(left: Expression, right: Expression) =
this(left, right, '\\')

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Expression = {
Expand All @@ -273,7 +281,8 @@ sealed abstract class MultiLikeBase

protected def isNotSpecified: Boolean

override def inputTypes: Seq[DataType] = StringType :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeBinaryLcase :: Nil
final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def nullable: Boolean = true

Expand All @@ -282,7 +291,8 @@ sealed abstract class MultiLikeBase
protected lazy val hasNull: Boolean = patterns.contains(null)

protected lazy val cache = patterns.filterNot(_ == null)
.map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\')))
.map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'),
CollationSupport.collationAwareRegexFlags(collationId)))

protected lazy val matchFunc = if (isNotSpecified) {
(p: Pattern, inputValue: String) => !p.matcher(inputValue).matches()
Expand Down Expand Up @@ -475,7 +485,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
val regexStr =
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
val pattern = ctx.addMutableState(patternClass, "patternRLike",
v => s"""$v = $patternClass.compile("$regexStr");""")
v => s"""$v = $patternClass.compile("$regexStr",
|CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin)
uros-db marked this conversation as resolved.
Show resolved Hide resolved

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -499,7 +510,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($rightStr);
$patternClass $pattern = $patternClass.compile($rightStr,
CollationSupport.collationAwareRegexFlags($collationId));
uros-db marked this conversation as resolved.
Show resolved Hide resolved
${ev.value} = $pattern.matcher($eval1.toString()).find(0);
"""
})
Expand Down Expand Up @@ -543,25 +555,29 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId

def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))

override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
val pattern = CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId)
val strings = string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int])
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
// Arrays in Java are covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
s"""${ev.value} = new $arrayClass($str.split(
|CollationSupport.collationAwareRegex($regex, $collationId),$limit));""".stripMargin
})
}

Expand Down Expand Up @@ -657,8 +673,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE)

override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = {
if (!p.equals(lastRegex)) {
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName)
val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId)
if (!regex.equals(lastRegex)) {
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName)
pattern = patternAndRegex._1
lastRegex = patternAndRegex._2
}
Expand All @@ -683,9 +700,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
}
}

override def dataType: DataType = StringType
override def dataType: DataType = subject.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, StringType, IntegerType)
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType)
final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId
override def prettyName: String = "regexp_replace"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -708,7 +726,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => {
s"""
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)}
if (!$rep.equals($termLastReplacementInUTF8)) {
// replacement string changed
$termLastReplacementInUTF8 = $rep.clone();
Expand Down Expand Up @@ -771,15 +789,19 @@ abstract class RegExpExtractBase

final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
override def first: Expression = subject
override def second: Expression = regexp
override def third: Expression = idx

final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId

protected def getLastMatcher(s: Any, p: Any): Matcher = {
if (p != lastRegex) {
val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId)
if (regex != lastRegex) {
// regex value changed
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName)
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName)
pattern = patternAndRegex._1
lastRegex = patternAndRegex._2
}
Expand Down Expand Up @@ -848,7 +870,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
}
}

override def dataType: DataType = StringType
override def dataType: DataType = subject.dataType
override def prettyName: String = "regexp_extract"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -863,7 +885,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)}
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
$classNameRegExpExtractBase.checkGroupIndex("$prettyName", $matchResult.groupCount(), $idx);
Expand Down Expand Up @@ -947,7 +969,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
}

override def dataType: DataType = ArrayType(StringType)
override def dataType: DataType = ArrayType(subject.dataType)
override def prettyName: String = "regexp_extract_all"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -963,7 +985,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
}
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName,
collationId)}
| java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
| while ($matcher.find()) {
| java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
Expand Down Expand Up @@ -1020,7 +1043,8 @@ case class RegExpCount(left: Expression, right: Expression)

override def children: Seq[Expression] = Seq(left, right)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RegExpCount =
Expand Down Expand Up @@ -1053,13 +1077,14 @@ case class RegExpSubStr(left: Expression, right: Expression)
override lazy val replacement: Expression =
new NullIf(
RegExpExtract(subject = left, regexp = right, idx = Literal(0)),
Literal(""))
Literal.create("", left.dataType))

override def prettyName: String = "regexp_substr"

override def children: Seq[Expression] = Seq(left, right)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RegExpSubStr =
Expand Down Expand Up @@ -1127,7 +1152,8 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression)
s"""
|try {
| $setEvNotNull
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName,
collationId)}
| if ($matcher.find()) {
| ${ev.value} = $matcher.toMatchResult().start() + 1;
| } else {
Expand All @@ -1151,16 +1177,19 @@ object RegExpUtils {
subject: String,
regexp: String,
matcher: String,
prettyName: String): String = {
prettyName: String,
collationId: Int): String = {
val classNamePattern = classOf[Pattern].getCanonicalName
val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
val termPattern = ctx.addMutableState(classNamePattern, "pattern")
val collAwareRegexp = ctx.freshName("collAwareRegexp")

s"""
|if (!$regexp.equals($termLastRegex)) {
|UTF8String $collAwareRegexp = CollationSupport.collationAwareRegex($regexp, $collationId);
|if (!$collAwareRegexp.equals($termLastRegex)) {
| // regex value changed
| try {
| UTF8String r = $regexp.clone();
| UTF8String r = $collAwareRegexp.clone();
| $termPattern = $classNamePattern.compile(r.toString());
| $termLastRegex = r;
| } catch (java.util.regex.PatternSyntaxException e) {
Expand Down
Loading