diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f69fe22..147b48047f530 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -75,6 +75,21 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
+/**
+ * Represents data where tuples have been clustered according to the `clustering`
+ * [[Expression Expressions]] and, within each partition, sorted by those expressions. This is a
+ * strictly stronger guarantee than [[ClusteredDistribution]], which does not require any
+ * ordering of the tuples within a partition.
+ */
+case class ClusteredOrderedDistribution(clustering: Seq[Expression])
+  extends Distribution {
+  require(
+    clustering != Nil,
+    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+}
+
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
@@ -162,6 +177,40 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition. In addition, rows within each partition are sorted by those expressions.
+ */
+case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression
+  with Partitioning {
+
+  override def children = expressions
+  override def nullable = false
+  override def dataType = IntegerType
+
+  private[this] lazy val clusteringSet = expressions.toSet
+
+  override def satisfies(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case ClusteredOrderedDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case ClusteredDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case _ => false
+  }
+
+  override def compatibleWith(other: Partitioning) = other match {
+    case BroadcastPartitioning => true
+    case h: HashSortedPartitioning if h == this => true
+    case _ => false
+  }
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
+
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`.  When data is partitioned in this manner the following
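Note (not part of the patch): a rough sketch of how the new distribution/partitioning pair is meant to interact through `satisfies`, mirroring the cases added above. The attribute construction and the import paths are assumptions and may need adjusting for this branch.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.types.IntegerType  // assumed location of IntegerType on this branch
    import org.apache.spark.sql.catalyst.plans.physical._

    val a = AttributeReference("a", IntegerType)()
    // The new hash-sorted output satisfies both the ordered and the plain clustered requirement...
    HashSortedPartitioning(a :: Nil, 8).satisfies(ClusteredOrderedDistribution(a :: Nil))  // true
    HashSortedPartitioning(a :: Nil, 8).satisfies(ClusteredDistribution(a :: Nil))         // true
    // ...whereas plain hash partitioning gives no in-partition ordering guarantee.
    HashPartitioning(a :: Nil, 8).satisfies(ClusteredOrderedDistribution(a :: Nil))        // false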
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4815620c6fe57..e2e07c1b804a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -27,6 +27,7 @@ private[spark] object SQLConf {
   val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
   val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
+  val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
   val CODEGEN_ENABLED = "spark.sql.codegen"
@@ -143,6 +144,12 @@ private[sql] class SQLConf extends Serializable {
   private[spark] def autoBroadcastJoinThreshold: Int =
     getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt
 
+  /**
+   * When set to true (the default), inner equi-joins are planned as sort merge joins.
+   */
+  private[spark] def autoSortMergeJoin: Boolean =
+    getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean
+
   /**
    * The default size in bytes to assign to a logical operator's estimation statistics. By default,
    * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator
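Note (not part of the patch): the new flag can be toggled at runtime through `SQLContext.setConf`; the `sqlContext` value below is assumed to be in scope.

    sqlContext.setConf("spark.sql.autoSortMergeJoin", "false")  // fall back to hash-based equi-joins
    sqlContext.setConf("spark.sql.autoSortMergeJoin", "true")   // default: prefer sort merge join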
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 437408d30bfd2..6c01dee9a8969 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
@@ -73,6 +72,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
+      case HashSortedPartitioning(expressions, numPartitions) =>
+        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newMutableProjection(expressions, child.output)()
+            iter.map(r => (hashExpressions(r).copy(), r.copy()))
+          }
+        } else {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newMutableProjection(expressions, child.output)()
+            val mutablePair = new MutablePair[Row, Row]()
+            iter.map(r => mutablePair.update(hashExpressions(r), r))
+          }
+        }
+        val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        val part = new HashPartitioner(numPartitions)
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
+        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.map(_._2)
+
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val rdd = if (sortBasedShuffleOn) {
          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -173,6 +192,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
           addExchangeIfNecessary(SinglePartition, child)
         case (ClusteredDistribution(clustering), child) =>
           addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
+        case (ClusteredOrderedDistribution(clustering), child) =>
+          addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
         case (OrderedDistribution(ordering), child) =>
           addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
         case (UnspecifiedDistribution, child) => child
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f754fa770d1b5..72f41e4bd7685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,6 +90,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
 
+      // For now only inner joins are planned as sort merge joins; outer join support will follow.
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+        if sqlContext.conf.autoSortMergeJoin =>
+        val mergeJoin =
+          joins.SortMergeJoin(leftKeys, rightKeys, Inner, planLater(left), planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
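Note (not part of the patch): a sketch of how the strategy change can be observed, mirroring the JoinSuite expectations further down. The `sqlContext` value and the `testData`/`testData2` tables are assumed to be those registered by the test suite.

    import org.apache.spark.sql.execution.joins.SortMergeJoin

    val plan = sqlContext.sql("SELECT * FROM testData JOIN testData2 ON key = a")
      .queryExecution.sparkPlan
    // With spark.sql.autoSortMergeJoin enabled, the inner equi-join should plan as SortMergeJoin.
    assert(plan.collect { case j: SortMergeJoin => j }.size === 1)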
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
new file mode 100644
index 0000000000000..3c0ab080e7f4d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge join of two child relations.
+ */
+@DeveloperApi
+case class SortMergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] =
+    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
+
+  private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending))
+  private val ordering: RowOrdering = new RowOrdering(orders, left.output)
+
+  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
+  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+
+  override def execute() = {
+    val leftResults = left.execute().map(_.copy())
+    val rightResults = right.execute().map(_.copy())
+
+    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow5
+        private[this] var leftElement: Row = _
+        private[this] var rightElement: Row = _
+        private[this] var leftKey: Row = _
+        private[this] var rightKey: Row = _
+        private[this] var read: Boolean = false
+        private[this] var currentlMatches: CompactBuffer[Row] = _
+        private[this] var currentrMatches: CompactBuffer[Row] = _
+        private[this] var currentlPosition: Int = -1
+        private[this] var currentrPosition: Int = -1
+
+        override final def hasNext: Boolean =
+          (currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
+            (leftIter.hasNext && rightIter.hasNext && nextMatchingPair)
+
+        override final def next(): Row = {
+          val joinedRow =
+            joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition))
+          currentrPosition += 1
+          if (currentrPosition >= currentrMatches.size) {
+            currentlPosition += 1
+            currentrPosition = 0
+          }
+          joinedRow
+        }
+
+        /**
+         * Searches the left/right iterator for the next rows that match.
+         *
+         * @return true if the search is successful, and false if the left/right iterator runs out
+         *         of tuples.
+         */
+        private def nextMatchingPair(): Boolean = {
+          currentlPosition = -1
+          currentlMatches = null
+          if (rightElement == null) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          }
+          while (currentlMatches == null && leftIter.hasNext) {
+            if (!read) {
+              leftElement = leftIter.next()
+              leftKey = leftKeyGenerator(leftElement)
+            }
+            while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) {
+              rightElement = rightIter.next()
+              rightKey = rightKeyGenerator(rightElement)
+            }
+            currentrMatches = new CompactBuffer[Row]()
+            while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) {
+              currentrMatches += rightElement
+              rightElement = rightIter.next()
+              rightKey = rightKeyGenerator(rightElement)
+            }
+            if (ordering.compare(leftKey, rightKey) == 0) {
+              currentrMatches += rightElement
+            }
+            if (currentrMatches.size > 0) {
+              // Matching rows exist in the right table, so also collect the matching left rows.
+              currentlMatches = new CompactBuffer[Row]()
+              val leftMatch = leftKey.copy()
+              while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) {
+                currentlMatches += leftElement
+                leftElement = leftIter.next()
+                leftKey = leftKeyGenerator(leftElement)
+              }
+              if (ordering.compare(leftKey, leftMatch) == 0) {
+                currentlMatches += leftElement
+              } else {
+                read = true
+              }
+            }
+          }
+
+          if (currentlMatches == null) {
+            false
+          } else {
+            currentlPosition = 0
+            currentrPosition = 0
+            true
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e4dee87849fd4..bba2f223c55dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
+      case j: SortMergeJoin => j
     }
 
     assert(operators.size === 1)
@@ -75,9 +76,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
        classOf[HashOuterJoin]),
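Note (not part of the patch): a self-contained sketch of the merge step that `nextMatchingPair` implements: walk two key-sorted inputs in lock step, buffer the run of equal keys on each side, and emit the cross product of the two runs. Plain Scala collections stand in for the sorted, co-partitioned row iterators.

    def mergeJoin[K, L, R](left: Seq[(K, L)], right: Seq[(K, R)])
        (implicit ord: Ordering[K]): Seq[(L, R)] = {
      val out = scala.collection.mutable.ArrayBuffer.empty[(L, R)]
      var i = 0
      var j = 0
      while (i < left.size && j < right.size) {
        val c = ord.compare(left(i)._1, right(j)._1)
        if (c < 0) {
          i += 1              // left key is smaller: advance the left side
        } else if (c > 0) {
          j += 1              // right key is smaller: advance the right side
        } else {
          val key = left(i)._1
          var iEnd = i        // find the run of equal keys on the left...
          while (iEnd < left.size && ord.equiv(left(iEnd)._1, key)) iEnd += 1
          var jEnd = j        // ...and on the right
          while (jEnd < right.size && ord.equiv(right(jEnd)._1, key)) jEnd += 1
          // emit the cross product of the two runs, then skip past both
          for (x <- i until iEnd; y <- j until jEnd) out += ((left(x)._2, right(y)._2))
          i = iEnd
          j = jEnd
        }
      }
      out
    }

    // mergeJoin(Seq(1 -> "a", 2 -> "b", 2 -> "c", 3 -> "d"), Seq(2 -> "x", 3 -> "y", 4 -> "z"))
    // yields ("b","x"), ("c","x"), ("d","y")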