[SPARK-2213][SQL] Sort Merge Join #3173

Closed · wants to merge 8 commits
@@ -58,6 +58,20 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
"a single partition.")
}

/**
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located. Based on the context, this
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
Contributor: Seems you want to update this comment.

case class ClusteredOrderedDistribution(clustering: Seq[Expression]) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
}

/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
@@ -162,6 +176,37 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

/**
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition. In each partition, the keys are sorted according to expressions
*/

Contributor: the keys are sorted => the rows are sorted?

case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression
with Partitioning {

override def children = expressions
override def nullable = false
override def dataType = IntegerType

private[this] lazy val clusteringSet = expressions.toSet

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredOrderedDistribution(requiredClustering) =>
clusteringSet.subsetOf(requiredClustering.toSet)

Contributor: Since ClusteredOrderedDistribution is a special case of ClusteredDistribution, I think you can also have

  case ClusteredDistribution(requiredClustering) =>
    clusteringSet.subsetOf(requiredClustering.toSet)

case _ => false
}

override def compatibleWith(other: Partitioning) = other match {
case BroadcastPartitioning => true
case h: HashSortedPartitioning if h == this => true
case _ => false
}

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
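
For intuition (this sketch is not part of the patch), the layout that HashSortedPartitioning guarantees and ClusteredOrderedDistribution requires can be illustrated with plain Scala collections: rows whose keys hash to the same partition land together, and within each partition the rows are ordered by the key:

// Toy rows of (joinKey, payload).
val rows = Seq((3, "c"), (1, "a"), (3, "d"), (2, "b"), (1, "e"))
val numPartitions = 2

// Hash-partition by key, then sort every partition by key.
val partitions: Map[Int, Seq[(Int, String)]] =
  rows
    .groupBy { case (key, _) => math.abs(key.hashCode) % numPartitions }
    .map { case (pid, part) => pid -> part.sortBy(_._1) }

// partition 0 -> Seq((2,b)); partition 1 -> Seq((1,a), (1,e), (3,c), (3,d))
partitions.foreach { case (pid, part) => println(s"partition $pid: $part") }
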
/**
* Represents a partitioning where rows are split across partitions based on some total ordering of
* the expressions specified in `ordering`. When data is partitioned in this manner the following
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -57,11 +58,33 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}

val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

case HashSortedPartitioning(expressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter =>
val hashExpressions = newProjection(expressions, child.output)
iter.map(r => (hashExpressions(r), r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}

val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

case RangePartitioning(sortingExpressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -158,6 +181,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), child) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
case (ClusteredOrderedDistribution(clustering), child) =>
addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
case (OrderedDistribution(ordering), child) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, child) => child
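
The key-ordered shuffle assembled above with setKeyOrdering closely resembles what the public pair-RDD operation repartitionAndSortWithinPartitions does; a minimal standalone sketch (assuming an already-created SparkContext named sc, which is not part of this patch):

import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._

// Pair RDD keyed by the join key.
val keyed = sc.parallelize(Seq((3, "c"), (1, "a"), (3, "d"), (2, "b")))

// One shuffle that hash-partitions by key and sorts by key inside every
// partition - the layout MergeJoin expects from both of its children.
val shuffledSorted = keyed.repartitionAndSortWithinPartitions(new HashPartitioner(2))

// Inspect the per-partition layout.
shuffledSorted.glom().collect().foreach(p => println(p.mkString(", ")))
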
@@ -84,15 +84,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
joins.BuildRight
} else {
joins.BuildLeft
}
val hashJoin = joins.ShuffledHashJoin(
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

// case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
// val buildSide =
// if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
// joins.BuildRight
// } else {
// joins.BuildLeft
// }
// val hashJoin = joins.ShuffledHashJoin(
// leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
// condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.HashOuterJoin(
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer

/**
 * :: DeveloperApi ::
 * Performs a sort-merge join of two child relations by first shuffling the data using the join
 * keys. The shuffle also sorts the rows within each partition by the join keys.
 */
@DeveloperApi
case class MergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan
) extends BinaryNode {
// Implementation: the tricky part is handling duplicate join keys.
// To handle duplicate keys, we use a buffer to store all matching rows
// from the right iterator for a given join key. The buffer is used to
// generate join tuples when the next left element has the same join key
// as the current one.
// TODO: add outer join support
override def outputPartitioning: Partitioning = left.outputPartitioning

override def output = left.output ++ right.output

private val orders = leftKeys.map(s => SortOrder(s, Ascending))

override def requiredChildDistribution =
ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil

@transient protected lazy val leftKeyGenerator: Projection =
newProjection(leftKeys, left.output)

@transient protected lazy val rightKeyGenerator: Projection =
newProjection(rightKeys, right.output)

private val ordering = new RowOrdering(orders, left.output)

override def execute() = {

left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
new Iterator[Row] {
private[this] val joinRow = new JoinedRow2
private[this] var leftElement: Row = _
private[this] var rightElement: Row = _
private[this] var leftKey: Row = _
private[this] var rightKey: Row = _
private[this] var buffer: CompactBuffer[Row] = _
private[this] var index = -1
private[this] var last = false

// initialize iterator
private def initialize() = {
if (leftIter.hasNext) {
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
} else {
last = true
}
if (rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
} else {
last = true
}
}

initialize()

override final def hasNext: Boolean = {
// Two cases in which hasNext returns true:
// 1. We are iterating the buffer
// 2. We can find a pair of tuples with matching join keys
//
// hasNext is effectively idempotent: nextMatchingPair() is only called
// when index == -1, and it sets index to 0 when it returns true, so
// multiple calls to hasNext modify the iterator state at most once.
if (index != -1) return true
if (last) return false
return nextMatchingPair()
}

override final def next(): Row = {
if (index == -1) {
// We need this because the client of the join iterator may
// call next() without calling hasNext.
if (!hasNext) return null
}
val joinedRow = joinRow(leftElement, buffer(index))
index += 1
if (index == buffer.size) {
// finished iterating the buffer, fetch
// next element from left iterator
if (leftIter.hasNext) {
// fetch next element
val leftElem = leftElement
val leftK = leftKeyGenerator(leftElem)
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
if (ordering.compare(leftKey, leftK) == 0) {
// need to go over the buffer again
// as we have the same join key for
// next left element
index = 0
} else {
// need to find a matching element from
// right iterator
index = -1
}
} else {
// no next left element, we are done
index = -1
last = true
}
}
joinedRow
}

// find the next pair of left/right tuples that have a
// matching join key
private def nextMatchingPair(): Boolean = {
while (ordering.compare(leftKey, rightKey) != 0) {
if (ordering.compare(leftKey, rightKey) < 0) {
if (leftIter.hasNext) {
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
} else {
last = true
return false
}
} else {
if (rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
} else {
last = true
return false
}
}
}
// Keys match: leftKey == rightKey. Buffer every right row that shares
// this join key so it can be replayed for duplicate left keys.
index = 0
buffer = new CompactBuffer[Row]()
buffer += rightElement
val rightElem = rightElement
val rightK = rightKeyGenerator(rightElem)
while (rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
if (ordering.compare(rightKey, rightK) == 0) {
buffer += rightElement
} else {
return true
}
}
true
}
}
}
}
}
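
To make the duplicate-key buffering described in the implementation notes above concrete, here is a small standalone Scala sketch (not part of the patch) that inner-joins two key-sorted sequences; the buffer of right rows for the current key plays the role of MergeJoin's CompactBuffer:

// Inner sort-merge join of two sequences that are already sorted by key.
def sortMergeJoin[A, B](left: Seq[(Int, A)], right: Seq[(Int, B)]): Seq[(Int, (A, B))] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(Int, (A, B))]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val lk = left(i)._1
    val rk = right(j)._1
    if (lk < rk) {
      i += 1
    } else if (lk > rk) {
      j += 1
    } else {
      // Buffer every right row with this key.
      var jEnd = j
      while (jEnd < right.length && right(jEnd)._1 == lk) jEnd += 1
      val buffer = right.slice(j, jEnd)
      // Replay the buffer for each left row that shares the key.
      while (i < left.length && left(i)._1 == lk) {
        buffer.foreach { case (_, rv) => out.append((lk, (left(i)._2, rv))) }
        i += 1
      }
      j = jEnd
    }
  }
  out.toSeq
}

// sortMergeJoin(Seq((1, "a"), (2, "b"), (2, "c")), Seq((2, "x"), (2, "y"), (3, "z")))
// returns Seq((2,(b,x)), (2,(b,y)), (2,(c,x)), (2,(c,y)))
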
7 changes: 4 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: MergeJoin => j
}

assert(operators.size === 1)
@@ -72,9 +73,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[MergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[MergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[MergeJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[HashOuterJoin]),