Reduce shuffles for successive full outer join (apache#63)
* Modify existing Partitioning & Distribution to reduce shuffles for full outer join

* Refactor and test
Guo Chenzhao authored and plusplusjiajia committed Jan 15, 2019
1 parent 471fd59 commit 61bd1c9
Showing 5 changed files with 78 additions and 21 deletions.
partitioning.scala
@@ -92,16 +92,22 @@ case class ClusteredDistribution(
}

/**
* Represents data where tuples have been clustered according to the hash of the given
* `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
* If exceptNull == false: Represents data where tuples have been clustered according to the hash
* of the given `expressions`.
* If exceptNull == true: Represents data where tuples have been clustered according to the hash
* of the given `expressions`, except for NULL keys, which may be distributed across any
* partitions. This is often used for join conditions, where the placement of NULL keys does not
* matter because NULL is never considered equal to any value.
* The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
* [[HashPartitioning]] can satisfy this distribution.
*
* This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
* number of partitions, this distribution strictly requires which partition the tuple should be in.
*/
case class HashClusteredDistribution(
expressions: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends Distribution {
requiredNumPartitions: Option[Int] = None,
exceptNull: Boolean = false) extends Distribution {
require(
expressions != Nil,
"The expressions for hash of a HashClusteredDistribution should not be Nil. " +
@@ -112,7 +118,7 @@ case class HashClusteredDistribution(
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
HashPartitioning(expressions, numPartitions)
HashPartitioning(expressions, numPartitions, exceptNull)
}
}

@@ -207,12 +213,16 @@ case object SinglePartition extends Partitioning {
}

/**
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* If exceptNull == false: Represents a partitioning where rows are split up across partitions based
* on the hash of `expressions`.
* If exceptNull == true: Represents a partitioning where rows are split up across partitions based
* on the hash of `expressions`, except for null, which is the only key that is not co-partitioned.
* All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression with Partitioning with Unevaluable {
case class HashPartitioning(
expressions: Seq[Expression], numPartitions: Int, exceptNull: Boolean = false)
extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
@@ -222,8 +232,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
super.satisfies0(required) || {
required match {
case h: HashClusteredDistribution =>
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
expressions.length == h.expressions.length && (h.exceptNull || !exceptNull) &&
expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case ClusteredDistribution(requiredClustering, _) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
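
A minimal sketch of the guarantee the new flag encodes, using the patched classes above. The attribute, element type, and partition count are illustrative, and the expected results are read off the satisfies0 rule added above rather than captured output:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.LongType

val key = Seq(AttributeReference("a", LongType)())

// Non-null values of `a` are co-located; NULL keys may sit in any partition.
val relaxed = HashPartitioning(key, 8, exceptNull = true)

// Satisfied: the consumer also tolerates NULL keys being anywhere (e.g. an equi-join condition).
relaxed.satisfies(HashClusteredDistribution(key, exceptNull = true))    // true

// Not satisfied: the consumer needs every key, including NULL, to be co-located.
relaxed.satisfies(HashClusteredDistribution(key))                       // false

// A strict HashPartitioning still satisfies the relaxed requirement.
HashPartitioning(key, 8).satisfies(HashClusteredDistribution(key, exceptNull = true))  // true

This mirrors the `h.exceptNull || !exceptNull` condition added to `satisfies0`: a partitioning that lets NULL keys float can only satisfy a distribution that also tolerates that.
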
EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -163,13 +162,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
case HashPartitioning(leftExpressions, _, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
case HashPartitioning(rightExpressions, _, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
ShuffleExchangeExec.scala
@@ -190,7 +190,7 @@ object ShuffleExchangeExec {
serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
case HashPartitioning(_, n, _) =>
new Partitioner {
override def numPartitions: Int = n
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
SortMergeJoinExec.scala
@@ -70,15 +70,27 @@ case class SortMergeJoinExec(
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case FullOuter =>
// The output of Full Outer Join is similar to pure HashPartitioning, except for NULL, which
// is the only key not co-partitioned.
(left.outputPartitioning, right.outputPartitioning) match {
case (l: HashPartitioning, r: HashPartitioning) =>
PartitioningCollection(Seq(l.copy(exceptNull = true), r.copy(exceptNull = true)))
case _ => UnknownPartitioning(left.outputPartitioning.numPartitions)
}
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
override def requiredChildDistribution: Seq[Distribution] = joinType match {
case Inner | LeftOuter | RightOuter | FullOuter =>
HashClusteredDistribution(leftKeys, exceptNull = true) ::
HashClusteredDistribution(rightKeys, exceptNull = true) :: Nil
case _ => HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
}


override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
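
The reasoning behind the FullOuter case: an unmatched row keeps the partition chosen for the side that produced it and carries NULL in the other side's key columns, so every non-null key value is still co-partitioned and only NULL keys can appear anywhere, which is harmless for a following equi-join because NULL never equals anything. A rough spark-shell sketch against a build that includes this change (expression ids and partition counts will differ; the expected plan shape is paraphrased from the code above, not captured output):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.execution.joins.SortMergeJoinExec

val joined = spark.range(10).selectExpr("id as a")
  .join(spark.range(10).selectExpr("id as b"), col("a") === col("b"), "full_outer")

val smj = joined.queryExecution.executedPlan.collect { case j: SortMergeJoinExec => j }.head

// Before this change: UnknownPartitioning(n). After it: a PartitioningCollection of
// hashpartitioning(a, exceptNull = true) and hashpartitioning(b, exceptNull = true),
// so a subsequent equi-join on `a` or `b` needs no additional Exchange.
smj.outputPartitioning
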
PlannerSuite.scala
@@ -18,21 +18,21 @@
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.{QueryTest, Row, execution}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._

class PlannerSuite extends SharedSQLContext {
class PlannerSuite extends QueryTest with SharedSQLContext {
import testImplicits._

setupTestData()
@@ -683,6 +683,41 @@ class PlannerSuite extends SharedSQLContext {
case _ => fail()
}
}
test("EnsureRequirements doesn't add shuffle between 2 successive full outer joins on the same " +
"key") {
val df1 = spark.range(1, 100, 1, 2).filter(_ % 2 == 0).selectExpr("id as a1")
val df2 = spark.range(1, 100, 1, 2).selectExpr("id as b2")
val df3 = spark.range(1, 100, 1, 2).selectExpr("id as a3")
val fullOuterJoins = df1
.join(df2, col("a1") === col("b2"), "full_outer")
.join(df3, col("a1") === col("a3"), "full_outer")
assert(
fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
.length === 3)
val expected = (1 until 100).filter(_ % 2 == 0).map(i => Row(i, i, i)) ++
(1 until 100).filterNot(_ % 2 == 0).map(Row(null, _, null)) ++
(1 until 100).filterNot(_ % 2 == 0).map(Row(null, null, _))
checkAnswer(fullOuterJoins, expected)
}

test("EnsureRequirements still adds shuffle for non-successive full outer joins on the same key")
{
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val df1 = spark.range(1, 100).selectExpr("id as a1")
val df2 = spark.range(1, 100).selectExpr("id as b2")
val df3 = spark.range(1, 100).selectExpr("id as a3")
val df4 = spark.range(1, 100).selectExpr("id as a4")

val fullOuterJoins = df1
.join(df2, col("a1") === col("b2"), "full_outer")
.join(df3, col("a1") === col("a3"), "left_outer")
.join(df4, col("a3") === col("a4"), "full_outer")
fullOuterJoins.explain(true)
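// The last full outer join is on a different key (a3), so besides one Exchange per base input
// the intermediate result has to be re-shuffled as well, giving 5 exchanges in total.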
assert(
fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
.length === 5)
}
}
}

// Used for unit-testing EnsureRequirements
