diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index a7e9bd7f129dc..482c3a3091f86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -395,8 +395,9 @@ case class BroadcastNestedLoopJoinExec(
     }
   }
 
-  override def supportCodegen: Boolean = {
-    joinType.isInstanceOf[InnerLike]
+  override def supportCodegen: Boolean = (joinType, buildSide) match {
+    case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true
+    case _ => false
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -410,29 +411,33 @@ case class BroadcastNestedLoopJoinExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    joinType match {
-      case _: InnerLike => codegenInner(ctx, input)
+    (joinType, buildSide) match {
+      case (_: InnerLike, _) => codegenInner(ctx, input)
+      case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
+      case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
       case _ => throw new IllegalArgumentException(
-        s"BroadcastNestedLoopJoin code-gen should not take $joinType as the JoinType")
+        s"BroadcastNestedLoopJoin code-gen should not take $joinType as the JoinType " +
+        s"or $buildSide as the BuildSide")
     }
   }
 
   /**
-   * Returns the variable name for [[Broadcast]] side.
+   * Returns a tuple of the materialized [[Broadcast]] side rows and the variable name for them.
    */
-  private def prepareBroadcast(ctx: CodegenContext): String = {
+  private def prepareBroadcast(ctx: CodegenContext): (Array[InternalRow], String) = {
     // Create a name for broadcast side
     val broadcastArray = broadcast.executeBroadcast[Array[InternalRow]]()
     val broadcastTerm = ctx.addReferenceObj("broadcastTerm", broadcastArray)
     // Inline mutable state since not many join operations in a task
-    ctx.addMutableState("InternalRow[]", "buildRowArray",
+    val arrayTerm = ctx.addMutableState("InternalRow[]", "buildRowArray",
       v => s"$v = (InternalRow[]) $broadcastTerm.value();", forceInline = true)
+    (broadcastArray.value, arrayTerm)
   }
 
   private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val buildRowArrayTerm = prepareBroadcast(ctx)
+    val (_, buildRowArrayTerm) = prepareBroadcast(ctx)
     val (buildRow, checkCondition, buildVars) = getJoinCondition(ctx, input, streamed, broadcast)
 
     val resultVars = buildSide match {
@@ -452,4 +457,50 @@ case class BroadcastNestedLoopJoinExec(
       |}
     """.stripMargin
   }
+
+  private def codegenLeftExistence(
+      ctx: CodegenContext,
+      input: Seq[ExprCode],
+      exists: Boolean): String = {
+    val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    if (condition.isEmpty) {
+      if (buildRowArray.nonEmpty == exists) {
+        // Return streamed side if join condition is empty and
+        // 1. build side is non-empty for LeftSemi join
+        // or
+        // 2. build side is empty for LeftAnti join.
+        s"""
+           |$numOutput.add(1);
+           |${consume(ctx, input)}
+         """.stripMargin
+      } else {
+        // Return nothing if join condition is empty and
+        // 1. build side is empty for LeftSemi join
+        // or
+        // 2. build side is non-empty for LeftAnti join.
+ "" + } + } else { + val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) + val foundMatch = ctx.freshName("foundMatch") + val arrayIndex = ctx.freshName("arrayIndex") + + s""" + |boolean $foundMatch = false; + |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { + | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | $checkCondition { + | $foundMatch = true; + | break; + | } + |} + |if ($foundMatch == $exists) { + | $numOutput.add(1); + | ${consume(ctx, input)} + |} + """.stripMargin + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index dc76edd1fc2ce..8246bca1893a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -211,6 +211,42 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } + test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { + val df1 = spark.range(4).select($"id".as("k1")) + val df2 = spark.range(3).select($"id".as("k2")) + val df3 = spark.range(2).select($"id".as("k3")) + + Seq(true, false).foreach { codegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) { + // test left semi join + val semiJoinDF = df1.join(df2, $"k1" + 1 <= $"k2", "left_semi") + var hasJoinInCodegen = semiJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastNestedLoopJoinExec)) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(semiJoinDF, Seq(Row(0), Row(1))) + + // test left anti join + val antiJoinDF = df1.join(df2, $"k1" + 1 <= $"k2", "left_anti") + hasJoinInCodegen = antiJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastNestedLoopJoinExec)) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(antiJoinDF, Seq(Row(2), Row(3))) + + // test a combination of left semi and left anti joins + val twoJoinsDF = df1.join(df2, $"k1" < $"k2", "left_semi") + .join(df3, $"k1" > $"k3", "left_anti") + hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(ProjectExec(_, BroadcastNestedLoopJoinExec( + _: BroadcastNestedLoopJoinExec, _, _, _, _))) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(twoJoinsDF, Seq(Row(0))) + } + } + } + test("Sort should be included in WholeStageCodegen") { val df = spark.range(3, 0, -1).toDF().sort(col("id")) val plan = df.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 8d4ee102e4fbc..93c5d05da2e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -175,7 +175,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using BroadcastNestedLoopJoin build right") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, 
         rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements.apply(
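
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of how the new code path gets exercised end to end. The local-mode session, object name, and column names are illustrative assumptions; the join condition mirrors the new unit test. A non-equi-join condition rules out hash and sort-merge joins, so the query plans as BroadcastNestedLoopJoinExec with BuildRight, which this patch now compiles into the surrounding WholeStageCodegen stage.

import org.apache.spark.sql.SparkSession

object BnljLeftExistenceCodegenDemo {
  def main(args: Array[String]): Unit = {
    // Local session purely for demonstration.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("bnlj-left-existence-codegen-demo")
      .getOrCreate()
    import spark.implicits._

    val left = spark.range(4).select($"id".as("k1"))
    val right = spark.range(3).select($"id".as("k2"))

    // Non-equi condition: the planner falls back to a broadcast
    // nested loop join with the right side broadcast.
    val semiJoinDF = left.join(right, $"k1" + 1 <= $"k2", "left_semi")

    // With this patch the join body is emitted inside WholeStageCodegen;
    // "codegen" explain mode prints the generated Java for inspection.
    semiJoinDF.explain("codegen")
    semiJoinDF.show() // expected rows: k1 = 0 and k1 = 1, as in the new test

    spark.stop()
  }
}

Rerunning the explain with spark.sql.codegen.wholeStage=false makes it easy to compare the generated plan against the interpreted one.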