[SPARK-40382][SQL] Group distinct aggregate expressions by semantically equivalent children in `RewriteDistinctAggregates`

### What changes were proposed in this pull request?

In `RewriteDistinctAggregates`, when grouping aggregate expressions by function children, treat children that are semantically equivalent as the same.
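
The grouping idea can be sketched with a toy expression tree instead of Catalyst's `Expression` and `ExpressionSet` classes. The names `Expr`, `Col`, `Lit`, `Add`, `Canon`, and `GroupingDemo` are hypothetical, and the canonical form used here (sorting the two operands of an addition) is an assumed simplification that does not handle nested additions such as `b + c + d`.

```scala
// Toy expression tree: a stand-in for Catalyst expressions, not the real API.
sealed trait Expr
final case class Col(name: String) extends Expr
final case class Lit(value: Int) extends Expr
final case class Add(left: Expr, right: Expr) extends Expr

object Canon {
  // Assumed canonical form: order the two operands of the commutative Add
  // by their string rendering, so `1 + b` and `b + 1` become the same tree.
  def canonicalize(e: Expr): Expr = e match {
    case Add(l, r) =>
      val operands = Seq(canonicalize(l), canonicalize(r)).sortBy(_.toString)
      Add(operands(0), operands(1))
    case other => other
  }
}

object GroupingDemo {
  // Children of the three aggregate functions in the example query.
  val children: Seq[Expr] = Seq(
    Add(Col("b"), Lit(1)), // b + 1
    Add(Lit(1), Col("b")), // 1 + b
    Col("c"))              // c

  // Grouping by structural equality keeps `b + 1` and `1 + b` apart.
  val structuralGroups: Int = children.groupBy(identity).size

  // Grouping by the canonical form puts them in the same group.
  val semanticGroups: Int = children.groupBy(Canon.canonicalize).size
}
```

In this toy model, `structuralGroups` is 3 and `semanticGroups` is 2, which mirrors the difference between keying the groups on a plain `Set` of children and keying them on a set that compares canonical forms.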

### Why are the changes needed?

This PR reduces the number of projections in the Expand operator when there are multiple distinct aggregations with superficially different children. In some cases, it eliminates the need for an Expand operator altogether.

Example: In the following query, the Expand operator creates 3\*n rows (where n is the number of incoming rows) because it has a projection for each of the function children `b + 1`, `1 + b`, and `c`.

```
create or replace temp view v1 as
select * from values
(1, 2, 3.0),
(1, 3, 4.0),
(2, 4, 2.5),
(2, 3, 1.0)
v1(a, b, c);

select
  a,
  count(distinct b + 1),
  avg(distinct 1 + b) filter (where c > 0),
  sum(c)
from
  v1
group by a;
```
The Expand operator has three projections (each producing a row for each incoming row):
```
[a#87, null, null, 0, null, UnscaledValue(c#89)], <== projection #1 (for regular aggregation)
[a#87, (b#88 + 1), null, 1, null, null],          <== projection #2 (for distinct aggregation of b + 1)
[a#87, null, (1 + b#88), 2, (c#89 > 0.0), null]], <== projection #3 (for distinct aggregation of 1 + b)
```
In fact, the Expand operator needs only one projection for `1 + b` and `b + 1`, because they are semantically equivalent.

With the proposed change, the Expand operator's projections look like this:
```
[a#67, null, 0, null, UnscaledValue(c#69)],  <== projection #1 (for regular aggregations)
[a#67, (b#68 + 1), 1, (c#69 > 0.0), null]],  <== projection #2 (for distinct aggregation on b + 1 and 1 + b)
```
With one fewer projection, Expand produces 2\*n rows instead of 3\*n rows while still producing the correct result.
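
The row-count arithmetic can be illustrated with a toy model of the replication step. This is only a sketch of "one output row per input row per projection", not Spark's Expand operator. The names `ExpandDemo`, `Expanded`, and `expand` are hypothetical, and column `c` is modeled as a `Double` for simplicity.

```scala
object ExpandDemo {
  // One output row per (input row, projection) pair; `gid` tags the projection.
  final case class Expanded(gid: Int, input: (Int, Int, Double))

  // Toy replication: every input row is emitted once for each projection.
  def expand(rows: Seq[(Int, Int, Double)], numProjections: Int): Seq[Expanded] =
    for {
      row <- rows
      gid <- 0 until numProjections
    } yield Expanded(gid, row)

  // The four rows of view v1 from the example, as (a, b, c).
  val v1: Seq[(Int, Int, Double)] =
    Seq((1, 2, 3.0), (1, 3, 4.0), (2, 4, 2.5), (2, 3, 1.0))

  val rowsBefore: Int = expand(v1, numProjections = 3).size // 3 * 4 rows
  val rowsAfter: Int = expand(v1, numProjections = 2).size  // 2 * 4 rows
}
```

For the four-row view `v1`, this gives 12 rows with three projections and 8 rows with two.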

In the case where all distinct aggregates have semantically equivalent children, the Expand operator is not needed at all.

The benchmark code is in the JIRA (SPARK-40382).

Before the PR:
```
distinct aggregates:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
all semantically equivalent                       14721          14859         195          5.7         175.5       1.0X
some semantically equivalent                      14569          14572           5          5.8         173.7       1.0X
none semantically equivalent                      14408          14488         113          5.8         171.8       1.0X
```
After the PR:
```
distinct aggregates:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
all semantically equivalent                        3658           3692          49         22.9          43.6       1.0X
some semantically equivalent                       9124           9214         127          9.2         108.8       0.4X
none semantically equivalent                      14601          14777         250          5.7         174.1       0.3X
```
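
Note that the `Relative` column in each table compares rows within that table, not before against after. A small sketch for deriving the before/after speedup from the `Best Time(ms)` columns follows. The object name `BenchmarkRatios` is hypothetical, and the numbers are copied from the two tables.

```scala
object BenchmarkRatios {
  // Best times in ms, copied from the tables: (case, before, after).
  val bestMs: Seq[(String, Double, Double)] = Seq(
    ("all semantically equivalent", 14721.0, 3658.0),
    ("some semantically equivalent", 14569.0, 9124.0),
    ("none semantically equivalent", 14408.0, 14601.0))

  // Speedup = before / after; values above 1.0 mean the patched code is faster.
  val speedups: Seq[(String, Double)] =
    bestMs.map { case (name, before, after) => name -> before / after }

  def main(args: Array[String]): Unit =
    speedups.foreach { case (name, s) => println(f"$name%-30s $s%.2fx") }
}
```

By this measure the best time improves roughly 4.0x when all children are semantically equivalent and roughly 1.6x when some are, and it is essentially unchanged when none are.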

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests.

Closes #37825 from bersprockets/rewritedistinct_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
bersprockets authored and cloud-fan committed Oct 13, 2022
1 parent 9bc8c06 commit 6e0ef86
Showing 6 changed files with 87 additions and 9 deletions.
@@ -220,7 +220,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -231,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         // count(distinct 1) will be explained to count(1) after the rewrite function.
         // Generally, the distinct aggregateFunction should not run
         // foldable TypeCheck for the first child.
-        e.aggregateFunction.children.take(1).toSet
+        ExpressionSet(e.aggregateFunction.children.take(1))
       }
     }

@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

     // Setup unique distinct aggregate children.
     val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-    val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+    val distinctAggChildAttrMap = distinctAggChildren.map { e =>
+      e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
+    }
     val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
     // Setup all the filters in distinct aggregate.
     val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
@@ -292,7 +294,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         af
       } else {
         patchAggregateFunctionChildren(af) { x =>
-          distinctAggChildAttrLookup.get(x)
+          distinctAggChildAttrLookup.get(x.canonicalized)
         }
       }
       val newCondition = if (condition.isDefined) {

@@ -75,4 +75,37 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       .analyze
     checkRewrite(RewriteDistinctAggregates(input))
   }
+
+  test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c").as("agg1"),
+        countDistinct($"c" + $"b").as("agg2"),
+        max($"c").as("agg3"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, LocalRelation(_, _, _)) =>
+      case _ => fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
+
+  test("SPARK-40382: reduce multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c" + $"d").as("agg1"),
+        countDistinct($"d" + $"c" + $"b").as("agg2"),
+        countDistinct($"b" + $"c").as("agg3"),
+        countDistinct($"c" + $"b").as("agg4"),
+        max($"c").as("agg5"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+        assert(e.projections.size == 3)
+      case _ => fail(s"Plan is not rewritten:\n$rewrite")
+    }
+  }
 }

@@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

     val (functionsWithDistinct, functionsWithoutDistinct) =
       aggregateExpressions.partition(_.isDistinct)
-    if (functionsWithDistinct.map(
-      _.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
+    val distinctAggChildSets = functionsWithDistinct.map { ae =>
+      ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
+    }.distinct
+    if (distinctAggChildSets.length > 1) {
       // This is a sanity check. We should not reach here when we have multiple distinct
       // column sets. Our `RewriteDistinctAggregates` should take care this case.
       throw new IllegalStateException(

@@ -219,14 +219,17 @@ object AggUtils {
     }

     // 3. Create an Aggregate operator for partial aggregation (for distinct)
-    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
+    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
+      distinctAttributes)
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
       case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
-        aggregateFunction.transformDown(distinctColumnAttributeLookup)
-          .asInstanceOf[AggregateFunction]
+        aggregateFunction.transformDown {
+          case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
+            distinctColumnAttributeLookup(e.canonicalized)
+        }.asInstanceOf[AggregateFunction]
       case agg =>
         throw new IllegalArgumentException(
           "Non-distinct aggregate is found in functionsWithDistinct " +

@@ -1485,6 +1485,40 @@ class DataFrameAggregateSuite extends QueryTest
     val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
     checkAnswer(df, Row(2, 3, 1))
   }
+
+  test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") {
+    Seq(
+      (1, 1, 3),
+      (1, 2, 3),
+      (1, 2, 3),
+      (2, 1, 1),
+      (2, 2, 5)
+    ).toDF("k", "c1", "c2").createOrReplaceTempView("df")
+
+    // all distinct aggregation children are semantically equivalent
+    val res1 = sql(
+      """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res1, Row(1, 5, 2.5, 2) :: Row(2, 5, 2.5, 2) :: Nil)
+
+    // some distinct aggregation children are semantically equivalent
+    val res2 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res2, Row(1, 7, 3.5, 1) :: Row(2, 7, 3.5, 2) :: Nil)
+
+    // no distinct aggregation children are semantically equivalent
+    val res3 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil)
+  }
 }

 case class B(c: Option[Double])

@@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       // 2 distinct columns with different order
       val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
       assertNoExpand(query3.queryExecution.executedPlan)
+
+      // SPARK-40382: 1 distinct expression with cosmetic differences
+      val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
+      assertNoExpand(query4.queryExecution.executedPlan)
     }
   }
