diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 81de086e78f91..4cbff62e16cc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -183,7 +183,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)
 
       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
         if (sortOrders.nonEmpty) throw new IllegalArgumentException(
           s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
           throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")
 
         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
        case param @ Seq(_*) if shuffle =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }
 
@@ -224,7 +224,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)
 
       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
         if (invalidParams.nonEmpty) {
           throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@ object ResolveHints {
 
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index d3bd5d07a0932..513f1d001f757 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
     checkAnalysis(
       UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
-        Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+        Seq(AttributeReference("a", IntegerType)()), testRelation, None))
 
     val e = intercept[IllegalArgumentException] {
       checkAnalysis(
@@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
         "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
         Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
-        testRelation, conf.numShufflePartitions))
+        testRelation, None))
 
     val errMsg2 = "REPARTITION Hint parameter should include columns, but"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 4814f627e3748..f4567b8c39b86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -740,7 +740,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       ctx: QueryOrganizationContext,
       expressions: Seq[Expression],
       query: LogicalPlan): LogicalPlan = {
-    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+    RepartitionByExpression(expressions, query, None)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 343d3c1c13469..79dc516b226bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -209,20 +209,20 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        None))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
   }
 
   test("pipeline concatenation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9cfff10d59c36..c9992e484e672 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.OptimizeLocalShuffleReader.LOCAL_SHUFFLE_READER_DESCRIPTION
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }
 
+  private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -892,14 +903,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)
 
-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -933,14 +938,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)
 
-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -966,4 +965,52 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1, 10)
+            checkInitialPartitionNum(df2, 10)
+            checkInitialPartitionNum(df3, 10)
+            checkInitialPartitionNum(df4, 10)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }
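
For context, a minimal sketch (not part of the patch) of the user-visible behavior this change enables, mirroring the new AdaptiveQueryExecSuite test: a REPARTITION hint, DISTRIBUTE BY, or CLUSTER BY without an explicit partition count now produces RepartitionByExpression with numPartitions = None, so AQE may coalesce the shuffle below the initial partition number, while an explicit count such as REPARTITION(10, id) is preserved. The object name, local master, and SparkSession setup below are assumptions for illustration; the config keys are the string forms of the SQLConf entries used in the test.

// Illustrative sketch only; assumes a local Spark 3.x build with this patch applied.
import org.apache.spark.sql.SparkSession

object RepartitionHintAqeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("repartition-hint-aqe-demo")
      // Same settings as the new test, expressed as plain config key strings.
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "10")
      .config("spark.sql.shuffle.partitions", "10")
      .getOrCreate()

    spark.range(10).toDF("id").createOrReplaceTempView("test")

    // No partition number in the hint: the plan now carries numPartitions = None,
    // so AQE is free to coalesce the shuffle below 10 partitions for this tiny input.
    val coalesced = spark.sql("SELECT /*+ REPARTITION(id) */ * FROM test")
    println(s"hint without number: ${coalesced.rdd.getNumPartitions} partitions")

    // Explicit partition number: Some(10) is kept, so AQE does not coalesce it.
    val fixed = spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM test")
    println(s"hint with number: ${fixed.rdd.getNumPartitions} partitions")

    spark.stop()
  }
}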