diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ccb0df113c063..3ed83f04039b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -58,6 +58,20 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
       "a single partition.")
 }
 
+/**
+ * Represents data where tuples that share the same values for the `clustering`
+ * [[Expression Expressions]] will be co-located in the same partition, and where
+ * the tuples within each partition are sorted by those clustering
+ * [[Expression Expressions]].
+ */
+case class ClusteredOrderedDistribution(clustering: Seq[Expression]) extends Distribution {
+  require(
+    clustering != Nil,
+    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+}
+
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. This is a strictly stronger guarantee than
@@ -162,6 +176,37 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition. Within each partition, rows are sorted by `expressions`.
+ */
+case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression
+  with Partitioning {
+
+  override def children = expressions
+  override def nullable = false
+  override def dataType = IntegerType
+
+  private[this] lazy val clusteringSet = expressions.toSet
+
+  override def satisfies(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case ClusteredOrderedDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case _ => false
+  }
+
+  override def compatibleWith(other: Partitioning) = other match {
+    case BroadcastPartitioning => true
+    case h: HashSortedPartitioning if h == this => true
+    case _ => false
+  }
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`.  When data is partitioned in this manner the following
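
The satisfies() check added above boils down to a subset test between the partitioning expressions and the required clustering expressions. A minimal, dependency-free sketch of that rule (plain Scala, with column names standing in for Catalyst Expressions; the sketch classes below only mirror the patch and are not the real Catalyst types):

// Simplified stand-ins for the Catalyst types, using column names as keys.
case class ClusteredOrderedDistributionSketch(clustering: Seq[String])

case class HashSortedPartitioningSketch(expressions: Seq[String], numPartitions: Int) {
  private val clusteringSet = expressions.toSet

  // Mirrors the subset check in HashSortedPartitioning.satisfies above.
  def satisfies(required: ClusteredOrderedDistributionSketch): Boolean =
    clusteringSet.subsetOf(required.clustering.toSet)
}

object SatisfiesSketch extends App {
  val part = HashSortedPartitioningSketch(Seq("key"), numPartitions = 8)
  println(part.satisfies(ClusteredOrderedDistributionSketch(Seq("key"))))       // true
  println(part.satisfies(ClusteredOrderedDistributionSketch(Seq("key", "a"))))  // true: subset
  println(part.satisfies(ClusteredOrderedDistributionSketch(Seq("other"))))     // false
}
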
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 927f40063e47e..443266725ea05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.ShuffledRDD
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
@@ -57,11 +58,33 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
             iter.map(r => mutablePair.update(hashExpressions(r), r))
           }
         }
+
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
+      case HashSortedPartitioning(expressions, numPartitions) =>
+        val rdd = if (sortBasedShuffleOn) {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newProjection(expressions, child.output)
+            iter.map(r => (hashExpressions(r), r.copy()))
+          }
+        } else {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newMutableProjection(expressions, child.output)()
+            val mutablePair = new MutablePair[Row, Row]()
+            iter.map(r => mutablePair.update(hashExpressions(r), r))
+          }
+        }
+
+        val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        val part = new HashPartitioner(numPartitions)
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
+        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.map(_._2)
+
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val rdd = if (sortBasedShuffleOn) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -158,6 +181,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
           addExchangeIfNecessary(SinglePartition, child)
         case (ClusteredDistribution(clustering), child) =>
           addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
+        case (ClusteredOrderedDistribution(clustering), child) =>
+          addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
         case (OrderedDistribution(ordering), child) =>
           addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
         case (UnspecifiedDistribution, child) => child
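
The new HashSortedPartitioning case shuffles with an ordinary HashPartitioner and additionally asks the shuffle to sort each output partition by the join keys (setKeyOrdering with the ascending SortOrder built above). A toy, single-process illustration of the resulting layout (plain Scala, no Spark; the object name is illustrative only):

object HashSortedExchangeSketch extends App {
  val rows = Seq(3 -> "c", 1 -> "a", 2 -> "b", 1 -> "a2", 3 -> "c2")
  val numPartitions = 2

  // Hash-partition by key (like HashPartitioner), then sort rows within each partition.
  val partitions: Map[Int, Seq[(Int, String)]] = rows
    .groupBy { case (key, _) => ((key.hashCode % numPartitions) + numPartitions) % numPartitions }
    .map { case (pid, part) => pid -> part.sortBy { case (key, _) => key } }

  partitions.toSeq.sortBy { case (pid, _) => pid }.foreach { case (pid, part) =>
    println(s"partition $pid: ${part.mkString(", ")}")
  }
  // Rows with equal keys land in the same partition, and each partition is sorted by key.
}
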
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cc7e0c05ffc70..03db6d1126519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -84,15 +84,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
-        val buildSide =
-          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
-            joins.BuildRight
-          } else {
-            joins.BuildLeft
-          }
-        val hashJoin = joins.ShuffledHashJoin(
-          leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
-        condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+        val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
+      // case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
+      //   val buildSide =
+      //     if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+      //       joins.BuildRight
+      //     } else {
+      //       joins.BuildLeft
+      //     }
+      //   val hashJoin = joins.ShuffledHashJoin(
+      //     leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+      //   condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
         joins.HashOuterJoin(
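
With this strategy change, inner equi-joins are planned as the new merge join while outer equi-joins stay on the hash-based path; the old ShuffledHashJoin branch is retained above only as a comment. A simplified model of that dispatch (illustrative types only, not the real Catalyst Strategy API):

object JoinSelectionSketch extends App {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  sealed trait PhysicalJoin
  case object MergeJoinOp extends PhysicalJoin
  case object HashOuterJoinOp extends PhysicalJoin

  // Inner equi-joins now become a merge join (previously ShuffledHashJoin);
  // other equi-join types keep using the hash-based outer join.
  def planEquiJoin(joinType: JoinType): PhysicalJoin = joinType match {
    case Inner => MergeJoinOp
    case _     => HashOuterJoinOp
  }

  println(planEquiJoin(Inner))      // MergeJoinOp
  println(planEquiJoin(LeftOuter))  // HashOuterJoinOp
}
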
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
new file mode 100644
index 0000000000000..2ca9541ccc307
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+/** :: DeveloperApi ::
+ * Performs a sort-merge join of two child relations by first shuffling the data using the join
+ * keys. While shuffling, the data is also sorted by the join keys.
+ */
+@DeveloperApi
+case class MergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan
+) extends BinaryNode {
+  // Implementation: the tricky part is handling duplicate join keys.
+  // To handle duplicate keys, we use a buffer to store all matching elements
+  // in the right iterator for a given join key. The buffer is used for
+  // generating join tuples when the join key of the next left element is
+  // the same as the current join key.
+  // TODO: add outer join support
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def output = left.output ++ right.output
+
+  private val orders = leftKeys.map(s => SortOrder(s, Ascending))
+
+  override def requiredChildDistribution =
+    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
+
+  @transient protected lazy val leftKeyGenerator: Projection =
+    newProjection(leftKeys, left.output)
+
+  @transient protected lazy val rightKeyGenerator: Projection =
+    newProjection(rightKeys, right.output)
+
+  private val ordering = new RowOrdering(orders, left.output)
+
+  override def execute() = {
+
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        private[this] val joinRow = new JoinedRow2
+        private[this] var leftElement: Row = _
+        private[this] var rightElement: Row = _
+        private[this] var leftKey: Row = _
+        private[this] var rightKey: Row = _
+        private[this] var buffer: CompactBuffer[Row] = _
+        private[this] var index = -1
+        private[this] var last = false
+
+        // initialize the iterator
+        private def initialize() = {
+          if (leftIter.hasNext) {
+            leftElement = leftIter.next()
+            leftKey = leftKeyGenerator(leftElement)
+          } else {
+            last = true
+          }
+          if (rightIter.hasNext) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          } else {
+            last = true
+          }
+        }
+
+        initialize()
+
+        override final def hasNext: Boolean = {
+          // Two cases where hasNext returns true:
+          // 1. We are iterating the buffer
+          // 2. We can find tuple pairs that have a matching join key
+          //
+          // hasNext is idempotent: nextMatchingPair() is only called when
+          // index == -1, and it sets index to 0 when it returns true, so
+          // multiple calls to hasNext modify the iterator state
+          // at most once.
+          if (index != -1) return true
+          if (last) return false
+          return nextMatchingPair()
+        }
+
+        override final def next(): Row = {
+          if (index == -1) {
+            // We need this because the client of the join iterator may
+            // call next() without calling hasNext.
+            if (!hasNext) return null
+          }
+          val joinedRow = joinRow(leftElement, buffer(index))
+          index += 1
+          if (index == buffer.size) {
+            // We have finished iterating the buffer; fetch the
+            // next element from the left iterator.
+            if (leftIter.hasNext) {
+              // fetch the next left element
+              val leftElem = leftElement
+              val leftK = leftKeyGenerator(leftElem)
+              leftElement = leftIter.next()
+              leftKey = leftKeyGenerator(leftElement)
+              if (ordering.compare(leftKey, leftK) == 0) {
+                // The next left element has the same join key,
+                // so we need to go over the buffer
+                // again.
+                index = 0
+              } else {
+                // We need to find a matching element from
+                // the right iterator.
+                index = -1
+              }
+            } else {
+              // No more left elements; we are done.
+              index = -1
+              last = true
+            }
+          }
+          joinedRow
+        }
+
+        // Find the next pair of left/right tuples that have a
+        // matching join key.
+        private def nextMatchingPair(): Boolean = {
+          while (ordering.compare(leftKey, rightKey) != 0) {
+            if (ordering.compare(leftKey, rightKey) < 0) {
+              if (leftIter.hasNext) {
+                leftElement = leftIter.next()
+                leftKey = leftKeyGenerator(leftElement)
+              } else {
+                last = true
+                return false
+              }
+            } else {
+              if (rightIter.hasNext) {
+                rightElement = rightIter.next()
+                rightKey = rightKeyGenerator(rightElement)
+              } else {
+                last = true
+                return false
+              }
+            }
+          }
+          // leftKey == rightKey: buffer all right elements sharing this join key
+          index = 0
+          buffer = null
+          buffer = new CompactBuffer[Row]()
+          buffer += rightElement
+          val rightElem = rightElement
+          val rightK = rightKeyGenerator(rightElem)
+          while (rightIter.hasNext) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+            if (ordering.compare(rightKey, rightK) == 0) {
+              buffer += rightElement
+            } else {
+              return true
+            }
+          }
+          true
+        }
+      }
+    }
+  }
+}
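
The iterator above interleaves buffering and probing inside hasNext/next. The same duplicate-key strategy is easier to see on plain sorted sequences: advance whichever side has the smaller key, buffer every right row for the current key, and replay that buffer for each left row that repeats the key. A self-contained sketch (ordinary Scala collections, not the SparkPlan iterator above):

object SortMergeJoinSketch extends App {
  def mergeJoin[K](left: Seq[(K, String)], right: Seq[(K, String)])
                  (implicit ord: Ordering[K]): Seq[(String, String)] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1
      else if (c > 0) j += 1
      else {
        // Buffer every right row that shares the current key.
        val key = left(i)._1
        val buffer = scala.collection.mutable.ArrayBuffer.empty[String]
        var jj = j
        while (jj < right.length && ord.equiv(right(jj)._1, key)) {
          buffer += right(jj)._2
          jj += 1
        }
        // Replay the buffer for every left row with the same key.
        while (i < left.length && ord.equiv(left(i)._1, key)) {
          buffer.foreach(r => out += ((left(i)._2, r)))
          i += 1
        }
        j = jj
      }
    }
    out.toSeq
  }

  val left  = Seq(1 -> "l1", 2 -> "l2a", 2 -> "l2b", 4 -> "l4")
  val right = Seq(2 -> "r2a", 2 -> "r2b", 3 -> "r3", 4 -> "r4")
  mergeJoin(left, right).foreach(println)
  // (l2a,r2a), (l2a,r2b), (l2b,r2a), (l2b,r2b), (l4,r4)
}
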
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8b4cf5bac0187..22c1dd482c36c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
+      case j: MergeJoin => j
     }
 
     assert(operators.size === 1)
@@ -72,9 +73,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[MergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[MergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[MergeJoin]),
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
        classOf[HashOuterJoin]),