Skip to content

Commit

Permalink
Left semi sort merge join code-gen
Browse files Browse the repository at this point in the history
  • Loading branch information
c21 committed May 12, 2021
1 parent ae0579a commit 8eb55c3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
Expand All @@ -365,7 +365,7 @@ case class SortMergeJoinExec(
private lazy val bufferedOutput = bufferedPlan.output

override def supportCodegen: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter => true
case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true
case _ => false
}

Expand Down Expand Up @@ -424,8 +424,18 @@ case class SortMergeJoinExec(
// A list to hold all matched rows from buffered side.
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName

// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
// It is true only when the join type is `LeftSemi` and there is no extra join
// condition (`condition` is `None`); every other combination yields false.
val onlyBufferFirstMatchedRow = (joinType, condition) match {
case (LeftSemi, None) => true
case _ => false
}
// Threshold value: 1 when only the first matched row is buffered, otherwise
// whatever `getInMemoryThreshold` returns.
// NOTE(review): presumably this is passed to the `matches` buffer created below;
// the constructor call is outside this view, so confirm there.
val inMemoryThreshold =
if (onlyBufferFirstMatchedRow) {
1
} else {
getInMemoryThreshold
}
val spillThreshold = getSpillThreshold
val inMemoryThreshold = getInMemoryThreshold

// Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches",
Expand All @@ -435,7 +445,7 @@ case class SortMergeJoinExec(

// Handle the case when the streamed row has any NULL keys.
val handleStreamedAnyNull = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"""
|$streamedRow = null;
Expand All @@ -457,7 +467,7 @@ case class SortMergeJoinExec(

// Handle the case when the streamed keys have no match on the buffered side.
val handleStreamedWithoutMatch = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter =>
Expand All @@ -468,6 +478,17 @@ case class SortMergeJoinExec(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

// Generated Java snippet that appends the current `bufferedRow` to `matches`.
// When `onlyBufferFirstMatchedRow` is set, the add is guarded by an emptiness
// check so that at most one row is added while `matches` is empty; otherwise
// the row is added unconditionally.
val addRowToBuffer =
if (onlyBufferFirstMatchedRow) {
s"""
|if ($matches.isEmpty()) {
| $matches.add((UnsafeRow) $bufferedRow);
|}
""".stripMargin
} else {
s"$matches.add((UnsafeRow) $bufferedRow);"
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
Expand All @@ -483,17 +504,18 @@ case class SortMergeJoinExec(
// The function has the following step:
// - Step 1: Find the next `streamedRow` with non-null join keys.
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
// 1. Inner join: skip the row. `matches` will be cleared later when hitting the
// next `streamedRow` with non-null join keys.
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
// hitting the next `streamedRow` with non-null join
// keys.
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
// and return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches`. Return true when getting all matched rows.
// `matches` (`addRowToBuffer`). Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner join: skip the row.
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
Expand Down Expand Up @@ -543,7 +565,7 @@ case class SortMergeJoinExec(
| $handleStreamedWithoutMatch
| }
| } else {
| $matches.add((UnsafeRow) $bufferedRow);
| $addRowToBuffer
| $bufferedRow = null;
| }
| } while ($streamedRow != null);
Expand Down Expand Up @@ -639,6 +661,8 @@ case class SortMergeJoinExec(
streamedVars ++ bufferedVars
case RightOuter =>
bufferedVars ++ streamedVars
case LeftSemi =>
streamedVars
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
Expand All @@ -650,8 +674,9 @@ case class SortMergeJoinExec(
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
ctx.currentVars = streamedVars ++ bufferedVars
val cond = BindReferences.bindReference(
condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
// Evaluate the columns that are used by the condition before the loop.
val before =
s"""
Expand Down Expand Up @@ -724,9 +749,32 @@ case class SortMergeJoinExec(
""".stripMargin
}

// Generated Java code for the left semi join.
// Outer loop: runs while `findNextJoinRows` returns true for the streamed and
// buffered inputs.
// Inner loop: walks the iterator over `matches` and stops as soon as one output
// row has been produced (`hasOutputRow`), so each pass of the outer loop emits
// at most one row. When a row is emitted, the `numOutput` metric is incremented
// and `resultVars` is handed to `consume`.
// NOTE(review): `condCheck` is defined outside this view; presumably it skips to
// the next buffered row when the join condition does not hold — confirm.
// Declared `lazy`, so the code string is only built if this value is referenced.
lazy val semiJoin = {
// Fresh name for the generated boolean that records whether a row was already output.
val hasOutputRow = ctx.freshName("hasOutputRow")
s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| boolean $hasOutputRow = false;
|
| while (!$hasOutputRow && $iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $hasOutputRow = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

joinType match {
case _: InnerLike => innerJoin
case LeftOuter | RightOuter => outerJoin
case LeftSemi => semiJoin
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,28 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Row(null, null, 6), Row(null, null, 7), Row(null, null, 8), Row(null, null, 9)))
}

// Checks that a left semi sort merge join appears inside a WholeStageCodegenExec
// node of the executed plan and still returns the expected rows, first for a
// single join and then for two chained joins.
test("Left Semi SortMergeJoin should be included in WholeStageCodegen") {
// Three single-column inputs built from `spark.range`: k1 in [0, 10), k2 in [0, 4),
// k3 in [0, 6).
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val df3 = spark.range(6).select($"id".as("k3"))

// test one left semi sort merge join
// The SHUFFLE_MERGE hint on the right side requests a sort merge join.
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
// Exactly one codegen stage must wrap a ProjectExec over a SortMergeJoinExec.
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
}.size === 1)
// Only the k1 values that also occur in k2 (0..3) are expected.
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))

// test two left semi sort merge joins
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
// Two codegen stages are expected; each may hold the join either directly or
// underneath a ProjectExec.
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
WholeStageCodegenExec(_ : SortMergeJoinExec) => true
}.size === 2)
// Only the k3 values present in both k2 and k1 (0..3) are expected.
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}

test("Inner/Cross BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
}

test(s"$testName using SortMergeJoin") {
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
Expand Down

0 comments on commit 8eb55c3

Please sign in to comment.