Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35350][SQL] Add code-gen for left semi sort merge join #32528

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,18 @@ case class SortMergeJoinExec(
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
}

// True when only the first matched row has to be buffered: the join type matches
// `LeftExistence` and there is no extra join condition. Avoids buffering unnecessary rows.
private val onlyBufferFirstMatchedRow = joinType match {
  case LeftExistence(_) => condition.isEmpty
  case _ => false
}

/**
 * Returns the in-memory threshold: 1 when only the first matched row is buffered
 * (`onlyBufferFirstMatchedRow`), otherwise the configured
 * `sortMergeJoinExecBufferInMemoryThreshold`.
 */
private def getInMemoryThreshold: Int = {
  // The configured value is only read on the branch that actually returns it; a bare read
  // whose result was discarded used to precede this `if` and had no effect.
  if (onlyBufferFirstMatchedRow) {
    1
  } else {
    sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
  }
}

protected override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -236,7 +246,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -273,7 +283,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -317,7 +327,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -354,7 +364,7 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
Expand All @@ -365,7 +375,7 @@ case class SortMergeJoinExec(
private lazy val bufferedOutput = bufferedPlan.output

// Code generation is reported as supported for inner-like, left outer, right outer and
// left semi joins; every other join type reports false.
override def supportCodegen: Boolean = joinType match {
  case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true
  case _ => false
}

Expand Down Expand Up @@ -435,7 +445,7 @@ case class SortMergeJoinExec(

// Handle the case when the streamed row has any NULL keys.
val handleStreamedAnyNull = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"""
|$streamedRow = null;
Expand All @@ -457,7 +467,7 @@ case class SortMergeJoinExec(

// Handle the case when the streamed keys have no match on the buffered side.
val handleStreamedWithoutMatch = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter =>
Expand All @@ -468,6 +478,17 @@ case class SortMergeJoinExec(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

// Java snippet that appends `bufferedRow` to `matches`. When only the first matched row is
// buffered, the append is wrapped in a guard so it only runs while `matches` is empty.
val addRowToBuffer = {
  val appendRow = s"$matches.add((UnsafeRow) $bufferedRow);"
  if (!onlyBufferFirstMatchedRow) {
    appendRow
  } else {
    // The closing delimiter sits on a margin line so no trailing whitespace enters the snippet.
    s"""
       |if ($matches.isEmpty()) {
       | $appendRow
       |}
       |""".stripMargin
  }
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
Expand All @@ -483,17 +504,18 @@ case class SortMergeJoinExec(
// The function has the following step:
// - Step 1: Find the next `streamedRow` with non-null join keys.
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
// 1. Inner join: skip the row. `matches` will be cleared later when hitting the
// next `streamedRow` with non-null join keys.
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
// hitting the next `streamedRow` with non-null join
// keys.
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
// and return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches`. Return true when getting all matched rows.
// `matches` (`addRowToBuffer`). Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner join: skip the row.
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
Expand Down Expand Up @@ -543,7 +565,7 @@ case class SortMergeJoinExec(
| $handleStreamedWithoutMatch
| }
| } else {
| $matches.add((UnsafeRow) $bufferedRow);
| $addRowToBuffer
| $bufferedRow = null;
| }
| } while ($streamedRow != null);
Expand Down Expand Up @@ -639,19 +661,22 @@ case class SortMergeJoinExec(
streamedVars ++ bufferedVars
case RightOuter =>
bufferedVars ++ streamedVars
case LeftSemi =>
streamedVars
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}

val (beforeLoop, condCheck) = if (condition.isDefined) {
val (streamedBeforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
ctx.currentVars = streamedVars ++ bufferedVars
val cond = BindReferences.bindReference(
condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
// Evaluate the columns that are used by the condition before the loop.
val before =
s"""
Expand All @@ -674,65 +699,129 @@ case class SortMergeJoinExec(
|}
|$bufferedAfter
""".stripMargin
(before, checking)
(before, checking.trim)
} else {
(evaluateVariables(streamedVars), "")
}

val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

lazy val innerJoin =
val beforeLoop =
s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin

lazy val outerJoin = {
val hasOutputRow = ctx.freshName("hasOutputRow")
|${streamedVarDecl.mkString("\n")}
|${streamedBeforeLoop.trim}
|scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
""".stripMargin
val outputRow =
s"""
|while ($streamedInput.hasNext()) {
| findNextJoinRows($streamedInput, $bufferedInput);
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| boolean $hasOutputRow = false;
|
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
| while ($iterator.hasNext() || !$hasOutputRow) {
| InternalRow $bufferedRow = $iterator.hasNext() ?
| (InternalRow) $iterator.next() : null;
| ${condCheck.trim}
| $hasOutputRow = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin
}
val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)"
val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

joinType match {
case _: InnerLike => innerJoin
case LeftOuter | RightOuter => outerJoin
case _: InnerLike =>
codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we pass beforeLoop.trim so that we don't need to do it in all the 3 methods?

Copy link
Contributor Author

@c21 c21 May 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually after double checking, we do not need to do beforeLoop.trim as beforeLoop already has stripMargin, and has no trailing spaces. Also updated to avoid repeated conditionCheck.trim

eagerCleanup)
case LeftOuter | RightOuter =>
codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
case LeftSemi =>
codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}
}

/**
 * Generates the code for Inner join.
 *
 * The emitted Java code is an outer `while` loop that keeps running as long as
 * `findNextJoinRows` evaluates to true. Each outer iteration emits `beforeLoop`, then an
 * inner `while` loop that drains `matchIterator`: every element is cast to `InternalRow`
 * and bound to `bufferedRow`, followed by `conditionCheck` and `outputRow`. After the inner
 * loop the generated code returns if `shouldStop()` is true. `eagerCleanup` is emitted
 * after the outer loop.
 *
 * @param findNextJoinRows Java boolean expression used as the outer loop condition.
 * @param beforeLoop Java code placed at the top of the outer loop body, before the inner loop.
 * @param matchIterator name of the Java iterator variable the inner loop reads from; this
 *                      template does not declare it.
 * @param bufferedRow name of the Java `InternalRow` variable declared per iterator element.
 * @param conditionCheck Java code placed right after the `bufferedRow` declaration.
 * @param outputRow Java code placed at the end of the inner loop body.
 * @param eagerCleanup Java code placed after the outer loop.
 * @return the generated Java source text.
 */
private def codegenInner(
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($findNextJoinRows) {
| $beforeLoop
| while ($matchIterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
| $conditionCheck
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

/**
 * Generates the code for Left or Right Outer join.
 *
 * The emitted Java code is an outer `while` loop driven by `streamedInput.hasNext()`. Each
 * outer iteration emits `findNextJoinRows` as a statement (its result is not used), then
 * `beforeLoop`, then declares a boolean named `hasOutputRow` initialised to false. The inner
 * loop keeps running while `matchIterator` has elements or `hasOutputRow` is still false;
 * `bufferedRow` is bound to the next element when there is one and to `null` otherwise.
 * Inside the inner loop, `conditionCheck` comes first, then `hasOutputRow` is set to true,
 * then `outputRow`. After the inner loop the generated code returns if `shouldStop()` is
 * true. `eagerCleanup` is emitted after the outer loop.
 *
 * @param streamedInput name of the Java variable whose `hasNext()` drives the outer loop.
 * @param findNextJoinRows Java expression emitted as a statement at the top of the outer loop.
 * @param beforeLoop Java code placed before the `hasOutputRow` declaration.
 * @param matchIterator name of the Java iterator variable the inner loop reads from; this
 *                      template does not declare it.
 * @param bufferedRow name of the Java `InternalRow` variable declared in the inner loop.
 * @param conditionCheck Java code placed right after the `bufferedRow` declaration.
 * @param hasOutputRow name of the Java boolean flag declared by this template.
 * @param outputRow Java code placed at the end of the inner loop body.
 * @param eagerCleanup Java code placed after the outer loop.
 * @return the generated Java source text.
 */
private def codegenOuter(
streamedInput: String,
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
hasOutputRow: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($streamedInput.hasNext()) {
| $findNextJoinRows;
| $beforeLoop
| boolean $hasOutputRow = false;
|
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
| while ($matchIterator.hasNext() || !$hasOutputRow) {
| InternalRow $bufferedRow = $matchIterator.hasNext() ?
| (InternalRow) $matchIterator.next() : null;
| $conditionCheck
| $hasOutputRow = true;
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

/**
 * Generates the code for Left Semi join.
 *
 * The emitted Java code is an outer `while` loop that keeps running as long as
 * `findNextJoinRows` evaluates to true. Each outer iteration emits `beforeLoop` and declares
 * a boolean named `hasOutputRow` initialised to false. The inner loop runs only while
 * `hasOutputRow` is false and `matchIterator` has elements: the next element is bound to
 * `bufferedRow`, then `conditionCheck` is emitted, then `hasOutputRow` is set to true, then
 * `outputRow`. Because the flag is set before `outputRow` and is part of the loop condition,
 * `outputRow` is reached at most once per outer iteration. After the inner loop the
 * generated code returns if `shouldStop()` is true. `eagerCleanup` is emitted after the
 * outer loop.
 *
 * @param findNextJoinRows Java boolean expression used as the outer loop condition.
 * @param beforeLoop Java code placed at the top of the outer loop body.
 * @param matchIterator name of the Java iterator variable the inner loop reads from; this
 *                      template does not declare it.
 * @param bufferedRow name of the Java `InternalRow` variable declared in the inner loop.
 * @param conditionCheck Java code placed right after the `bufferedRow` declaration.
 * @param hasOutputRow name of the Java boolean flag declared by this template.
 * @param outputRow Java code placed at the end of the inner loop body.
 * @param eagerCleanup Java code placed after the outer loop.
 * @return the generated Java source text.
 */
private def codegenSemi(
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
hasOutputRow: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($findNextJoinRows) {
| $beforeLoop
| boolean $hasOutputRow = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this flag if we are sure matchIterator has at most one element?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - matchIterator will only has at most one element if join condition is empty. So yes we don't need this if join condition is empty. But consider the extra code is just a while loop check on hasOutputRow, and set value of hasOutputRow, I don't see much value to specialize another code-gen for left semi join without join condition. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, let's keep it

|
| while (!$hasOutputRow && $matchIterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
| $conditionCheck
| $hasOutputRow = true;
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -783,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
spillThreshold)
new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)

// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
Expand Down
Loading