From 303b6da2ef427d08e367ab85e29fd18a6572c90a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 31 Mar 2015 00:48:35 -0700 Subject: [PATCH] fix several errors --- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../sql/execution/joins/SortMergeJoin.scala | 88 ++++++++++--------- .../spark/sql/hive/StatisticsSuite.scala | 2 + 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6c01dee9a8969..58c62997a843e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -86,7 +86,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending)) - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + val ordering = new RowOrdering(sortingExpressions, child.output) val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index c241a7ae69cde..7048f91f80eab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -67,17 +67,21 @@ case class SortMergeJoin( private[this] var currentlPosition: Int = -1 private[this] var currentrPosition: Int = -1 - override final def hasNext: Boolean = - (currentlPosition != -1 && currentlPosition < currentlMatches.size) || - nextMatchingPair + override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair override final def next(): Row = { + if (!hasNext) { + return null + } val joinedRow = joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition)) currentrPosition += 1 if (currentrPosition >= currentrMatches.size) { currentlPosition += 1 currentrPosition = 0 + if (currentlPosition >= currentlMatches.size) { + currentlPosition = -1 + } } joinedRow } @@ -100,13 +104,13 @@ case class SortMergeJoin( } } - // initialize iterator - private def initialize() = { + private def fetchFirst() = { fetchLeft() fetchRight() + currentrPosition = 0 } - - initialize() + // initialize iterator + fetchFirst() /** * Searches the left/right iterator for the next rows that matches. @@ -115,49 +119,49 @@ case class SortMergeJoin( * of tuples. */ private def nextMatchingPair(): Boolean = { - currentlPosition = -1 - currentlMatches = null - var stop: Boolean = false - while (!stop && leftElement != null && rightElement != null) { - if (ordering.compare(leftKey, rightKey) > 0) { - fetchRight() - } else if (ordering.compare(leftKey, rightKey) < 0) { - fetchLeft() - } else { - stop = true + if (currentlPosition > -1) { + true + } else { + currentlPosition = -1 + currentlMatches = null + var stop: Boolean = false + while (!stop && leftElement != null && rightElement != null) { + if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) { + stop = true + } else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { + fetchRight() + } else { //if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) + fetchLeft() + } } - } - currentrMatches = new CompactBuffer[Row]() - while (stop && rightElement != null) { - if (!rightKey.anyNull) { + currentrMatches = new CompactBuffer[Row]() + while (stop && rightElement != null) { currentrMatches += rightElement + fetchRight() + if (ordering.compare(leftKey, rightKey) != 0) { + stop = false + } } - fetchRight() - if (ordering.compare(leftKey, rightKey) != 0) { + if (currentrMatches.size > 0) { stop = false - } - } - if (currentrMatches.size > 0) { - stop = false - currentlMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (!stop && leftElement != null) { - if (!leftKey.anyNull) { + currentlMatches = new CompactBuffer[Row]() + val leftMatch = leftKey.copy() + while (!stop && leftElement != null) { currentlMatches += leftElement - } - fetchLeft() - if (ordering.compare(leftKey, leftMatch) != 0) { - stop = true + fetchLeft() + if (ordering.compare(leftKey, leftMatch) != 0) { + stop = true + } } } - } - if (currentlMatches == null) { - false - } else { - currentlPosition = 0 - currentrPosition = 0 - true + if (currentlMatches == null) { + false + } else { + currentlPosition = 0 + currentrPosition = 0 + true + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ccd0e5aa51f95..dc1d9fbd299e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -144,6 +144,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() + conf.setConf("spark.sql.autoSortMergeJoin", "false") var df = sql(query) @@ -178,6 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") } + conf.setConf("spark.sql.autoSortMergeJoin", "true") after() }