sort merge join for spark sql
adrian-wang committed Apr 3, 2015
1 parent 5db8912 commit 880d8e9
Showing 6 changed files with 235 additions and 5 deletions.
@@ -75,6 +75,21 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
def clustering: Set[Expression] = ordering.map(_.child).toSet
}

/**
* Represents data where tuples have been clustered according to the `clustering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
* [[ClusteredDistribution]], as it also ensures that tuples within each partition are
* sorted by those expressions.
*/
case class ClusteredOrderedDistribution(clustering: Seq[Expression])
extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
}

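The guarantee added by this class can be pictured with a small, self-contained sketch. This is illustrative only (not part of the commit): partitions are modeled as plain Scala sequences and `key` stands in for evaluating the clustering expressions on a row; all names here are hypothetical.

// Illustrative sketch: what ClusteredDistribution vs. ClusteredOrderedDistribution promise
// about data laid out as Seq[Seq[R]] (one inner Seq per partition).
def satisfiesClustered[R, K](partitions: Seq[Seq[R]])(key: R => K): Boolean =
  // every distinct key value lives in exactly one partition
  partitions.flatMap(_.map(key)).distinct.forall { k =>
    partitions.count(_.exists(r => key(r) == k)) == 1
  }

def satisfiesClusteredOrdered[R, K: Ordering](partitions: Seq[Seq[R]])(key: R => K): Boolean =
  // the clustering guarantee, plus: each partition is sorted by the key
  satisfiesClustered(partitions)(key) &&
    partitions.forall(p => p.map(key) == p.map(key).sorted)
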
sealed trait Partitioning {
/** Returns the number of partitions that the data is split across */
val numPartitions: Int
@@ -162,6 +177,40 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

/**
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed
* to be in the same partition. In addition, rows within each partition are sorted by those
* expressions.
*/
case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression
with Partitioning {

override def children = expressions
override def nullable = false
override def dataType = IntegerType

private[this] lazy val clusteringSet = expressions.toSet

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredOrderedDistribution(requiredClustering) =>
clusteringSet.subsetOf(requiredClustering.toSet)
case ClusteredDistribution(requiredClustering) =>
clusteringSet.subsetOf(requiredClustering.toSet)
case _ => false
}

override def compatibleWith(other: Partitioning) = other match {
case BroadcastPartitioning => true
case h: HashSortedPartitioning if h == this => true
case _ => false
}

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

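A brief, hedged illustration of the satisfies check above: both distribution cases reduce to a subset test between the partitioning's expressions and the required clustering. Plain Strings stand in for catalyst expressions; the values are hypothetical.

// Illustrative only: HashSortedPartitioning(Seq(a), n) vs. ClusteredOrderedDistribution(Seq(a, b))
val partitioningExprs = Set("a")
val requiredClustering = Set("a", "b")
val satisfied = partitioningExprs.subsetOf(requiredClustering)   // true
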
/**
* Represents a partitioning where rows are split across partitions based on some total ordering of
* the expressions specified in `ordering`. When data is partitioned in this manner the following
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -27,6 +27,7 @@ private[spark] object SQLConf {
val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val CODEGEN_ENABLED = "spark.sql.codegen"
@@ -143,6 +144,12 @@ private[sql] class SQLConf extends Serializable {
private[spark] def autoBroadcastJoinThreshold: Int =
getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt

/**
* When true (the default), the planner uses sort merge join for inner equi-joins that are
* not converted into broadcast joins.
*/
private[spark] def autoSortMergeJoin: Boolean =
getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean

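A hedged usage sketch for the new flag, assuming an existing SQLContext named sqlContext; it can be toggled per session either programmatically or through SQL.

// Usage sketch (assumes a SQLContext called sqlContext is in scope).
sqlContext.setConf("spark.sql.autoSortMergeJoin", "false")
// or, equivalently, via SQL:
sqlContext.sql("SET spark.sql.autoSortMergeJoin=false")
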
/**
* The default size in bytes to assign to a logical operator's estimation statistics. By default,
* it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator
@@ -19,12 +19,11 @@ package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -73,6 +72,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

case HashSortedPartitioning(expressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}
val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

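The HashSortedPartitioning case above amounts to "hash-partition rows by the projected key, then sort each partition by that key" (the per-partition sort is what setKeyOrdering requests from the sort-based shuffle). Below is a conceptual, self-contained sketch of that effect, not Spark code; the names are hypothetical, and the real HashPartitioner handles negative hash codes more carefully.

// Conceptual sketch: hash-partition by key, then sort within each partition.
def hashSortedExchange[K: Ordering, V](rows: Seq[(K, V)], numPartitions: Int): Seq[Seq[(K, V)]] = {
  val byPartition = rows.groupBy { case (k, _) => math.abs(k.hashCode) % numPartitions }
  (0 until numPartitions).map { p =>
    byPartition.getOrElse(p, Seq.empty).sortBy(_._1)   // sorted run per partition
  }
}
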
case RangePartitioning(sortingExpressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -173,6 +192,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan]
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), child) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
case (ClusteredOrderedDistribution(clustering), child) =>
addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
case (OrderedDistribution(ordering), child) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, child) => child
@@ -90,6 +90,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

// For now only inner joins are supported; outer join support will be added later.
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoSortMergeJoin =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, Inner, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

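With this strategy in place and spark.sql.autoSortMergeJoin left at its default, a plain inner equi-join is planned as a SortMergeJoin rather than a ShuffledHashJoin (see the JoinSuite update further down). A hedged sketch, assuming the testData and testData2 tables registered by the test suite and a SQLContext named sqlContext:

// Usage sketch: this query should now be planned with a SortMergeJoin operator.
val df = sqlContext.sql("SELECT * FROM testData JOIN testData2 ON key = a")
// df.queryExecution.executedPlan should contain a SortMergeJoin node.
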
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer

/**
* :: DeveloperApi ::
* Performs a sort merge join of two child relations. Both children are required to be
* clustered and sorted by the join keys, so that matching rows can be merged in a single pass.
*/
@DeveloperApi
case class SortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def output: Seq[Attribute] = left.output ++ right.output

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] =
ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil

private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending))
private val ordering: RowOrdering = new RowOrdering(orders, left.output)

@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)

override def execute() = {
val leftResults = left.execute().map(_.copy())
val rightResults = right.execute().map(_.copy())

leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
new Iterator[Row] {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow5
private[this] var leftElement: Row = _
private[this] var rightElement: Row = _
private[this] var leftKey: Row = _
private[this] var rightKey: Row = _
private[this] var read: Boolean = false
private[this] var currentlMatches: CompactBuffer[Row] = _
private[this] var currentrMatches: CompactBuffer[Row] = _
private[this] var currentlPosition: Int = -1
private[this] var currentrPosition: Int = -1

override final def hasNext: Boolean =
(currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
(leftIter.hasNext && rightIter.hasNext && nextMatchingPair)

override final def next(): Row = {
val joinedRow =
joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition))
currentrPosition += 1
if (currentrPosition >= currentrMatches.size) {
currentlPosition += 1
currentrPosition = 0
}
joinedRow
}

/**
* Searches the left/right iterators for the next pair of rows whose join keys match.
*
* @return true if such a matching pair is found, and false if either iterator runs out
* of tuples first.
*/
private def nextMatchingPair(): Boolean = {
currentlPosition = -1
currentlMatches = null
if (rightElement == null) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
}
while (currentlMatches == null && leftIter.hasNext) {
if (!read) {
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
}
while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) {
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
}
currentrMatches = new CompactBuffer[Row]()
while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) {
currentrMatches += rightElement
rightElement = rightIter.next()
rightKey = rightKeyGenerator(rightElement)
}
if (ordering.compare(leftKey, rightKey) == 0) {
currentrMatches += rightElement
}
if (currentrMatches.size > 0) {
// matching rows exist in the right table, so collect all left rows with the same key
currentlMatches = new CompactBuffer[Row]()
val leftMatch = leftKey.copy()
while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) {
currentlMatches += leftElement
leftElement = leftIter.next()
leftKey = leftKeyGenerator(leftElement)
}
if (ordering.compare(leftKey, leftMatch) == 0) {
currentlMatches += leftElement
} else {
read = true
}
}
}

if (currentlMatches == null) {
false
} else {
currentlPosition = 0
currentrPosition = 0
true
}
}
}
}
}
}
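
The merge loop above advances the right iterator until its key is no longer smaller than the left key, buffers the run of right rows with an equal key (currentrMatches), buffers the run of left rows with that key (currentlMatches), and then next() emits their cross product. Below is a minimal, self-contained sketch of the same strategy over two key-sorted sequences, illustrative only; the operator above works on iterators of catalyst Rows.

// Standalone merge join over key-sorted inputs; inner join semantics with duplicate keys.
def mergeJoin[K: Ordering, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(A, B)] = {
  val ord = implicitly[Ordering[K]]
  val out = scala.collection.mutable.ArrayBuffer.empty[(A, B)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val c = ord.compare(left(i)._1, right(j)._1)
    if (c < 0) i += 1
    else if (c > 0) j += 1
    else {
      // Collect the run of equal keys on both sides, then emit the cross product,
      // mirroring currentlMatches / currentrMatches above.
      val k = left(i)._1
      val li = i
      while (i < left.length && ord.equiv(left(i)._1, k)) i += 1
      val rj = j
      while (j < right.length && ord.equiv(right(j)._1, k)) j += 1
      for (x <- li until i; y <- rj until j) out += ((left(x)._2, right(y)._2))
    }
  }
  out.toSeq
}

// Example: mergeJoin(Seq(1 -> "a", 2 -> "b", 2 -> "c"), Seq(2 -> "x", 3 -> "y"))
// returns Seq(("b", "x"), ("c", "x")).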
7 changes: 4 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
case j: SortMergeJoin => j
}

assert(operators.size === 1)
@@ -75,9 +76,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[HashOuterJoin]),