From 91cfd3375f4603e94944952f2b6bc2b8c5d4468e Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 25 Mar 2014 15:20:43 +0800 Subject: [PATCH] add implementation for rlike/like Conflicts: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 + .../expressions/stringOperations.scala | 110 +++++++++++++++++- .../ExpressionEvaluationSuite.scala | 97 +++++++++++++-- .../org/apache/spark/sql/hive/HiveQl.scala | 6 +- 4 files changed, 200 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 9dec4e3d9e4c2..83b836f94f7cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers { protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") protected val OR = Keyword("OR") + protected val LIKE = Keyword("LIKE") + protected val RLIKE = Keyword("RLIKE") protected val ORDER = Keyword("ORDER") protected val OUTER = Keyword("OUTER") protected val RIGHT = Keyword("RIGHT") @@ -267,6 +269,8 @@ class SqlParser extends StandardTokenParsers { termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | + termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { case e1 ~ _ ~ _ ~ e2 => In(e1, e2) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 7584fe03cf745..04e066237e2ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,10 +20,114 @@ package catalyst package expressions import org.apache.spark.sql.catalyst.types.BooleanType +import java.util.regex.Pattern -case class Like(left: Expression, right: Expression) extends BinaryExpression { - def dataType = BooleanType - def nullable = left.nullable // Right cannot be null. +import catalyst.types.StringType +import catalyst.types.BooleanType +import catalyst.trees.TreeNode + +import catalyst.errors.`package`.TreeNodeException +import org.apache.spark.sql.catalyst.types.DataType + + +/** + * Thrown when an invalid RegEx string is found. + */ +class InvalidRegExException[TreeType <: TreeNode[_]](tree: TreeType, reason: String) extends + errors.TreeNodeException(tree, s"$reason", null) + +trait StringRegexExpression { + self: BinaryExpression => + + type EvaluatedType = Any + + def escape(v: String): String + def nullable: Boolean = true + def dataType: DataType = BooleanType + + // try cache the pattern for Literal + private lazy val cache: Pattern = right match { + case x @ Literal(value: String, StringType) => compile(value) + case _ => null + } + + protected def compile(str: Any): Pattern = str match { + // TODO or let it be null if couldn't compile the regex? + case x: String if(x != null) => Pattern.compile(escape(x)) + case x: String => null + case _ => throw new InvalidRegExException(this, "$str can not be compiled to regex pattern") + } + + protected def pattern(str: String) = if(cache == null) compile(str) else cache + + protected def filter: PartialFunction[(Row, (String, String)), Any] = { + case (row, (null, r)) => { false } + case (row, (l, null)) => { false } + case (row, (l, r)) => { + val regex = pattern(r) + if(regex == null) { + null + } else { + regex.matcher(l).matches + } + } + } + + override def apply(input: Row): Any = { + val l = left.apply(input) + if(l == null) { + null + } else { + val r = right.apply(input) + if(r == null) { + null + } else { + filter.lift(input, (l.asInstanceOf[String], r.asInstanceOf[String])).get + } + } + } +} + +/** + * Simple RegEx pattern matching function + */ +case class Like(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + def symbol = "LIKE" + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + override def escape(v: String) = { + val sb = new StringBuilder() + var i = 0; + while (i < v.length) { + // Make a special case for "\\_" and "\\%" + val n = v.charAt(i); + if (n == '\\' && i + 1 < v.length && (v.charAt(i + 1) == '_' || v.charAt(i + 1) == '%')) { + sb.append(v.charAt(i + 1)) + i += 1 + } else { + if (n == '_') { + sb.append("."); + } else if (n == '%') { + sb.append(".*"); + } else { + sb.append(Pattern.quote(Character.toString(n))); + } + } + + i += 1 + } + + sb.toString() + } } +case class RLike(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + + def symbol = "RLIKE" + + override def escape(v: String) = v +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 0a684855fa332..6a108b1639463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -26,6 +26,11 @@ import org.apache.spark.sql.catalyst.types._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import types._ +import expressions._ +import dsl._ +import dsl.expressions._ + /** * Root class of expression evaluation test @@ -50,7 +55,7 @@ trait ExpressionEvaluationSuite extends FunSuite { */ def executor(exprs: Seq[Expression]): ExprEvalTest - val data: Row = new GenericRow(Array(1, null, 1.0, true, 4, 5, null, "abcccd")) + val data: Row = new GenericRow(Array(1, null, 1.0, true, 4, 5, null, "abcccd", "a%")) // TODO add to DSL val c1 = BoundReference(0, AttributeReference("a", IntegerType)()) @@ -61,37 +66,72 @@ trait ExpressionEvaluationSuite extends FunSuite { val c6 = BoundReference(5, AttributeReference("f", IntegerType)()) val c7 = BoundReference(6, AttributeReference("g", StringType)()) val c8 = BoundReference(7, AttributeReference("h", StringType)()) + val c9 = BoundReference(8, AttributeReference("i", StringType)()) - def verify(expected: Seq[(Boolean, Any)], result: Row, input: Row) { + /** + * Compare each of the field if it equals the expected value. + * + * expected is a sequence of (Any, Any), + * and the first element indicates: + * true: the expected value is field is null + * false: the expected value is not null + * Exception Class: the expected exception class while computing the value + * the second element is the real value when first element equals false(not null) + */ + def verify(expected: Seq[(Any, Any)], result: Row, input: Row) { Seq.tabulate(expected.size) { i => expected(i) match { case (false, expected) => { - assert(result.isNullAt(i) == false, s"Input:($input), Output field:$i shouldn't be null") + assert(result.isNullAt(i) == false, + s"Input:($input), Output field:$i shouldn't be null") + val real = result.apply(i) - assert(real == expected, s"Input:($input), Output field:$i is expected as $expected, but got $real") + assert(real == expected, + s"Input:($input), Output field:$i is expected as $expected, but got $real") } case (true, _) => { - assert(result.isNullAt(i), s"Input:($input), Output field:$i is expected as null") + assert(result.isNullAt(i) == true, s"Input:($input), Output field:$i is expected as null") + } + case (exception: Class[_], _) => { + assert(result.isNullAt(i) == false, + s"Input:($input), Output field:$i should be exception") + + val real = result.apply(i).getClass.getName + val expect = exception.getName + assert(real == expect, + s"Input:($input), Output field:$i expect exception $expect, but got $real") } } } } - def verify(expecteds: Seq[Seq[(Boolean, Any)]], results: Seq[Row], inputs: Seq[Row]) { + def verify(expecteds: Seq[Seq[(Any, Any)]], results: Seq[Row], inputs: Seq[Row]) { Range(0, expecteds.length).foreach { i => verify(expecteds(i), results(i), inputs(i)) } } - def run(exprs: Seq[Expression], expected: Seq[(Boolean, Any)], input: Row) { + def proc(tester: ExprEvalTest, input: Row): Row = { + try { + tester.engine.apply(input) + } catch { + case x: Any => { + println(x.printStackTrace()) + new GenericRow(Array(x.asInstanceOf[Any])) + } + } + } + + def run(exprs: Seq[Expression], expected: Seq[(Any, Any)], input: Row) { val tester = executor(exprs) - verify(expected, tester.engine.apply(input), input) + + verify(expected, proc(tester,input), input) } - def run(exprs: Seq[Expression], expecteds: Seq[Seq[(Boolean, Any)]], inputs: Seq[Row]) { + def run(exprs: Seq[Expression], expecteds: Seq[Seq[(Any, Any)]], inputs: Seq[Row]) { val tester = executor(exprs) - verify(expecteds, inputs.map(tester.engine.apply(_)), inputs) + verify(expecteds, inputs.map(proc(tester,_)), inputs) } test("logical") { @@ -133,6 +173,43 @@ trait ExpressionEvaluationSuite extends FunSuite { run(exprs, expecteds, data) } + test("string like / rlike") { + val exprs = Seq( + Like(c7, Literal("a", StringType)), + Like(c7, Literal(null, StringType)), + Like(c8, Literal(null, StringType)), + Like(c8, Literal("a_c", StringType)), + Like(c8, Literal("a%c", StringType)), + Like(c8, Literal("a%d", StringType)), + Like(c8, Literal("a\\%d", StringType)), // to escape the % + Like(c8, c9), + RLike(c7, Literal("a+", StringType)), + RLike(c7, Literal(null, StringType)), + RLike(c8, Literal(null, StringType)), + RLike(c8, Literal("a.*", StringType)) + ) + + val expecteds = Seq( + (true, false), + (true, false), + (true, false), + (false, false), + (false, false), + (false, true), + (false, false), + (false, true), + (true, false), + (true, false), + (true, false), + (false, true)) + + run(exprs, expecteds, data) + + val expr = Seq(RLike(c8, Literal("[a.(*])", StringType))) + val expected = Seq((classOf[java.util.regex.PatternSyntaxException], false)) + run(expr, expected, data) + } + test("literals") { assert((Literal(1) + Literal(1)).apply(null) === 2) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 8e76a7348e957..7a2ecb165ff91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -848,10 +848,8 @@ object HiveQl { case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("LIKE", left :: right:: Nil) => - UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right))) - case Token("RLIKE", left :: right:: Nil) => - UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("LIKE", left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) + case Token("RLIKE", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) case Token("REGEXP", left :: right:: Nil) => UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right))) case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>