add support for left semi join
adrian-wang committed May 20, 2014
1 parent 753b04d commit 14cff80
Showing 27 changed files with 197 additions and 0 deletions.
@@ -22,3 +22,4 @@ case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
case object LeftSemi extends JoinType
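A left semi join returns each left-side row at most once when at least one right-side row matches, and never emits right-side columns; for example, with left rows {(1, a), (2, b)} and right keys {2, 3}, a semi join on the key yields only (2, b).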
@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TakeOrdered ::
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
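Ordering matters here: the planner tries strategies in sequence, so listing LeftSemiJoin ahead of HashJoin and the other join strategies lets the specialized semi-join operators claim LeftSemi plans first.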
@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching hash keys
// using the HashFilteredJoin pattern.
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin =
execution.LeftSemiJoinHash(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), LeftSemi, condition)(sparkContext) :: Nil
case _ => Nil
}
}
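
// For illustration, hypothetical logical plans and the operators the two cases
// above select (a and b stand for LogicalPlans; aKey, bKey, aVal, bVal are
// assumed attribute references, not part of this commit):
//   Join(a, b, LeftSemi, Some(EqualTo(aKey, bKey)))  => LeftSemiJoinHash
//   Join(a, b, LeftSemi, Some(LessThan(aVal, bVal))) => LeftSemiJoinBNL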

object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
144 changes: 144 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -140,6 +140,150 @@ case class HashJoin(
}
}

/**
* :: DeveloperApi ::
* A left semi join implemented by hashing the build side's join keys and then
* streaming the other side, emitting each streamed row at most once if its
* key appears in the hash table.
*/
@DeveloperApi
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
}

val (buildKeys, streamedKeys) = buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}

def output = left.output

@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)

def execute() = {

buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
// TODO: Use Spark's HashMap implementation.
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
var currentRow: Row = null

// Create a mapping of buildKeys -> rows
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
existingMatchList
}
matchList += currentRow.copy()
}
}

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatched: Boolean = false

private[this] val joinKeys = streamSideKeyGenerator()

override final def hasNext: Boolean =
streamIter.hasNext && fetchNext()

override final def next() = {
currentStreamedRow
}

/**
* Searches the streamed iterator for the next row that has at least one match in the hash table.
*
* @return true if the search is successful, and false if the streamed iterator runs out of
*         tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatched = false
while (!currentHashMatched && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
// A streamed row qualifies only if its join key appears in the build-side
// hash table; keys containing nulls never match.
currentHashMatched = hashTable.containsKey(joinKeys.currentValue)
}
}
currentHashMatched
}
}
}
}
}
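
// A possible simplification, sketched here rather than taken from this commit:
// a semi join only tests key membership and never reads the matched rows, so
// the ArrayBuffer-valued HashMap above could be replaced by a HashSet of keys:
//
//   val hashSet = new java.util.HashSet[Row]()
//   while (buildIter.hasNext) {
//     val rowKey = buildSideKeyGenerator(buildIter.next())
//     if (!rowKey.anyNull) {
//       hashSet.add(rowKey)
//     }
//   }
//
// with the streamed side testing hashSet.contains(joinKeys.currentValue).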

/**
* :: DeveloperApi ::
* A left semi join implemented as a broadcast nested loop: the right relation
* is broadcast to every node and each streamed left row is kept if any
* broadcast row satisfies the join condition. Used when no equi-join keys are
* available for hashing.
*/
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
(@transient sc: SparkContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def otherCopyArgs = sc :: Nil

def output = left.output

/** The Streamed Relation */
def left = streamed
/** The Broadcast relation */
def right = broadcast

@transient lazy val boundCondition =
InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
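// With no condition the bound predicate defaults to `true`, so a streamed row
// matches as soon as the broadcast side is non-empty.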

def execute() = {
val broadcastedRelation =
sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow

// Keep a streamed row as soon as any broadcast row satisfies the join
// condition; the scan stops at the first match, so each streamed row is
// emitted at most once.
streamedIter.filter { streamedRow =>
var i = 0
var matched = false

while (i < broadcastedRelation.value.size && !matched) {
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matched = true
}
i += 1
}
matched
}
}
}
}
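
// Note: this operator is the fall-through for LeftSemi plans whose condition
// yields no equi-join keys for hashing (for example a pure LessThan
// predicate), matching the second case of the LeftSemiJoin strategy.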

/**
* :: DeveloperApi ::
*/
@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
DataSinks,
Scripts,
PartialAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
CartesianProduct,
@@ -680,6 +680,7 @@ private[hive] object HiveQl {
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
}
assert(other.size <= 1, "Unhandled join clauses.")
Join(nodeToRelation(relation1),
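With this mapping, Hive's TOK_LEFTSEMIJOIN token is translated into the new LeftSemi join type, so HiveQL queries using LEFT SEMI JOIN flow through the same planning path as above.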
Empty file.
Empty file.
@@ -0,0 +1 @@
0
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
@@ -0,0 +1 @@
0
@@ -0,0 +1,2 @@
1
1
Empty file.
Empty file.
@@ -0,0 +1,2 @@
1
1
@@ -0,0 +1,20 @@
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
@@ -0,0 +1 @@
0
@@ -0,0 +1 @@
0
@@ -0,0 +1 @@
0
@@ -0,0 +1,2 @@
1
1
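The empty and single-value files above are, presumably, the golden answer files for the leftsemijoin queries whitelisted below; HiveQueryFileTest compares each query's output against these recorded results.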
@@ -478,6 +478,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"lateral_view_cp",
"lateral_view_outer",
"lateral_view_ppd",
"leftsemijoin",
"leftsemijoin_mr",
"lineage1",
"literal_double",
"literal_ints",
