[SPARK-32282][SQL] Improve EnsureRequirements.reorderJoinKeys to handle more scenarios such as PartitioningCollection

### What changes were proposed in this pull request?

This PR proposes to improve `EnsureRequirements.reorderJoinKeys` to handle the following scenarios:
1. If the keys cannot be reordered to match the left-side `HashPartitioning`, consider the right-side `HashPartitioning`.
2. Handle `PartitioningCollection`, which may contain `HashPartitioning` (see the sketch below).
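
For intuition, here is a minimal, self-contained sketch of the search order this change implements. The types and names below are illustrative stand-ins, not Spark's API; the real code is in the diff further down:

```scala
// Illustrative sketch (not Spark source). Candidate key orders are gathered from
// each side's partitioning and tried in priority order: a plain hash partitioning
// on either side first, then the members of any partitioning collection.
sealed trait Part
case class Hash(keys: Seq[String]) extends Part
case class Collection(parts: Seq[Part]) extends Part

object ReorderSketch {
  private def hashes(p: Part): Seq[Seq[String]] = p match {
    case Hash(keys) => Seq(keys)
    case Collection(_) => Nil
  }

  // Dig into collections (including nested ones) for further hash candidates.
  private def members(p: Part): Seq[Seq[String]] = p match {
    case Hash(_) => Nil
    case Collection(ps) => ps.flatMap(c => hashes(c) ++ members(c))
  }

  // `reorderable` stands in for the real check that the join keys can be
  // permuted to line up with a candidate key order.
  def pickOrder(left: Part, right: Part)(reorderable: Seq[String] => Boolean): Option[Seq[String]] =
    (hashes(left) ++ hashes(right) ++ members(left) ++ members(right)).find(reorderable)
}
```

With a `Hash` on the left, its key order is tried first; only when no permutation of the join keys matches does the search fall through to the right side and then into collection members, which is exactly the fallback behavior the examples below exercise.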

### Why are the changes needed?

1. For scenario 1), the current behavior matches either the left-side `HashPartitioning` or the right-side `HashPartitioning`, but never both: if both sides are `HashPartitioning`, only the left side is tried, even when reordering against the right side would succeed.
In the following example, the right-side `HashPartitioning` is never considered:
```
val df1 = (0 until 10).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 10).map(i => (i % 7, i % 11)).toDF("i2", "j2")
df1.write.format("parquet").bucketBy(4, "i1", "j1").saveAsTable("t1")
df2.write.format("parquet").bucketBy(4, "i2", "j2").saveAsTable("t2")
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val join = t1.join(t2, t1("i1") === t2("j2") && t1("i1") === t2("i2"))
join.explain

== Physical Plan ==
*(5) SortMergeJoin [i1#26, i1#26], [j2#31, i2#30], Inner
:- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#69]
:     +- *(1) Project [i1#26, j1#27]
:        +- *(1) Filter isnotnull(i1#26)
:           +- *(1) ColumnarToRow
:              +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct<i1:int,j1:int>, SelectedBucketsCount: 4 out of 4
+- *(4) Sort [j2#31 ASC NULLS FIRST, i2#30 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(j2#31, i2#30, 4), true, [id=#79]       <===== This can be removed
      +- *(3) Project [i2#30, j2#31]
         +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30))
            +- *(3) ColumnarToRow
               +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct<i2:int,j2:int>, SelectedBucketsCount: 4 out of 4

```
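
One way to observe the extra exchange is to count the shuffle nodes in the executed plan. A quick sketch, using `ShuffleExchangeExec` from Spark 3.0's `org.apache.spark.sql.execution.exchange` package and the `join` defined above:

```scala
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Count shuffle nodes in the physical plan of `join` from the snippet above.
val numShuffles = join.queryExecution.executedPlan.collect {
  case e: ShuffleExchangeExec => e
}.length
// Before this patch the plan above contains 2 exchanges; with it, only 1.
```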

2. For scenario 2), the current behavior does not handle `PartitioningCollection`:
```
val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 100).map(i => (i % 7, i % 11)).toDF("i2", "j2")
val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2")) // PartitioningCollection
val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3"))
join2.explain

== Physical Plan ==
*(9) SortMergeJoin [j1#8, i1#7], [j3#30, i3#29], Inner
:- *(6) Sort [j1#8 ASC NULLS FIRST, i1#7 ASC NULLS FIRST], false, 0        <===== This can be removed
:  +- Exchange hashpartitioning(j1#8, i1#7, 5), true, [id=#58]             <===== This can be removed
:     +- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner
:        :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0
:        :  +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#45]
:        :     +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8]
:        :        +- *(1) LocalTableScan [_1#2, _2#3]
:        +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0
:           +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#51]
:              +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19]
:                 +- *(3) LocalTableScan [_1#13, _2#14]
+- *(8) Sort [j3#30 ASC NULLS FIRST, i3#29 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(j3#30, i3#29, 5), true, [id=#64]
      +- *(7) Project [_1#24 AS i3#29, _2#25 AS j3#30]
         +- *(7) LocalTableScan [_1#24, _2#25]
```
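
For context, the intermediate join reports a `PartitioningCollection` because an inner equi-join guarantees that the left and right key columns are equal in every output row, so the output remains hash-partitioned by either side's keys. A simplified sketch of that idea (not the exact `SortMergeJoinExec` source):

```scala
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}

// Simplified: an inner equi-join's output satisfies both children's partitionings
// at once (the children are co-partitioned, so their numbers of partitions match),
// and Spark can advertise them together as a collection of valid partitionings.
def innerJoinOutputPartitioning(left: Partitioning, right: Partitioning): Partitioning =
  PartitioningCollection(Seq(left, right))
```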
### Does this PR introduce _any_ user-facing change?

Yes. In the examples above, the shuffle/sort nodes marked with `This can be removed` are now eliminated:
1. Scenario 1):
```
== Physical Plan ==
*(4) SortMergeJoin [i1#26, i1#26], [i2#30, j2#31], Inner
:- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#67]
:     +- *(1) Project [i1#26, j1#27]
:        +- *(1) Filter isnotnull(i1#26)
:           +- *(1) ColumnarToRow
:              +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct<i1:int,j1:int>, SelectedBucketsCount: 4 out of 4
+- *(3) Sort [i2#30 ASC NULLS FIRST, j2#31 ASC NULLS FIRST], false, 0
   +- *(3) Project [i2#30, j2#31]
      +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30))
         +- *(3) ColumnarToRow
            +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct<i2:int,j2:int>, SelectedBucketsCount: 4 out of 4
```
2. Scenario 2):
```
== Physical Plan ==
*(8) SortMergeJoin [i1#7, j1#8], [i3#29, j3#30], Inner
:- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner
:  :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0
:  :  +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#43]
:  :     +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8]
:  :        +- *(1) LocalTableScan [_1#2, _2#3]
:  +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0
:     +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#49]
:        +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19]
:           +- *(3) LocalTableScan [_1#13, _2#14]
+- *(7) Sort [i3#29 ASC NULLS FIRST, j3#30 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(i3#29, j3#30, 5), true, [id=#58]
      +- *(6) Project [_1#24 AS i3#29, _2#25 AS j3#30]
         +- *(6) LocalTableScan [_1#24, _2#25]
```

### How was this patch tested?

Added tests.
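
For reference, the new suite can typically be run from the Spark repo with sbt (the exact invocation may vary with the build setup):

```
build/sbt "sql/testOnly *EnsureRequirementsSuite"
```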

Closes #29074 from imback82/reorder_keys.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
imback82 authored and cloud-fan committed Oct 8, 2020
1 parent bbc887b commit 1c781a4
Showing 2 changed files with 168 additions and 12 deletions.
`EnsureRequirements` (first changed file):

```diff
@@ -135,9 +135,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftKeys: IndexedSeq[Expression],
       rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
-      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
     if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
-      return (leftKeys, rightKeys)
+      return None
     }
+
+    // Check if the current order already satisfies the expected order.
+    if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) {
+      return Some(leftKeys, rightKeys)
+    }
 
     // Build a lookup between an expression and the positions it holds in the current key seq.
@@ -164,10 +169,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           rightKeysBuffer += rightKeys(index)
         case _ =>
           // The expression cannot be found, or we have exhausted all indices for that expression.
-          return (leftKeys, rightKeys)
+          return None
       }
     }
-    (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
+    Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
   }
 
   private def reorderJoinKeys(
@@ -176,19 +181,48 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      (leftPartitioning, rightPartitioning) match {
-        case (HashPartitioning(leftExpressions, _), _) =>
-          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
-        case (_, HashPartitioning(rightExpressions, _)) =>
-          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
-        case _ =>
-          (leftKeys, rightKeys)
-      }
+      reorderJoinKeysRecursively(
+        leftKeys,
+        rightKeys,
+        Some(leftPartitioning),
+        Some(rightPartitioning))
+        .getOrElse((leftKeys, rightKeys))
     } else {
       (leftKeys, rightKeys)
     }
   }
 
+  /**
+   * Recursively reorders the join keys based on partitioning. It starts reordering the
+   * join keys to match HashPartitioning on either side, followed by PartitioningCollection.
+   */
+  private def reorderJoinKeysRecursively(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Option[Partitioning],
+      rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = {
+    (leftPartitioning, rightPartitioning) match {
+      case (Some(HashPartitioning(leftExpressions, _)), _) =>
+        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+          .orElse(reorderJoinKeysRecursively(
+            leftKeys, rightKeys, None, rightPartitioning))
+      case (_, Some(HashPartitioning(rightExpressions, _))) =>
+        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+          .orElse(reorderJoinKeysRecursively(
+            leftKeys, rightKeys, leftPartitioning, None))
+      case (Some(PartitioningCollection(partitionings)), _) =>
+        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
+          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning))
+        }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))
+      case (_, Some(PartitioningCollection(partitionings))) =>
+        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
+          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)))
+        }.orElse(None)
+      case _ =>
+        None
+    }
+  }
+
   /**
    * When the physical operators are created for JOIN, the ordering of join keys is based on order
    * in which the join keys appear in the user query. That might not match with the output
```
`EnsureRequirementsSuite` (second changed file; new, 122 lines added):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.test.SharedSparkSession

class EnsureRequirementsSuite extends SharedSparkSession {
  private val exprA = Literal(1)
  private val exprB = Literal(2)
  private val exprC = Literal(3)

  test("reorder should handle PartitioningCollection") {
    val plan1 = DummySparkPlan(
      outputPartitioning = PartitioningCollection(Seq(
        HashPartitioning(exprA :: exprB :: Nil, 5),
        HashPartitioning(exprA :: Nil, 5))))
    val plan2 = DummySparkPlan()

    // Test PartitioningCollection on the left side of join.
    val smjExec1 = SortMergeJoinExec(
      exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprA, exprB))
        assert(rightKeys === Seq(exprB, exprA))
      case other => fail(other.toString)
    }

    // Test PartitioningCollection on the right side of join.
    val smjExec2 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
        SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprA, exprB))
      case other => fail(other.toString)
    }

    // Both sides are PartitioningCollection, but the left side cannot be reordered
    // to match, so it should fall back to the right side.
    val smjExec3 = SortMergeJoinExec(
      exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
        SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprC, exprA))
        assert(rightKeys === Seq(exprA, exprB))
      case other => fail(other.toString)
    }
  }

  test("reorder should fallback to the other side partitioning") {
    val plan1 = DummySparkPlan(
      outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5))
    val plan2 = DummySparkPlan(
      outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5))

    // Test fallback to the right side, which has HashPartitioning.
    val smjExec1 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
        SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprB, exprC))
      case other => fail(other.toString)
    }

    // Test fallback to the right side, which has PartitioningCollection.
    val plan3 = DummySparkPlan(
      outputPartitioning = PartitioningCollection(Seq(HashPartitioning(exprB :: exprC :: Nil, 5))))
    val smjExec2 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
        SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprB, exprC))
      case other => fail(other.toString)
    }

    // The right side has HashPartitioning, so it is matched first, but no reordering match is
    // found, and it should fall back to the left side, which has a PartitioningCollection.
    val smjExec3 = SortMergeJoinExec(
      exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
        SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
        SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprC))
        assert(rightKeys === Seq(exprB, exprA))
      case other => fail(other.toString)
    }
  }
}
```
