diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 6a32244bd03b3..d45230a02e501 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -234,7 +234,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
 }
 
 object RowOrdering {
-  def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering =
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
     new RowOrdering(dataTypes.zipWithIndex.map {
       case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
     })
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f69fe22..fb4217a44807b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -94,6 +94,9 @@ sealed trait Partitioning {
    * only compatible if the `numPartitions` of them is the same.
    */
   def compatibleWith(other: Partitioning): Boolean
+
+  /** Returns the expressions that are used to key the partitioning. */
+  def keyExpressions: Seq[Expression]
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
     case UnknownPartitioning(_) => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 /**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = expressions
+
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
   override def eval(input: Row): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 39dd14e796f06..f35b1cfb9ac43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1080,7 +1080,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches =
-      Batch("Add exchange", Once, AddExchange(self)) :: Nil
+      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
   }
 
   protected[sql] def openSession(): SQLSession = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index d67e5c40a4727..a56f6a100a515 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
 
+object Exchange {
+  /** Returns true when the ordering expressions are a subset of the key. */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
- * Shuffle data according to a new partition rule, and sort inside each partition if necessary.
- * @param newPartitioning The new partitioning way that required by parent
- * @param sort Whether we will sort inside each partition
- * @param child Child operator
+ * :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
  */
 @DeveloperApi
 case class Exchange(
     newPartitioning: Partitioning,
-    sort: Boolean,
+    newOrdering: Seq[SortOrder],
     child: SparkPlan)
   extends UnaryNode {
 
   override def outputPartitioning: Partitioning = newPartitioning
 
+  override def outputOrdering: Seq[SortOrder] = newOrdering
+
   override def output: Seq[Attribute] = child.output
 
   /** We must copy rows when sort based shuffle is on */
@@ -51,6 +60,20 @@ case class Exchange(
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
+  private val keyOrdering = {
+    if (newOrdering.nonEmpty) {
+      val key = newPartitioning.keyExpressions
+      val boundOrdering = newOrdering.map { o =>
+        val ordinal = key.indexOf(o.child)
+        if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+        o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+      }
+      new RowOrdering(boundOrdering)
+    } else {
+      null // Ordering will not be used
+    }
+  }
+
   override def execute(): RDD[Row] = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
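Reviewer note: the sketch below (plain Scala with hypothetical `Expr`/`SortOrder` stand-ins, not the real Catalyst classes) mirrors how `canSortWithShuffle` and the `keyOrdering` binding above fit together: a requested ordering may ride along with the shuffle only when every ordering expression appears in the partitioning key, and each such expression is then rebound to its ordinal within the key.

```scala
// Hypothetical simplified stand-ins for Catalyst's Expression and SortOrder.
case class Expr(name: String)
case class SortOrder(child: Expr)

// Mirrors Exchange.canSortWithShuffle: the sort can piggyback on the shuffle only
// if every ordering expression is part of the partitioning key.
def canSortWithShuffle(keyExpressions: Seq[Expr], desiredOrdering: Seq[SortOrder]): Boolean =
  desiredOrdering.map(_.child).toSet.subsetOf(keyExpressions.toSet)

// Mirrors the keyOrdering binding: each ordering expression is replaced by the
// ordinal of the matching key expression, failing fast on an invalid request.
def bindToKey(keyExpressions: Seq[Expr], newOrdering: Seq[SortOrder]): Seq[Int] =
  newOrdering.map { o =>
    val ordinal = keyExpressions.indexOf(o.child)
    if (ordinal == -1) sys.error(s"Invalid ordering on $o requested")
    ordinal
  }

// Partitioned by (a, b): sorting by a can ride along; sorting by c cannot.
assert(canSortWithShuffle(Seq(Expr("a"), Expr("b")), Seq(SortOrder(Expr("a")))))
assert(!canSortWithShuffle(Seq(Expr("a"), Expr("b")), Seq(SortOrder(Expr("c")))))
assert(bindToKey(Seq(Expr("a"), Expr("b")), Seq(SortOrder(Expr("b")))) == Seq(1))
```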
@@ -62,7 +85,9 @@ case class Exchange(
         // we can avoid the defensive copies to improve performance. In the long run, we probably
         // want to include information in shuffle dependencies to indicate whether elements in the
         // source RDD should be copied.
-        val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
+        val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -75,21 +100,17 @@ case class Exchange(
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = sort match {
-          case false => new ShuffledRDD[Row, Row, Row](rdd, part)
-          case true =>
-            val sortingExpressions = expressions.zipWithIndex.map {
-              case (exp, index) =>
-                new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
-            }
-            val ordering = new RowOrdering(sortingExpressions, child.output)
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-        }
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Row, Row](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn) {
+        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
           child.execute().mapPartitions { iter =>
@@ -102,7 +123,12 @@ case class Exchange(
         implicit val ordering = new RowOrdering(sortingExpressions, child.output)
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Null, Null](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
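Both branches above piggyback the sort on the shuffle via `ShuffledRDD.setKeyOrdering`, so sort-based shuffle merges each output partition in key order and no separate sort pass is needed afterwards. A plain-Scala simulation of the effect (not the Spark API):

```scala
// Bucket rows with a hash "partitioner", then emit each bucket in key order. This is
// what a shuffle with a key ordering produces: partitions that arrive already sorted.
val rows = Seq(5 -> "e", 1 -> "a", 4 -> "d", 2 -> "b", 3 -> "c")
val numPartitions = 2

val shuffledAndSorted: Map[Int, Seq[(Int, String)]] =
  rows
    .groupBy { case (key, _) => math.abs(key.hashCode) % numPartitions } // partitioner
    .map { case (pid, bucket) => pid -> bucket.sortBy(_._1) }            // key ordering

// Every partition comes out of the "shuffle" sorted by key.
assert(shuffledAndSorted.values.forall(b => b == b.sortBy(_._1)))
```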
@@ -135,27 +161,35 @@ case class Exchange(
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensures that the
+ * input partition ordering requirements of each operator are met.
  */
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
   def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
-      // Check if every child's outputPartitioning satisfies the corresponding
+      // True iff every child's outputPartitioning satisfies the corresponding
       // required data distribution.
       def meetsRequirements: Boolean =
-        !operator.requiredChildDistribution.zip(operator.children).map {
+        operator.requiredChildDistribution.zip(operator.children).forall {
           case (required, child) =>
             val valid = child.outputPartitioning.satisfies(required)
             logDebug(
               s"${if (valid) "Valid" else "Invalid"} distribution," +
                 s"required: $required current: ${child.outputPartitioning}")
             valid
-        }.exists(!_)
+        }
+
+      // True iff any of the children are incorrectly sorted.
+      def needsAnySort: Boolean =
+        operator.requiredChildOrdering.zip(operator.children).exists {
+          case (required, child) => required.nonEmpty && required != child.outputOrdering
+        }
+
-      // Check if outputPartitionings of children are compatible with each other.
+      // True iff outputPartitionings of children are compatible with each other.
       // It is possible that every child satisfies its required data distribution
       // but two children have incompatible outputPartitionings. For example,
       // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
           case Seq(a,b) => a compatibleWith b
         }.exists(!_)
 
-      // Check if the partitioning we want to ensure is the same as the child's output
-      // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(
+      // Adds Exchange or Sort operators as required
+      def addOperatorsIfNecessary(
           partitioning: Partitioning,
-          child: SparkPlan,
-          rowOrdering: Option[Ordering[Row]] = None): SparkPlan = {
-        val needSort = child.outputOrdering != rowOrdering
-        if (child.outputPartitioning != partitioning || needSort) {
-          // TODO: if only needSort, we need only sort each partition instead of an Exchange
-          Exchange(partitioning, sort = needSort, child)
+          rowOrdering: Seq[SortOrder],
+          child: SparkPlan): SparkPlan = {
+        val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+        val needsShuffle = child.outputPartitioning != partitioning
+        val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+        if (needSort && needsShuffle && canSortWithShuffle) {
+          Exchange(partitioning, rowOrdering, child)
         } else {
-          child
+          val withShuffle = if (needsShuffle) {
+            Exchange(partitioning, Nil, child)
+          } else {
+            child
+          }
+
+          val withSort = if (needSort) {
+            Sort(rowOrdering, global = false, withShuffle)
+          } else {
+            withShuffle
+          }
+
+          withSort
         }
       }
 
-      if (meetsRequirements && compatible) {
+      if (meetsRequirements && compatible && !needsAnySort) {
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
         // at least one child's outputPartitioning is not compatible with another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = operator.requiredChildDistribution.zip(
-          operator.children.zip(operator.requiredChildOrdering)
-        ).map {
-          case (AllTuples, (child, _)) =>
-            addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
-          case (OrderedDistribution(ordering), (child, None)) =>
-            addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
-          case (UnspecifiedDistribution, (child, _)) => child
-          case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+        val requirements =
+          (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+        val fixedChildren = requirements.zipped.map {
+          case (AllTuples, rowOrdering, child) =>
+            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+          case (ClusteredDistribution(clustering), rowOrdering, child) =>
+            addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+          case (OrderedDistribution(ordering), rowOrdering, child) =>
+            addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child)
+
+          case (UnspecifiedDistribution, Seq(), child) =>
+            child
+          case (UnspecifiedDistribution, rowOrdering, child) =>
+            Sort(rowOrdering, global = false, child)
+
+          case (dist, ordering, _) =>
+            sys.error(s"Don't know how to ensure $dist with ordering $ordering")
         }
-        operator.withNewChildren(repartitionedChildren)
+
+        operator.withNewChildren(fixedChildren)
       }
   }
 }
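To make the rule's control flow easier to review, here is a toy-model sketch of the decision `addOperatorsIfNecessary` makes, with hypothetical simplified plan nodes instead of `SparkPlan`: a single Exchange absorbs the sort when the sort keys are covered by the partitioning key; otherwise Exchange and Sort are composed separately.

```scala
// Hypothetical toy plan nodes, not the real physical operators.
sealed trait Plan
case object Leaf extends Plan
case class Exchange(sortsOutput: Boolean, child: Plan) extends Plan
case class Sort(child: Plan) extends Plan

def addOperatorsIfNecessary(
    needsShuffle: Boolean,
    needSort: Boolean,
    canSortWithShuffle: Boolean,
    child: Plan): Plan =
  if (needSort && needsShuffle && canSortWithShuffle) {
    // One Exchange does both: shuffle and per-partition sort.
    Exchange(sortsOutput = true, child)
  } else {
    // Otherwise compose the two operators independently.
    val withShuffle = if (needsShuffle) Exchange(sortsOutput = false, child) else child
    if (needSort) Sort(withShuffle) else withShuffle
  }

assert(addOperatorsIfNecessary(true, true, true, Leaf) == Exchange(true, Leaf))
assert(addOperatorsIfNecessary(true, true, false, Leaf) == Sort(Exchange(false, Leaf)))
assert(addOperatorsIfNecessary(false, true, true, Leaf) == Sort(Leaf))
```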
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 748c478a5a839..9c5ecb984245b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     Seq.fill(children.size)(UnspecifiedDistribution)
 
   /** Specifies how data is ordered in each partition. */
-  def outputOrdering: Option[Ordering[Row]] = None
+  def outputOrdering: Seq[SortOrder] = Nil
 
   /** Specifies sort order for each partition requirements on the input data for this operator. */
-  def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)
+  def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
   /**
    * Runs this query returning the result as an RDD.
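One motivation for switching these signatures from `Option[Ordering[Row]]` to `Seq[SortOrder]` is that `EnsureRequirements` compares required and actual orderings with `==`: `Ordering` instances compare by reference, so two behaviorally identical orderings never match, while `SortOrder` is a case-class expression tree with structural equality. A small illustration (hypothetical `Expr`/`SortOrder` stand-ins):

```scala
case class Expr(name: String)     // hypothetical stand-in for Expression
case class SortOrder(child: Expr) // hypothetical stand-in for Catalyst's SortOrder

val byValue1: Ordering[Int] = Ordering.by((i: Int) => i)
val byValue2: Ordering[Int] = Ordering.by((i: Int) => i)
assert(byValue1 != byValue2) // behaviorally identical, but never compare equal

// Case-class trees compare structurally, so the planner can detect a satisfied
// ordering requirement with a plain == check.
assert(Seq(SortOrder(Expr("a"))) == Seq(SortOrder(Expr("a"))))
```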
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c6ff8c30c24e7..e680b169c5975 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -308,7 +308,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
         execution.Exchange(
-          HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
+          HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index cbfcca1ea1546..a13c7f7c8611b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -42,12 +42,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
     iter.map(resuableProjection)
   }
 
-  /**
-   * outputOrdering of Project is not always same with child's outputOrdering if the certain
-   * key is pruned, however, if the key is pruned then we must not require child using this
-   * ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting.
-   */
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -63,7 +58,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
     iter.filter(conditionEvaluator)
   }
 
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -111,7 +106,7 @@ case class Limit(limit: Int, child: SparkPlan)
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
 
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def executeCollect(): Array[Row] = child.executeTake(limit)
 
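The Project, Filter, and Limit overrides above forward `child.outputOrdering` unchanged, relying on the invariant that row-wise operators never reorder their input. (For Project, the deleted comment's argument still applies: the forwarded ordering may mention pruned columns, but a parent cannot require an ordering on columns it cannot see, so forwarding avoids unnecessary sorts.) A trivial illustration of the invariant:

```scala
// Filtering a sorted sequence leaves it sorted; the same holds per partition for
// Filter, Project, and Limit, which is why forwarding the child's ordering is safe.
val sortedRows = Seq(1, 3, 5, 7, 9)
val filtered = sortedRows.filter(_ > 3)
assert(filtered == filtered.sorted)
```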
@@ -158,7 +153,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   // TODO: Pick num splits based on |limit|.
   override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -185,7 +180,7 @@ case class Sort(
 
   override def output: Seq[Attribute] = child.output
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -217,7 +212,7 @@ case class ExternalSort(
 
   override def output: Seq[Attribute] = child.output
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 8262a3c14b0b3..e2393ba3bb8f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -48,19 +48,19 @@ case class SortMergeJoin(
 
   // this is to manually construct an ordering that can be used to compare keys from both sides
   private val keyOrdering: RowOrdering =
-    RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType))
+    RowOrdering.forSchema(leftKeys.map(_.dataType))
 
-  private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] =
-    newOrdering(keys.map(SortOrder(_, Ascending)), side.output)
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(requiredOrders(leftKeys, left))
-
-  override def requiredChildOrdering: Seq[Option[Ordering[Row]]] =
-    Some(requiredOrders(leftKeys, left)) :: Some(requiredOrders(rightKeys, right)) :: Nil
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
 
   @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
   @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
 
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+    keys.map(SortOrder(_, Ascending))
+
   override def execute(): RDD[Row] = {
     val leftResults = left.execute().map(_.copy())
     val rightResults = right.execute().map(_.copy())
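Finally, a sketch (hypothetical stand-in types, not the Catalyst classes) of the contract SortMergeJoin now exposes: it declares ascending `SortOrder`s on its join keys through `requiredChildOrdering` and leaves it to `EnsureRequirements` to insert whatever Exchange/Sort operators satisfy them.

```scala
// Hypothetical simplified stand-ins for Expression and SortOrder.
case class Expr(name: String)
sealed trait Direction
case object Ascending extends Direction
case class SortOrder(child: Expr, direction: Direction)

// Mirrors the new SortMergeJoin.requiredOrders: one ascending order per join key.
def requiredOrders(keys: Seq[Expr]): Seq[SortOrder] = keys.map(SortOrder(_, Ascending))

val leftKeys = Seq(Expr("left.id"))
val rightKeys = Seq(Expr("right.id"))

// One required ordering per child, matching requiredChildOrdering's Seq[Seq[SortOrder]].
val requiredChildOrdering: Seq[Seq[SortOrder]] =
  requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil

assert(requiredChildOrdering ==
  Seq(Seq(SortOrder(Expr("left.id"), Ascending)), Seq(SortOrder(Expr("right.id"), Ascending))))
```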