Skip to content

Commit

Permalink
improvement according to Michael
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Jun 3, 2014
1 parent 8d4a121 commit 4c726e5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
// All predicates can be evaluated for left semi join (those that are in the WHERE
// clause can only from left table, so they can all be pushed down.)
case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering hash join on: $condition")
splitPredicates(condition.toSeq, join)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find leftsemi joins where at least some predicates can be evaluated by matching hash keys
// using the HashFilteredJoin pattern.
// Find left semi joins where at least some predicates can be evaluated by matching hash
// keys using the HashFilteredJoin pattern.
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = execution.LeftSemiJoinHash(
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
Expand Down
49 changes: 18 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -142,29 +142,23 @@ case class HashJoin(

/**
* :: DeveloperApi ::
* Build the right table's join keys into a HashSet, and iteratively go through the left
* table, to find the if join keys are in the Hash set.
*/
@DeveloperApi
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
}

val (buildKeys, streamedKeys) = buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
val (buildPlan, streamedPlan) = (right, left)
val (buildKeys, streamedKeys) = (rightKeys, leftKeys)

def output = left.output

Expand All @@ -175,24 +169,18 @@ case class LeftSemiJoinHash(
def execute() = {

buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
// TODO: Use Spark's HashMap implementation.
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
val hashTable = new java.util.HashSet[Row]()
var currentRow: Row = null

// Create a mapping of buildKeys -> rows
// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
existingMatchList
val keyExists = hashTable.contains(rowKey)
if (!keyExists) {
hashTable.add(rowKey)
}
matchList += currentRow.copy()
}
}

Expand Down Expand Up @@ -220,7 +208,7 @@ case class LeftSemiJoinHash(
while (!currentHashMatched && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatched = true
currentHashMatched = hashTable.contains(joinKeys.currentValue)
}
}
currentHashMatched
Expand All @@ -232,6 +220,8 @@ case class LeftSemiJoinHash(

/**
* :: DeveloperApi ::
* Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
* for hash join.
*/
@DeveloperApi
case class LeftSemiJoinBNL(
Expand Down Expand Up @@ -261,26 +251,23 @@ case class LeftSemiJoinBNL(
def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow

streamedIter.filter(streamedRow => {
var i = 0
var matched = false

while (i < broadcastedRelation.value.size && !matched) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matched = true
}
i += 1
}
matched
}).map(streamedRow => (streamedRow, null))
})
}

streamedPlusMatches.map(_._1)
}
}

Expand Down

0 comments on commit 4c726e5

Please sign in to comment.