diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 3c0ab080e7f4d..1bf3baa75ace0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -41,7 +41,7 @@ case class SortMergeJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] = + override def requiredChildDistribution: Seq[Distribution] = ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending)) @@ -62,7 +62,6 @@ case class SortMergeJoin( private[this] var rightElement: Row = _ private[this] var leftKey: Row = _ private[this] var rightKey: Row = _ - private[this] var read: Boolean = false private[this] var currentlMatches: CompactBuffer[Row] = _ private[this] var currentrMatches: CompactBuffer[Row] = _ private[this] var currentlPosition: Int = -1 @@ -70,7 +69,7 @@ case class SortMergeJoin( override final def hasNext: Boolean = (currentlPosition != -1 && currentlPosition < currentlMatches.size) || - (leftIter.hasNext && rightIter.hasNext && nextMatchingPair) + nextMatchingPair override final def next(): Row = { val joinedRow = @@ -83,6 +82,32 @@ case class SortMergeJoin( joinedRow } + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + // initialize iterator + private def initialize() = { + fetchLeft() + fetchRight() + } + + initialize() + /** * Searches the left/right iterator for the next rows that matches. * @@ -92,42 +117,33 @@ case class SortMergeJoin( private def nextMatchingPair(): Boolean = { currentlPosition = -1 currentlMatches = null - if (rightElement == null) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + var stop: Boolean = false + while (!stop && leftElement != null && rightElement != null) { + if (ordering.compare(leftKey, rightKey) > 0) + fetchRight() + else if (ordering.compare(leftKey, rightKey) < 0) + fetchLeft() + else + stop = true } - while (currentlMatches == null && leftIter.hasNext) { - if (!read) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } - while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } - currentrMatches = new CompactBuffer[Row]() - while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) { + currentrMatches = new CompactBuffer[Row]() + while (stop && rightElement != null) { + if (!rightKey.anyNull) currentrMatches += rightElement - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } - if (ordering.compare(leftKey, rightKey) == 0) { - currentrMatches += rightElement - } - if (currentrMatches.size > 0) { - // there exists rows match in right table, should search left table - currentlMatches = new CompactBuffer[Row]() - val leftMatch = leftKey.copy() - while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) { - currentlMatches += leftElement - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } - if (ordering.compare(leftKey, leftMatch) == 0) { + fetchRight() + if (ordering.compare(leftKey, rightKey) != 0) + stop = false + } + if (currentrMatches.size > 0) { + stop = false + currentlMatches = new CompactBuffer[Row]() + val leftMatch = leftKey.copy() + while (!stop && leftElement != null) { + if (!leftKey.anyNull) currentlMatches += leftElement - } else { - read = true - } + fetchLeft() + if (ordering.compare(leftKey, leftMatch) != 0) + stop = true } }