Commit 0a9223f
Coalesce partitions for repartition by key when AQE is enabled.
viirya committed Jun 23, 2020
1 parent 338efee commit 0a9223f
Showing 4 changed files with 86 additions and 28 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.random.RandomSampler

@@ -953,16 +954,18 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
}

/**
-* This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+* This method repartitions data using [[Expression]]s into `optNumPartitions`, and receives
* information about the number of partitions during execution. Used when a specific ordering or
* distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
-* `coalesce` and `repartition`.
+* `coalesce` and `repartition`. If no `optNumPartitions` is given, by default it partitions data
+* into `numShufflePartitions` defined in `SQLConf`, and could be coalesced by AQE.
*/
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
-numPartitions: Int) extends RepartitionOperation {
+optNumPartitions: Option[Int]) extends RepartitionOperation {

+val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions)
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

val partitioning: Partitioning = {
@@ -990,6 +993,15 @@ case class RepartitionByExpression(
override def shuffle: Boolean = true
}

+object RepartitionByExpression {
+  def apply(
+      partitionExpressions: Seq[Expression],
+      child: LogicalPlan,
+      numPartitions: Int): RepartitionByExpression = {
+    RepartitionByExpression(partitionExpressions, child, Some(numPartitions))
+  }
+}

/**
* A relation with one row. This is used in "SELECT ..." without a from clause.
*/
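Side note (not part of the diff): a minimal sketch of how the new `optNumPartitions` field surfaces through the public API, assuming Spark 3.0 and a local `SparkSession`. An explicit count pins the partition number; omitting it leaves `None`, which is what later allows AQE to coalesce the shuffle.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// repartition(5, $"id") pins the count: optNumPartitions = Some(5).
val pinned = spark.range(10).repartition(5, $"id")
  .queryExecution.logical.collectFirst { case r: RepartitionByExpression => r }
assert(pinned.exists(_.optNumPartitions.contains(5)))

// repartition($"id") leaves optNumPartitions = None; numPartitions then falls
// back to spark.sql.shuffle.partitions, and AQE may coalesce the shuffle output.
val flexible = spark.range(10).repartition($"id")
  .queryExecution.logical.collectFirst { case r: RepartitionByExpression => r }
assert(flexible.exists(_.optNumPartitions.isEmpty))
```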
54 changes: 33 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2991,17 +2991,9 @@ class Dataset[T] private[sql](
Repartition(numPartitions, shuffle = true, logicalPlan)
}

-/**
-* Returns a new Dataset partitioned by the given partitioning expressions into
-* `numPartitions`. The resulting Dataset is hash partitioned.
-*
-* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
-*
-* @group typedrel
-* @since 2.0.0
-*/
-@scala.annotation.varargs
-def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+private def repartitionByExpression(
+    numPartitions: Option[Int],
+    partitionExprs: Column*): Dataset[T] = {
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
@@ -3015,6 +3007,20 @@ class Dataset[T] private[sql](
}
}

+/**
+ * Returns a new Dataset partitioned by the given partitioning expressions into
+ * `numPartitions`. The resulting Dataset is hash partitioned.
+ *
+ * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+@scala.annotation.varargs
+def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+  repartitionByExpression(Some(numPartitions), partitionExprs: _*)
+}

/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
* `spark.sql.shuffle.partitions` as number of partitions.
@@ -3027,7 +3033,20 @@
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = {
-repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+repartitionByExpression(None, partitionExprs: _*)
}

+private def repartitionByRange(
+    numPartitions: Option[Int],
+    partitionExprs: Column*): Dataset[T] = {
+  require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+  val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
+    case expr: SortOrder => expr
+    case expr: Expression => SortOrder(expr, Ascending)
+  })
+  withTypedPlan {
+    RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
+  }
+}

/**
@@ -3049,14 +3068,7 @@
*/
@scala.annotation.varargs
def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
-require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
-val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
-case expr: SortOrder => expr
-case expr: Expression => SortOrder(expr, Ascending)
-})
-withTypedPlan {
-RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
-}
+repartitionByRange(Some(numPartitions), partitionExprs: _*)
}

/**
@@ -3078,7 +3090,7 @@
*/
@scala.annotation.varargs
def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
-repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+repartitionByRange(None, partitionExprs: _*)
}

/**
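Usage note (illustrative, not from the patch): with AQE enabled, only the overloads without an explicit count leave room for coalescing. The config keys are the Spark 3.0 names; the values are arbitrary.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.shuffle.partitions", "200")

val df = spark.range(100).toDF("id")

// No explicit count: AQE may coalesce the 200 shuffle partitions down to a
// few, since this tiny dataset does not need that much parallelism.
val coalesced = df.repartition($"id").rdd.getNumPartitions

// Explicit count: the request is honored and the number stays fixed.
val pinned = df.repartition(200, $"id").rdd.getNumPartitions // stays 200

println(s"coalesced = $coalesced, pinned = $pinned")
```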
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

@@ -685,8 +685,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case r: logical.RepartitionByExpression =>
+val canChangeNumParts = r.optNumPartitions.isEmpty
exchange.ShuffleExchangeExec(
-r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil
+r.partitioning, planLater(r.child), canChangeNumPartitions = canChangeNumParts) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
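In other words (a restatement, not the planner's actual code): the shuffle introduced for a repartition-by-expression is marked as changeable by AQE exactly when the user did not pin a partition count.

```scala
// Hypothetical one-liner mirroring the rule above: AQE may only rewrite the
// shuffle's partition count when the user did not request a specific one.
def canChangeNumPartitions(optNumPartitions: Option[Int]): Boolean =
  optNumPartitions.isEmpty

assert(canChangeNumPartitions(None))       // df.repartition($"key")      -> coalescible
assert(!canChangeNumPartitions(Some(200))) // df.repartition(200, $"key") -> fixed
```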
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
@@ -1026,15 +1026,48 @@ class AdaptiveQueryExecSuite
Seq(true, false).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "6",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
-val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length
+val df = spark.range(10).repartition($"id")
+val partitionsNum = df.rdd.collectPartitions().length
if (enableAQE) {
-assert(partitionsNum === 7)
+assert(partitionsNum < 6)
+
+val plan = df.queryExecution.executedPlan
+assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+  case s: ShuffleExchangeExec => s
+}
+assert(shuffle.size == 1)
+assert(shuffle(0).outputPartitioning.numPartitions == 7)
} else {
assert(partitionsNum === 6)
}
}
}
}

test("SPARK-32056 coalesce partitions for repartition by expressions when AQE is enabled") {
Seq(true, false).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "50",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
val partitionsNum1 = (1 to 10).toDF.repartition($"value")
.rdd.collectPartitions().length

val partitionsNum2 = (1 to 10).toDF.repartitionByRange($"value".asc)
.rdd.collectPartitions().length
if (enableAQE) {
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
}
}
}
}
}
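For readers who want to reproduce the assertions outside the suite, a sketch along the same lines (it touches internal `execution` classes such as `AdaptiveSparkPlanExec.executedPlan`, so treat it as test-scope code, not stable API):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

val spark = SparkSession.builder().master("local[*]")
  .config("spark.sql.adaptive.enabled", "true").getOrCreate()
import spark.implicits._

val df = spark.range(10).repartition($"id")
df.collect() // run the query so AQE can finalize its plan

df.queryExecution.executedPlan match {
  case a: AdaptiveSparkPlanExec =>
    // The shuffle still advertises the initial partition number; coalescing
    // is applied by the shuffle reader stacked on top of it.
    val shuffles = a.executedPlan.collect { case s: ShuffleExchangeExec => s }
    println(s"shuffles = ${shuffles.size}, " +
      s"numPartitions = ${shuffles.map(_.outputPartitioning.numPartitions).mkString(",")}")
  case other =>
    println(s"AQE disabled, root node = ${other.nodeName}")
}
```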
