From 645c70b24762d14d88e19d708e53a384a3daa2a1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 10 Apr 2015 03:44:52 -0700 Subject: [PATCH] address comments using sort --- .../plans/physical/partitioning.scala | 49 ------------------- .../apache/spark/sql/execution/Exchange.scala | 38 +++++--------- .../spark/sql/execution/SparkPlan.scala | 6 +++ .../sql/execution/joins/SortMergeJoin.scala | 13 ++++- .../SortMergeCompatibilitySuite.scala | 13 +++++ 5 files changed, 42 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e0f981ef37960..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,21 +75,6 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } -/** - * Represents data where tuples have been ordered according to the `clustering` - * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted - * by the expressions. - */ -case class ClusteredOrderedDistribution(clustering: Seq[Expression]) - extends Distribution { - require( - clustering != Nil, - "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " + - "An AllTuples should be used to represent a distribution that only has " + - "a single partition.") -} - sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -177,40 +162,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. And rows within the same partition are sorted by the expressions. - */ -case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression - with Partitioning { - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: DataType = IntegerType - - private[this] lazy val clusteringSet = expressions.toSet - - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredOrderedDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case _ => false - } - - override def compatibleWith(other: Partitioning) = other match { - case BroadcastPartitioning => true - case h: HashSortedPartitioning if h == this => true - case _ => false - } - - override def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") -} - /** * Represents a partitioning where rows are split across partitions based on some total ordering of * the expressions specified in `ordering`. 
When data is partitioned in this manner the following diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 63d2e92697c04..c89b2a068351b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -72,29 +72,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) - case HashSortedPartitioning(expressions, numPartitions) => - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - iter.map(r => (hashExpressions(r).copy(), r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[Row, Row]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) - } - } - val sortingExpressions = expressions.zipWithIndex.map { - case (exp, index) => - new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending) - } - val ordering = new RowOrdering(sortingExpressions, child.output) - val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - shuffled.map(_._2) - case RangePartitioning(sortingExpressions, numPartitions) => val rdd = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} @@ -184,6 +161,11 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Check if the ordering we want to ensure is the same as the child's output + // ordering. If so, we do not need to add a Sort operator. 
+ def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan = + if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child + if (meetsRequirements && compatible) { operator } else { @@ -195,14 +177,18 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl addExchangeIfNecessary(SinglePartition, child) case (ClusteredDistribution(clustering), child) => addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (ClusteredOrderedDistribution(clustering), child) => - addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child) case (OrderedDistribution(ordering), child) => addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) case (UnspecifiedDistribution, child) => child case (dist, _) => sys.error(s"Don't know how to ensure $dist") } - operator.withNewChildren(repartitionedChildren) + val reorderedChildren = operator.requiredInPartitionOrdering.zip(repartitionedChildren).map { + case (Nil, child) => + child + case (ordering, child) => + addSortIfNecessary(ordering, child) + } + operator.withNewChildren(reorderedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d239637cd4b4e..41d80969e376b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies the sort ordering this operator requires, within each partition, of the input data from each child. */ + def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 259e7ab264e29..fd65320d55139 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -39,16 +39,25 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = HashSortedPartitioning(leftKeys, 0) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = - ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) } private val ordering: RowOrdering = new RowOrdering(orders, left.output) + private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map { + k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending) + } + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left) + + override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index f49555c9142b1..3e08a0ce8c003 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -70,6 +70,19 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { "auto_sortmerge_join_7", "auto_sortmerge_join_8", "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", "join0", "join1", "join10",
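A note on the planning pattern this patch settles on: rather than keeping the special-purpose `ClusteredOrderedDistribution`/`HashSortedPartitioning` pair, `AddExchange` now satisfies distribution and ordering independently. It first inserts an `Exchange` where the child's partitioning does not match, then wraps any child whose `outputOrdering` does not match `requiredInPartitionOrdering` in a per-partition `Sort`. The sketch below is a minimal toy model of that rule, not the real Catalyst code; `Plan`, `Scan`, `MergeJoin`, and the plain string sort keys are hypothetical stand-ins for `SparkPlan`, `SortOrder`, and friends. It can be pasted into a Scala REPL as-is.

```scala
// Toy plan model: string column names stand in for SortOrder expressions.
sealed trait Plan {
  def children: Seq[Plan]
  def outputOrdering: Seq[String] = Nil                 // like SparkPlan.outputOrdering
  def requiredInPartitionOrdering: Seq[Seq[String]] =   // one requirement per child
    Seq.fill(children.size)(Nil)
  def withNewChildren(newChildren: Seq[Plan]): Plan
}

case class Scan(table: String) extends Plan {
  def children: Seq[Plan] = Nil
  def withNewChildren(newChildren: Seq[Plan]): Plan = this
}

case class Sort(keys: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  override def outputOrdering: Seq[String] = keys       // a Sort guarantees its keys' order
  def withNewChildren(newChildren: Seq[Plan]): Plan = copy(child = newChildren.head)
}

case class MergeJoin(leftKeys: Seq[String], rightKeys: Seq[String], left: Plan, right: Plan)
    extends Plan {
  def children: Seq[Plan] = Seq(left, right)
  override def requiredInPartitionOrdering: Seq[Seq[String]] = Seq(leftKeys, rightKeys)
  def withNewChildren(newChildren: Seq[Plan]): Plan =
    copy(left = newChildren(0), right = newChildren(1))
}

// The rule itself, mirroring the patch's addSortIfNecessary: add a Sort only when the
// child's existing output ordering does not already satisfy the requirement.
def ensureOrdering(operator: Plan): Plan = {
  def addSortIfNecessary(ordering: Seq[String], child: Plan): Plan =
    if (child.outputOrdering != ordering) Sort(ordering, child) else child

  val reorderedChildren = operator.requiredInPartitionOrdering.zip(operator.children).map {
    case (Nil, child) => child
    case (ordering, child) => addSortIfNecessary(ordering, child)
  }
  operator.withNewChildren(reorderedChildren)
}

// Only the left side gains a Sort; the right side already exposes the required ordering.
ensureOrdering(MergeJoin(Seq("k"), Seq("k"), Scan("a"), Sort(Seq("k"), Scan("b"))))
// => MergeJoin(List(k), List(k), Sort(List(k), Scan(a)), Sort(List(k), Scan(b)))
```

The design point this illustrates: because ordering is checked separately from partitioning, an input that is already sorted (for example, the output of another sort-merge join, via `outputOrdering`) avoids a redundant per-partition sort.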
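The new `requiredOrders` helper in `SortMergeJoin` binds each join-key expression to the child's output via `BindReferences.bindReference(..., allowFailures = false)`, so the resulting `SortOrder` refers to a column position rather than an unresolved name. Roughly, binding means resolving a key to an ordinal, as in this toy sketch; `Attr` and `Bound` are made-up stand-ins for Catalyst's `AttributeReference` and `BoundReference`:

```scala
case class Attr(name: String)
case class Bound(ordinal: Int)

// Resolve a key to the position of the matching attribute in the child's output schema.
// Throwing when nothing matches plays the role of allowFailures = false.
def bindReference(key: Attr, output: Seq[Attr]): Bound = {
  val ordinal = output.indexWhere(_.name == key.name)
  require(ordinal >= 0, s"Could not bind key '${key.name}'")
  Bound(ordinal)
}

bindReference(Attr("k"), Seq(Attr("id"), Attr("k")))  // => Bound(1)
```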
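As for why a per-partition sort is all `SortMergeJoin` needs: once both inputs of a co-partitioned pair are sorted by the join keys, an inner equi-join reduces to a single forward pass that pairs up runs of equal keys. The sketch below is a deliberately simplified in-memory version of that idea, not the patch's row-iterator implementation:

```scala
import scala.collection.mutable.ArrayBuffer

// Inner equi-join of two key-sorted sequences in one pass.
def sortMergeJoin[K: Ordering, L, R](
    left: IndexedSeq[(K, L)],   // assumed sorted by key
    right: IndexedSeq[(K, R)]   // assumed sorted by key
): Seq[(L, R)] = {
  val ord = implicitly[Ordering[K]]
  val out = ArrayBuffer.empty[(L, R)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val c = ord.compare(left(i)._1, right(j)._1)
    if (c < 0) i += 1       // left key is behind: advance left
    else if (c > 0) j += 1  // right key is behind: advance right
    else {
      // Equal keys: emit the cross product of the two matching runs.
      val key = left(i)._1
      var jEnd = j
      while (jEnd < right.length && ord.equiv(right(jEnd)._1, key)) jEnd += 1
      while (i < left.length && ord.equiv(left(i)._1, key)) {
        var jj = j
        while (jj < jEnd) { out += ((left(i)._2, right(jj)._2)); jj += 1 }
        i += 1
      }
      j = jEnd
    }
  }
  out.toSeq
}

sortMergeJoin(IndexedSeq(1 -> "a", 2 -> "b", 2 -> "c"), IndexedSeq(2 -> "x", 3 -> "y"))
// => Seq((b,x), (c,x))
```

This is also why `outputPartitioning` can simply become `left.outputPartitioning` and `outputOrdering` can report the left keys' order: the merge pass consumes and preserves the sorted, hash-partitioned layout of its inputs rather than requiring a dedicated shuffle format.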