From b12bbe5687901fd447a11a09a8724e0336d3d8cf Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 16 Apr 2024 09:59:39 +0200 Subject: [PATCH 01/16] Lowercase collation-aware regexp expressions --- .../sql/catalyst/util/CollationSupport.java | 14 +++ .../expressions/regexpExpressions.scala | 68 +++++++--- .../sql/CollationRegexpExpressionsSuite.scala | 118 ++++++++---------- 3 files changed, 111 insertions(+), 89 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index fe1952921b7fb..51d7e549ca275 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -171,4 +171,18 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern } + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); + + public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return regex; + } else { + return lowercaseRegex(regex); + } + } + + public static UTF8String lowercaseRegex(final UTF8String regex) { + return UTF8String.concat(lowercaseRegexPrefix, regex); + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b33de303b5d55..96e464d60f5f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} +import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +45,10 @@ abstract class StringRegexExpression extends BinaryExpression def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId // try cache foldable pattern private lazy val cache: Pattern = right match { @@ -58,7 +62,11 @@ abstract class StringRegexExpression extends BinaryExpression } else { // Let it raise exception if couldn't compile the regex string try { - Pattern.compile(escape(str)) + var patternFlags: Int = 0 + if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + patternFlags = Pattern.CASE_INSENSITIVE + } + Pattern.compile(escape(str), patternFlags) } catch { case e: PatternSyntaxException => throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e) @@ -258,7 +266,8 @@ case class ILike( def this(left: Expression, right: Expression) = this(left, right, '\\') - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { @@ -543,17 +552,21 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress case class StringSplit(str: Expression, regex: Expression, limit: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = ArrayType(StringType, containsNull = false) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def dataType: DataType = ArrayType(str.dataType, containsNull = false) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) override def first: Expression = str override def second: Expression = regex override def third: Expression = limit + final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)) override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { val strings = string.asInstanceOf[UTF8String].split( - regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) + CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId), + limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } @@ -561,7 +574,8 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + s"""${ev.value} = new $arrayClass($str.split(CollationSupport.collationAwareRegex($regex, + |$collationId),$limit));""".stripMargin }) } @@ -657,8 +671,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { - if (!p.equals(lastRegex)) { - val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName) + var regex: UTF8String = p.asInstanceOf[UTF8String] + if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + regex = CollationSupport.lowercaseRegex(regex) + } + if (!regex.equals(lastRegex)) { + val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) pattern = patternAndRegex._1 lastRegex = patternAndRegex._2 } @@ -683,9 +701,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } } - override def dataType: DataType = StringType + override def dataType: DataType = subject.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringType, StringType, StringType, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType) + final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -771,15 +790,22 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx + final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId + protected def getLastMatcher(s: Any, p: Any): Matcher = { - if (p != lastRegex) { + var regex: UTF8String = p.asInstanceOf[UTF8String] + if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + regex = CollationSupport.lowercaseRegex(regex) + } + if (regex != lastRegex) { // regex value changed - val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName) + val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) pattern = patternAndRegex._1 lastRegex = patternAndRegex._2 } @@ -848,7 +874,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } } - override def dataType: DataType = StringType + override def dataType: DataType = subject.dataType override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -947,7 +973,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]]) } - override def dataType: DataType = ArrayType(StringType) + override def dataType: DataType = ArrayType(subject.dataType) override def prettyName: String = "regexp_extract_all" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1020,7 +1046,8 @@ case class RegExpCount(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = @@ -1053,13 +1080,14 @@ case class RegExpSubStr(left: Expression, right: Expression) override lazy val replacement: Expression = new NullIf( RegExpExtract(subject = left, regexp = right, idx = Literal(0)), - Literal("")) + Literal.create("", left.dataType)) override def prettyName: String = "regexp_substr" override def children: Seq[Expression] = Seq(left, right) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 0876425847bbb..fc0a6821c8923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -34,288 +34,268 @@ class CollationRegexpExpressionsSuite // Supported collations case class LikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - LikeTestCase("ABC", "%B%", "UTF8_BINARY", true) + LikeTestCase("ABC", "%B%", "UTF8_BINARY", true), + LikeTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", true), + LikeTestCase("ABC", "%b%", "UNICODE", false) ) testCases.foreach(t => { - val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT like(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class LikeTestFail(l: String, r: String, c: String) val failCases = Seq( - LikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), - LikeTestFail("ABC", "%B%", "UNICODE"), LikeTestFail("ABC", "%b%", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT like(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support ILike string expression with collation") { // Supported collations case class ILikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true) + ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true), + ILikeTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", true), + ILikeTestCase("ABC", "%b%", "UNICODE", true) ) testCases.foreach(t => { val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class ILikeTestFail(l: String, r: String, c: String) val failCases = Seq( - ILikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"), - ILikeTestFail("ABC", "%b%", "UNICODE"), ILikeTestFail("ABC", "%b%", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RLike string expression with collation") { // Supported collations case class RLikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true) + RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true), + RLikeTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", true), + RLikeTestCase("ABC", ".b.", "UNICODE", false) ) testCases.foreach(t => { - val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RLikeTestFail(l: String, r: String, c: String) val failCases = Seq( - RLikeTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"), - RLikeTestFail("ABC", ".B.", "UNICODE"), RLikeTestFail("ABC", ".b.", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support StringSplit string expression with collation") { // Supported collations case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")) + StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")), + StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")) ) testCases.foreach(t => { - val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class StringSplitTestFail(l: String, r: String, c: String) val failCases = Seq( - StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"), - StringSplitTestFail("ABC", "[B]", "UNICODE"), StringSplitTestFail("ABC", "[b]", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpReplace string expression with collation") { // Supported collations case class RegExpReplaceTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") + RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE"), + RegExpReplaceTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "AFFFE"), + RegExpReplaceTestCase("ABCDE", ".c.", "UNICODE", "ABCDE") ) testCases.foreach(t => { val query = - s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 'FFF')" + s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', 'FFF')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpReplaceTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpReplaceTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpReplaceTestFail("ABCDE", ".C.", "UNICODE"), RegExpReplaceTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { val query = - s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 'FFF')" + s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', 'FFF')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpExtract string expression with collation") { // Supported collations case class RegExpExtractTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), + RegExpExtractTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "BCD"), + RegExpExtractTestCase("ABCDE", ".c.", "UNICODE", "") ) testCases.foreach(t => { val query = - s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), '${t.r}', 0)" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpExtractTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpExtractTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpExtractTestFail("ABCDE", ".C.", "UNICODE"), RegExpExtractTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { val query = - s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), '${t.r}', 0)" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpExtractAll string expression with collation") { // Supported collations case class RegExpExtractAllTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")) + RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")), + RegExpExtractAllTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", Seq("BCD")), + RegExpExtractAllTestCase("ABCDE", ".c.", "UNICODE", Seq()) ) testCases.foreach(t => { val query = - s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), '${t.r}', 0)" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpExtractAllTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpExtractAllTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpExtractAllTestFail("ABCDE", ".C.", "UNICODE"), RegExpExtractAllTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { val query = - s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'), 0)" + s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), '${t.r}', 0)" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpCount string expression with collation") { // Supported collations case class RegExpCountTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1), + RegExpCountTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 1), + RegExpCountTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { - val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpCountTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpCountTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpCountTestFail("ABCDE", ".C.", "UNICODE"), RegExpCountTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpSubStr string expression with collation") { // Supported collations case class RegExpSubStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), + RegExpSubStrTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "BCD"), + RegExpSubStrTestCase("ABCDE", ".c.", "UNICODE", null) ) testCases.foreach(t => { - val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpSubStrTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpSubStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpSubStrTestFail("ABCDE", ".C.", "UNICODE"), RegExpSubStrTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpInStr string expression with collation") { // Supported collations case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( - RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) + RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2), + RegExpInStrTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 2), + RegExpInStrTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { - val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class RegExpInStrTestFail(l: String, r: String, c: String) val failCases = Seq( - RegExpInStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"), - RegExpInStrTestFail("ABCDE", ".C.", "UNICODE"), RegExpInStrTestFail("ABCDE", ".c.", "UNICODE_CI") ) failCases.foreach(t => { - val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } } From c125be5cb3b0cfd60b8e4d5eac7b541fda65ef07 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:18:42 +0200 Subject: [PATCH 02/16] Uniform lowercase collation handling in nullSafeEval --- .../sql/catalyst/util/CollationSupport.java | 20 +++++-------------- .../expressions/regexpExpressions.scala | 11 +++++----- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 51d7e549ca275..17d77938ee8f4 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -143,7 +143,11 @@ public static boolean execICU(final UTF8String l, final UTF8String r, * Collation-aware regexp expressions. */ - // TODO: Add more collation-aware regexp expressions. + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); + + public static UTF8String lowercaseRegex(final UTF8String regex) { + return UTF8String.concat(lowercaseRegexPrefix, regex); + } /** * Other collation-aware expressions. @@ -171,18 +175,4 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern } - private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); - - public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return regex; - } else { - return lowercaseRegex(regex); - } - } - - public static UTF8String lowercaseRegex(final UTF8String regex) { - return UTF8String.concat(lowercaseRegexPrefix, regex); - } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 96e464d60f5f0..310be8fb07076 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -564,9 +564,11 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)) override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split( - CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId), - limit.asInstanceOf[Int]) + var pattern = regex.asInstanceOf[UTF8String] + if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + pattern = CollationSupport.lowercaseRegex(pattern) + } + val strings = string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } @@ -574,8 +576,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split(CollationSupport.collationAwareRegex($regex, - |$collationId),$limit));""".stripMargin + s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin }) } From 281f1b2b55f81a971c72e286ff9ea18013852310 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:53:03 +0200 Subject: [PATCH 03/16] No need to use "u" flag, "i" seems to be enough --- .../org/apache/spark/sql/catalyst/util/CollationSupport.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 17d77938ee8f4..95ba6268fd65e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -143,7 +143,7 @@ public static boolean execICU(final UTF8String l, final UTF8String r, * Collation-aware regexp expressions. */ - private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?i)"); public static UTF8String lowercaseRegex(final UTF8String regex) { return UTF8String.concat(lowercaseRegexPrefix, regex); From 2e19e3e93b3face3104beefa87356eba1a482575 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:06:17 +0200 Subject: [PATCH 04/16] Need to use "u" flag in conjunction with "i" flag --- .../sql/catalyst/util/CollationSupport.java | 2 +- .../expressions/regexpExpressions.scala | 2 +- .../sql/CollationRegexpExpressionsSuite.scala | 20 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 95ba6268fd65e..17d77938ee8f4 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -143,7 +143,7 @@ public static boolean execICU(final UTF8String l, final UTF8String r, * Collation-aware regexp expressions. */ - private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?i)"); + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); public static UTF8String lowercaseRegex(final UTF8String regex) { return UTF8String.concat(lowercaseRegexPrefix, regex); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 310be8fb07076..6c72614186222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -64,7 +64,7 @@ abstract class StringRegexExpression extends BinaryExpression try { var patternFlags: Int = 0 if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - patternFlags = Pattern.CASE_INSENSITIVE + patternFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE } Pattern.compile(escape(str), patternFlags) } catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index fc0a6821c8923..9774a5df29521 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -35,7 +35,7 @@ class CollationRegexpExpressionsSuite case class LikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( LikeTestCase("ABC", "%B%", "UTF8_BINARY", true), - LikeTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", true), + LikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), LikeTestCase("ABC", "%b%", "UNICODE", false) ) testCases.foreach(t => { @@ -61,7 +61,7 @@ class CollationRegexpExpressionsSuite case class ILikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true), - ILikeTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", true), + ILikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), ILikeTestCase("ABC", "%b%", "UNICODE", true) ) testCases.foreach(t => { @@ -87,7 +87,7 @@ class CollationRegexpExpressionsSuite case class RLikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true), - RLikeTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", true), + RLikeTestCase("AḂC", ".ḃ.", "UTF8_BINARY_LCASE", true), RLikeTestCase("ABC", ".b.", "UNICODE", false) ) testCases.foreach(t => { @@ -113,7 +113,7 @@ class CollationRegexpExpressionsSuite case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), - StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")), + StringSplitTestCase("AḂC", "[ḃ]", "UTF8_BINARY_LCASE", Seq("A", "C")), StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")) ) testCases.foreach(t => { @@ -139,7 +139,7 @@ class CollationRegexpExpressionsSuite case class RegExpReplaceTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE"), - RegExpReplaceTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "AFFFE"), + RegExpReplaceTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "AFFFE"), RegExpReplaceTestCase("ABCDE", ".c.", "UNICODE", "ABCDE") ) testCases.foreach(t => { @@ -167,7 +167,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpExtractTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "BCD"), + RegExpExtractTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), RegExpExtractTestCase("ABCDE", ".c.", "UNICODE", "") ) testCases.foreach(t => { @@ -195,7 +195,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractAllTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")), - RegExpExtractAllTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", Seq("BCD")), + RegExpExtractAllTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", Seq("BĆD")), RegExpExtractAllTestCase("ABCDE", ".c.", "UNICODE", Seq()) ) testCases.foreach(t => { @@ -223,7 +223,7 @@ class CollationRegexpExpressionsSuite case class RegExpCountTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1), - RegExpCountTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 1), + RegExpCountTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 1), RegExpCountTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { @@ -249,7 +249,7 @@ class CollationRegexpExpressionsSuite case class RegExpSubStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpSubStrTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", "BCD"), + RegExpSubStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), RegExpSubStrTestCase("ABCDE", ".c.", "UNICODE", null) ) testCases.foreach(t => { @@ -275,7 +275,7 @@ class CollationRegexpExpressionsSuite case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2), - RegExpInStrTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 2), + RegExpInStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 2), RegExpInStrTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { From 957b46d6cbe3ed2b1c43ec31f0cb883c43636ee2 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:51:27 +0200 Subject: [PATCH 05/16] Update CollationRegexpExpressionsSuite.scala --- .../sql/CollationRegexpExpressionsSuite.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 9774a5df29521..0fcf1456290a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -35,7 +35,7 @@ class CollationRegexpExpressionsSuite case class LikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( LikeTestCase("ABC", "%B%", "UTF8_BINARY", true), - LikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), + LikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), // scalastyle:ignore LikeTestCase("ABC", "%b%", "UNICODE", false) ) testCases.foreach(t => { @@ -61,7 +61,7 @@ class CollationRegexpExpressionsSuite case class ILikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true), - ILikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), + ILikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), // scalastyle:ignore ILikeTestCase("ABC", "%b%", "UNICODE", true) ) testCases.foreach(t => { @@ -87,7 +87,7 @@ class CollationRegexpExpressionsSuite case class RLikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true), - RLikeTestCase("AḂC", ".ḃ.", "UTF8_BINARY_LCASE", true), + RLikeTestCase("AḂC", ".ḃ.", "UTF8_BINARY_LCASE", true), // scalastyle:ignore RLikeTestCase("ABC", ".b.", "UNICODE", false) ) testCases.foreach(t => { @@ -113,7 +113,7 @@ class CollationRegexpExpressionsSuite case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), - StringSplitTestCase("AḂC", "[ḃ]", "UTF8_BINARY_LCASE", Seq("A", "C")), + StringSplitTestCase("AḂC", "[ḃ]", "UTF8_BINARY_LCASE", Seq("A", "C")), // scalastyle:ignore StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")) ) testCases.foreach(t => { @@ -139,7 +139,7 @@ class CollationRegexpExpressionsSuite case class RegExpReplaceTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE"), - RegExpReplaceTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "AFFFE"), + RegExpReplaceTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "AFFFE"), // scalastyle:ignore RegExpReplaceTestCase("ABCDE", ".c.", "UNICODE", "ABCDE") ) testCases.foreach(t => { @@ -167,7 +167,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpExtractTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), + RegExpExtractTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), // scalastyle:ignore RegExpExtractTestCase("ABCDE", ".c.", "UNICODE", "") ) testCases.foreach(t => { @@ -195,7 +195,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractAllTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")), - RegExpExtractAllTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", Seq("BĆD")), + RegExpExtractAllTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", Seq("BĆD")), // scalastyle:ignore RegExpExtractAllTestCase("ABCDE", ".c.", "UNICODE", Seq()) ) testCases.foreach(t => { @@ -223,7 +223,7 @@ class CollationRegexpExpressionsSuite case class RegExpCountTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1), - RegExpCountTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 1), + RegExpCountTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 1), // scalastyle:ignore RegExpCountTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { @@ -249,7 +249,7 @@ class CollationRegexpExpressionsSuite case class RegExpSubStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpSubStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), + RegExpSubStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), // scalastyle:ignore RegExpSubStrTestCase("ABCDE", ".c.", "UNICODE", null) ) testCases.foreach(t => { @@ -275,7 +275,7 @@ class CollationRegexpExpressionsSuite case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2), - RegExpInStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 2), + RegExpInStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 2), // scalastyle:ignore RegExpInStrTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { From fe910f75a1ad8b1b3d720b6fdecd668c5a79c3a2 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 17 Apr 2024 08:48:55 +0200 Subject: [PATCH 06/16] Implement codegen --- .../sql/catalyst/util/CollationSupport.java | 17 +++++- .../expressions/regexpExpressions.scala | 57 +++++++++---------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 17d77938ee8f4..869053eccea2e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -20,6 +20,8 @@ import org.apache.spark.unsafe.types.UTF8String; +import java.util.regex.Pattern; + /** * Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and * other expressions that require custom collation support), as well as private utility methods for @@ -143,11 +145,24 @@ public static boolean execICU(final UTF8String l, final UTF8String r, * Collation-aware regexp expressions. */ - private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); + public static boolean supportsLowercaseRegex(final int collationId) { + // for regex, only Unicode case-insensitive matching is possible, + // so UTF8_BINARY_LCASE is treated as UNICODE_CI in this context + return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + } + private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; + public static int collationAwareRegexFlags(final int collationId) { + return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0; + } + + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); public static UTF8String lowercaseRegex(final UTF8String regex) { return UTF8String.concat(lowercaseRegexPrefix, regex); } + public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) { + return supportsLowercaseRegex(collationId) ? lowercaseRegex(regex) : regex; + } /** * Other collation-aware expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 6c72614186222..e3aa863709154 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} -import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport, GenericArrayData, StringUtils} +import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ @@ -62,11 +62,7 @@ abstract class StringRegexExpression extends BinaryExpression } else { // Let it raise exception if couldn't compile the regex string try { - var patternFlags: Int = 0 - if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - patternFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE - } - Pattern.compile(escape(str), patternFlags) + Pattern.compile(escape(str), CollationSupport.collationAwareRegexFlags(collationId)) } catch { case e: PatternSyntaxException => throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e) @@ -166,7 +162,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = $patternClass.compile("$regexStr");""") + v => + s"""$v = $patternClass.compile("$regexStr", + |CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -194,7 +192,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) s""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile( - $escapeFunc($rightStr, '$escapedEscapeChar')); + $escapeFunc($rightStr, '$escapedEscapeChar'), + CollationSupport.collationAwareRegexFlags($collationId)); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) @@ -484,7 +483,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = $patternClass.compile("$regexStr");""") + v => s"""$v = $patternClass.compile("$regexStr", + |CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -508,7 +508,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); - $patternClass $pattern = $patternClass.compile($rightStr); + $patternClass $pattern = $patternClass.compile($rightStr, + CollationSupport.collationAwareRegexFlags($collationId)); ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) @@ -564,10 +565,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)) override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { - var pattern = regex.asInstanceOf[UTF8String] - if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - pattern = CollationSupport.lowercaseRegex(pattern) - } + val pattern = CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId) val strings = string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } @@ -576,7 +574,8 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + s"""${ev.value} = new $arrayClass($str.split( + |CollationSupport.collationAwareRegex($regex, $collationId),$limit));""".stripMargin }) } @@ -672,10 +671,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { - var regex: UTF8String = p.asInstanceOf[UTF8String] - if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - regex = CollationSupport.lowercaseRegex(regex) - } + val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId) if (!regex.equals(lastRegex)) { val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) pattern = patternAndRegex._1 @@ -728,7 +724,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => { s""" - ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)} + ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} if (!$rep.equals($termLastReplacementInUTF8)) { // replacement string changed $termLastReplacementInUTF8 = $rep.clone(); @@ -800,10 +796,7 @@ abstract class RegExpExtractBase final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId protected def getLastMatcher(s: Any, p: Any): Matcher = { - var regex: UTF8String = p.asInstanceOf[UTF8String] - if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - regex = CollationSupport.lowercaseRegex(regex) - } + val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId) if (regex != lastRegex) { // regex value changed val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) @@ -890,7 +883,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)} + ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} if ($matcher.find()) { java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); $classNameRegExpExtractBase.checkGroupIndex("$prettyName", $matchResult.groupCount(), $idx); @@ -990,7 +983,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres } nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)} + | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} | java.util.ArrayList $matchResults = new java.util.ArrayList(); | while ($matcher.find()) { | java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); @@ -1156,7 +1149,8 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) s""" |try { | $setEvNotNull - | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName)} + | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, + collationId)} | if ($matcher.find()) { | ${ev.value} = $matcher.toMatchResult().start() + 1; | } else { @@ -1180,16 +1174,19 @@ object RegExpUtils { subject: String, regexp: String, matcher: String, - prettyName: String): String = { + prettyName: String, + collationId: Int): String = { val classNamePattern = classOf[Pattern].getCanonicalName val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val collAwareRegexp = ctx.freshName("collAwareRegexp") s""" - |if (!$regexp.equals($termLastRegex)) { + |UTF8String $collAwareRegexp = CollationSupport.collationAwareRegex($regexp, $collationId); + |if (!$collAwareRegexp.equals($termLastRegex)) { | // regex value changed | try { - | UTF8String r = $regexp.clone(); + | UTF8String r = $collAwareRegexp.clone(); | $termPattern = $classNamePattern.compile(r.toString()); | $termLastRegex = r; | } catch (java.util.regex.PatternSyntaxException e) { From f0bcfe7899ee78e61456d8459bfb002ae65bd2d6 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:34:10 +0200 Subject: [PATCH 07/16] scalastyle fix --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index e3aa863709154..75a975b9a6305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -983,7 +983,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres } nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} + | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, + collationId)} | java.util.ArrayList $matchResults = new java.util.ArrayList(); | while ($matcher.find()) { | java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); From 011177046e9475d89ce8cab7ec207925fb295cae Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:09:11 +0200 Subject: [PATCH 08/16] MultiLikeBase collation support --- .../expressions/regexpExpressions.scala | 6 ++- .../CollationExpressionSuite.scala | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 75a975b9a6305..9d524b2355a73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -281,7 +281,8 @@ sealed abstract class MultiLikeBase protected def isNotSpecified: Boolean - override def inputTypes: Seq[DataType] = StringType :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeBinaryLcase :: Nil + final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId override def nullable: Boolean = true @@ -290,7 +291,8 @@ sealed abstract class MultiLikeBase protected lazy val hasNull: Boolean = patterns.contains(null) protected lazy val cache = patterns.filterNot(_ == null) - .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'))) + .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'), + CollationSupport.collationAwareRegexFlags(collationId))) protected lazy val matchFunc = if (isNotSpecified) { (p: Pattern, inputValue: String) => !p.matcher(inputValue).matches() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 537bac9aae9b4..69fa95e2b545b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types._ @@ -161,4 +162,40 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayExcept(left, right), out) } } + + test("MultiLikeBase regexp expressions with collated strings") { + // Supported collations (StringTypeBinaryLcase) + val binaryCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY")) + val lowercaseCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")) + val unicodeCollation = StringType(CollationFactory.collationNameToId("UNICODE")) + // LikeAll + checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%bar%"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%bar%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%bar%"), false) + // NotLikeAll + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%goo%", "%bar%"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%goo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%goo%", "%bar%"), true) + // LikeAny + checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%foo%", "%bar%"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%foo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%foo%", "%bar%"), true) + // NotLikeAny + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%hoo"), true) + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%oo%"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%Foo%", "%hoo"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%foo%", "%oo%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%Foo%", "%hoo"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%foo%", "%oo%"), false) + } + } From 542fb40ad0201e5c9cfe847fe8c9319d83db34f4 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:55:58 +0200 Subject: [PATCH 09/16] Add sql tests for MultiLikeBase expressions --- .../sql/CollationRegexpExpressionsSuite.scala | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 0fcf1456290a4..3133b9b446696 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -82,6 +82,118 @@ class CollationRegexpExpressionsSuite }) } + test("Support LikeAll string expression with collation") { + // Supported collations + case class LikeAllTestCase[R](s: String, p: Seq[String], c: String, result: R) + val testCases = Seq( + LikeAllTestCase("foo", Seq("%foo%", "%oo"), "UTF8_BINARY", true), + LikeAllTestCase("Foo", Seq("%foo%", "%oo"), "UTF8_BINARY_LCASE", true), + LikeAllTestCase("foo", Seq("%foo%", "%bar%"), "UNICODE", false) + ) + testCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') LIKE ALL ('${t.p.mkString("','")}')" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + }) + // Unsupported collations + case class LikeAllTestFail(s: String, p: Seq[String], c: String) + val failCases = Seq( + LikeAllTestFail("Foo", Seq("%foo%", "%oo"), "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') LIKE ALL ('${t.p.mkString("','")}')" + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + } + + test("Support NotLikeAll string expression with collation") { + // Supported collations + case class NotLikeAllTestCase[R](s: String, p: Seq[String], c: String, result: R) + val testCases = Seq( + NotLikeAllTestCase("foo", Seq("%foo%", "%oo"), "UTF8_BINARY", false), + NotLikeAllTestCase("Foo", Seq("%foo%", "%oo"), "UTF8_BINARY_LCASE", false), + NotLikeAllTestCase("foo", Seq("%goo%", "%bar%"), "UNICODE", true) + ) + testCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') NOT LIKE ALL ('${t.p.mkString("','")}')" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + }) + // Unsupported collations + case class NotLikeAllTestFail(s: String, p: Seq[String], c: String) + val failCases = Seq( + NotLikeAllTestFail("Foo", Seq("%foo%", "%oo"), "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') NOT LIKE ALL ('${t.p.mkString("','")}')" + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + } + + test("Support LikeAny string expression with collation") { + // Supported collations + case class LikeAnyTestCase[R](s: String, p: Seq[String], c: String, result: R) + val testCases = Seq( + LikeAnyTestCase("foo", Seq("%foo%", "%bar"), "UTF8_BINARY", true), + LikeAnyTestCase("Foo", Seq("%foo%", "%bar"), "UTF8_BINARY_LCASE", true), + LikeAnyTestCase("foo", Seq("%goo%", "%hoo%"), "UNICODE", false) + ) + testCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') LIKE ANY ('${t.p.mkString("','")}')" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + }) + // Unsupported collations + case class LikeAnyTestFail(s: String, p: Seq[String], c: String) + val failCases = Seq( + LikeAnyTestFail("Foo", Seq("%foo%", "%oo"), "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') LIKE ANY ('${t.p.mkString("','")}')" + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + } + + test("Support NotLikeAny string expression with collation") { + // Supported collations + case class NotLikeAnyTestCase[R](s: String, p: Seq[String], c: String, result: R) + val testCases = Seq( + NotLikeAnyTestCase("foo", Seq("%foo%", "%hoo"), "UTF8_BINARY", true), + NotLikeAnyTestCase("Foo", Seq("%foo%", "%hoo"), "UTF8_BINARY_LCASE", true), + NotLikeAnyTestCase("foo", Seq("%foo%", "%oo%"), "UNICODE", false) + ) + testCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') NOT LIKE ANY ('${t.p.mkString("','")}')" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) + }) + // Unsupported collations + case class NotLikeAnyTestFail(s: String, p: Seq[String], c: String) + val failCases = Seq( + NotLikeAnyTestFail("Foo", Seq("%foo%", "%oo"), "UNICODE_CI") + ) + failCases.foreach(t => { + val query = s"SELECT collate('${t.s}', '${t.c}') NOT LIKE ANY ('${t.p.mkString("','")}')" + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } + assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + } + test("Support RLike string expression with collation") { // Supported collations case class RLikeTestCase[R](l: String, r: String, c: String, result: R) From 2a7ad8aae2b622f33e842311fcfbddd0ec00133c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 18 Apr 2024 17:18:11 +0200 Subject: [PATCH 10/16] Implicit cast for RegExpReplace --- .../sql/catalyst/analysis/CollationTypeCasts.scala | 7 ++++++- .../spark/sql/CollationRegexpExpressionsSuite.scala | 12 +++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1a14b4227de8f..d95bafd174a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, RegExpReplace} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -45,6 +45,11 @@ object CollationTypeCasts extends TypeCoercionRule { caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e)) CaseWhen(newBranches, newElseValue) + case regExpReplace: RegExpReplace => + val singleType = collateToSingleType(Seq(regExpReplace.subject, regExpReplace.rep)) + val newChildren = Seq(singleType.head, regExpReplace.regexp, singleType(1), regExpReplace.pos) + regExpReplace.withNewChildren(newChildren) + case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 3133b9b446696..367ea6eab8479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -256,11 +256,21 @@ class CollationRegexpExpressionsSuite ) testCases.foreach(t => { val query = - s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', 'FFF')" + s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', collate('FFF', '${t.c}'))" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + // Implicit casting + checkAnswer(sql(s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', 'FFF')"), + Row(t.result)) + checkAnswer(sql(s"SELECT regexp_replace('${t.l}', '${t.r}', collate('FFF', '${t.c}'))"), + Row(t.result)) }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT regexp_replace(collate('ABCDE','UTF8_BINARY'), '.c.', collate('FFF','UNICODE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") // Unsupported collations case class RegExpReplaceTestFail(l: String, r: String, c: String) val failCases = Seq( From 599afef70f701d6cd1037e4bd76f53bb75f89dc4 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:29:25 +0200 Subject: [PATCH 11/16] Fix tests and casting --- .../analysis/CollationTypeCasts.scala | 4 +- .../CollationExpressionSuite.scala | 36 ----- .../CollationRegexpExpressionSuite.scala | 147 ++++++++++++++++++ .../sql/CollationRegexpExpressionsSuite.scala | 60 ++++--- 4 files changed, 189 insertions(+), 58 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 20b5dca416ccf..e5f88b791ca53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -53,8 +53,8 @@ object CollationTypeCasts extends TypeCoercionRule { ++ Seq(overlay.pos, overlay.len)) case regExpReplace: RegExpReplace => - val singleType = collateToSingleType(Seq(regExpReplace.subject, regExpReplace.rep)) - val newChildren = Seq(singleType.head, regExpReplace.regexp, singleType(1), regExpReplace.pos) + val Seq(subject, rep) = collateToSingleType(Seq(regExpReplace.subject, regExpReplace.rep)) + val newChildren = Seq(subject, regExpReplace.regexp, rep, regExpReplace.pos) regExpReplace.withNewChildren(newChildren) case otherExpr @ ( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 69fa95e2b545b..f74f237d33da9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types._ @@ -163,39 +162,4 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("MultiLikeBase regexp expressions with collated strings") { - // Supported collations (StringTypeBinaryLcase) - val binaryCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY")) - val lowercaseCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")) - val unicodeCollation = StringType(CollationFactory.collationNameToId("UNICODE")) - // LikeAll - checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%oo"), true) - checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%bar%"), false) - checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%oo"), true) - checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%bar%"), false) - checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%oo"), true) - checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%bar%"), false) - // NotLikeAll - checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%foo%", "%oo"), false) - checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%goo%", "%bar%"), true) - checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%foo%", "%oo"), false) - checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%goo%", "%bar%"), true) - checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%foo%", "%oo"), false) - checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%goo%", "%bar%"), true) - // LikeAny - checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%goo%", "%hoo"), false) - checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%foo%", "%bar%"), true) - checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%goo%", "%hoo"), false) - checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%foo%", "%bar%"), true) - checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%goo%", "%hoo"), false) - checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%foo%", "%bar%"), true) - // NotLikeAny - checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%hoo"), true) - checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%oo%"), false) - checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%Foo%", "%hoo"), true) - checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%foo%", "%oo%"), false) - checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%Foo%", "%hoo"), true) - checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%foo%", "%oo%"), false) - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala new file mode 100644 index 0000000000000..200159281537c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types._ + +class CollationRegexpExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Like/ILike/RLike expressions with collated strings") { + case class LikeTestCase[R](l: String, regexLike: String, regexRLike: String, collation: String, + expectedLike: R, expectedILike: R, expectedRLike: R) + val testCases = Seq( + LikeTestCase("AbC", "%AbC%", ".b.", "UTF8_BINARY", true, true, true), + LikeTestCase("AbC", "%ABC%", ".B.", "UTF8_BINARY", false, true, false), + LikeTestCase("AbC", "%abc%", ".b.", "UTF8_BINARY_LCASE", true, true, true), + LikeTestCase("", "", "", "UTF8_BINARY_LCASE", true, true, true), + LikeTestCase("Foo", "", "", "UTF8_BINARY_LCASE", false, false, true), + LikeTestCase("", "%foo%", ".o.", "UTF8_BINARY_LCASE", false, false, false), + LikeTestCase("AbC", "%ABC%", ".B.", "UNICODE", false, true, false), + LikeTestCase(null, "%foo%", ".o.", "UNICODE", null, null, null), + LikeTestCase("Foo", null, null, "UNICODE", null, null, null), + LikeTestCase(null, null, null, "UNICODE", null, null, null) + ) + testCases.foreach(t => { + // Like + checkEvaluation(Like( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.regexLike, StringType), '\\'), t.expectedLike) + // ILike + checkEvaluation(ILike( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.regexLike, StringType), '\\').replacement, t.expectedILike) + // RLike + checkEvaluation(RLike( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.regexRLike, StringType)), t.expectedRLike) + }) + } + + test("Regexp expressions with collated strings") { + case class RegexpTestCase[R](l: String, r: String, collation: String, + expectedExtract: R, expectedExtractAll: R, expectedCount: R) + val testCases = Seq( + RegexpTestCase("AbC-aBc", ".b.", "UTF8_BINARY", "AbC", Seq("AbC"), 1), + RegexpTestCase("AbC-abc", ".b.", "UTF8_BINARY", "AbC", Seq("AbC", "abc"), 2), + RegexpTestCase("AbC-aBc", ".b.", "UTF8_BINARY_LCASE", "AbC", Seq("AbC", "aBc"), 2), + RegexpTestCase("ABC-abc", ".b.", "UTF8_BINARY_LCASE", "ABC", Seq("ABC", "abc"), 2), + RegexpTestCase("", "", "UTF8_BINARY_LCASE", "", Seq(""), 1), + RegexpTestCase("Foo", "", "UTF8_BINARY_LCASE", "", Seq("", "", "", ""), 4), + RegexpTestCase("", ".o.", "UTF8_BINARY_LCASE", "", Seq(), 0), + RegexpTestCase("Foo", ".O.", "UNICODE", "", Seq(), 0), + RegexpTestCase(null, ".O.", "UNICODE", null, null, null), + RegexpTestCase("Foo", null, "UNICODE", null, null, null), + RegexpTestCase(null, null, "UNICODE", null, null, null) + ) + testCases.foreach(t => { + // RegExpExtract + checkEvaluation(RegExpExtract( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.r, StringType), 0), t.expectedExtract) + // RegExpExtractAll + checkEvaluation(RegExpExtractAll( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.r, StringType), 0), t.expectedExtractAll) + // RegExpCount + checkEvaluation(RegExpCount( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.r, StringType)).replacement, t.expectedCount) + // RegExpInStr + def expectedInStr(count: Any): Any = count match { + case null => null + case 0 => 0 + case n: Int if n >= 1 => 1 + } + checkEvaluation(RegExpInStr( + Literal.create(t.l, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.r, StringType), 0), expectedInStr(t.expectedCount)) + }) + } + + test("MultiLikeBase regexp expressions with collated strings") { + val nullStr = Literal.create(null, StringType) + // Supported collations (StringTypeBinaryLcase) + val binaryCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY")) + val lowercaseCollation = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")) + val unicodeCollation = StringType(CollationFactory.collationNameToId("UNICODE")) + // LikeAll + checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("foo", binaryCollation).likeAll("%foo%", "%bar%"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAll("%foo%", "%bar%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%oo"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", "%bar%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%foo%", nullStr), null) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAll("%feo%", nullStr), false) + checkEvaluation(Literal.create(null, unicodeCollation).likeAll("%foo%", "%oo"), null) + // NotLikeAll + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAll("%goo%", "%bar%"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAll("%goo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%foo%", "%oo"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%goo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%foo%", nullStr), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAll("%feo%", nullStr), null) + checkEvaluation(Literal.create(null, unicodeCollation).notLikeAll("%foo%", "%oo"), null) + // LikeAny + checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("foo", binaryCollation).likeAny("%foo%", "%bar%"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).likeAny("%foo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%goo%", "%hoo"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%foo%", "%bar%"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%foo%", nullStr), true) + checkEvaluation(Literal.create("foo", unicodeCollation).likeAny("%feo%", nullStr), null) + checkEvaluation(Literal.create(null, unicodeCollation).likeAny("%foo%", "%oo"), null) + // NotLikeAny + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%hoo"), true) + checkEvaluation(Literal.create("foo", binaryCollation).notLikeAny("%foo%", "%oo%"), false) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%Foo%", "%hoo"), true) + checkEvaluation(Literal.create("Foo", lowercaseCollation).notLikeAny("%foo%", "%oo%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%Foo%", "%hoo"), true) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%foo%", "%oo%"), false) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%foo%", nullStr), null) + checkEvaluation(Literal.create("foo", unicodeCollation).notLikeAny("%feo%", nullStr), true) + checkEvaluation(Literal.create(null, unicodeCollation).notLikeAny("%foo%", "%oo"), null) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 367ea6eab8479..e56ff7c801ec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, StringType} +// scalastyle:off nonascii class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession @@ -35,7 +36,7 @@ class CollationRegexpExpressionsSuite case class LikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( LikeTestCase("ABC", "%B%", "UTF8_BINARY", true), - LikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), // scalastyle:ignore + LikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), LikeTestCase("ABC", "%b%", "UNICODE", false) ) testCases.foreach(t => { @@ -51,7 +52,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT like(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -61,11 +64,11 @@ class CollationRegexpExpressionsSuite case class ILikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true), - ILikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), // scalastyle:ignore + ILikeTestCase("AḂC", "%ḃ%", "UTF8_BINARY_LCASE", true), ILikeTestCase("ABC", "%b%", "UNICODE", true) ) testCases.foreach(t => { - val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), '${t.r}')" // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) @@ -77,7 +80,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT ilike(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -199,7 +204,7 @@ class CollationRegexpExpressionsSuite case class RLikeTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true), - RLikeTestCase("AḂC", ".ḃ.", "UTF8_BINARY_LCASE", true), // scalastyle:ignore + RLikeTestCase("AḂC", ".ḃ.", "UTF8_BINARY_LCASE", true), RLikeTestCase("ABC", ".b.", "UNICODE", false) ) testCases.foreach(t => { @@ -215,7 +220,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT rlike(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -225,7 +232,7 @@ class CollationRegexpExpressionsSuite case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), - StringSplitTestCase("AḂC", "[ḃ]", "UTF8_BINARY_LCASE", Seq("A", "C")), // scalastyle:ignore + StringSplitTestCase("AḂC", "[ḃ]", "UTF8_BINARY_LCASE", Seq("A", "C")), StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")) ) testCases.foreach(t => { @@ -241,7 +248,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -251,7 +260,7 @@ class CollationRegexpExpressionsSuite case class RegExpReplaceTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE"), - RegExpReplaceTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "AFFFE"), // scalastyle:ignore + RegExpReplaceTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "AFFFE"), RegExpReplaceTestCase("ABCDE", ".c.", "UNICODE", "ABCDE") ) testCases.foreach(t => { @@ -279,7 +288,9 @@ class CollationRegexpExpressionsSuite failCases.foreach(t => { val query = s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), '${t.r}', 'FFF')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -289,7 +300,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpExtractTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), // scalastyle:ignore + RegExpExtractTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), RegExpExtractTestCase("ABCDE", ".c.", "UNICODE", "") ) testCases.foreach(t => { @@ -307,7 +318,9 @@ class CollationRegexpExpressionsSuite failCases.foreach(t => { val query = s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), '${t.r}', 0)" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -317,7 +330,7 @@ class CollationRegexpExpressionsSuite case class RegExpExtractAllTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD")), - RegExpExtractAllTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", Seq("BĆD")), // scalastyle:ignore + RegExpExtractAllTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", Seq("BĆD")), RegExpExtractAllTestCase("ABCDE", ".c.", "UNICODE", Seq()) ) testCases.foreach(t => { @@ -335,7 +348,9 @@ class CollationRegexpExpressionsSuite failCases.foreach(t => { val query = s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'), '${t.r}', 0)" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -345,7 +360,7 @@ class CollationRegexpExpressionsSuite case class RegExpCountTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1), - RegExpCountTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 1), // scalastyle:ignore + RegExpCountTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 1), RegExpCountTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { @@ -361,7 +376,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -371,7 +388,7 @@ class CollationRegexpExpressionsSuite case class RegExpSubStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD"), - RegExpSubStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), // scalastyle:ignore + RegExpSubStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", "BĆD"), RegExpSubStrTestCase("ABCDE", ".c.", "UNICODE", null) ) testCases.foreach(t => { @@ -387,7 +404,9 @@ class CollationRegexpExpressionsSuite ) failCases.foreach(t => { val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), '${t.r}')" - val unsupportedCollation = intercept[AnalysisException] { sql(query) } + val unsupportedCollation = intercept[AnalysisException] { + sql(query) + } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) } @@ -397,7 +416,7 @@ class CollationRegexpExpressionsSuite case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R) val testCases = Seq( RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2), - RegExpInStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 2), // scalastyle:ignore + RegExpInStrTestCase("ABĆDE", ".ć.", "UTF8_BINARY_LCASE", 2), RegExpInStrTestCase("ABCDE", ".c.", "UNICODE", 0) ) testCases.foreach(t => { @@ -421,6 +440,7 @@ class CollationRegexpExpressionsSuite } } +// scalastyle:on nonascii class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { override protected def sparkConf: SparkConf = From f7a271041fa5e5b42d8e5a47ed2f86fbf5cf11d5 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:30:33 +0200 Subject: [PATCH 12/16] Remove unwanted changes --- .../sql/catalyst/expressions/CollationExpressionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index f74f237d33da9..537bac9aae9b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -161,5 +161,4 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayExcept(left, right), out) } } - } From b05bf3f307723fcc4dcdd8bb991e94607e7deabb Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:20:28 +0200 Subject: [PATCH 13/16] Fix regex compilation and add tests --- .../expressions/regexpExpressions.scala | 24 +++++++++---------- .../CollationRegexpExpressionSuite.scala | 23 ++++++++++++++++++ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9d524b2355a73..63715a22f602b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -673,9 +673,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { - val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId) - if (!regex.equals(lastRegex)) { - val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) + if (!p.equals(lastRegex)) { + val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName, collationId) pattern = patternAndRegex._1 lastRegex = patternAndRegex._2 } @@ -798,10 +797,9 @@ abstract class RegExpExtractBase final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId protected def getLastMatcher(s: Any, p: Any): Matcher = { - val regex = CollationSupport.collationAwareRegex(p.asInstanceOf[UTF8String], collationId) - if (regex != lastRegex) { + if (p != lastRegex) { // regex value changed - val patternAndRegex = RegExpUtils.getPatternAndLastRegex(regex, prettyName) + val patternAndRegex = RegExpUtils.getPatternAndLastRegex(p, prettyName, collationId) pattern = patternAndRegex._1 lastRegex = patternAndRegex._2 } @@ -1182,15 +1180,14 @@ object RegExpUtils { val classNamePattern = classOf[Pattern].getCanonicalName val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") val termPattern = ctx.addMutableState(classNamePattern, "pattern") - val collAwareRegexp = ctx.freshName("collAwareRegexp") s""" - |UTF8String $collAwareRegexp = CollationSupport.collationAwareRegex($regexp, $collationId); - |if (!$collAwareRegexp.equals($termLastRegex)) { + |if (!$regexp.equals($termLastRegex)) { | // regex value changed | try { - | UTF8String r = $collAwareRegexp.clone(); - | $termPattern = $classNamePattern.compile(r.toString()); + | UTF8String r = $regexp.clone(); + | $termPattern = $classNamePattern.compile(r.toString(), + | CollationSupport.collationAwareRegexFlags($collationId)); | $termLastRegex = r; | } catch (java.util.regex.PatternSyntaxException e) { | throw QueryExecutionErrors.invalidPatternError("$prettyName", e.getPattern(), e); @@ -1200,10 +1197,11 @@ object RegExpUtils { |""".stripMargin } - def getPatternAndLastRegex(p: Any, prettyName: String): (Pattern, UTF8String) = { + def getPatternAndLastRegex(p: Any, prettyName: String, collationId: Int): (Pattern, UTF8String) = + { val r = p.asInstanceOf[UTF8String].clone() val pattern = try { - Pattern.compile(r.toString) + Pattern.compile(r.toString, CollationSupport.collationAwareRegexFlags(collationId)) } catch { case e: PatternSyntaxException => throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala index 200159281537c..f53dd77caade8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala @@ -55,6 +55,29 @@ class CollationRegexpExpressionSuite extends SparkFunSuite with ExpressionEvalHe }) } + test("StringSplit expression with collated strings") { + case class StringSplitTestCase[R](s: String, r: String, collation: String, expected: R) + val testCases = Seq( + StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_BINARY", Seq("1", "2", "3", "")), + StringSplitTestCase("1A2B3C", "[abc]", "UTF8_BINARY", Seq("1A2B3C")), + StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_BINARY_LCASE", Seq("1", "2", "3", "")), + StringSplitTestCase("1A2B3C", "[abc]", "UTF8_BINARY_LCASE", Seq("1", "2", "3", "")), + StringSplitTestCase("1A2B3C", "[1-9]+", "UNICODE", Seq("", "A", "B", "C")), + StringSplitTestCase("", "", "UNICODE", Seq("")), + StringSplitTestCase("1A2B3C", "", "UNICODE", Seq("1", "A", "2", "B", "3", "C")), + StringSplitTestCase("", "[1-9]+", "UNICODE", Seq("")), + StringSplitTestCase(null, "[1-9]+", "UNICODE", null), + StringSplitTestCase("1A2B3C", null, "UNICODE", null), + StringSplitTestCase(null, null, "UNICODE", null) + ) + testCases.foreach(t => { + // StringSplit + checkEvaluation(StringSplit( + Literal.create(t.s, StringType(CollationFactory.collationNameToId(t.collation))), + Literal.create(t.r, StringType), -1), t.expected) + }) + } + test("Regexp expressions with collated strings") { case class RegexpTestCase[R](l: String, r: String, collation: String, expectedExtract: R, expectedExtractAll: R, expectedCount: R) From 3c9f208f07ee20fc42564344b17f14d7a928ec3f Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:29:13 +0200 Subject: [PATCH 14/16] Fix style --- .../catalyst/expressions/CollationRegexpExpressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala index f53dd77caade8..f45d85dabcbff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala @@ -26,7 +26,7 @@ class CollationRegexpExpressionSuite extends SparkFunSuite with ExpressionEvalHe test("Like/ILike/RLike expressions with collated strings") { case class LikeTestCase[R](l: String, regexLike: String, regexRLike: String, collation: String, - expectedLike: R, expectedILike: R, expectedRLike: R) + expectedLike: R, expectedILike: R, expectedRLike: R) val testCases = Seq( LikeTestCase("AbC", "%AbC%", ".b.", "UTF8_BINARY", true, true, true), LikeTestCase("AbC", "%ABC%", ".B.", "UTF8_BINARY", false, true, false), @@ -80,7 +80,7 @@ class CollationRegexpExpressionSuite extends SparkFunSuite with ExpressionEvalHe test("Regexp expressions with collated strings") { case class RegexpTestCase[R](l: String, r: String, collation: String, - expectedExtract: R, expectedExtractAll: R, expectedCount: R) + expectedExtract: R, expectedExtractAll: R, expectedCount: R) val testCases = Seq( RegexpTestCase("AbC-aBc", ".b.", "UTF8_BINARY", "AbC", Seq("AbC"), 1), RegexpTestCase("AbC-abc", ".b.", "UTF8_BINARY", "AbC", Seq("AbC", "abc"), 2), From da474b0af37ec5888ec6acdff2afa692846b2d81 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 23 Apr 2024 10:28:35 +0200 Subject: [PATCH 15/16] Rename test files --- ...e.scala => CollationRegexpExpressionsSuite.scala} | 2 +- ...ionsSuite.scala => CollationSQLRegexpSuite.scala} | 12 +----------- 2 files changed, 2 insertions(+), 12 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{CollationRegexpExpressionSuite.scala => CollationRegexpExpressionsSuite.scala} (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{CollationRegexpExpressionsSuite.scala => CollationSQLRegexpSuite.scala} (97%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionsSuite.scala similarity index 99% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionsSuite.scala index f45d85dabcbff..cc50aebf589e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationRegexpExpressionsSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types._ -class CollationRegexpExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { +class CollationRegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Like/ILike/RLike expressions with collated strings") { case class LikeTestCase[R](l: String, regexLike: String, regexRLike: String, collation: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index c519efa93081b..739b000492c55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, StringType} // scalastyle:off nonascii -class CollationRegexpExpressionsSuite +class CollationSQLRegexpSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { @@ -439,11 +437,3 @@ class CollationRegexpExpressionsSuite } // scalastyle:on nonascii - -class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { - override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.ANSI_ENABLED, true) - - // TODO: If needed, add more tests for other regexp expressions (with ANSI mode enabled) - -} From 05b3bd480b1d444e00d8d8358eeaada927a48539 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 24 Apr 2024 13:43:53 +0200 Subject: [PATCH 16/16] Regex flags as static field --- .../expressions/regexpExpressions.scala | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 63715a22f602b..297c709c6d7d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -49,6 +49,7 @@ abstract class StringRegexExpression extends BinaryExpression Seq(StringTypeBinaryLcase, StringTypeAnyCollation) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) // try cache foldable pattern private lazy val cache: Pattern = right match { @@ -62,7 +63,7 @@ abstract class StringRegexExpression extends BinaryExpression } else { // Let it raise exception if couldn't compile the regex string try { - Pattern.compile(escape(str), CollationSupport.collationAwareRegexFlags(collationId)) + Pattern.compile(escape(str), collationRegexFlags) } catch { case e: PatternSyntaxException => throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern, e) @@ -163,8 +164,7 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) val pattern = ctx.addMutableState(patternClass, "patternLike", v => - s"""$v = $patternClass.compile("$regexStr", - |CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin) + s"""$v = $patternClass.compile("$regexStr", $collationRegexFlags);""".stripMargin) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -192,8 +192,7 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) s""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile( - $escapeFunc($rightStr, '$escapedEscapeChar'), - CollationSupport.collationAwareRegexFlags($collationId)); + $escapeFunc($rightStr, '$escapedEscapeChar'), $collationRegexFlags); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) @@ -283,6 +282,7 @@ sealed abstract class MultiLikeBase override def inputTypes: Seq[AbstractDataType] = StringTypeBinaryLcase :: Nil final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) override def nullable: Boolean = true @@ -290,9 +290,8 @@ sealed abstract class MultiLikeBase protected lazy val hasNull: Boolean = patterns.contains(null) - protected lazy val cache = patterns.filterNot(_ == null) - .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'), - CollationSupport.collationAwareRegexFlags(collationId))) + protected lazy val cache = patterns.filterNot(_ == null).map(s => + Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'), collationRegexFlags)) protected lazy val matchFunc = if (isNotSpecified) { (p: Pattern, inputValue: String) => !p.matcher(inputValue).matches() @@ -485,8 +484,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = $patternClass.compile("$regexStr", - |CollationSupport.collationAwareRegexFlags($collationId));""".stripMargin) + v => s"""$v = $patternClass.compile("$regexStr", $collationRegexFlags);""".stripMargin) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -510,8 +508,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); - $patternClass $pattern = $patternClass.compile($rightStr, - CollationSupport.collationAwareRegexFlags($collationId)); + $patternClass $pattern = $patternClass.compile($rightStr, $collationRegexFlags); ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) @@ -1180,14 +1177,14 @@ object RegExpUtils { val classNamePattern = classOf[Pattern].getCanonicalName val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val collationRegexFlags = CollationSupport.collationAwareRegexFlags(collationId) s""" |if (!$regexp.equals($termLastRegex)) { | // regex value changed | try { | UTF8String r = $regexp.clone(); - | $termPattern = $classNamePattern.compile(r.toString(), - | CollationSupport.collationAwareRegexFlags($collationId)); + | $termPattern = $classNamePattern.compile(r.toString(), $collationRegexFlags); | $termLastRegex = r; | } catch (java.util.regex.PatternSyntaxException e) { | throw QueryExecutionErrors.invalidPatternError("$prettyName", e.getPattern(), e);