Skip to content

Commit

Permalink
[SPARK-47414][SQL] Lowercase collation support for regexp expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduce collation awareness for regexp expressions: like, ilike, like all, not like all, like any, not like any, rlike, split, regexp_replace, regexp_extract, regexp_extract_all, regexp_count, regexp_substr, regexp_instr. Note: collation support is only enabled for binary (UTF8_BINARY, UNICODE) & lowercase (UTF8_BINARY_LCASE) collation.

### Why are the changes needed?
Add collation support for built-in regexp functions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for built-in regexp functions: like, ilike, like all, not like all, like any, not like any, rlike, split, regexp_replace, regexp_extract, regexp_extract_all, regexp_count, regexp_substr, regexp_instr.

### How was this patch tested?
Unit regexp expression tests and e2e sql tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46077 from uros-db/SPARK-47414.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed Apr 25, 2024
1 parent c6aaa18 commit b4624bf
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import org.apache.spark.unsafe.types.UTF8String;

import java.util.regex.Pattern;

/**
* Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and
* other expressions that require custom collation support), as well as private utility methods for
Expand Down Expand Up @@ -310,7 +312,24 @@ public static int execICU(final UTF8String string, final UTF8String substring,
* Collation-aware regexp expressions.
*/

// TODO: Add more collation-aware regexp expressions.
public static boolean supportsLowercaseRegex(final int collationId) {
// for regex, only Unicode case-insensitive matching is possible,
// so UTF8_BINARY_LCASE is treated as UNICODE_CI in this context
return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality;
}

private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE;
public static int collationAwareRegexFlags(final int collationId) {
return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0;
}

private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)");
public static UTF8String lowercaseRegex(final UTF8String regex) {
return UTF8String.concat(lowercaseRegexPrefix, regex);
}
public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) {
return supportsLowercaseRegex(collationId) ? lowercaseRegex(regex) : regex;
}

/**
* Other collation-aware expressions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import javax.annotation.Nullable
import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, StringLPad, StringRPad}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, RegExpReplace, StringLPad, StringRPad}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
Expand Down Expand Up @@ -52,6 +52,11 @@ object CollationTypeCasts extends TypeCoercionRule {
overlayExpr.withNewChildren(collateToSingleType(Seq(overlayExpr.input, overlayExpr.replace))
++ Seq(overlayExpr.pos, overlayExpr.len))

case regExpReplace: RegExpReplace =>
val Seq(subject, rep) = collateToSingleType(Seq(regExpReplace.subject, regExpReplace.rep))
val newChildren = Seq(subject, regExpReplace.regexp, rep, regExpReplace.pos)
regExpReplace.withNewChildren(newChildren)

case stringPadExpr @ (_: StringRPad | _: StringLPad) =>
val Seq(str, len, pad) = stringPadExpr.children
val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -44,7 +45,11 @@ abstract class StringRegexExpression extends BinaryExpression
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId
final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId)

// try cache foldable pattern
private lazy val cache: Pattern = right match {
Expand All @@ -58,7 +63,7 @@ abstract class StringRegexExpression extends BinaryExpression
} else {
// Let it raise exception if couldn't compile the regex string
try {
Pattern.compile(escape(str))
Pattern.compile(escape(str), collationRegexFlags)
} catch {
case e: PatternSyntaxException =>
throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e)
Expand Down Expand Up @@ -158,7 +163,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
val pattern = ctx.addMutableState(patternClass, "patternLike",
v => s"""$v = $patternClass.compile("$regexStr");""")
v =>
s"""$v = $patternClass.compile("$regexStr", $collationRegexFlags);""".stripMargin)

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand Down Expand Up @@ -186,7 +192,7 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
s"""
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile(
$escapeFunc($rightStr, '$escapedEscapeChar'));
$escapeFunc($rightStr, '$escapedEscapeChar'), $collationRegexFlags);
${ev.value} = $pattern.matcher($eval1.toString()).matches();
"""
})
Expand Down Expand Up @@ -258,7 +264,8 @@ case class ILike(
def this(left: Expression, right: Expression) =
this(left, right, '\\')

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Expression = {
Expand All @@ -273,16 +280,18 @@ sealed abstract class MultiLikeBase

protected def isNotSpecified: Boolean

override def inputTypes: Seq[DataType] = StringType :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeBinaryLcase :: Nil
final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId
final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId)

override def nullable: Boolean = true

final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)

protected lazy val hasNull: Boolean = patterns.contains(null)

protected lazy val cache = patterns.filterNot(_ == null)
.map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\')))
protected lazy val cache = patterns.filterNot(_ == null).map(s =>
Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'), collationRegexFlags))

protected lazy val matchFunc = if (isNotSpecified) {
(p: Pattern, inputValue: String) => !p.matcher(inputValue).matches()
Expand Down Expand Up @@ -475,7 +484,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
val regexStr =
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
val pattern = ctx.addMutableState(patternClass, "patternRLike",
v => s"""$v = $patternClass.compile("$regexStr");""")
v => s"""$v = $patternClass.compile("$regexStr", $collationRegexFlags);""".stripMargin)

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -499,7 +508,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($rightStr);
$patternClass $pattern = $patternClass.compile($rightStr, $collationRegexFlags);
${ev.value} = $pattern.matcher($eval1.toString()).find(0);
"""
})
Expand Down Expand Up @@ -543,25 +552,29 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId

def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))

override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
val pattern = CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId)
val strings = string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int])
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
s"""${ev.value} = new $arrayClass($str.split(
|CollationSupport.collationAwareRegex($regex, $collationId),$limit));""".stripMargin
})
}

Expand Down Expand Up @@ -658,7 +671,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = {
if (!p.equals(lastRegex)) {
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName)
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName, collationId)
pattern = patternAndRegex._1
lastRegex = patternAndRegex._2
}
Expand All @@ -683,9 +696,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
}
}

override def dataType: DataType = StringType
override def dataType: DataType = subject.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, StringType, IntegerType)
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType)
final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId
override def prettyName: String = "regexp_replace"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -708,7 +722,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => {
s"""
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)}
if (!$rep.equals($termLastReplacementInUTF8)) {
// replacement string changed
$termLastReplacementInUTF8 = $rep.clone();
Expand Down Expand Up @@ -771,15 +785,18 @@ abstract class RegExpExtractBase

final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
override def first: Expression = subject
override def second: Expression = regexp
override def third: Expression = idx

final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId

protected def getLastMatcher(s: Any, p: Any): Matcher = {
if (p != lastRegex) {
// regex value changed
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName)
val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName, collationId)
pattern = patternAndRegex._1
lastRegex = patternAndRegex._2
}
Expand Down Expand Up @@ -848,7 +865,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
}
}

override def dataType: DataType = StringType
override def dataType: DataType = subject.dataType
override def prettyName: String = "regexp_extract"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -863,7 +880,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)}
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
$classNameRegExpExtractBase.checkGroupIndex("$prettyName", $matchResult.groupCount(), $idx);
Expand Down Expand Up @@ -947,7 +964,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
}

override def dataType: DataType = ArrayType(StringType)
override def dataType: DataType = ArrayType(subject.dataType)
override def prettyName: String = "regexp_extract_all"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -963,7 +980,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
}
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName,
collationId)}
| java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
| while ($matcher.find()) {
| java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
Expand Down Expand Up @@ -1020,7 +1038,8 @@ case class RegExpCount(left: Expression, right: Expression)

override def children: Seq[Expression] = Seq(left, right)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RegExpCount =
Expand Down Expand Up @@ -1053,13 +1072,14 @@ case class RegExpSubStr(left: Expression, right: Expression)
override lazy val replacement: Expression =
new NullIf(
RegExpExtract(subject = left, regexp = right, idx = Literal(0)),
Literal(""))
Literal.create("", left.dataType))

override def prettyName: String = "regexp_substr"

override def children: Seq[Expression] = Seq(left, right)

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RegExpSubStr =
Expand Down Expand Up @@ -1127,7 +1147,8 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression)
s"""
|try {
| $setEvNotNull
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)}
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName,
collationId)}
| if ($matcher.find()) {
| ${ev.value} = $matcher.toMatchResult().start() + 1;
| } else {
Expand All @@ -1151,17 +1172,19 @@ object RegExpUtils {
subject: String,
regexp: String,
matcher: String,
prettyName: String): String = {
prettyName: String,
collationId: Int): String = {
val classNamePattern = classOf[Pattern].getCanonicalName
val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
val termPattern = ctx.addMutableState(classNamePattern, "pattern")
val collationRegexFlags = CollationSupport.collationAwareRegexFlags(collationId)

s"""
|if (!$regexp.equals($termLastRegex)) {
| // regex value changed
| try {
| UTF8String r = $regexp.clone();
| $termPattern = $classNamePattern.compile(r.toString());
| $termPattern = $classNamePattern.compile(r.toString(), $collationRegexFlags);
| $termLastRegex = r;
| } catch (java.util.regex.PatternSyntaxException e) {
| throw QueryExecutionErrors.invalidPatternError("$prettyName", e.getPattern(), e);
Expand All @@ -1171,10 +1194,11 @@ object RegExpUtils {
|""".stripMargin
}

def getPatternAndLastRegex(p: Any, prettyName: String): (Pattern, UTF8String) = {
def getPatternAndLastRegex(p: Any, prettyName: String, collationId: Int): (Pattern, UTF8String) =
{
val r = p.asInstanceOf[UTF8String].clone()
val pattern = try {
Pattern.compile(r.toString)
Pattern.compile(r.toString, CollationSupport.collationAwareRegexFlags(collationId))
} catch {
case e: PatternSyntaxException =>
throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e)
Expand Down
Loading

0 comments on commit b4624bf

Please sign in to comment.