From 3af6ba546b016d8bd2fa8d9a729a7bc9993e8e50 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sat, 11 Apr 2015 21:15:48 +0800 Subject: [PATCH] use buffer for only one side --- .../sql/execution/joins/SortMergeJoin.scala | 49 ++++++++----------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index fd65320d55139..7e7b692d401cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -73,22 +73,23 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var leftMatches: CompactBuffer[Row] = _ private[this] var rightMatches: CompactBuffer[Row] = _ - private[this] var leftPosition: Int = -1 private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ - override final def hasNext: Boolean = leftPosition != -1 || nextMatchingPair + override final def hasNext: Boolean = nextMatchingPair() override final def next(): Row = { if (hasNext) { - val joinedRow = joinRow(leftMatches(leftPosition), rightMatches(rightPosition)) + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) rightPosition += 1 if (rightPosition >= rightMatches.size) { - leftPosition += 1 rightPosition = 0 - if (leftPosition >= leftMatches.size) { - leftPosition = -1 + fetchLeft() + if (leftElement == null || ordering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null } } joinedRow @@ -130,9 +131,7 @@ case class SortMergeJoin( * of tuples. */ private def nextMatchingPair(): Boolean = { - if (leftPosition == -1) { - leftMatches = null - var stop: Boolean = false + if (!stop && rightElement != null) { while (!stop && leftElement != null && rightElement != null) { stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) { @@ -142,27 +141,21 @@ case class SortMergeJoin( } } rightMatches = new CompactBuffer[Row]() - while (stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - // exit loop when run out of right matches - stop = ordering.compare(leftKey, rightKey) == 0 - } - if (rightMatches.size > 0) { - leftMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (ordering.compare(leftKey, leftMatch) == 0 && leftElement != null) { - leftMatches += leftElement - fetchLeft() + if (stop) { + stop = false + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + // exit loop when run out of right matches + stop = ordering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey } - } - - if (leftMatches != null) { - leftPosition = 0 - rightPosition = 0 } } - leftPosition > -1 + rightMatches != null && rightMatches.size > 0 } } }