Skip to content

Commit

Permalink
[SPARK-35350][SQL] Add code-gen for left semi sort merge join
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

As title. This PR is to add code-gen support for LEFT SEMI sort merge join. The main change is to add `semiJoin` code path in `SortMergeJoinExec.doProduce()` and introduce `onlyBufferFirstMatchedRow` in `SortMergeJoinExec.genScanner()`. The latter is for left semi sort merge join without condition. For this kind of query, we don't need to buffer all matched rows, but only the first one (this is same as non-code-gen code path).

Example query:

```
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
```

Example of generated code for the query:

```
== Subtree 5 / 5 (maxMethodCodeSize:302; maxConstantPoolSize:156(0.24% used); numInnerClasses:0) ==
*(5) Project [id#0L AS k1#2L]
+- *(5) SortMergeJoin [id#0L], [k2#6L], LeftSemi
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 5), ENSURE_REQUIREMENTS, [id=#27]
   :     +- *(1) Range (0, 10, step=1, splits=2)
   +- *(4) Sort [k2#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(k2#6L, 5), ENSURE_REQUIREMENTS, [id=#33]
         +- *(3) Project [id#4L AS k2#6L]
            +- *(3) Range (0, 4, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage5(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=5
/* 006 */ final class GeneratedIteratorForCodegenStage5 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator smj_streamedInput_0;
/* 010 */   private scala.collection.Iterator smj_bufferedInput_0;
/* 011 */   private InternalRow smj_streamedRow_0;
/* 012 */   private InternalRow smj_bufferedRow_0;
/* 013 */   private long smj_value_2;
/* 014 */   private org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray smj_matches_0;
/* 015 */   private long smj_value_3;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] smj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage5(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     smj_streamedInput_0 = inputs[0];
/* 026 */     smj_bufferedInput_0 = inputs[1];
/* 027 */
/* 028 */     smj_matches_0 = new org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray(1, 2147483647);
/* 029 */     smj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     smj_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */
/* 032 */   }
/* 033 */
/* 034 */   private boolean findNextJoinRows(
/* 035 */     scala.collection.Iterator streamedIter,
/* 036 */     scala.collection.Iterator bufferedIter) {
/* 037 */     smj_streamedRow_0 = null;
/* 038 */     int comp = 0;
/* 039 */     while (smj_streamedRow_0 == null) {
/* 040 */       if (!streamedIter.hasNext()) return false;
/* 041 */       smj_streamedRow_0 = (InternalRow) streamedIter.next();
/* 042 */       long smj_value_0 = smj_streamedRow_0.getLong(0);
/* 043 */       if (false) {
/* 044 */         smj_streamedRow_0 = null;
/* 045 */         continue;
/* 046 */
/* 047 */       }
/* 048 */       if (!smj_matches_0.isEmpty()) {
/* 049 */         comp = 0;
/* 050 */         if (comp == 0) {
/* 051 */           comp = (smj_value_0 > smj_value_3 ? 1 : smj_value_0 < smj_value_3 ? -1 : 0);
/* 052 */         }
/* 053 */
/* 054 */         if (comp == 0) {
/* 055 */           return true;
/* 056 */         }
/* 057 */         smj_matches_0.clear();
/* 058 */       }
/* 059 */
/* 060 */       do {
/* 061 */         if (smj_bufferedRow_0 == null) {
/* 062 */           if (!bufferedIter.hasNext()) {
/* 063 */             smj_value_3 = smj_value_0;
/* 064 */             return !smj_matches_0.isEmpty();
/* 065 */           }
/* 066 */           smj_bufferedRow_0 = (InternalRow) bufferedIter.next();
/* 067 */           long smj_value_1 = smj_bufferedRow_0.getLong(0);
/* 068 */           if (false) {
/* 069 */             smj_bufferedRow_0 = null;
/* 070 */             continue;
/* 071 */           }
/* 072 */           smj_value_2 = smj_value_1;
/* 073 */         }
/* 074 */
/* 075 */         comp = 0;
/* 076 */         if (comp == 0) {
/* 077 */           comp = (smj_value_0 > smj_value_2 ? 1 : smj_value_0 < smj_value_2 ? -1 : 0);
/* 078 */         }
/* 079 */
/* 080 */         if (comp > 0) {
/* 081 */           smj_bufferedRow_0 = null;
/* 082 */         } else if (comp < 0) {
/* 083 */           if (!smj_matches_0.isEmpty()) {
/* 084 */             smj_value_3 = smj_value_0;
/* 085 */             return true;
/* 086 */           } else {
/* 087 */             smj_streamedRow_0 = null;
/* 088 */           }
/* 089 */         } else {
/* 090 */           if (smj_matches_0.isEmpty()) {
/* 091 */             smj_matches_0.add((UnsafeRow) smj_bufferedRow_0);
/* 092 */           }
/* 093 */
/* 094 */           smj_bufferedRow_0 = null;
/* 095 */         }
/* 096 */       } while (smj_streamedRow_0 != null);
/* 097 */     }
/* 098 */     return false; // unreachable
/* 099 */   }
/* 100 */
/* 101 */   protected void processNext() throws java.io.IOException {
/* 102 */     while (findNextJoinRows(smj_streamedInput_0, smj_bufferedInput_0)) {
/* 103 */       long smj_value_4 = -1L;
/* 104 */       smj_value_4 = smj_streamedRow_0.getLong(0);
/* 105 */       scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
/* 106 */       boolean smj_hasOutputRow_0 = false;
/* 107 */
/* 108 */       while (!smj_hasOutputRow_0 && smj_iterator_0.hasNext()) {
/* 109 */         InternalRow smj_bufferedRow_1 = (InternalRow) smj_iterator_0.next();
/* 110 */
/* 111 */         smj_hasOutputRow_0 = true;
/* 112 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 113 */
/* 114 */         // common sub-expressions
/* 115 */
/* 116 */         smj_mutableStateArray_0[1].reset();
/* 117 */
/* 118 */         smj_mutableStateArray_0[1].write(0, smj_value_4);
/* 119 */         append((smj_mutableStateArray_0[1].getRow()).copy());
/* 120 */
/* 121 */       }
/* 122 */       if (shouldStop()) return;
/* 123 */     }
/* 124 */     ((org.apache.spark.sql.execution.joins.SortMergeJoinExec) references[1] /* plan */).cleanupResources();
/* 125 */   }
/* 126 */
/* 127 */ }
```

### Why are the changes needed?

Improve query CPU performance. Test with one query:

```
 def sortMergeJoin(): Unit = {
    val N = 2 << 20
    codegenBenchmark("left semi sort merge join", N) {
      val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
      val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
      val df = df1.join(df2, col("k1") === col("k2"), "left_semi")
      assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
      df.noop()
    }
  }
```

Seeing 30% of run-time improvement:

```
Running benchmark: left semi sort merge join
  Running case: left semi sort merge join code-gen off
  Stopped after 2 iterations, 1369 ms
  Running case: left semi sort merge join code-gen on
  Stopped after 5 iterations, 2743 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
left semi sort merge join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
left semi sort merge join code-gen off              676            685          13          3.1         322.2       1.0X
left semi sort merge join code-gen on               524            549          32          4.0         249.7       1.3X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit test in `WholeStageCodegenSuite.scala` and `ExistenceJoinSuite.scala`.

Closes apache#32528 from c21/smj-left-semi.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
c21 authored and cloud-fan committed May 13, 2021
1 parent 5181543 commit c1e995a
Show file tree
Hide file tree
Showing 47 changed files with 1,797 additions and 1,587 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,18 @@ case class SortMergeJoinExec(
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
}

// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
private val onlyBufferFirstMatchedRow = (joinType, condition) match {
case (LeftExistence(_), None) => true
case _ => false
}

private def getInMemoryThreshold: Int = {
sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
if (onlyBufferFirstMatchedRow) {
1
} else {
sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
}
}

protected override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -236,7 +246,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -273,7 +283,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -317,7 +327,7 @@ case class SortMergeJoinExec(
inMemoryThreshold,
spillThreshold,
cleanupResources,
condition.isEmpty
onlyBufferFirstMatchedRow
)
private[this] val joinRow = new JoinedRow

Expand Down Expand Up @@ -354,7 +364,7 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
Expand All @@ -365,7 +375,7 @@ case class SortMergeJoinExec(
private lazy val bufferedOutput = bufferedPlan.output

override def supportCodegen: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter => true
case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true
case _ => false
}

Expand Down Expand Up @@ -435,7 +445,7 @@ case class SortMergeJoinExec(

// Handle the case when streamed rows has any NULL keys.
val handleStreamedAnyNull = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"""
|$streamedRow = null;
Expand All @@ -457,7 +467,7 @@ case class SortMergeJoinExec(

// Handle the case when streamed keys has no match with buffered side.
val handleStreamedWithoutMatch = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter =>
Expand All @@ -468,6 +478,17 @@ case class SortMergeJoinExec(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

val addRowToBuffer =
if (onlyBufferFirstMatchedRow) {
s"""
|if ($matches.isEmpty()) {
| $matches.add((UnsafeRow) $bufferedRow);
|}
""".stripMargin
} else {
s"$matches.add((UnsafeRow) $bufferedRow);"
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
Expand All @@ -483,17 +504,18 @@ case class SortMergeJoinExec(
// The function has the following step:
// - Step 1: Find the next `streamedRow` with non-null join keys.
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
// 1. Inner join: skip the row. `matches` will be cleared later when hitting the
// next `streamedRow` with non-null join keys.
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
// hitting the next `streamedRow` with non-null join
// keys.
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
// and return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches`. Return true when getting all matched rows.
// `matches` (`addRowToBuffer`). Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner join: skip the row.
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
Expand Down Expand Up @@ -543,7 +565,7 @@ case class SortMergeJoinExec(
| $handleStreamedWithoutMatch
| }
| } else {
| $matches.add((UnsafeRow) $bufferedRow);
| $addRowToBuffer
| $bufferedRow = null;
| }
| } while ($streamedRow != null);
Expand Down Expand Up @@ -639,19 +661,22 @@ case class SortMergeJoinExec(
streamedVars ++ bufferedVars
case RightOuter =>
bufferedVars ++ streamedVars
case LeftSemi =>
streamedVars
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}

val (beforeLoop, condCheck) = if (condition.isDefined) {
val (streamedBeforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
ctx.currentVars = streamedVars ++ bufferedVars
val cond = BindReferences.bindReference(
condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before =
s"""
Expand All @@ -674,65 +699,129 @@ case class SortMergeJoinExec(
|}
|$bufferedAfter
""".stripMargin
(before, checking)
(before, checking.trim)
} else {
(evaluateVariables(streamedVars), "")
}

val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

lazy val innerJoin =
val beforeLoop =
s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin

lazy val outerJoin = {
val hasOutputRow = ctx.freshName("hasOutputRow")
|${streamedVarDecl.mkString("\n")}
|${streamedBeforeLoop.trim}
|scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
""".stripMargin
val outputRow =
s"""
|while ($streamedInput.hasNext()) {
| findNextJoinRows($streamedInput, $bufferedInput);
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| boolean $hasOutputRow = false;
|
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
| while ($iterator.hasNext() || !$hasOutputRow) {
| InternalRow $bufferedRow = $iterator.hasNext() ?
| (InternalRow) $iterator.next() : null;
| ${condCheck.trim}
| $hasOutputRow = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin
}
val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)"
val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

joinType match {
case _: InnerLike => innerJoin
case LeftOuter | RightOuter => outerJoin
case _: InnerLike =>
codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
eagerCleanup)
case LeftOuter | RightOuter =>
codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
case LeftSemi =>
codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}
}

/**
* Generates the code for Inner join.
*/
private def codegenInner(
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($findNextJoinRows) {
| $beforeLoop
| while ($matchIterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
| $conditionCheck
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

/**
* Generates the code for Left or Right Outer join.
*/
private def codegenOuter(
streamedInput: String,
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
hasOutputRow: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($streamedInput.hasNext()) {
| $findNextJoinRows;
| $beforeLoop
| boolean $hasOutputRow = false;
|
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
| while ($matchIterator.hasNext() || !$hasOutputRow) {
| InternalRow $bufferedRow = $matchIterator.hasNext() ?
| (InternalRow) $matchIterator.next() : null;
| $conditionCheck
| $hasOutputRow = true;
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

/**
* Generates the code for Left Semi join.
*/
private def codegenSemi(
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
hasOutputRow: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($findNextJoinRows) {
| $beforeLoop
| boolean $hasOutputRow = false;
|
| while (!$hasOutputRow && $matchIterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
| $conditionCheck
| $hasOutputRow = true;
| $outputRow
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -783,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
spillThreshold)
new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)

// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
Expand Down
Loading

0 comments on commit c1e995a

Please sign in to comment.