
Commit 303b6da
fix several errors
adrian-wang committed Apr 3, 2015
1 parent 95db7ad commit 303b6da
Showing 3 changed files with 49 additions and 43 deletions.
@@ -86,7 +86,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
}
}
val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val ordering = new RowOrdering(sortingExpressions, child.output)
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
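
In this hunk the ordering loses its implicit modifier; it is already passed explicitly to setKeyOrdering on the line below, so nothing needs to resolve it from implicit scope. As a rough, self-contained illustration of the idea (plain Scala only; Key and KeyOrderingSketch are made-up names, not Spark types), an explicit column-by-column ascending comparator can be built once and handed straight to the sort step:

// Sketch: build an explicit ascending, column-by-column ordering and pass it
// explicitly to the per-partition sort, rather than relying on an implicit.
object KeyOrderingSketch {
  type Key = Seq[Int] // stand-in for a Row restricted to the sort columns

  val ordering: Ordering[Key] = new Ordering[Key] {
    def compare(a: Key, b: Key): Int =
      a.zip(b).iterator.map { case (x, y) => x.compare(y) }.find(_ != 0).getOrElse(0)
  }

  def main(args: Array[String]): Unit = {
    val partition = Seq(Seq(2, 1), Seq(1, 3), Seq(1, 2))
    // analogous to setKeyOrdering(ordering): the comparator is an explicit argument
    println(partition.sorted(ordering)) // List(List(1, 2), List(1, 3), List(2, 1))
  }
}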
@@ -67,17 +67,21 @@ case class SortMergeJoin(
private[this] var currentlPosition: Int = -1
private[this] var currentrPosition: Int = -1

override final def hasNext: Boolean =
(currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
nextMatchingPair
override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair

override final def next(): Row = {
if (!hasNext) {
return null
}
val joinedRow =
joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition))
currentrPosition += 1
if (currentrPosition >= currentrMatches.size) {
currentlPosition += 1
currentrPosition = 0
if (currentlPosition >= currentlMatches.size) {
currentlPosition = -1
}
}
joinedRow
}
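
The reworked hasNext only has to test the -1 sentinel because next() now maintains it: the new guard returns null when nothing is left, and once every buffered (left, right) pair has been emitted, currentlPosition is reset to -1 so the next hasNext call searches for a fresh matching pair. A minimal model of that iteration pattern, in plain Scala with string payloads instead of Rows (the names and the NoSuchElementException are illustrative, not Spark's), looks like this:

// Sketch of the buffered cross-product iteration: emit every (left, right) pair
// in the current match buffers, then reset the sentinel so hasNext refills them.
class MatchIterator(groups: Iterator[(Seq[String], Seq[String])])
  extends Iterator[(String, String)] {

  private var lefts: Seq[String] = Seq.empty
  private var rights: Seq[String] = Seq.empty
  private var lPos = -1 // -1 means "no buffered matches left"
  private var rPos = 0

  override def hasNext: Boolean = lPos != -1 || nextMatchingPair()

  override def next(): (String, String) = {
    if (!hasNext) throw new NoSuchElementException // the real code returns null here
    val pair = (lefts(lPos), rights(rPos))
    rPos += 1
    if (rPos >= rights.size) { // finished the right matches for this left row
      lPos += 1
      rPos = 0
      if (lPos >= lefts.size) lPos = -1 // whole buffer consumed: force a refill
    }
    pair
  }

  // stand-in for the real pairing logic; assumes both sides of a group are non-empty
  private def nextMatchingPair(): Boolean = {
    if (groups.hasNext) {
      val (l, r) = groups.next()
      lefts = l; rights = r; lPos = 0; rPos = 0
      true
    } else false
  }
}

For example, new MatchIterator(Iterator((Seq("a1", "a2"), Seq("b1")), (Seq("c1"), Seq("d1", "d2")))).toList yields List((a1,b1), (a2,b1), (c1,d1), (c1,d2)).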
@@ -100,13 +104,13 @@
}
}

// initialize iterator
private def initialize() = {
private def fetchFirst() = {
fetchLeft()
fetchRight()
currentrPosition = 0
}

initialize()
// initialize iterator
fetchFirst()

/**
* Searches the left/right iterator for the next rows that matches.
@@ -115,49 +119,49 @@
* of tuples.
*/
private def nextMatchingPair(): Boolean = {
currentlPosition = -1
currentlMatches = null
var stop: Boolean = false
while (!stop && leftElement != null && rightElement != null) {
if (ordering.compare(leftKey, rightKey) > 0) {
fetchRight()
} else if (ordering.compare(leftKey, rightKey) < 0) {
fetchLeft()
} else {
stop = true
if (currentlPosition > -1) {
true
} else {
currentlPosition = -1
currentlMatches = null
var stop: Boolean = false
while (!stop && leftElement != null && rightElement != null) {
if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) {
stop = true
} else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) {
fetchRight()
} else { //if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull)
fetchLeft()
}
}
}
currentrMatches = new CompactBuffer[Row]()
while (stop && rightElement != null) {
if (!rightKey.anyNull) {
currentrMatches = new CompactBuffer[Row]()
while (stop && rightElement != null) {
currentrMatches += rightElement
fetchRight()
if (ordering.compare(leftKey, rightKey) != 0) {
stop = false
}
}
fetchRight()
if (ordering.compare(leftKey, rightKey) != 0) {
if (currentrMatches.size > 0) {
stop = false
}
}
if (currentrMatches.size > 0) {
stop = false
currentlMatches = new CompactBuffer[Row]()
val leftMatch = leftKey.copy()
while (!stop && leftElement != null) {
if (!leftKey.anyNull) {
currentlMatches = new CompactBuffer[Row]()
val leftMatch = leftKey.copy()
while (!stop && leftElement != null) {
currentlMatches += leftElement
}
fetchLeft()
if (ordering.compare(leftKey, leftMatch) != 0) {
stop = true
fetchLeft()
if (ordering.compare(leftKey, leftMatch) != 0) {
stop = true
}
}
}
}

if (currentlMatches == null) {
false
} else {
currentlPosition = 0
currentrPosition = 0
true
if (currentlMatches == null) {
false
} else {
currentlPosition = 0
currentrPosition = 0
true
}
}
}
}
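
The rewritten nextMatchingPair adds two things visible above: an early return of true while a buffered match is still pending (currentlPosition > -1), and NULL handling while advancing the two sorted streams, so rows whose join key is NULL are skipped before any buffering happens rather than filtered out afterwards. A compact sketch of that merge step over pre-sorted inputs, using Option[Int] as a stand-in for a nullable key (plain Scala, not the operator's actual code), is:

// Sketch of the equi-join merge step over two key-sorted buffers.
// None plays the role of a NULL join key and never matches anything.
def mergeMatches(
    left: List[(Option[Int], String)],
    right: List[(Option[Int], String)]): List[(String, String)] = {
  val out = scala.collection.mutable.ListBuffer[(String, String)]()
  var l = left
  var r = right
  while (l.nonEmpty && r.nonEmpty) {
    (l.head._1, r.head._1) match {
      case (Some(lk), Some(rk)) if lk == rk =>
        // keys match and neither is NULL: buffer both sides, emit the cross product
        val (lMatch, lRest) = l.span(_._1.contains(lk))
        val (rMatch, rRest) = r.span(_._1.contains(rk))
        for ((_, lv) <- lMatch; (_, rv) <- rMatch) out += ((lv, rv))
        l = lRest
        r = rRest
      case (_, None) => r = r.tail                       // NULL right key: advance right
      case (None, _) => l = l.tail                       // NULL left key: advance left
      case (Some(lk), Some(rk)) if lk > rk => r = r.tail // left key is ahead
      case _ => l = l.tail                               // right key is ahead
    }
  }
  out.toList
}

With left = List((None, "l0"), (Some(1), "l1"), (Some(1), "l2"), (Some(3), "l3")) and right = List((Some(1), "r1"), (Some(2), "r2"), (Some(3), "r3")), this yields List((l1,r1), (l2,r1), (l3,r3)); the NULL-keyed l0 never joins.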
@@ -144,6 +144,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
expectedAnswer: Seq[Row],
ct: ClassTag[_]) = {
before()
conf.setConf("spark.sql.autoSortMergeJoin", "false")

var df = sql(query)

@@ -178,6 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
}

conf.setConf("spark.sql.autoSortMergeJoin", "true")
after()
}
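
These test hunks disable spark.sql.autoSortMergeJoin for the duration of the check and turn it back on afterwards, presumably so the planner still chooses the join operator the assertion expects. A small sketch of the same toggle written as a loan pattern (the map-backed conf and the withConf helper are illustrative stand-ins, not the suite's API) restores the previous value even when the body throws:

import scala.collection.mutable

object ConfToggleSketch {
  val conf = mutable.Map[String, String]() // stand-in for the suite's SQL conf

  // Set a key for the duration of `body`, then restore whatever was there before.
  def withConf[T](key: String, value: String)(body: => T): T = {
    val previous = conf.get(key)
    conf(key) = value
    try body
    finally {
      previous match {
        case Some(v) => conf(key) = v
        case None => conf -= key
      }
    }
  }

  def main(args: Array[String]): Unit = {
    withConf("spark.sql.autoSortMergeJoin", "false") {
      // run the join assertions here with sort-merge join disabled
      println(conf("spark.sql.autoSortMergeJoin")) // false
    }
    println(conf.get("spark.sql.autoSortMergeJoin")) // None: previous state restored
  }
}

Where the hunk hard-codes "true" on the way out, the try/finally variant simply puts back whatever value was set before the test ran.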

