diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 291ab29f1b3ba..1f5120b23d26d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -354,7 +354,7 @@ case class SortMergeJoinExec( } private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { - case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys)) + case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys)) case RightOuter => ((right, rightKeys), (left, leftKeys)) case x => throw new IllegalArgumentException( @@ -365,7 +365,7 @@ case class SortMergeJoinExec( private lazy val bufferedOutput = bufferedPlan.output override def supportCodegen: Boolean = joinType match { - case _: InnerLike | LeftOuter | RightOuter => true + case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true case _ => false } @@ -424,8 +424,18 @@ case class SortMergeJoinExec( // A list to hold all matched rows from buffered side. val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + // Flag to only buffer first matched row, to avoid buffering unnecessary rows. + val onlyBufferFirstMatchedRow = (joinType, condition) match { + case (LeftSemi, None) => true + case _ => false + } + val inMemoryThreshold = + if (onlyBufferFirstMatchedRow) { + 1 + } else { + getInMemoryThreshold + } val spillThreshold = getSpillThreshold - val inMemoryThreshold = getInMemoryThreshold // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", @@ -435,7 +445,7 @@ case class SortMergeJoinExec( // Handle the case when streamed rows has any NULL keys. val handleStreamedAnyNull = joinType match { - case _: InnerLike => + case _: InnerLike | LeftSemi => // Skip streamed row. s""" |$streamedRow = null; @@ -457,7 +467,7 @@ case class SortMergeJoinExec( // Handle the case when streamed keys has no match with buffered side. val handleStreamedWithoutMatch = joinType match { - case _: InnerLike => + case _: InnerLike | LeftSemi => // Skip streamed row. s"$streamedRow = null;" case LeftOuter | RightOuter => @@ -468,6 +478,17 @@ case class SortMergeJoinExec( s"SortMergeJoin.genScanner should not take $x as the JoinType") } + val addRowToBuffer = + if (onlyBufferFirstMatchedRow) { + s""" + |if ($matches.isEmpty()) { + | $matches.add((UnsafeRow) $bufferedRow); + |} + """.stripMargin + } else { + s"$matches.add((UnsafeRow) $bufferedRow);" + } + // Generate a function to scan both streamed and buffered sides to find a match. // Return whether a match is found. // @@ -483,17 +504,18 @@ case class SortMergeJoinExec( // The function has the following step: // - Step 1: Find the next `streamedRow` with non-null join keys. // For `streamedRow` with null join keys (`handleStreamedAnyNull`): - // 1. Inner join: skip the row. `matches` will be cleared later when hitting the - // next `streamedRow` with non-null join keys. + // 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when + // hitting the next `streamedRow` with non-null join + // keys. // 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row, // and return false. // // - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`. // Clear `matches` if we hit a new `streamedRow`, as we need to find new matches. // Use `bufferedRow` to iterate buffered side to put all matched rows into - // `matches`. Return true when getting all matched rows. + // `matches` (`addRowToBuffer`). Return true when getting all matched rows. // For `streamedRow` without `matches` (`handleStreamedWithoutMatch`): - // 1. Inner join: skip the row. + // 1. Inner and Left Semi join: skip the row. // 2. Left/Right Outer join: keep the row and return false (with `matches` being // empty). ctx.addNewFunction("findNextJoinRows", @@ -543,7 +565,7 @@ case class SortMergeJoinExec( | $handleStreamedWithoutMatch | } | } else { - | $matches.add((UnsafeRow) $bufferedRow); + | $addRowToBuffer | $bufferedRow = null; | } | } while ($streamedRow != null); @@ -639,6 +661,8 @@ case class SortMergeJoinExec( streamedVars ++ bufferedVars case RightOuter => bufferedVars ++ streamedVars + case LeftSemi => + streamedVars case x => throw new IllegalArgumentException( s"SortMergeJoin.doProduce should not take $x as the JoinType") @@ -650,8 +674,9 @@ case class SortMergeJoinExec( val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars) // Generate code for condition - ctx.currentVars = resultVars - val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) + ctx.currentVars = streamedVars ++ bufferedVars + val cond = BindReferences.bindReference( + condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" @@ -724,9 +749,32 @@ case class SortMergeJoinExec( """.stripMargin } + lazy val semiJoin = { + val hasOutputRow = ctx.freshName("hasOutputRow") + s""" + |while (findNextJoinRows($streamedInput, $bufferedInput)) { + | ${streamedVarDecl.mkString("\n")} + | ${beforeLoop.trim} + | scala.collection.Iterator $iterator = $matches.generateIterator(); + | boolean $hasOutputRow = false; + | + | while (!$hasOutputRow && $iterator.hasNext()) { + | InternalRow $bufferedRow = (InternalRow) $iterator.next(); + | ${condCheck.trim} + | $hasOutputRow = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | if (shouldStop()) return; + |} + |$eagerCleanup + """.stripMargin + } + joinType match { case _: InnerLike => innerJoin case LeftOuter | RightOuter => outerJoin + case LeftSemi => semiJoin case x => throw new IllegalArgumentException( s"SortMergeJoin.doProduce should not take $x as the JoinType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9199a5e51e669..f019e34b60118 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -203,6 +203,28 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession Row(null, null, 6), Row(null, null, 7), Row(null, null, 8), Row(null, null, 9))) } + test("Left Semi SortMergeJoin should be included in WholeStageCodegen") { + val df1 = spark.range(10).select($"id".as("k1")) + val df2 = spark.range(4).select($"id".as("k2")) + val df3 = spark.range(6).select($"id".as("k3")) + + // test one left semi sort merge join + val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi") + assert(oneJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true + }.size === 1) + checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3))) + + // test two left semi sort merge joins + val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi") + .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi") + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) | + WholeStageCodegenExec(_ : SortMergeJoinExec) => true + }.size === 2) + checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3))) + } + test("Inner/Cross BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 93c5d05da2e8b..3588b9dda90d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -141,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using SortMergeJoin") { + testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>