[SPARK-29231][SQL] Constraints should be inferred from cast equality constraint

### What changes were proposed in this pull request?

This PR adds support for inferring constraints from cast equality constraints, i.e., join conditions of the form `cast(a) = b`. For example, before this PR:
```scala
scala> spark.sql("create table spark_29231_1(c1 bigint, c2 bigint)")
res0: org.apache.spark.sql.DataFrame = []

scala> spark.sql("create table spark_29231_2(c1 int, c2 bigint)")
res1: org.apache.spark.sql.DataFrame = []

scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain
== Physical Plan ==
*(2) Project [c1#5L, c2#6L]
+- *(2) BroadcastHashJoin [c1#5L], [cast(c1#7 as bigint)], Inner, BuildRight
   :- *(2) Project [c1#5L, c2#6L]
   :  +- *(2) Filter (isnotnull(c1#5L) AND (c1#5L = 1))
   :     +- *(2) ColumnarToRow
   :        +- FileScan parquet default.spark_29231_1[c1#5L,c2#6L] Batched: true, DataFilters: [isnotnull(c1#5L), (c1#5L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct<c1:bigint,c2:bigint>
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#209]
      +- *(1) Project [c1#7]
         +- *(1) Filter isnotnull(c1#7)
            +- *(1) ColumnarToRow
               +- FileScan parquet default.spark_29231_2[c1#7] Batched: true, DataFilters: [isnotnull(c1#7)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct<c1:int>
```

In the plan above, the scan of `spark_29231_2` is filtered only by `isnotnull(c1)`: the predicate `t1.c1 = 1` cannot be propagated to `t2` because the join keys are equal only through a cast. After this PR, the constraint `cast(c1 as bigint) = 1` is inferred and pushed down to the `t2` side:
```scala
scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain
== Physical Plan ==
*(2) Project [c1#0L, c2#1L]
+- *(2) BroadcastHashJoin [c1#0L], [cast(c1#2 as bigint)], Inner, BuildRight
   :- *(2) Project [c1#0L, c2#1L]
   :  +- *(2) Filter (isnotnull(c1#0L) AND (c1#0L = 1))
   :     +- *(2) ColumnarToRow
   :        +- FileScan parquet default.spark_29231_1[c1#0L,c2#1L] Batched: true, DataFilters: [isnotnull(c1#0L), (c1#0L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct<c1:bigint,c2:bigint>
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#99]
      +- *(1) Project [c1#2]
         +- *(1) Filter ((cast(c1#2 as bigint) = 1) AND isnotnull(c1#2))
            +- *(1) ColumnarToRow
               +- FileScan parquet default.spark_29231_2[c1#2] Batched: true, DataFilters: [(cast(c1#2 as bigint) = 1), isnotnull(c1#2)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct<c1:int>
```
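
The core idea is a substitution: when the predicate set contains an equality between an attribute and a cast of another attribute, every other predicate mentioning that attribute can be rewritten with the cast expression in its place. A minimal standalone sketch of that substitution on a toy expression type (the `Expr`, `Attr`, and `rewriteAll` names below are illustrative, not Catalyst APIs):

```scala
// Toy expression AST standing in for Catalyst expressions (illustrative only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class Cast(child: Expr, to: String) extends Expr
case class Lit(value: Long) extends Expr
case class EqualTo(left: Expr, right: Expr) extends Expr

// Replace every occurrence of `source` with `destination` inside each constraint,
// mimicking what replaceConstraints does once the destination may be any expression.
def rewriteAll(constraints: Set[Expr], source: Expr, destination: Expr): Set[Expr] = {
  def rewrite(e: Expr): Expr = e match {
    case `source`      => destination
    case EqualTo(l, r) => EqualTo(rewrite(l), rewrite(r))
    case Cast(c, t)    => Cast(rewrite(c), t)
    case other         => other
  }
  constraints.map(rewrite)
}

// From the join key  t1.c1 = cast(t2.c1 as bigint)  and the filter  t1.c1 = 1,
// substituting t1.c1 with the cast side yields  cast(t2.c1 as bigint) = 1,
// which can then be pushed to the t2 side of the join.
val filter   = EqualTo(Attr("t1.c1"), Lit(1))
val inferred = rewriteAll(Set(filter), Attr("t1.c1"), Cast(Attr("t2.c1"), "bigint"))
println(inferred) // Set(EqualTo(Cast(Attr(t2.c1),bigint),Lit(1)))
```

The real implementation does this with Catalyst's `transform` and `semanticEquals`, as the diff below shows.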

### Why are the changes needed?

Improve query performance: the inferred predicate is pushed down to the other side of the join, so fewer rows are scanned and joined.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #27252 from wangyum/SPARK-29231.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
wangyum authored and cloud-fan committed Feb 13, 2020
1 parent 04604b9 commit fb0e07b
Showing 2 changed files with 64 additions and 5 deletions.
Changes to `ConstraintHelper`:

```diff
@@ -62,11 +62,17 @@ trait ConstraintHelper {
    */
   def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
     var inferredConstraints = Set.empty[Expression]
-    constraints.foreach {
+    // IsNotNull should be constructed by `constructIsNotNullConstraints`.
+    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
+    predicates.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
-        val candidateConstraints = constraints - eq
+        val candidateConstraints = predicates - eq
         inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
         inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+      case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
+        inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
+      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
+        inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
       case _ => // No inference
     }
     inferredConstraints -- constraints
@@ -75,7 +81,7 @@ trait ConstraintHelper {
   private def replaceConstraints(
       constraints: Set[Expression],
       source: Expression,
-      destination: Attribute): Set[Expression] = constraints.map(_ transform {
+      destination: Expression): Set[Expression] = constraints.map(_ transform {
     case e: Expression if e.semanticEquals(source) => destination
   })
 
```
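
The two new `Cast` cases reuse the existing substitution machinery; the only signature change needed is letting `replaceConstraints` substitute an arbitrary `Expression` (the cast) rather than just an `Attribute`. A hedged sketch of exercising the new behavior directly (this assumes `ConstraintHelper` is public and instantiable from user code, which may differ across Spark versions):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.ConstraintHelper
import org.apache.spark.sql.types.LongType

object Demo extends ConstraintHelper {
  def main(args: Array[String]): Unit = {
    val a = 'a.int   // int attribute, as on the spark_29231_2 side
    val b = 'b.long  // bigint attribute, as on the spark_29231_1 side

    // Join-key equality through a cast, plus a filter on the bigint side.
    val joinEq = EqualTo(b, Cast(a, LongType))
    val filter = EqualTo(b, Literal(1L))

    // The new `EqualTo(l: Attribute, r @ Cast(...))` case should substitute
    // b with cast(a as bigint), inferring cast(a as bigint) = 1.
    println(inferAdditionalConstraints(Set(joinEq, filter)))
  }
}
```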
Changes to `InferFiltersFromConstraintsSuite`:

```diff
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
@@ -46,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       y: LogicalPlan,
       expectedLeft: LogicalPlan,
       expectedRight: LogicalPlan,
-      joinType: JoinType) = {
-    val condition = Some("x.a".attr === "y.a".attr)
+      joinType: JoinType,
+      condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
     val originalQuery = x.join(y, joinType, condition).analyze
     val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
     val optimized = Optimize.execute(originalQuery)
@@ -263,4 +264,56 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val y = testRelation.subquery('y)
     testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
   }
+
+  test("Constraints should be inferred from cast equality constraint(filter higher data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.subquery('left)
+    val originalRight = testRelation2.where('b === 1L).subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
+    val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        testRelation1.where(IsNotNull('a)).subquery('left),
+        right,
+        Inner,
+        condition)
+    }
+  }
+
+  test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.where('a === 1).subquery('left)
+    val originalRight = testRelation2.subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
+    val right = testRelation2.where(IsNotNull('b)).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        left,
+        testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
+        Inner,
+        condition)
+    }
+  }
 }
```
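
The second test pins down the direction of the rewrite: with the join key `cast(a as bigint) = b`, substitution replaces the attribute `b` with the cast expression in the remaining predicates, so a filter on the bare narrower attribute (`'a === 1`) is never rewritten into a filter on `b`, and the rule never inverts a cast. Inverting would be unsound in general, since a narrowing cast is not injective; a plain-Scala illustration of the bigint-to-int case (assuming Spark's default wrap-around cast semantics):

```scala
// Long-to-Int truncation keeps only the low 32 bits, so distinct bigint values
// can collide after cast(b as int). Given a = cast(b as int) and a = 1, the
// rule may infer cast(b as int) = 1, but it must not conclude b = 1L:
val b1 = 1L
val b2 = 4294967297L    // 2^32 + 1
println(b1.toInt == 1)  // true
println(b2.toInt == 1)  // true: b2 also satisfies cast(b as int) = 1
println(b1 == b2)       // false: so b = 1L would be an unsound inference
```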
