From 171001fea71669a009e440dfb6caf927aecd924b Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Sat, 11 Apr 2015 19:26:01 +0800
Subject: [PATCH] change default outputordering

---
 .../apache/spark/sql/execution/Aggregate.scala     |  2 ++
 .../apache/spark/sql/execution/Exchange.scala      |  4 ++--
 .../apache/spark/sql/execution/SparkPlan.scala     |  1 +
 .../spark/sql/execution/basicOperators.scala       |  8 ++++++++
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 17 ++++++++++-------
 5 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 18b1ba4c5c4b9..296c71df6a11e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -60,6 +60,8 @@ case class Aggregate(
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
+  override def outputOrdering: Seq[SortOrder] = Nil
+
   /**
    * An aggregate that needs to be computed for each row in a group.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index dc7e3a5c41070..e6ac2926320f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -161,8 +161,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
     def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
       if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
 
-    // Check if the partitioning we want to ensure is the same as the child's output
-    // partitioning. If so, we do not need to add the Exchange operator.
+    // Check if the ordering we want to ensure is the same as the child's output
+    // ordering. If so, we do not need to add the Sort operator.
     def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan =
       if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 41d80969e376b..b3252da2df201 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -183,6 +183,7 @@ private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
 private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
   self: Product =>
   override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 1f5251a20376f..e13a3699318cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -70,6 +70,8 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
   override def execute(): RDD[Row] = {
     child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
   }
+
+  override def outputOrdering: Seq[SortOrder] = Nil
 }
 
 /**
@@ -146,6 +148,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
   override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -171,6 +175,8 @@ case class Sort(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -201,6 +207,8 @@ case class ExternalSort(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a2c9778883389..826db143a9211 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -95,13 +95,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    conf.setConf("spark.sql.autoSortMergeJoin", "true")
-    Seq(
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString)
+    try {
+      conf.setConf("spark.sql.autoSortMergeJoin", "true")
+      Seq(
+        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString)
+    }
   }
 
   test("broadcasted hash join operator selection") {
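Note: the net effect of this patch is that every physical operator reports the
row ordering it guarantees via outputOrdering, so addSortIfNecessary in
AddExchange can skip inserting a Sort when the child already delivers the
required ordering. The minimal, self-contained Scala sketch below illustrates
that mechanism; the names Plan, Scan, SortNode, Filter and ensureOrdering are
toy stand-ins for illustration only, not Spark's actual API.

    object OutputOrderingSketch {
      case class SortOrder(column: String, ascending: Boolean = true)

      sealed trait Plan {
        // Default: an operator makes no ordering guarantee.
        def outputOrdering: Seq[SortOrder] = Nil
      }

      // A leaf operator that promises nothing about row order.
      case class Scan(table: String) extends Plan

      // A sort guarantees exactly the ordering it was asked for, like the
      // Sort/ExternalSort/TakeOrdered overrides in this patch.
      case class SortNode(order: Seq[SortOrder], child: Plan) extends Plan {
        override def outputOrdering: Seq[SortOrder] = order
      }

      // An order-preserving unary operator forwards its child's ordering,
      // as the UnaryNode change in SparkPlan.scala does.
      case class Filter(child: Plan) extends Plan {
        override def outputOrdering: Seq[SortOrder] = child.outputOrdering
      }

      // Toy counterpart of addSortIfNecessary: only sort when needed.
      def ensureOrdering(required: Seq[SortOrder], child: Plan): Plan =
        if (child.outputOrdering == required) child else SortNode(required, child)

      def main(args: Array[String]): Unit = {
        val byKey = Seq(SortOrder("key"))
        // The ordering survives the Filter, so no second sort is added.
        println(ensureOrdering(byKey, Filter(SortNode(byKey, Scan("testData")))))
        // A bare scan guarantees nothing, so a sort is inserted.
        println(ensureOrdering(byKey, Scan("testData")))
      }
    }

Because UnaryNode now forwards the child's ordering, operators that do not (or
may not) preserve their input order, such as Aggregate and Sample, explicitly
override outputOrdering back to Nil, which is why those overrides appear in
the diff.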