diff --git a/docs/compatibility.md b/docs/compatibility.md index 5b925aa36c3..4aaa4db1b17 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -473,7 +473,7 @@ Here are some examples of regular expression patterns that are not supported on - Empty groups: `()` - Regular expressions containing null characters (unless the pattern is a simple literal string) - Beginning-of-line and end-of-line anchors (`^` and `$`) are not supported in some contexts, such as when combined - with a choice (`^|a`) or when used anywhere in `regexp_replace` patterns. +- with a choice (`^|a`). In addition to these cases that can be detected, there are also known issues that can cause incorrect results: diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 7a7843b6357..65f9042b9c2 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -336,6 +336,13 @@ def test_re_replace(): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( 'REGEXP_REPLACE(a, "TEST", "PROD")', + 'REGEXP_REPLACE(a, "^TEST", "PROD")', + 'REGEXP_REPLACE(a, "^TEST$", "PROD")', + 'REGEXP_REPLACE(a, "TEST$", "PROD")', + 'REGEXP_REPLACE(a, "$TEST", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\$", "PROD")', + 'REGEXP_REPLACE(a, "\\^TEST$", "PROD")', + 'REGEXP_REPLACE(a, "\\^TEST\\$", "PROD")', 'REGEXP_REPLACE(a, "TEST", "")', 'REGEXP_REPLACE(a, "TEST", "%^[]\ud720")', 'REGEXP_REPLACE(a, "TEST", NULL)'), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 43ac51cb56d..4a71ce1d51e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -167,7 +167,7 @@ class RegexParser(pattern: String) { peek() match { case None => throw new RegexUnsupportedException( - s"unexpected EOF while parsing escaped character", Some(pos)) + s"Unclosed character class", Some(pos)) case Some(ch) => // typically an escaped metacharacter ('\\', '^', '-', ']', '+') // within the character class, but could be any escaped character @@ -203,7 +203,7 @@ class RegexParser(pattern: String) { } if (!characterClassComplete) { throw new RegexUnsupportedException( - s"unexpected EOF while parsing character class", Some(pos)) + s"Unclosed character class", Some(pos)) } characterClass } @@ -440,10 +440,6 @@ class CudfRegexTranspiler(replace: Boolean) { case '.' => // workaround for https://github.com/rapidsai/cudf/issues/9619 RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) - case '^' | '$' if replace => - // this is a bit extreme and it would be good to replace with finer-grained - // rules - throw new RegexUnsupportedException("regexp_replace on GPU does not support ^ or $") case '$' => RegexSequence(ListBuffer( RegexRepetition( @@ -552,9 +548,21 @@ class CudfRegexTranspiler(replace: Boolean) { // falling back to CPU throw new RegexUnsupportedException(nothingToRepeat) } - if (replace && parts.length == 1 && (isRegexChar(parts.head, '^') - || isRegexChar(parts.head, '$'))) { - throw new RegexUnsupportedException("regexp_replace on GPU does not support ^ or $") + def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match { + case RegexSequence(parts) => parts.nonEmpty && parts.forall(isBeginOrEndLineAnchor) + case RegexGroup(_, term) => isBeginOrEndLineAnchor(term) + case RegexChoice(l, r) => isBeginOrEndLineAnchor(l) && isBeginOrEndLineAnchor(r) + case RegexRepetition(term, _) => isBeginOrEndLineAnchor(term) + case RegexChar(ch) => ch == '^' || ch == '$' + case _ => false + } + if (parts.forall(isBeginOrEndLineAnchor)) { + throw new RegexUnsupportedException( + "sequences that only contain '^' or '$' are not supported") + } + if (isRegexChar(parts.last, '^')) { + throw new RegexUnsupportedException( + "sequences that end with '^' are not supported") } RegexSequence(parts.map(rewrite)) @@ -773,7 +781,7 @@ class RegexUnsupportedException(message: String, index: Option[Int] = None) extends SQLException { override def getMessage: String = { index match { - case Some(i) => s"$message at index $index" + case Some(i) => s"$message near index $i" case _ => message } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 4836c8f15cb..2d31835cf55 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -94,6 +94,13 @@ class RegularExpressionParserSuite extends FunSuite { ListBuffer(RegexChar('a'))), RegexChar(']')))) } + test("unclosed character class") { + val e = intercept[RegexUnsupportedException] { + parse("[ab") + } + assert(e.getMessage === "Unclosed character class near index 3") + } + test("hex digit") { assert(parse(raw"\xFF") === RegexSequence(ListBuffer(RegexHexDigit("FF")))) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index db3a1791317..d1dd5e67a26 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -70,17 +70,13 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite { frame => frame.selectExpr("regexp_replace(strings,'[a-z]+','D')") } - testGpuFallback("String regexp_replace regex 3 cpu fall back", - "RegExpReplace", - nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { + testSparkResultsAreEqual("String regexp_replace regex 3", + nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'foo$','D')") } - testGpuFallback("String regexp_replace regex 4 cpu fall back", - "RegExpReplace", - nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { + testSparkResultsAreEqual("String regexp_replace regex 4", + nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'^foo','D')") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 73295d02ecb..3ffa7f804ac 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -36,7 +36,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "a*+", "\t+|a", "(\t+|a)Dc$1", - "(?d)" + "(?d)", + "$|$[^\n]2]}|B", + "a^|b", + "w$|b" ) // data is not relevant because we are checking for compilation errors val inputs = Seq("a") @@ -70,7 +73,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support choice with nothing to repeat") { val patterns = Seq("b+|^\t") patterns.foreach(pattern => - assertUnsupported(pattern, "nothing to repeat") + assertUnsupported(pattern, replace = false, "nothing to repeat") ) } @@ -94,14 +97,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support possessive quantifier") { val patterns = Seq("a*+", "a|(a?|a*+)") patterns.foreach(pattern => - assertUnsupported(pattern, "nothing to repeat") + assertUnsupported(pattern, replace = false, "nothing to repeat") ) } test("cuDF does not support empty sequence") { val patterns = Seq("", "a|", "()") patterns.foreach(pattern => - assertUnsupported(pattern, "empty sequence not supported") + assertUnsupported(pattern, replace = false, "empty sequence not supported") ) } @@ -109,27 +112,23 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // note that we could choose to transpile and escape the '{' and '}' characters val patterns = Seq("{1,2}", "{1,}", "{1}", "{2,1}") patterns.foreach(pattern => - assertUnsupported(pattern, "nothing to repeat") + assertUnsupported(pattern, replace = false, "nothing to repeat") ) } test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { - assertUnsupported(pattern, "nothing to repeat") + assertUnsupported(pattern, replace = false, + "nothing to repeat") }) } test("cuDF does not support null in pattern") { val patterns = Seq("\u0000", "a\u0000b", "a(\u0000)b", "a[a-b][\u0000]") patterns.foreach(pattern => - assertUnsupported(pattern, "cuDF does not support null characters in regular expressions")) - } - - test("nothing to repeat") { - val patterns = Seq("$*", "^+") - patterns.foreach(pattern => - assertUnsupported(pattern, "nothing to repeat")) + assertUnsupported(pattern, replace = false, + "cuDF does not support null characters in regular expressions")) } test("end of line anchor with strings ending in valid newline") { @@ -248,6 +247,24 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpReplace(patterns, inputs) } + test("compare CPU and GPU: regexp replace BOL / EOL supported use cases") { + val inputs = Seq("a", "b", "c", "cat", "", "^", "$", "^a", "t$") + val patterns = Seq("^a", "a$", "^a$", "(^a|t$)", "(^a)|(t$)", "^[ac]$", "^^^a$$$", + "[\\^\\$]") + assertCpuGpuMatchesRegexpReplace(patterns, inputs) + } + + test("cuDF does not support some uses of BOL/EOL in regexp_replace") { + Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+").foreach(pattern => + assertUnsupported(pattern, replace = true, + "sequences that only contain '^' or '$' are not supported") + ) + Seq("^|$", "^^|$$").foreach(pattern => + assertUnsupported(pattern, replace = true, + "nothing to repeat") + ) + } + test("compare CPU and GPU: regexp replace negated character class") { val inputs = Seq("a", "b", "a\nb", "a\r\nb\n\rc\rd") val patterns = Seq("[^z]", "[^\r]", "[^\n]", "[^\r]", "[^\r\n]", @@ -405,11 +422,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { new CudfRegexTranspiler(replace).transpile(pattern) } - private def assertUnsupported(pattern: String, message: String): Unit = { + private def assertUnsupported(pattern: String, replace: Boolean, message: String): Unit = { val e = intercept[RegexUnsupportedException] { - transpile(pattern, replace = false) + transpile(pattern, replace) } - assert(e.getMessage.startsWith(message)) + assert(e.getMessage.startsWith(message), pattern) } private def parse(pattern: String): RegexAST = new RegexParser(pattern).parse()