Support DPP + AQE when the broadcast exchange can be reused
JkSelf committed Apr 15, 2021
1 parent 31555f7 · commit 3bc4baf
Showing 2 changed files with 32 additions and 17 deletions.
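In short: PlanAdaptiveDynamicPruningFilters now takes the original physical plan instead of the AQE stage cache, and decides whether the DPP filter can reuse a broadcast by searching that plan for a BroadcastHashJoinExec whose build side (BuildLeft or BuildRight) returns the same result as the pruning subquery's executed plan. When it can, it plans a BroadcastExchangeExec over the build side and wraps it in a new AdaptiveSparkPlanExec; the matching doExecuteBroadcast override added to AdaptiveSparkPlanExec lets that wrapper be executed as a broadcast. Otherwise the filter still degrades to Literal.TrueLiteral.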
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

@@ -26,7 +26,7 @@ import scala.collection.mutable
 import scala.concurrent.ExecutionContext
 import scala.util.control.NonFatal
 
-import org.apache.spark.SparkException
+import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -94,7 +94,7 @@ case class AdaptiveSparkPlanExec(
   // A list of physical optimizer rules to be applied to a new stage before its execution. These
   // optimizations should be stage-independent.
   @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
-    PlanAdaptiveDynamicPruningFilters(context.stageCache),
+    PlanAdaptiveDynamicPruningFilters(inputPlan),
     ReuseAdaptiveSubquery(context.subqueryCache),
     CoalesceShufflePartitions(context.session),
     // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
@@ -310,6 +310,11 @@ case class AdaptiveSparkPlanExec(
     rdd
   }
 
+  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    val broadcastPlan = getFinalPhysicalPlan()
+    broadcastPlan.doExecuteBroadcast()
+  }
+
   protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")
 
   override def generateTreeString(
@@ -476,7 +481,7 @@ case class AdaptiveSparkPlanExec(
           throw new IllegalStateException(
             "Custom columnar rules cannot transform shuffle node to something else.")
         }
-        ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
+        ShuffleQueryStageExec(currentStageId, newShuffle, s.child.canonicalized)
       case b: BroadcastExchangeLike =>
         val newBroadcast = applyPhysicalRules(
           b.withNewChildren(Seq(optimizedPlan)),
@@ -486,7 +491,7 @@ case class AdaptiveSparkPlanExec(
           throw new IllegalStateException(
             "Custom columnar rules cannot transform broadcast node to something else.")
         }
-        BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
+        BroadcastQueryStageExec(currentStageId, newBroadcast, b.child.canonicalized)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, e)
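Reviewer note: the doExecuteBroadcast override above follows a simple pattern — AdaptiveSparkPlanExec only knows its final physical plan after re-optimization, so a broadcast request must first materialize that plan and then delegate to it. A minimal self-contained sketch of the pattern, using toy types rather than Spark's SparkPlan hierarchy (all names below are illustrative, not Spark's API):

// Toy stand-ins for the plan hierarchy; not Spark source.
trait PlanNode {
  def executeBroadcast[T](): T =
    throw new UnsupportedOperationException(
      s"${getClass.getSimpleName} does not support broadcast")
}

// Stand-in for a broadcast exchange node that already holds its result.
final class BroadcastNode(value: Any) extends PlanNode {
  override def executeBroadcast[T](): T = value.asInstanceOf[T]
}

// Mirrors the shape of the new override: the final plan is only resolved
// lazily, so the broadcast call first finalizes the plan, then delegates.
final class AdaptiveNode(resolveFinalPlan: () => PlanNode) extends PlanNode {
  private lazy val finalPlan: PlanNode = resolveFinalPlan()
  override def executeBroadcast[T](): T = finalPlan.executeBroadcast[T]()
}

object BroadcastDelegationDemo extends App {
  val adaptive = new AdaptiveNode(() => new BroadcastNode(Map(1 -> "a")))
  println(adaptive.executeBroadcast[Map[Int, String]]()) // Map(1 -> a)
}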
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala

@@ -17,19 +17,18 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.collection.concurrent.TrieMap
-
 import org.apache.spark.sql.catalyst.expressions.{BindReferences, DynamicPruningExpression, Literal}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
-import org.apache.spark.sql.execution.joins.{HashedRelationBroadcastMode, HashJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelationBroadcastMode, HashJoin}
 
 /**
  * A rule to insert dynamic pruning predicates in order to reuse the results of broadcast.
  */
 case class PlanAdaptiveDynamicPruningFilters(
-    stageCache: TrieMap[SparkPlan, QueryStageExec]) extends Rule[SparkPlan] {
+    originalPlan: SparkPlan) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.dynamicPartitionPruningEnabled) {
       return plan
@@ -41,15 +40,26 @@ case class PlanAdaptiveDynamicPruningFilters(
           adaptivePlan: AdaptiveSparkPlanExec), exprId, _)) =>
         val packedKeys = BindReferences.bindReferences(
           HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
-        val mode = HashedRelationBroadcastMode(packedKeys)
-        // plan a broadcast exchange of the build side of the join
-        val exchange = BroadcastExchangeExec(mode, adaptivePlan.executedPlan)
-        val existingStage = stageCache.get(exchange.canonicalized)
-        if (existingStage.nonEmpty && conf.exchangeReuseEnabled) {
-          val name = s"dynamicpruning#${exprId.id}"
-          val reuseQueryStage = existingStage.get.newReuseInstance(0, exchange.output)
-          val broadcastValues =
-            SubqueryBroadcastExec(name, index, buildKeys, reuseQueryStage)
+
+        val canReuseExchange = conf.exchangeReuseEnabled && buildKeys.nonEmpty &&
+          originalPlan.find {
+            case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _, _) =>
+              left.sameResult(adaptivePlan.executedPlan)
+            case BroadcastHashJoinExec(_, _, _, BuildRight, _, _, right, _) =>
+              right.sameResult(adaptivePlan.executedPlan)
+            case _ => false
+          }.isDefined
+
+        if(canReuseExchange) {
+          val mode = HashedRelationBroadcastMode(packedKeys)
+          // plan a broadcast exchange of the build side of the join
+          val exchange = BroadcastExchangeExec(mode, adaptivePlan.executedPlan)
+          exchange.setLogicalLink(adaptivePlan.executedPlan.logicalLink.get)
+          val newAdaptivePlan = AdaptiveSparkPlanExec(
+            exchange, adaptivePlan.context, adaptivePlan.preprocessingRules, true)
+
+          val broadcastValues = SubqueryBroadcastExec(
+            name, index, buildKeys, newAdaptivePlan)
           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
         } else {
           DynamicPruningExpression(Literal.TrueLiteral)
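Reviewer note: for context, here is the query shape this rule targets, sketched as an end-user program. The table names and schemas are made up; the three confs are the public switches the rule depends on (spark.sql.optimizer.dynamicPartitionPruning.enabled gates the rule itself, spark.sql.exchange.reuse gates canReuseExchange). Whether DPP actually fires on such tiny tables depends on stats and plan shape, so treat this as an illustration only:

import org.apache.spark.sql.SparkSession

object DppAqeReuseDemo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("dpp-aqe-broadcast-reuse")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
    .config("spark.sql.exchange.reuse", "true")
    .getOrCreate()
  import spark.implicits._

  // A partitioned fact table and a small dimension table.
  Seq((1, 10), (2, 20), (2, 30)).toDF("part", "value")
    .write.partitionBy("part").mode("overwrite").saveAsTable("fact")
  Seq((1, "a"), (2, "b")).toDF("part", "name")
    .write.mode("overwrite").saveAsTable("dim")

  // The selective filter on dim makes it the broadcast build side of the
  // join; DPP can then prune fact's partitions using that same broadcast
  // instead of planning the exchange twice.
  spark.sql(
    """SELECT f.part, f.value
      |FROM fact f JOIN dim d ON f.part = d.part
      |WHERE d.name = 'a'
      |""".stripMargin).explain()
}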
