Skip to content

Commit

Permalink
Add more rewrite rules and unit tests (UT)
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <daichen@amazon.com>
  • Loading branch information
dai-chen committed Jul 10, 2023
1 parent 91b2a06 commit dc57bfa
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.json4s.CustomSerializer
import org.json4s.JsonAST.JString
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Predicate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

Expand Down Expand Up @@ -54,6 +55,9 @@ trait FlintSparkSkippingStrategy {
* new filtering condition on index data or empty if index not applicable
*/
def rewritePredicate(predicate: Predicate): Option[Predicate]

/**
 * Unwrap the Catalyst expression behind the given column and view it as a predicate.
 *
 * @param col
 *   column whose underlying expression is expected to be a [[Predicate]]
 * @return
 *   the column's expression, cast to [[Predicate]]
 * @throws ClassCastException
 *   if the column's expression is not a [[Predicate]]
 */
protected def convertToPredicate(col: Column): Predicate = {
  val expression = col.expr
  expression.asInstanceOf[Predicate]
}
}

object FlintSparkSkippingStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ package org.opensearch.flint.spark.skipping.minmax
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MinMax, SkippingKind}

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, Predicate}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Or, Predicate}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
import org.apache.spark.sql.functions.col

Expand All @@ -33,10 +32,20 @@ case class MinMaxSkippingStrategy(
Seq(Min(col(columnName).expr), Max(col(columnName).expr))

/**
 * Rewrite a comparison between the indexed column and literal value(s) into a condition on the
 * min/max columns of the index data.
 *
 * Supported shapes are `=`, `<`, `<=`, `>`, `>=` and `IN` where the left side is the indexed
 * column and the right side consists of literals only. The predicate tree is searched with
 * `collect` and only the first match is returned.
 *
 * NOTE(review): `collect` matches sub-expressions anywhere in the tree, including below
 * `Or`/`Not`. This presumably relies on callers passing conjunctive predicates only - confirm.
 *
 * @param predicate
 *   filtering condition on the source table
 * @return
 *   rewritten condition on index data, or empty if no supported comparison was found
 */
override def rewritePredicate(predicate: Predicate): Option[Predicate] =
  predicate.collect {
    // col = value can only hold if value lies within [min, max]
    case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
      convertToPredicate(col(minColName) <= value && col(maxColName) >= value)
    // col < value needs at least one row below value, i.e. min < value
    case LessThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
      convertToPredicate(col(minColName) < value)
    case LessThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
      convertToPredicate(col(minColName) <= value)
    // col > value needs at least one row above value, i.e. max > value
    case GreaterThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
      convertToPredicate(col(maxColName) > value)
    case GreaterThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
      convertToPredicate(col(maxColName) >= value)
    // The element type of the IN list is erased at runtime, so check each element explicitly.
    // An empty list is skipped because reduceLeft would throw on it, and a non-literal element
    // cannot be turned into a literal comparison on the min/max columns.
    case In(AttributeReference(`columnName`, _, _, _), values)
        if values.nonEmpty && values.forall { case _: Literal => true; case _ => false } =>
      values
        .map(value => convertToPredicate(col(minColName) <= value && col(maxColName) >= value))
        .reduceLeft(Or)
  }.headOption

/**
 * View the Catalyst expression behind a column as a predicate.
 *
 * @param col
 *   column whose underlying expression is expected to be a [[Predicate]]
 * @return
 *   the column's expression, cast to [[Predicate]]
 */
private def rewriteTo(col: Column): Predicate = {
  val underlying = col.expr
  underlying.asInstanceOf[Predicate]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ case class PartitionSkippingStrategy(
/**
 * Rewrite an equality between the indexed column and a literal into the same equality on the
 * index data. Only the first such equality found in the predicate tree is returned.
 *
 * @param predicate
 *   filtering condition on the source table
 * @return
 *   equality condition on index data, or empty if the predicate contains no such equality
 */
override def rewritePredicate(predicate: Predicate): Option[Predicate] = {
  // Column has same name in index data, so just rewrite to the same equation
  predicate.collect { case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
    convertToPredicate(col(columnName) === value)
  }.headOption
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.minmax

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal, Or}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

/**
 * Unit tests for predicate rewriting in [[MinMaxSkippingStrategy]]: each supported comparison on
 * the indexed column should become a condition on the min/max columns, and anything else should
 * be left alone.
 */
class MinMaxSkippingStrategySuite extends SparkFunSuite with Matchers {

  private val strategy = MinMaxSkippingStrategy(columnName = "age", columnType = "integer")

  private val indexCol = AttributeReference("age", IntegerType, nullable = false)()

  // Min / max column names the rewritten predicates are expected to reference
  private val minCol = col("MinMax_age_0").expr
  private val maxCol = col("MinMax_age_1").expr

  test("should rewrite EqualTo(<indexCol>, <value>)") {
    strategy.rewritePredicate(EqualTo(indexCol, Literal(30))) shouldBe Some(
      And(LessThanOrEqual(minCol, Literal(30)), GreaterThanOrEqual(maxCol, Literal(30))))
  }

  test("should rewrite LessThan(<indexCol>, <value>)") {
    strategy.rewritePredicate(LessThan(indexCol, Literal(30))) shouldBe Some(
      LessThan(minCol, Literal(30)))
  }

  test("should rewrite LessThanOrEqual(<indexCol>, <value>)") {
    strategy.rewritePredicate(LessThanOrEqual(indexCol, Literal(30))) shouldBe Some(
      LessThanOrEqual(minCol, Literal(30)))
  }

  test("should rewrite GreaterThan(<indexCol>, <value>)") {
    strategy.rewritePredicate(GreaterThan(indexCol, Literal(30))) shouldBe Some(
      GreaterThan(maxCol, Literal(30)))
  }

  test("should rewrite GreaterThanOrEqual(<indexCol>, <value>)") {
    strategy.rewritePredicate(GreaterThanOrEqual(indexCol, Literal(30))) shouldBe Some(
      GreaterThanOrEqual(maxCol, Literal(30)))
  }

  test("should rewrite In(<indexCol>, <value1, value2 ...>)") {
    strategy.rewritePredicate(In(indexCol, Seq(Literal(25), Literal(30)))) shouldBe Some(
      Or(
        And(LessThanOrEqual(minCol, Literal(25)), GreaterThanOrEqual(maxCol, Literal(25))),
        And(LessThanOrEqual(minCol, Literal(30)), GreaterThanOrEqual(maxCol, Literal(30)))))
  }

  test("should not rewrite predicate with other column") {
    val predicate =
      EqualTo(AttributeReference("height", IntegerType, nullable = false)(), Literal(30))

    strategy.rewritePredicate(predicate) shouldBe empty
  }

  test("should not rewrite EqualTo(<indexCol>, <non-literal>)") {
    // Right-hand side is another attribute rather than a literal value
    val predicate =
      EqualTo(indexCol, AttributeReference("height", IntegerType, nullable = false)())

    strategy.rewritePredicate(predicate) shouldBe empty
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.partition

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.types.IntegerType

/**
 * Unit tests for predicate rewriting in [[PartitionSkippingStrategy]]: only an equality between
 * the indexed column and a literal is expected to be rewritten.
 */
class PartitionSkippingStrategySuite extends SparkFunSuite with Matchers {

  private val strategy = PartitionSkippingStrategy(columnName = "year", columnType = "int")

  private val indexCol = AttributeReference("year", IntegerType, nullable = false)()

  test("should rewrite EqualTo(<indexCol>, <value>)") {
    strategy.rewritePredicate(EqualTo(indexCol, Literal(2023))) shouldBe Some(
      EqualTo(UnresolvedAttribute("year"), Literal(2023)))
  }

  test("should not rewrite predicate with other column") {
    val predicate =
      EqualTo(AttributeReference("month", IntegerType, nullable = false)(), Literal(4))

    strategy.rewritePredicate(predicate) shouldBe empty
  }

  test("should not rewrite GreaterThan(<indexCol>, <value>)") {
    strategy.rewritePredicate(GreaterThan(indexCol, Literal(2023))) shouldBe empty
  }

  test("should only rewrite EqualTo(<indexCol>, <value>) in conjunction") {
    // The GreaterThan half of the conjunction is not supported and must be ignored
    val predicate =
      And(EqualTo(indexCol, Literal(2023)), GreaterThan(indexCol, Literal(2023)))

    strategy.rewritePredicate(predicate) shouldBe Some(
      EqualTo(UnresolvedAttribute("year"), Literal(2023)))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.valueset

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.types.StringType

/**
 * Unit tests for predicate rewriting in [[ValueSetSkippingStrategy]]: only an equality between
 * the indexed column and a literal is expected to be rewritten.
 */
class ValueSetSkippingStrategySuite extends SparkFunSuite with Matchers {

  private val strategy = ValueSetSkippingStrategy(columnName = "name", columnType = "string")

  private val indexCol = AttributeReference("name", StringType, nullable = false)()

  test("should rewrite EqualTo(<indexCol>, <value>)") {
    strategy.rewritePredicate(EqualTo(indexCol, Literal("hello"))) shouldBe Some(
      EqualTo(UnresolvedAttribute("name"), Literal("hello")))
  }

  test("should not rewrite predicate with other column") {
    val predicate =
      EqualTo(AttributeReference("address", StringType, nullable = false)(), Literal("hello"))

    strategy.rewritePredicate(predicate) shouldBe empty
  }

  test("should not rewrite GreaterThan(<indexCol>, <value>)") {
    strategy.rewritePredicate(GreaterThan(indexCol, Literal("hello"))) shouldBe empty
  }

  test("should only rewrite EqualTo(<indexCol>, <value>) in conjunction") {
    // Both sides compare the string column against string literals; only the equality is
    // supported, so the GreaterThan half must be ignored
    val predicate =
      And(EqualTo(indexCol, Literal("hello")), GreaterThan(indexCol, Literal("world")))

    strategy.rewritePredicate(predicate) shouldBe Some(
      EqualTo(UnresolvedAttribute("name"), Literal("hello")))
  }
}

0 comments on commit dc57bfa

Please sign in to comment.