From e71b3553bc2ad1910a900d5a844941558252fdbd Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 1 Jul 2020 16:14:51 -0700
Subject: [PATCH] [SPARK-32056][SQL][FOLLOW-UP] Coalesce partitions for
 repartition hint and sql when AQE is enabled

As the follow-up of #28900, this patch extends coalescing partitions to
repartitioning using hints and SQL syntax without specifying the number of
partitions, when AQE is enabled.

When repartitioning using hints and SQL syntax, we should follow the
shuffling behavior of repartition by expression/range and coalesce
partitions when AQE is enabled.

This is a user-facing change: after it, if users don't specify the number of
partitions when repartitioning using the `REPARTITION`/`REPARTITION_BY_RANGE`
hint or `DISTRIBUTE BY`/`CLUSTER BY`, AQE will coalesce partitions.

Tested by unit tests.

Closes #28952 from viirya/SPARK-32056-sql.

Authored-by: Liang-Chi Hsieh
Signed-off-by: Dongjoon Hyun
---
 .../sql/catalyst/analysis/ResolveHints.scala  | 16 ++--
 .../catalyst/analysis/ResolveHintsSuite.scala |  4 +-
 .../spark/sql/execution/SparkSqlParser.scala  |  2 +-
 .../sql/execution/SparkSqlParserSuite.scala   |  6 +-
 .../adaptive/AdaptiveQueryExecSuite.scala     | 81 +++++++++++++++----
 5 files changed, 78 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 81de086e78f91..4cbff62e16cc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -183,7 +183,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)
 
       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
         if (sortOrders.nonEmpty) throw new IllegalArgumentException(
           s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
           throw new AnalysisException(
             s"$hintName Hint expects a partition number as a parameter")
 
         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) if shuffle =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }
@@ -224,7 +224,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)
 
       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
         if (invalidParams.nonEmpty) {
           throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@ object ResolveHints {
 
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index d3bd5d07a0932..513f1d001f757 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
     checkAnalysis(
       UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
-        Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+        Seq(AttributeReference("a", IntegerType)()), testRelation, None))
 
     val e = intercept[IllegalArgumentException] {
       checkAnalysis(
@@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
         "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
         Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
-        testRelation, conf.numShufflePartitions))
+        testRelation, None))
 
     val errMsg2 = "REPARTITION Hint parameter should include columns, but"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 4814f627e3748..f4567b8c39b86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -740,7 +740,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       ctx: QueryOrganizationContext,
       expressions: Seq[Expression],
       query: LogicalPlan): LogicalPlan = {
-    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+    RepartitionByExpression(expressions, query, None)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 343d3c1c13469..79dc516b226bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -209,20 +209,20 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        None))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
   }
 
   test("pipeline concatenation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9cfff10d59c36..c9992e484e672 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.OptimizeLocalShuffleReader.LOCAL_SHUFFLE_READER_DESCRIPTION
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }
 
+  private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -892,14 +903,8 @@ class AdaptiveQueryExecSuite
           assert(partitionsNum1 < 10)
           assert(partitionsNum2 < 10)
 
-          // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-          val plan = df1.queryExecution.executedPlan
-          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-            case s: ShuffleExchangeExec => s
-          }
-          assert(shuffle.size == 1)
-          assert(shuffle(0).outputPartitioning.numPartitions == 10)
+          checkInitialPartitionNum(df1, 10)
+          checkInitialPartitionNum(df2, 10)
         } else {
           assert(partitionsNum1 === 10)
           assert(partitionsNum2 === 10)
@@ -933,14 +938,8 @@ class AdaptiveQueryExecSuite
           assert(partitionsNum1 < 10)
           assert(partitionsNum2 < 10)
 
-          // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-          val plan = df1.queryExecution.executedPlan
-          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-            case s: ShuffleExchangeExec => s
-          }
-          assert(shuffle.size == 1)
-          assert(shuffle(0).outputPartitioning.numPartitions == 10)
+          checkInitialPartitionNum(df1, 10)
+          checkInitialPartitionNum(df2, 10)
         } else {
           assert(partitionsNum1 === 10)
           assert(partitionsNum2 === 10)
@@ -966,4 +965,52 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1, 10)
+            checkInitialPartitionNum(df2, 10)
+            checkInitialPartitionNum(df3, 10)
+            checkInitialPartitionNum(df4, 10)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }
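
Usage sketch (not part of this patch): the snippet below is a minimal, assumed
example of the resulting behavior from user code with a local Spark 3.x session.
It uses the public config keys behind the `SQLConf` entries referenced in the
test above (`spark.sql.adaptive.enabled`, `spark.sql.adaptive.coalescePartitions.enabled`,
`spark.sql.adaptive.coalescePartitions.initialPartitionNum`, `spark.sql.shuffle.partitions`);
the builder settings, view name, and variable names are illustrative, and the
coalesced partition counts depend on data size and the advisory partition size.

import org.apache.spark.sql.SparkSession

// Start a local session with AQE and partition coalescing enabled,
// beginning every shuffle from 10 partitions.
val spark = SparkSession.builder()
  .master("local[4]")
  .appName("aqe-repartition-coalescing-sketch")
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
  .config("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "10")
  .config("spark.sql.shuffle.partitions", "10")
  .getOrCreate()

spark.range(10).toDF.createOrReplaceTempView("test")

// No partition number specified: the shuffle starts from 10 partitions and
// AQE may coalesce them (likely down to 1 for this tiny dataset).
val hinted = spark.sql("SELECT /*+ REPARTITION(id) */ * FROM test")
val distributed = spark.sql("SELECT * FROM test DISTRIBUTE BY id")
println(hinted.rdd.getNumPartitions)      // usually < 10 with AQE enabled
println(distributed.rdd.getNumPartitions) // usually < 10 with AQE enabled

// Partition number specified explicitly: coalescing is skipped, 10 partitions remain.
val fixed = spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM test")
println(fixed.rdd.getNumPartitions)       // 10

spark.stop()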