Skip to content

Commit

Permalink
[SPARK-44000][SQL] Add hint to disable broadcasting and replicating o…
Browse files Browse the repository at this point in the history
…ne side of join

### What changes were proposed in this pull request?

This PR adds a new internal join hint that disables broadcasting and replicating one side of a join.

### Why are the changes needed?

These changes are needed to disable broadcasting and replicating one side of a join when doing so is not permitted, for example for the cardinality check in MERGE operations (see PR apache#41448).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with tests. More tests are in apache#41448.

Closes apache#41499 from aokolnychyi/spark-44000.

Authored-by: aokolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
aokolnychyi authored and dongjoon-hyun committed Jun 8, 2023
1 parent caf905d commit d88633a
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,16 @@ trait JoinSelectionHelper {
)
}

/**
 * Picks the build side that a broadcast nested loop join is forced to use by a
 * NO_BROADCAST_AND_REPLICATION hint, if any.
 *
 * The left side is checked first: when it carries the hint the result is `BuildRight`,
 * regardless of what the right side says. Otherwise, when the right side carries the hint
 * the result is `BuildLeft`. With no such hint on either side the result is `None`.
 */
def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = hint match {
  case h if hintToNotBroadcastAndReplicateLeft(h) => Some(BuildRight)
  case h if hintToNotBroadcastAndReplicateRight(h) => Some(BuildLeft)
  case _ => None
}

/**
 * Chooses the build side by comparing the estimated sizes (`stats.sizeInBytes`) of the two
 * plans. The left side is chosen only when it is strictly smaller; ties go to `BuildRight`.
 */
def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = {
  val leftIsStrictlySmaller = left.stats.sizeInBytes < right.stats.sizeInBytes
  if (leftIsStrictlySmaller) {
    BuildLeft
  } else {
    BuildRight
  }
}
Expand Down Expand Up @@ -413,11 +423,19 @@ trait JoinSelectionHelper {
}

/**
 * Returns true when the left-side hint carries a strategy that rules out broadcasting that
 * side: either NO_BROADCAST_HASH or NO_BROADCAST_AND_REPLICATION. Returns false when there is
 * no left hint, no strategy, or any other strategy.
 */
def hintToNotBroadcastLeft(hint: JoinHint): Boolean = {
  hint.leftHint.flatMap(_.strategy).exists {
    case NO_BROADCAST_HASH | NO_BROADCAST_AND_REPLICATION => true
    case _ => false
  }
}

/**
 * Returns true when the right-side hint carries a strategy that rules out broadcasting that
 * side: either NO_BROADCAST_HASH or NO_BROADCAST_AND_REPLICATION. Returns false when there is
 * no right hint, no strategy, or any other strategy.
 */
def hintToNotBroadcastRight(hint: JoinHint): Boolean = {
  hint.rightHint.flatMap(_.strategy).exists {
    case NO_BROADCAST_HASH | NO_BROADCAST_AND_REPLICATION => true
    case _ => false
  }
}

def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = {
Expand Down Expand Up @@ -454,6 +472,18 @@ trait JoinSelectionHelper {
hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
}

/**
 * Returns true when at least one side of the join carries the
 * NO_BROADCAST_AND_REPLICATION hint. The right side is only inspected when the left side
 * does not carry it.
 */
def hintToNotBroadcastAndReplicate(hint: JoinHint): Boolean = {
  if (hintToNotBroadcastAndReplicateLeft(hint)) {
    true
  } else {
    hintToNotBroadcastAndReplicateRight(hint)
  }
}

/** Returns true when the left-side hint's strategy is NO_BROADCAST_AND_REPLICATION. */
def hintToNotBroadcastAndReplicateLeft(hint: JoinHint): Boolean = {
  val leftStrategy = hint.leftHint.flatMap(_.strategy)
  leftStrategy.contains(NO_BROADCAST_AND_REPLICATION)
}

/** Returns true when the right-side hint's strategy is NO_BROADCAST_AND_REPLICATION. */
def hintToNotBroadcastAndReplicateRight(hint: JoinHint): Boolean = {
  val rightStrategy = hint.rightHint.flatMap(_.strategy)
  rightStrategy.contains(NO_BROADCAST_AND_REPLICATION)
}

private def getBuildSide(
canBuildLeft: Boolean,
canBuildRight: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ case object PREFER_SHUFFLE_HASH extends JoinStrategyHint {
override def hintAliases: Set[String] = Set.empty
}

/**
 * Internal join strategy hint that forbids both broadcasting and replicating the side of the
 * join it is attached to. Rules that cannot tolerate a particular side being broadcast or
 * replicated (for example, the cardinality check in MERGE operations) attach this hint.
 *
 * It declares no aliases.
 */
case object NO_BROADCAST_AND_REPLICATION extends JoinStrategyHint {
  override def displayName: String = {
    "no_broadcast_and_replication"
  }

  override def hintAliases: Set[String] = Set.empty[String]
}

/**
* The callback for implementing customized strategies of handling hint errors.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

def createCartesianProduct() = {
if (joinType.isInstanceOf[InnerLike]) {
if (joinType.isInstanceOf[InnerLike] && !hintToNotBroadcastAndReplicate(hint)) {
// `CartesianProductExec` can't implicitly evaluate equal join condition, here we should
// pass the original condition which includes both equal and non-equal conditions.
Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), j.condition)))
Expand All @@ -288,7 +288,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
.orElse(createCartesianProduct())
.getOrElse {
// This join could be very slow or OOM
val buildSide = getSmallerSide(left, right)
val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint)
val buildSide = requiredBuildSide.getOrElse(getSmallerSide(left, right))
Seq(joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, j.condition))
}
Expand Down Expand Up @@ -336,7 +337,19 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight
}

def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
def createBroadcastNLJoin(onlyLookingAtHint: Boolean) = {
val buildLeft = if (onlyLookingAtHint) {
hintToBroadcastLeft(hint)
} else {
canBroadcastBySize(left, conf) && !hintToNotBroadcastAndReplicateLeft(hint)
}

val buildRight = if (onlyLookingAtHint) {
hintToBroadcastRight(hint)
} else {
canBroadcastBySize(right, conf) && !hintToNotBroadcastAndReplicateRight(hint)
}

val maybeBuildSide = if (buildLeft && buildRight) {
Some(desiredBuildSide)
} else if (buildLeft) {
Expand All @@ -354,27 +367,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

// Plans a cartesian product only for inner-like joins, and only when neither side carries the
// NO_BROADCAST_AND_REPLICATION hint. Yields None otherwise so that callers can fall through
// to another join implementation.
def createCartesianProduct() = {
  if (joinType.isInstanceOf[InnerLike] && !hintToNotBroadcastAndReplicate(hint)) {
    Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
  } else {
    None
  }
}

// Selects a join when no usable hint decides the strategy: first a size-based broadcast
// nested loop join, then a cartesian product, and finally a broadcast nested loop join as
// the last resort.
def createJoinWithoutHint() = {
  createBroadcastNLJoin(onlyLookingAtHint = false)
    .orElse(createCartesianProduct())
    .getOrElse {
      // This join could be very slow or OOM
      // A NO_BROADCAST_AND_REPLICATION hint dictates the build side; without one the
      // default desired build side is used.
      val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint)
      val buildSide = requiredBuildSide.getOrElse(desiredBuildSide)
      Seq(joins.BroadcastNestedLoopJoinExec(
        planLater(left), planLater(right), buildSide, joinType, condition))
    }
}

// With no hint, go straight to the hint-free selection. Otherwise try a hint-driven
// broadcast nested loop join first, then a cartesian product if the SHUFFLE_REPLICATE_NL
// hint is present, and finally fall back to the hint-free selection.
if (hint.isEmpty) {
  createJoinWithoutHint()
} else {
  createBroadcastNLJoin(onlyLookingAtHint = true)
    .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
    .getOrElse(createJoinWithoutHint())
}
Expand Down
64 changes: 63 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
Expand Down Expand Up @@ -92,6 +93,67 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
operators.head
}

test("NO_BROADCAST_AND_REPLICATION hint is respected in cross joins") {
  withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
    // The hint is attached to the right side only.
    val rightOnlyHint = JoinHint(
      leftHint = None,
      rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))

    val crossJoinPlan =
      testData.crossJoin(testData2).queryExecution.optimizedPlan.asInstanceOf[Join]
    val hintedPlan = crossJoinPlan.copy(hint = rightOnlyHint)

    // Baseline: the un-hinted cross join is expected to plan as a cartesian product.
    val baseline = spark.sessionState.planner.JoinSelection(crossJoinPlan)
    assert(baseline.size == 1)
    assert(baseline.head.isInstanceOf[CartesianProductExec])

    // Hinted: expect a broadcast nested loop join that builds on the left side.
    val hinted = spark.sessionState.planner.JoinSelection(hintedPlan)
    assert(hinted.size == 1)
    assert(hinted.head.isInstanceOf[BroadcastNestedLoopJoinExec])
    assert(hinted.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft)
  }
}

test("NO_BROADCAST_AND_REPLICATION hint disables broadcast hash joins") {
  sql("CACHE TABLE testData")
  sql("CACHE TABLE testData2")

  // The hint is attached to both sides of the join.
  val bothSidesHint = JoinHint(
    leftHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))),
    rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))

  val equiJoin = sql("SELECT * FROM testData JOIN testData2 ON key = a")
  val equiJoinPlan = equiJoin.queryExecution.optimizedPlan.asInstanceOf[Join]
  val hintedPlan = equiJoinPlan.copy(hint = bothSidesHint)

  // Baseline: the un-hinted equi-join is expected to plan as a broadcast hash join.
  val baseline = spark.sessionState.planner.JoinSelection(equiJoinPlan)
  assert(baseline.size == 1)
  assert(baseline.head.isInstanceOf[BroadcastHashJoinExec])

  // Hinted: expect a sort merge join instead.
  val hinted = spark.sessionState.planner.JoinSelection(hintedPlan)
  assert(hinted.size == 1)
  assert(hinted.head.isInstanceOf[SortMergeJoinExec])
}

test("NO_BROADCAST_AND_REPLICATION controls build side in BNLJ") {
  // The hint is attached to the right side only.
  val rightOnlyHint = JoinHint(
    leftHint = None,
    rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))

  val outerJoin = testData.join(testData2, $"key" === 1, "left_outer")
  val outerJoinPlan = outerJoin.queryExecution.optimizedPlan.asInstanceOf[Join]
  val hintedPlan = outerJoinPlan.copy(hint = rightOnlyHint)

  // Baseline: the un-hinted join is expected to build on the right side.
  val baseline = spark.sessionState.planner.JoinSelection(outerJoinPlan)
  assert(baseline.size == 1)
  assert(baseline.head.isInstanceOf[BroadcastNestedLoopJoinExec])
  assert(baseline.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildRight)

  // Hinted: the build side is expected to flip to the left.
  val hinted = spark.sessionState.planner.JoinSelection(hintedPlan)
  assert(hinted.size == 1)
  assert(hinted.head.isInstanceOf[BroadcastNestedLoopJoinExec])
  assert(hinted.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft)
}

test("join operator selection") {
spark.sharedState.cacheManager.clearCache()

Expand Down

0 comments on commit d88633a

Please sign in to comment.