From 80bef0d2c22bf91d216784f29839b95f22fb230f Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Wed, 8 Jul 2020 11:33:56 +0800 Subject: [PATCH] add more agg exec --- .../adaptive/OptimizeSkewedJoin.scala | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index e32480574ef11..4b20e3d66dec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -25,7 +25,7 @@ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf @@ -133,13 +133,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { private def canSplitLeftSide(joinType: JoinType, plan: SparkPlan) = { (joinType == Inner || joinType == Cross || joinType == LeftSemi || - joinType == LeftAnti || joinType == LeftOuter) && - plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty + joinType == LeftAnti || joinType == LeftOuter) && !containsAggregateExec(plan) } private def canSplitRightSide(joinType: JoinType, plan: SparkPlan) = { - (joinType == Inner || joinType == Cross || joinType == RightOuter) && - plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty + (joinType == Inner || joinType == Cross || + joinType == RightOuter) && !containsAggregateExec(plan) + } + + private def containsAggregateExec(plan: SparkPlan) = { + plan.find { + case _: HashAggregateExec => true + case _: SortAggregateExec => true + case _: ObjectHashAggregateExec => true + case _ => false + }.isDefined } private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {