[SPARK-35349][SQL] Add code-gen for left/right outer sort merge join
### What changes were proposed in this pull request?

This PR adds code-gen support for LEFT OUTER / RIGHT OUTER sort merge join. Currently sort merge join only supports code-gen for the inner join type (https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala#L374). There is no fundamental reason why we cannot support code-gen for other join types. Here we add code-gen for LEFT OUTER / RIGHT OUTER join; follow-up PRs will add LEFT SEMI, LEFT ANTI and FULL OUTER code-gen separately.

The change extends the current sort merge join logic to work with LEFT OUTER and RIGHT OUTER (it should work with LEFT SEMI/ANTI as well, but FULL OUTER join needs additional code changes). Left/right is replaced with streamed/buffered to make the code extendable to join types other than inner join.

Example query:

```
val df1 = spark.range(10).select($"id".as("k1"), $"id".as("k3"))
val df2 = spark.range(4).select($"id".as("k2"), $"id".as("k4"))
df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2" && $"k3" + 1 < $"k4", "left_outer").explain("codegen")
```
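Before reading the generated Java below, it may help to see the shape of the algorithm in plain Scala. The following is a minimal sketch of a left outer merge join over two inputs that are already sorted by key, using the streamed/buffered terminology this PR introduces; the object and method names are made up for illustration, and this is not the actual `SortMergeJoinExec` code.

```
// Minimal sketch (not Spark code): left outer merge join of two inputs that are
// already sorted by key. "streamed" is the outer (left) side; "buffered" is the side
// whose rows with the current key are collected into a group before emitting output.
object MergeJoinSketch {
  type Row = (Long, Long) // (key, value) -- hypothetical row shape for the example

  def leftOuterMergeJoin(streamed: Seq[Row], buffered: Seq[Row]): Seq[(Row, Option[Row])] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[(Row, Option[Row])]
    var bufIdx = 0
    var matches = Vector.empty[Row]     // buffered rows matching the current streamed key
    var matchedKey: Option[Long] = None // key the current "matches" group belongs to

    for (sRow <- streamed) {
      val key = sRow._1
      // Reuse the buffered group if the streamed key repeats; otherwise advance the
      // buffered side and collect all rows with the same key.
      if (!matchedKey.contains(key)) {
        matches = Vector.empty
        while (bufIdx < buffered.length && buffered(bufIdx)._1 < key) bufIdx += 1
        while (bufIdx < buffered.length && buffered(bufIdx)._1 == key) {
          matches :+= buffered(bufIdx)
          bufIdx += 1
        }
        matchedKey = Some(key)
      }
      // LEFT OUTER: emit a null-padded (None) output row when there is no buffered match.
      if (matches.isEmpty) out += ((sRow, None))
      else matches.foreach(b => out += ((sRow, Some(b))))
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    val left  = Seq((1L, 10L), (2L, 20L), (2L, 21L), (5L, 50L))
    val right = Seq((2L, 200L), (3L, 300L))
    leftOuterMergeJoin(left, right).foreach(println)
    // ((1,10),None), ((2,20),Some((2,200))), ((2,21),Some((2,200))), ((5,50),None)
  }
}
```

The generated `findNextJoinRows`/`processNext` pair below follows the same structure: collect the buffered rows matching the current streamed key into `smj_matches_0`, then emit one output row per match, or a single null-padded row when there is none.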
Example generated code:

```
== Subtree 5 / 5 (maxMethodCodeSize:396; maxConstantPoolSize:159(0.24% used); numInnerClasses:0) ==
*(5) SortMergeJoin [k1#2L], [k2#8L], LeftOuter, ((k3#3L + 1) < k4#9L)
:- *(2) Sort [k1#2L ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(k1#2L, 5), ENSURE_REQUIREMENTS, [id=#26]
:     +- *(1) Project [id#0L AS k1#2L, id#0L AS k3#3L]
:        +- *(1) Range (0, 10, step=1, splits=2)
+- *(4) Sort [k2#8L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(k2#8L, 5), ENSURE_REQUIREMENTS, [id=#32]
      +- *(3) Project [id#6L AS k2#8L, id#6L AS k4#9L]
         +- *(3) Range (0, 4, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage5(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=5
/* 006 */ final class GeneratedIteratorForCodegenStage5 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator smj_streamedInput_0;
/* 010 */   private scala.collection.Iterator smj_bufferedInput_0;
/* 011 */   private InternalRow smj_streamedRow_0;
/* 012 */   private InternalRow smj_bufferedRow_0;
/* 013 */   private long smj_value_2;
/* 014 */   private org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray smj_matches_0;
/* 015 */   private long smj_value_3;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] smj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage5(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     smj_streamedInput_0 = inputs[0];
/* 026 */     smj_bufferedInput_0 = inputs[1];
/* 027 */
/* 028 */     smj_matches_0 = new org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray(2147483632, 2147483647);
/* 029 */     smj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(4, 0);
/* 030 */
/* 031 */   }
/* 032 */
/* 033 */   private boolean findNextJoinRows(
/* 034 */     scala.collection.Iterator streamedIter,
/* 035 */     scala.collection.Iterator bufferedIter) {
/* 036 */     smj_streamedRow_0 = null;
/* 037 */     int comp = 0;
/* 038 */     while (smj_streamedRow_0 == null) {
/* 039 */       if (!streamedIter.hasNext()) return false;
/* 040 */       smj_streamedRow_0 = (InternalRow) streamedIter.next();
/* 041 */       long smj_value_0 = smj_streamedRow_0.getLong(0);
/* 042 */       if (false) {
/* 043 */         if (!smj_matches_0.isEmpty()) {
/* 044 */           smj_matches_0.clear();
/* 045 */         }
/* 046 */         return false;
/* 047 */
/* 048 */       }
/* 049 */       if (!smj_matches_0.isEmpty()) {
/* 050 */         comp = 0;
/* 051 */         if (comp == 0) {
/* 052 */           comp = (smj_value_0 > smj_value_3 ? 1 : smj_value_0 < smj_value_3 ? -1 : 0);
/* 053 */         }
/* 054 */
/* 055 */         if (comp == 0) {
/* 056 */           return true;
/* 057 */         }
/* 058 */         smj_matches_0.clear();
/* 059 */       }
/* 060 */
/* 061 */       do {
/* 062 */         if (smj_bufferedRow_0 == null) {
/* 063 */           if (!bufferedIter.hasNext()) {
/* 064 */             smj_value_3 = smj_value_0;
/* 065 */             return !smj_matches_0.isEmpty();
/* 066 */           }
/* 067 */           smj_bufferedRow_0 = (InternalRow) bufferedIter.next();
/* 068 */           long smj_value_1 = smj_bufferedRow_0.getLong(0);
/* 069 */           if (false) {
/* 070 */             smj_bufferedRow_0 = null;
/* 071 */             continue;
/* 072 */           }
/* 073 */           smj_value_2 = smj_value_1;
/* 074 */         }
/* 075 */
/* 076 */         comp = 0;
/* 077 */         if (comp == 0) {
/* 078 */           comp = (smj_value_0 > smj_value_2 ? 1 : smj_value_0 < smj_value_2 ? -1 : 0);
/* 079 */         }
/* 080 */
/* 081 */         if (comp > 0) {
/* 082 */           smj_bufferedRow_0 = null;
/* 083 */         } else if (comp < 0) {
/* 084 */           if (!smj_matches_0.isEmpty()) {
/* 085 */             smj_value_3 = smj_value_0;
/* 086 */             return true;
/* 087 */           } else {
/* 088 */             return false;
/* 089 */           }
/* 090 */         } else {
/* 091 */           smj_matches_0.add((UnsafeRow) smj_bufferedRow_0);
/* 092 */           smj_bufferedRow_0 = null;
/* 093 */         }
/* 094 */       } while (smj_streamedRow_0 != null);
/* 095 */     }
/* 096 */     return false; // unreachable
/* 097 */   }
/* 098 */
/* 099 */   protected void processNext() throws java.io.IOException {
/* 100 */     while (smj_streamedInput_0.hasNext()) {
/* 101 */       findNextJoinRows(smj_streamedInput_0, smj_bufferedInput_0);
/* 102 */       long smj_value_4 = -1L;
/* 103 */       long smj_value_5 = -1L;
/* 104 */       boolean smj_loaded_0 = false;
/* 105 */       smj_value_5 = smj_streamedRow_0.getLong(1);
/* 106 */       scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
/* 107 */       boolean smj_foundMatch_0 = false;
/* 108 */
/* 109 */       // the last iteration of this loop is to emit an empty row if there is no matched rows.
/* 110 */       while (smj_iterator_0.hasNext() || !smj_foundMatch_0) {
/* 111 */         InternalRow smj_bufferedRow_1 = smj_iterator_0.hasNext() ?
/* 112 */           (InternalRow) smj_iterator_0.next() : null;
/* 113 */         boolean smj_isNull_5 = true;
/* 114 */         long smj_value_9 = -1L;
/* 115 */         if (smj_bufferedRow_1 != null) {
/* 116 */           long smj_value_8 = smj_bufferedRow_1.getLong(1);
/* 117 */           smj_isNull_5 = false;
/* 118 */           smj_value_9 = smj_value_8;
/* 119 */         }
/* 120 */         if (smj_bufferedRow_1 != null) {
/* 121 */           boolean smj_isNull_6 = true;
/* 122 */           boolean smj_value_10 = false;
/* 123 */           long smj_value_11 = -1L;
/* 124 */
/* 125 */           smj_value_11 = smj_value_5 + 1L;
/* 126 */
/* 127 */           if (!smj_isNull_5) {
/* 128 */             smj_isNull_6 = false; // resultCode could change nullability.
/* 129 */             smj_value_10 = smj_value_11 < smj_value_9;
/* 130 */
/* 131 */           }
/* 132 */           if (smj_isNull_6 || !smj_value_10) {
/* 133 */             continue;
/* 134 */           }
/* 135 */         }
/* 136 */         if (!smj_loaded_0) {
/* 137 */           smj_loaded_0 = true;
/* 138 */           smj_value_4 = smj_streamedRow_0.getLong(0);
/* 139 */         }
/* 140 */         boolean smj_isNull_3 = true;
/* 141 */         long smj_value_7 = -1L;
/* 142 */         if (smj_bufferedRow_1 != null) {
/* 143 */           long smj_value_6 = smj_bufferedRow_1.getLong(0);
/* 144 */           smj_isNull_3 = false;
/* 145 */           smj_value_7 = smj_value_6;
/* 146 */         }
/* 147 */         smj_foundMatch_0 = true;
/* 148 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 149 */
/* 150 */         smj_mutableStateArray_0[0].reset();
/* 151 */
/* 152 */         smj_mutableStateArray_0[0].zeroOutNullBytes();
/* 153 */
/* 154 */         smj_mutableStateArray_0[0].write(0, smj_value_4);
/* 155 */
/* 156 */         smj_mutableStateArray_0[0].write(1, smj_value_5);
/* 157 */
/* 158 */         if (smj_isNull_3) {
/* 159 */           smj_mutableStateArray_0[0].setNullAt(2);
/* 160 */         } else {
/* 161 */           smj_mutableStateArray_0[0].write(2, smj_value_7);
/* 162 */         }
/* 163 */
/* 164 */         if (smj_isNull_5) {
/* 165 */           smj_mutableStateArray_0[0].setNullAt(3);
/* 166 */         } else {
/* 167 */           smj_mutableStateArray_0[0].write(3, smj_value_9);
/* 168 */         }
/* 169 */         append((smj_mutableStateArray_0[0].getRow()).copy());
/* 170 */
/* 171 */       }
/* 172 */       if (shouldStop()) return;
/* 173 */     }
/* 174 */     ((org.apache.spark.sql.execution.joins.SortMergeJoinExec) references[1] /* plan */).cleanupResources();
/* 175 */   }
/* 176 */
/* 177 */ }
```

### Why are the changes needed?

Improve query CPU performance. The example micro benchmark below showed a 10% run-time improvement.

```
def sortMergeJoinWithDuplicates(): Unit = {
  val N = 2 << 20
  codegenBenchmark("sort merge join with duplicates", N) {
    val df1 = spark.range(N)
      .selectExpr(s"(id * 15485863) % ${N*10} as k1", "id as k3")
    val df2 = spark.range(N)
      .selectExpr(s"(id * 15485867) % ${N*10} as k2", "id as k4")
    val df = df1.join(df2, col("k1") === col("k2") && col("k3") * 3 < col("k4"), "left_outer")
    assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
    df.noop()
  }
}
```

```
Running benchmark: sort merge join with duplicates
  Running case: sort merge join with duplicates outer-smj-codegen off
  Stopped after 2 iterations, 2696 ms
  Running case: sort merge join with duplicates outer-smj-codegen on
  Stopped after 5 iterations, 6058 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
sort merge join with duplicates:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------
sort merge join with duplicates outer-smj-codegen off           1333           1348          21         1.6         635.7       1.0X
sort merge join with duplicates outer-smj-codegen on            1169           1212          47         1.8         557.4       1.1X
```
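The "off"/"on" cases above come from the benchmark body being run once with whole-stage code generation disabled and once enabled, presumably by toggling the `spark.sql.codegen.wholeStage` configuration inside the benchmark framework (`df.noop()` there is a benchmark helper that writes to the `noop` data source). As a rough stand-alone sketch of the same comparison (spark-shell style; the timing helper, noop write, and config key are the only Spark APIs assumed, and no benchmark-quality methodology is implied):

```
// Stand-alone sketch (e.g. paste into spark-shell): run the same left outer
// sort merge join once with whole-stage codegen disabled and once enabled.
// The query is copied from the benchmark above; everything else is illustrative.
import org.apache.spark.sql.functions.col

val N = 2 << 20
def runOnce(): Unit = {
  val df1 = spark.range(N).selectExpr(s"(id * 15485863) % ${N * 10} as k1", "id as k3")
  val df2 = spark.range(N).selectExpr(s"(id * 15485867) % ${N * 10} as k2", "id as k4")
  val df = df1.join(df2, col("k1") === col("k2") && col("k3") * 3 < col("k4"), "left_outer")
  // Materialize the join result without collecting it to the driver.
  df.write.format("noop").mode("overwrite").save()
}

spark.conf.set("spark.sql.codegen.wholeStage", "false") // "codegen off" case
spark.time(runOnce())
spark.conf.set("spark.sql.codegen.wholeStage", "true")  // "codegen on" case
spark.time(runOnce())
```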
### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests in `WholeStageCodegenSuite.scala`.

Closes #32476 from c21/smj-outer-codegen.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>