add the InsertDynamicPruningFilters rule to plan the DPP filters when AQE is enabled
JkSelf committed Mar 23, 2021
1 parent 89c74bb commit 0d78a62
Showing 6 changed files with 124 additions and 40 deletions.
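In outline, the commit replaces the old insertDPPFilter step inside AdaptiveSparkPlanExec with a dedicated query stage preparation rule, InsertDynamicPruningFilters. During planning, InsertAdaptiveSparkPlan now wraps the DPP subquery's build side in a SubqueryAdaptiveBroadcastExec placeholder, and the new rule later rewrites that placeholder either into a SubqueryBroadcastExec that reuses an already-created broadcast query stage, or into a constant true predicate when no stage can be reused. The following toy sketch (plain Scala, no Spark dependencies; names such as ToyExpr and rewriteDpp are illustrative only and not part of the codebase) shows the shape of that rewrite.

// Toy sketch (not Spark code): models how InsertDynamicPruningFilters rewrites a
// placeholder DPP filter depending on whether a broadcast stage can be reused.
object DppRewriteSketch {
  sealed trait ToyExpr
  case class AdaptivePlaceholder(exchangeKey: String) extends ToyExpr // like SubqueryAdaptiveBroadcastExec
  case class ReuseBroadcastSubquery(stageId: Int) extends ToyExpr     // like SubqueryBroadcastExec over a reused stage
  case object TrueLiteral extends ToyExpr                             // pruning dropped for this filter

  // stageCache maps a canonicalized exchange key to the id of its materialized query stage.
  def rewriteDpp(filter: ToyExpr, stageCache: Map[String, Int], reuseEnabled: Boolean): ToyExpr =
    filter match {
      case AdaptivePlaceholder(key) =>
        stageCache.get(key) match {
          case Some(stageId) if reuseEnabled => ReuseBroadcastSubquery(stageId)
          case _ => TrueLiteral // no reusable broadcast stage: give up on pruning
        }
      case other => other
    }

  def main(args: Array[String]): Unit = {
    val cache = Map("exchange#1" -> 3)
    println(rewriteDpp(AdaptivePlaceholder("exchange#1"), cache, reuseEnabled = true))  // ReuseBroadcastSubquery(3)
    println(rewriteDpp(AdaptivePlaceholder("exchange#2"), cache, reuseEnabled = true))  // TrueLiteral
  }
}

The real rule keys its lookup on the canonicalized BroadcastExchangeExec, so the exchange planned for the DPP filter can be matched against the stage already built for the join's build side.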
SubqueryAdaptiveBroadcastExec.scala (new file)
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec

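/**
 * Placeholder physical node for the build side of a dynamic partition pruning (DPP) filter when
 * adaptive query execution is enabled. It carries the adaptive plan of the build side together
 * with the BroadcastExchangeExec planned for it. It is never executed directly; the
 * InsertDynamicPruningFilters rule later replaces it with a SubqueryBroadcastExec that reuses an
 * existing broadcast query stage, or drops the pruning filter if no stage can be reused.
 */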
case class SubqueryAdaptiveBroadcastExec(
    name: String,
    index: Int,
    buildKeys: Seq[Expression],
    child: SparkPlan,
    exchange: BroadcastExchangeExec) extends BaseSubqueryExec with UnaryExecNode {

  protected override def doExecute(): RDD[InternalRow] = {
    throw new UnsupportedOperationException(
      "SubqueryAdaptiveBroadcastExec does not support the execute() code path.")
  }
}
SubqueryBroadcastExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, LongHashedRelation}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
@@ -43,8 +42,7 @@ case class SubqueryBroadcastExec(
name: String,
index: Int,
buildKeys: Seq[Expression],
child: SparkPlan,
logicalPlan: Option[LogicalPlan] = None) extends BaseSubqueryExec with UnaryExecNode {
child: SparkPlan) extends BaseSubqueryExec with UnaryExecNode {

// `SubqueryBroadcastExec` is only used with `InSubqueryExec`. No one would reference this output,
// so the exprId doesn't matter here. But it's important to correctly report the output length, so
@@ -120,4 +118,4 @@ case class SubqueryBroadcastExec(
object SubqueryBroadcastExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16))
}
}
AdaptiveSparkPlanExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Literal}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -88,7 +88,8 @@ case class AdaptiveSparkPlanExec(
RemoveRedundantProjects,
EnsureRequirements,
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan
DisableUnnecessaryBucketedScan,
InsertDynamicPruningFilters(context.stageCache)
) ++ context.session.sessionState.queryStagePrepRules

// A list of physical optimizer rules to be applied to a new stage before its execution. These
@@ -139,6 +140,12 @@

private var currentStageId = 0

// Expose the current stage id so that query stage preparation rules such as
// InsertDynamicPruningFilters can allocate unique ids for the reuse query stages they create.
def stageId: Int = currentStageId

def setStageId(newStageId: Int): Unit = {
currentStageId = newStageId
}

/**
* Return type for `createQueryStages`
* @param newPlan the new plan with created query stages.
@@ -555,7 +562,6 @@ case class AdaptiveSparkPlanExec(
setTempTagRecursive(physicalNode.get, logicalNode)
// Replace the corresponding logical node with LogicalQueryStage
val newLogicalNode = LogicalQueryStage(logicalNode, physicalNode.get)
context.dppStageCache.getOrElseUpdate(logicalNode.canonicalized, physicalNode.get)
val newLogicalPlan = logicalPlan.transformDown {
case p if p.eq(logicalNode) => newLogicalNode
}
@@ -571,27 +577,6 @@
logicalPlan
}

private def insertDPPFilter(plan: SparkPlan): SparkPlan = {
plan transformAllExpressions {
case DynamicPruningExpression(InSubqueryExec(
value, broadcastValue: SubqueryBroadcastExec, exprId, _)) =>

val stage = context.dppStageCache.get(broadcastValue.logicalPlan.get.canonicalized)

if (conf.exchangeReuseEnabled && stage.nonEmpty) {
val name = s"dynamicpruning#${exprId.id}"
val bqs = stage.get.asInstanceOf[BroadcastQueryStageExec]
val newStage = reuseQueryStage(bqs, bqs.broadcast)
val broadcastValues =
SubqueryBroadcastExec(name, broadcastValue.index, broadcastValue.buildKeys, newStage)

DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
} else {
DynamicPruningExpression(Literal.TrueLiteral)
}
}
}

/**
* Re-optimize and run physical planning on the current logical plan based on the latest stats.
*/
@@ -603,8 +588,7 @@
sparkPlan,
preprocessingRules ++ queryStagePreparationRules,
Some((planChangeLogger, "AQE Replanning")))
val finalPlan = insertDPPFilter(newPlan)
(finalPlan, optimized)
(newPlan, optimized)
}

/**
@@ -734,9 +718,6 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
*/
val stageCache: TrieMap[SparkPlan, QueryStageExec] =
new TrieMap[SparkPlan, QueryStageExec]()

val dppStageCache: TrieMap[LogicalPlan, SparkPlan] =
new TrieMap[LogicalPlan, SparkPlan]()
}

/**
InsertAdaptiveSparkPlan.scala
@@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.adaptive
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, BindReferences, ListQuery, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange}
import org.apache.spark.sql.execution.joins.{HashedRelationBroadcastMode, HashJoin}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -132,12 +133,15 @@ case class InsertAdaptiveSparkPlan(
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(buildPlan)
verifyAdaptivePlan(executedPlan, buildPlan)

val adaptivePlan = executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
val packedKeys = BindReferences.bindReferences(
HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
val mode = HashedRelationBroadcastMode(packedKeys)
// plan a broadcast exchange of the build side of the join
val exchange = BroadcastExchangeExec(mode, adaptivePlan.executedPlan)
val name = s"dynamicpruning#${exprId.id}"
// place the broadcast adaptor for reusing the broadcast results on the probe side
val broadcastValues =
SubqueryBroadcastExec(name, broadcastKeyIndex, buildKeys, executedPlan, Some(buildPlan))

val broadcastValues = SubqueryAdaptiveBroadcastExec(
name, broadcastKeyIndex, buildKeys, adaptivePlan, exchange)
subqueryMap.put(exprId.id, broadcastValues)
case _ =>
}))
InsertDynamicPruningFilters.scala (new file)
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import scala.collection.concurrent.TrieMap

import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Literal}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._

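/**
 * A query stage preparation rule for adaptive execution. When dynamic partition pruning is
 * enabled, it rewrites the SubqueryAdaptiveBroadcastExec placeholders planted by
 * InsertAdaptiveSparkPlan: if exchange reuse is enabled and the stage cache already holds a
 * query stage for an equivalent broadcast exchange, the DPP filter becomes a
 * SubqueryBroadcastExec over a reuse instance of that stage; otherwise the filter is replaced
 * with a true literal and pruning is skipped.
 */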
case class InsertDynamicPruningFilters(
    stageCache: TrieMap[SparkPlan, QueryStageExec]) extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.dynamicPartitionPruningEnabled) {
      return plan
    }

    plan transformAllExpressions {
      case DynamicPruningExpression(InSubqueryExec(
          value, SubqueryAdaptiveBroadcastExec(
            name, index, buildKeys, adaptivePlan: AdaptiveSparkPlanExec, exchange), exprId, _)) =>

        // Look up a query stage that was already created for an equivalent broadcast exchange.
        val existingStage = stageCache.get(exchange.canonicalized)
        if (existingStage.nonEmpty && conf.exchangeReuseEnabled) {
          val name = s"dynamicpruning#${exprId.id}"

          val reuseQueryStage = existingStage.get.newReuseInstance(
            adaptivePlan.stageId, exchange.output)
          adaptivePlan.setStageId(adaptivePlan.stageId + 1)

          // Set the logical link for the reuse query stage.
          val link = exchange.getTagValue(AdaptiveSparkPlanExec.TEMP_LOGICAL_PLAN_TAG).orElse(
            exchange.logicalLink.orElse(exchange.collectFirst {
              case p if p.getTagValue(AdaptiveSparkPlanExec.TEMP_LOGICAL_PLAN_TAG).isDefined =>
                p.getTagValue(AdaptiveSparkPlanExec.TEMP_LOGICAL_PLAN_TAG).get
              case p if p.logicalLink.isDefined => p.logicalLink.get
            }))
          assert(link.isDefined)
          reuseQueryStage.setLogicalLink(link.get)

          val broadcastValues =
            SubqueryBroadcastExec(name, index, buildKeys, reuseQueryStage)
          DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
        } else {
          // No broadcast stage can be reused, so drop the pruning filter for this scan.
          DynamicPruningExpression(Literal.TrueLiteral)
        }
    }
  }
}
DynamicPartitionPruningSuite.scala
@@ -1419,6 +1419,7 @@ class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase
|SELECT f.date_id, f.store_id FROM fact_sk f
|JOIN dim_store s ON f.store_id = s.store_id AND s.country = 'NL'
""".stripMargin)
// df.show()
checkPartitionPruningPredicateWithAQE(df, false, true)

checkAnswer(df, Row(1000, 1) :: Row(1010, 2) :: Row(1020, 2) :: Nil)