From 0950e9a99f26dc80243f87eb914b76b28d89d1be Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Thu, 16 Jul 2020 13:31:25 +0800 Subject: [PATCH] fix ut --- .../spark/sql/execution/adaptive/OptimizeSkewedJoin.scala | 4 +--- .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 7c9560545dbfd..b6bb48ae9cc38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -199,9 +199,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { * 3 tasks separately. */ def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { - case smj @ SortMergeJoinExec(_, _, joinType, _, - s1 @ SortExec(_, _, _, _), - s2 @ SortExec(_, _, _, _), _) + case smj @ SortMergeJoinExec(_, _, joinType, _, s1: SortExec, s2: SortExec, _) if supportedJoinTypes.contains(joinType) => // find the shuffleStage from the plan tree val leftOpt = findShuffleStage(s1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 08c9cdb0b7a8a..ed7fe7a84b297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -660,9 +660,9 @@ class AdaptiveQueryExecSuite checkSkewJoin( "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2", true) - // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization + // After patched SPARK-32201, this query won't introduce additional shuffle anymore. checkSkewJoin( - "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", false) + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", true) } } }