diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index a2719c5060496..ec6a05542659d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -62,28 +62,29 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var currentlMatches: CompactBuffer[Row] = _ - private[this] var currentrMatches: CompactBuffer[Row] = _ - private[this] var currentlPosition: Int = -1 - private[this] var currentrPosition: Int = -1 + private[this] var leftMatches: CompactBuffer[Row] = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var leftPosition: Int = -1 + private[this] var rightPosition: Int = -1 - override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair + override final def hasNext: Boolean = leftPosition != -1 || nextMatchingPair override final def next(): Row = { - if (!hasNext) { - return null - } - val joinedRow = - joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition)) - currentrPosition += 1 - if (currentrPosition >= currentrMatches.size) { - currentlPosition += 1 - currentrPosition = 0 - if (currentlPosition >= currentlMatches.size) { - currentlPosition = -1 + if (hasNext) { + val joinedRow = joinRow(leftMatches(leftPosition), rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + leftPosition += 1 + rightPosition = 0 + if (leftPosition >= leftMatches.size) { + leftPosition = -1 + } } + joinedRow + } else { + // according to Scala doc, this is undefined + null } - joinedRow } private def fetchLeft() = { @@ -104,13 +105,12 @@ case class SortMergeJoin( } } - private def fetchFirst() = { + private def initialize() = { fetchLeft() fetchRight() - currentrPosition = 0 } // initialize iterator - fetchFirst() + initialize() /** * Searches the left/right iterator for the next rows that matches. @@ -119,50 +119,42 @@ case class SortMergeJoin( * of tuples. */ private def nextMatchingPair(): Boolean = { - if (currentlPosition > -1) { - true - } else { - currentlPosition = -1 - currentlMatches = null + if (leftPosition == -1) { + leftMatches = null var stop: Boolean = false while (!stop && leftElement != null && rightElement != null) { - if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) { - stop = true - } else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { + stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull + if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { fetchRight() - } else { // if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) + } else if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) { fetchLeft() } } - currentrMatches = new CompactBuffer[Row]() + rightMatches = new CompactBuffer[Row]() while (stop && rightElement != null) { - currentrMatches += rightElement + rightMatches += rightElement fetchRight() - if (ordering.compare(leftKey, rightKey) != 0) { - stop = false - } + // exit loop when run out of right matches + stop = ordering.compare(leftKey, rightKey) == 0 } - if (currentrMatches.size > 0) { + if (rightMatches.size > 0) { stop = false - currentlMatches = new CompactBuffer[Row]() + leftMatches = new CompactBuffer[Row]() val leftMatch = leftKey.copy() while (!stop && leftElement != null) { - currentlMatches += leftElement + leftMatches += leftElement fetchLeft() - if (ordering.compare(leftKey, leftMatch) != 0) { - stop = true - } + // exit loop when run out of left matches + stop = ordering.compare(leftKey, leftMatch) != 0 } } - if (currentlMatches == null) { - false - } else { - currentlPosition = 0 - currentrPosition = 0 - true + if (leftMatches != null) { + leftPosition = 0 + rightPosition = 0 } } + leftPosition > -1 } } }