Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-39915][SQL] Dataset.repartition(N) may not create N partitions Non-AQE part #37706

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ package object dsl {
def repartition(num: Integer): LogicalPlan =
Repartition(num, shuffle = true, logicalPlan)

def repartition(): LogicalPlan =
RepartitionByExpression(Seq.empty, logicalPlan, None)

def distribute(exprs: Expression*)(n: Int): LogicalPlan =
RepartitionByExpression(exprs, logicalPlan, numPartitions = n)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}

/**
Expand All @@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
// This tag is used to mark a repartition as a root repartition which is user-specified
private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")

protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
Expand Down Expand Up @@ -137,8 +141,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
case _: GlobalLimit if !p.isStreaming => empty(p)
case _: LocalLimit if !p.isStreaming => empty(p)
case _: Offset => empty(p)
case _: Repartition => empty(p)
case _: RepartitionByExpression => empty(p)
case _: RepartitionOperation =>
if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
empty(p)
} else {
p.unsetTagValue(ROOT_REPARTITION)
p
}
case _: RebalancePartitions => empty(p)
// An aggregate with non-empty group expression will return one output row per group when the
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
Expand All @@ -162,13 +171,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
case _ => p
}
}

protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
case _: Repartition => true
case r: RepartitionByExpression
if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
case _ => false
}

protected def applyInternal(plan: LogicalPlan): LogicalPlan

/**
* Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
* skip optimize it.
*/
private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: we can skip this earlier with something like if (!plan.containsPattern(REPARTITION))

case p: Project => p.mapChildren(addTagForRootRepartition)
case f: Filter => f.mapChildren(addTagForRootRepartition)
case r if userSpecifiedRepartition(r) =>
r.setTagValue(ROOT_REPARTITION, ())
r
case _ => plan
}

override def apply(plan: LogicalPlan): LogicalPlan = {
val planWithTag = addTagForRootRepartition(plan)
applyInternal(planWithTag)
}
}

/**
* This rule runs in the normal optimizer
*/
object PropagateEmptyRelation extends PropagateEmptyRelationBase {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
_.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
commonApplyFunc
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
.fromExternalRows(Seq($"a".int, $"b".int, $"window".long.withNullability(false)), Nil)
comparePlans(Optimize.execute(originalQuery.analyze), expected.analyze)
}

test("Propagate empty relation with repartition") {
val emptyRelation = LocalRelation($"a".int, $"b".int)
comparePlans(Optimize.execute(
emptyRelation.repartition(1).sortBy($"a".asc).analyze
), emptyRelation.analyze)

comparePlans(Optimize.execute(
emptyRelation.distribute($"a")(1).sortBy($"a".asc).analyze
), emptyRelation.analyze)

comparePlans(Optimize.execute(
emptyRelation.repartition().analyze
), emptyRelation.analyze)

comparePlans(Optimize.execute(
emptyRelation.repartition(1).sortBy($"a".asc).repartition().analyze
), emptyRelation.analyze)
}

test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
val emptyRelation = LocalRelation($"a".int, $"b".int)
val p1 = emptyRelation.repartition(1).analyze
comparePlans(Optimize.execute(p1), p1)

val p2 = emptyRelation.repartition(1).select($"a").analyze
comparePlans(Optimize.execute(p2), p2)

val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
comparePlans(Optimize.execute(p3), p3)

val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
comparePlans(Optimize.execute(p4), p4)

val p5 = emptyRelation.sortBy("$a".asc).repartition().limit(1).repartition(1).analyze
val expected5 = emptyRelation.repartition(1).analyze
comparePlans(Optimize.execute(p5), expected5)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
empty(j)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
// LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
// `PropagateEmptyRelationBase.commonApplyFunc`
// LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3419,6 +3419,13 @@ class DataFrameSuite extends QueryTest
Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
}

test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
}
}
}

case class GroupByKey(a: Int, b: Int)
Expand Down