Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34708][SQL] Code-gen for left semi/anti broadcast nested loop join (build right side) #31874

Closed
wants to merge 2 commits into from

Conversation

c21
Copy link
Contributor

@c21 c21 commented Mar 18, 2021

What changes were proposed in this pull request?

This PR is to add code-gen support for left semi / left anti BroadcastNestedLoopJoin (build side is right side). The execution code path for build left side cannot fit into whole stage code-gen framework, so only add the code-gen for build right side here.

Reference: the iterator (non-code-gen) code path is BroadcastNestedLoopJoinExec.leftExistenceJoin() with BuildRight.

Why are the changes needed?

Improve query CPU performance.
Tested with a simple query:

val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left semi broadcast nested loop join", N) {
  park.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_semi")
}

Seeing 5x run time improvement:

Running benchmark: left semi broadcast nested loop join
  Running case: left semi broadcast nested loop join codegen off
  Stopped after 2 iterations, 6958 ms
  Running case: left semi broadcast nested loop join codegen on
  Stopped after 5 iterations, 3383 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
left semi broadcast nested loop join:             Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------------------------
left semi broadcast nested loop join codegen off           3434           3479          65          6.1         163.7       1.0X
left semi broadcast nested loop join codegen on             672            677           5         31.2          32.1       5.1X

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Changed existing unit test in ExistenceJoinSuite.scala to cover all code paths:

  • left semi/anti + empty right side + empty condition
  • left semi/anti + non-empty right side + empty condition
  • left semi/anti + right side + non-empty condition

Added unit test in WholeStageCodegenSuite.scala to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.

Example query:

val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_semi").explain("codegen")

Example generated code (bnlj_doConsume_0 method):
This is for left semi join. The generated code for left anti join is mostly to be same as here, except L55 to be if (bnlj_findMatchedRow_0 == false) {.

== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) Project [id#0L AS k1#2L]
+- *(2) BroadcastNestedLoopJoin BuildRight, LeftSemi, ((id#0L + 1) <= k2#6L)
   :- *(2) Range (0, 4, step=1, splits=2)
   +- BroadcastExchange IdentityBroadcastMode, [id=#23]
      +- *(1) Project [id#4L AS k2#6L]
         +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 031 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_findMatchedRow_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */
/* 041 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 042 */
/* 043 */       long bnlj_value_3 = -1L;
/* 044 */
/* 045 */       bnlj_value_3 = bnlj_expr_0_0 + 1L;
/* 046 */
/* 047 */       boolean bnlj_value_2 = false;
/* 048 */       bnlj_value_2 = bnlj_value_3 <= bnlj_value_1;
/* 049 */       if (!(false || !bnlj_value_2))
/* 050 */       {
/* 051 */         bnlj_findMatchedRow_0 = true;
/* 052 */         break;
/* 053 */       }
/* 054 */     }
/* 055 */     if (bnlj_findMatchedRow_0 == true) {
/* 056 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 057 */
/* 058 */       // common sub-expressions
/* 059 */
/* 060 */       range_mutableStateArray_0[3].reset();
/* 061 */
/* 062 */       range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 063 */       append((range_mutableStateArray_0[3].getRow()).copy());
/* 064 */
/* 065 */     }
/* 066 */
/* 067 */   }
/* 068 */
/* 069 */   private void initRange(int idx) {
/* 070 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 071 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 072 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 073 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 074 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 075 */     long partitionEnd;
/* 076 */
/* 077 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 078 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 079 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 080 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 081 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 082 */     } else {
/* 083 */       range_nextIndex_0 = st.longValue();
/* 084 */     }
/* 085 */     range_batchEnd_0 = range_nextIndex_0;
/* 086 */
/* 087 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 088 */     .multiply(step).add(start);
/* 089 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 090 */       partitionEnd = Long.MAX_VALUE;
/* 091 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 092 */       partitionEnd = Long.MIN_VALUE;
/* 093 */     } else {
/* 094 */       partitionEnd = end.longValue();
/* 095 */     }
/* 096 */
/* 097 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 098 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 099 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 100 */     if (range_numElementsTodo_0 < 0) {
/* 101 */       range_numElementsTodo_0 = 0;
/* 102 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 103 */       range_numElementsTodo_0++;
/* 104 */     }
/* 105 */   }
/* 106 */
/* 107 */   protected void processNext() throws java.io.IOException {
/* 108 */     // initialize Range
/* 109 */     if (!range_initRange_0) {
/* 110 */       range_initRange_0 = true;
/* 111 */       initRange(partitionIndex);
/* 112 */     }
/* 113 */
/* 114 */     while (true) {
/* 115 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 116 */         long range_nextBatchTodo_0;
/* 117 */         if (range_numElementsTodo_0 > 1000L) {
/* 118 */           range_nextBatchTodo_0 = 1000L;
/* 119 */           range_numElementsTodo_0 -= 1000L;
/* 120 */         } else {
/* 121 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 122 */           range_numElementsTodo_0 = 0;
/* 123 */           if (range_nextBatchTodo_0 == 0) break;
/* 124 */         }
/* 125 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 126 */       }
/* 127 */
/* 128 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 129 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 130 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 131 */
/* 132 */         bnlj_doConsume_0(range_value_0);
/* 133 */
/* 134 */         if (shouldStop()) {
/* 135 */           range_nextIndex_0 = range_value_0 + 1L;
/* 136 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 137 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 138 */           return;
/* 139 */         }
/* 140 */
/* 141 */       }
/* 142 */       range_nextIndex_0 = range_batchEnd_0;
/* 143 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 144 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 145 */       range_taskContext_0.killTaskIfInterrupted();
/* 146 */     }
/* 147 */   }
/* 148 */
/* 149 */ }

@c21
Copy link
Contributor Author

c21 commented Mar 18, 2021

cc @cloud-fan and @maropu could you help take a look when you have time, thanks.

@github-actions github-actions bot added the SQL label Mar 18, 2021
@c21
Copy link
Contributor Author

c21 commented Mar 18, 2021

retest this please

@c21 c21 closed this Mar 18, 2021
@c21 c21 reopened this Mar 18, 2021
@c21
Copy link
Contributor Author

c21 commented Mar 18, 2021

Closed & reopened the PR to trigger test rerun.

val arrayIndex = ctx.freshName("arrayIndex")

s"""
|boolean $findMatchedRow = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
|boolean $findMatchedRow = false;
|boolean $foundMatch = false;

Maybe we can be consistent with non-codegen naming in BroadcastNestedLoopJoinExec?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, updated.

override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike]
override def supportCodegen: Boolean = (joinType, buildSide) match {
case (_: InnerLike, _) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: (LeftSemi | LeftAnti, BuildRight)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated.

@c21
Copy link
Contributor Author

c21 commented Mar 19, 2021

Just FYI - the unit test failure here is legit, and I found there's a bug around LIMIT codegen - https://issues.apache.org/jira/browse/SPARK-34796 . Will submit a PR to fix that first.

Comment on lines +469 to +484
if (buildRowArray.nonEmpty == exists) {
// Return streamed side if join condition is empty and
// 1. build side is non-empty for LeftSemi join
// or
// 2. build side is empty for LeftAnti join.
s"""
|$numOutput.add(1);
|${consume(ctx, input)}
""".stripMargin
} else {
// Return nothing if join condition is empty and
// 1. build side is empty for LeftSemi join
// or
// 2. build side is non-empty for LeftAnti join.
""
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this case, we don't need to add a mutable state for buildRowArrayTerm ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's already added inside prepareBroadcast

Copy link
Member

@maropu maropu Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, my qestion is wrong. I meant, is buildRowArrayTerm referenced in this case? Ah, nvm, it looks fine.

Copy link
Contributor Author

@c21 c21 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed comments and the PR is ready to review again, thanks.

override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike]
override def supportCodegen: Boolean = (joinType, buildSide) match {
case (_: InnerLike, _) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => true
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated.

val arrayIndex = ctx.freshName("arrayIndex")

s"""
|boolean $findMatchedRow = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, updated.

@c21 c21 force-pushed the code-semi-anti branch from e50a751 to f2e846d Compare March 21, 2021 06:09
@SparkQA
Copy link

SparkQA commented Mar 21, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40882/

@SparkQA
Copy link

SparkQA commented Mar 21, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40882/

@SparkQA
Copy link

SparkQA commented Mar 21, 2021

Test build #136300 has finished for PR 31874 at commit f2e846d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@AmplabJenkins
Copy link

Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/136300/

@cloud-fan
Copy link
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in f8838fe Mar 22, 2021
@c21
Copy link
Contributor Author

c21 commented Mar 22, 2021

Thank you @cloud-fan, @maropu and @linzebing for review!

@c21 c21 deleted the code-semi-anti branch March 22, 2021 17:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants