[SPARK-32056][SQL] Coalesce partitions for repartition by expressions when AQE is enabled

This patch proposes to coalesce partitions for repartition-by-expressions operations when the number of partitions is not specified and AQE is enabled.

When repartitioning by partition expressions, users may or may not specify the number of partitions. If the number of partitions is specified, we should not coalesce partitions, because that would break user expectations. But when no number of partitions is given, AQE should be able to coalesce partitions just as it does for other shuffles.

Yes. After this change, if users don't specify the number of partitions when repartitioning data by expressions, AQE will coalesce partitions.

Added unit tests.
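To illustrate the behavior, a minimal sketch (the `local[4]` master and session settings are illustrative assumptions; the actual post-coalesce partition count depends on data size and AQE's advisory partition size):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[4]")
  .config("spark.sql.adaptive.enabled", "true")                    // enable AQE
  .config("spark.sql.adaptive.coalescePartitions.enabled", "true") // allow coalescing
  .config("spark.sql.shuffle.partitions", "10")
  .getOrCreate()
import spark.implicits._

val df = spark.range(10).toDF("id")

// No partition count given: AQE is free to coalesce the 10 shuffle partitions.
val coalesced = df.repartition($"id")
assert(coalesced.rdd.getNumPartitions <= 10)

// Explicit partition count: AQE must not coalesce, so exactly 10 partitions remain.
val fixed = df.repartition(10, $"id")
assert(fixed.rdd.getNumPartitions == 10)
```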

Closes apache#28900 from viirya/SPARK-32056.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and LorenzoMartini committed May 19, 2021
1 parent f099edb commit f9bc035
Showing 4 changed files with 119 additions and 31 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.random.RandomSampler

@@ -948,16 +949,18 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
}

/**
- * This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+ * This method repartitions data using [[Expression]]s into `optNumPartitions`, and receives
* information about the number of partitions during execution. Used when a specific ordering or
* distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
- * `coalesce` and `repartition`.
+ * `coalesce` and `repartition`. If no `optNumPartitions` is given, by default it partitions data
+ * into `numShufflePartitions` defined in `SQLConf`, and could be coalesced by AQE.
*/
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Int) extends RepartitionOperation {
+    optNumPartitions: Option[Int]) extends RepartitionOperation {

+  val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions)
   require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

val partitioning: Partitioning = {
@@ -985,6 +988,15 @@ case class RepartitionByExpression(
override def shuffle: Boolean = true
}

+object RepartitionByExpression {
+  def apply(
+      partitionExpressions: Seq[Expression],
+      child: LogicalPlan,
+      numPartitions: Int): RepartitionByExpression = {
+    RepartitionByExpression(partitionExpressions, child, Some(numPartitions))
+  }
+}

/**
* A relation with one row. This is used in "SELECT ..." without a from clause.
*/
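As a side note, the companion object above keeps old `Int`-based call sites source-compatible while the case class itself stores an `Option[Int]`. A standalone sketch of that pattern (the names `Repartitioned`/`optNum` and the default of 200 are hypothetical, standing in for `SQLConf.get.numShufflePartitions`):

```scala
case class Repartitioned(optNum: Option[Int]) {
  // Fall back to a configured default when no count was requested.
  val num: Int = optNum.getOrElse(200)
}

object Repartitioned {
  // Auxiliary constructor keeping Int-based callers compiling unchanged.
  def apply(num: Int): Repartitioned = Repartitioned(Some(num))
}

assert(Repartitioned(5).num == 5)      // explicit count wins
assert(Repartitioned(None).num == 200) // unspecified count takes the default
```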
54 changes: 33 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2991,17 +2991,9 @@ class Dataset[T] private[sql](
Repartition(numPartitions, shuffle = true, logicalPlan)
}

-  /**
-   * Returns a new Dataset partitioned by the given partitioning expressions into
-   * `numPartitions`. The resulting Dataset is hash partitioned.
-   *
-   * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  @scala.annotation.varargs
-  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+  private def repartitionByExpression(
+      numPartitions: Option[Int],
+      partitionExprs: Seq[Column]): Dataset[T] = {
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
@@ -3015,6 +3007,20 @@
}
}

+  /**
+   * Returns a new Dataset partitioned by the given partitioning expressions into
+   * `numPartitions`. The resulting Dataset is hash partitioned.
+   *
+   * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  @scala.annotation.varargs
+  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+    repartitionByExpression(Some(numPartitions), partitionExprs)
+  }

/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
* `spark.sql.shuffle.partitions` as number of partitions.
@@ -3027,7 +3033,20 @@
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = {
-    repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+    repartitionByExpression(None, partitionExprs)
}

+  private def repartitionByRange(
+      numPartitions: Option[Int],
+      partitionExprs: Seq[Column]): Dataset[T] = {
+    require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+    val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)
+    })
+    withTypedPlan {
+      RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
+    }
+  }

/**
@@ -3049,14 +3068,7 @@
*/
@scala.annotation.varargs
def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
-    require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
-    val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
-      case expr: SortOrder => expr
-      case expr: Expression => SortOrder(expr, Ascending)
-    })
-    withTypedPlan {
-      RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
-    }
+    repartitionByRange(Some(numPartitions), partitionExprs)
}

/**
@@ -3078,7 +3090,7 @@
*/
@scala.annotation.varargs
def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
-    repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+    repartitionByRange(None, partitionExprs)
}

/**
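Taken together, both `repartition` and `repartitionByRange` now funnel into `Option[Int]`-taking helpers, so the overloads without a partition count become coalescible under AQE. A hedged usage sketch (reusing a `spark` session configured as in the earlier example):

```scala
import spark.implicits._

val df = spark.range(100).toDF("id")

// Range partitioning without an explicit count: AQE may coalesce.
val flexible = df.repartitionByRange($"id".asc)

// Range partitioning with an explicit count: always exactly 5 partitions.
val pinned = df.repartitionByRange(5, $"id".asc)
assert(pinned.rdd.getNumPartitions == 5)
```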
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -787,8 +787,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case r: logical.RepartitionByExpression =>
+        val canChangeNumParts = r.optNumPartitions.isEmpty
         exchange.ShuffleExchangeExec(
-          r.partitioning, planLater(r.child), noUserSpecifiedNumPartition = false) :: Nil
+          r.partitioning, planLater(r.child), canChangeNumParts) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
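The strategy change above boils down to one rule: AQE may alter the partition count only when the user left it unspecified. A standalone distillation (the function name is illustrative, not Spark's internal API):

```scala
def canChangeNumPartitions(optNumPartitions: Option[Int]): Boolean =
  optNumPartitions.isEmpty

assert(canChangeNumPartitions(None))      // df.repartition($"id")     -> coalescible
assert(!canChangeNumPartitions(Some(10))) // df.repartition(10, $"id") -> fixed count
```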
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -874,18 +874,81 @@ class AdaptiveQueryExecSuite
}
}

test("SPARK-31220 repartition obeys initialPartitionNum when adaptiveExecutionEnabled") {
test("SPARK-31220, SPARK-32056: repartition by expression with AQE") {
Seq(true, false).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "6",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {

val df1 = spark.range(10).repartition($"id")
val df2 = spark.range(10).repartition($"id" + 1)

val partitionsNum1 = df1.rdd.collectPartitions().length
val partitionsNum2 = df2.rdd.collectPartitions().length

if (enableAQE) {
-          assert(partitionsNum === 7)
+          assert(partitionsNum1 < 10)
+          assert(partitionsNum2 < 10)
+
+          // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+          val plan = df1.queryExecution.executedPlan
+          assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+          val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+            case s: ShuffleExchangeExec => s
+          }
+          assert(shuffle.size == 1)
+          assert(shuffle(0).outputPartitioning.numPartitions == 10)
} else {
-          assert(partitionsNum === 6)
+          assert(partitionsNum1 === 10)
+          assert(partitionsNum2 === 10)
}


+        // Don't coalesce partitions if the number of partitions is specified.
+        val df3 = spark.range(10).repartition(10, $"id")
+        val df4 = spark.range(10).repartition(10)
+        assert(df3.rdd.collectPartitions().length == 10)
+        assert(df4.rdd.collectPartitions().length == 10)
}
}
}

test("SPARK-31220, SPARK-32056: repartition by range with AQE") {
Seq(true, false).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {

val df1 = spark.range(10).toDF.repartitionByRange($"id".asc)
val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc)

val partitionsNum1 = df1.rdd.collectPartitions().length
val partitionsNum2 = df2.rdd.collectPartitions().length

if (enableAQE) {
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)

// repartition obeys initialPartitionNum when adaptiveExecutionEnabled
val plan = df1.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
}

// Don't coalesce partitions if the number of partitions is specified.
val df3 = spark.range(10).repartitionByRange(10, $"id".asc)
assert(df3.rdd.collectPartitions().length == 10)
}
}
}
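To run just this suite locally, something like the following should work from the Spark repo root (exact invocation depends on the build setup):

```
build/sbt "sql/testOnly *AdaptiveQueryExecSuite"
```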
