Skip to content

Commit

Permalink
[SPARK-49261][SQL] Don't replace literals in aggregate expressions wi…
Browse files Browse the repository at this point in the history
…th group-by expressions

### What changes were proposed in this pull request?

Before this PR, `RewriteDistinctAggregates` could potentially replace literals in the aggregate expressions with output attributes from the `Expand` operator. This can occur when a group-by expression is a literal that happens by chance to match a literal used in an aggregate expression. E.g.:

```
create or replace temp view v1(a, b, c) as values
(1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4);

cache table v1;

select
  round(sum(b), 6) as sum1,
  count(distinct a) as count1,
  count(distinct c) as count2
from (
  select
    6 as gb,
    *
  from v1
)
group by a, gb;
```
In the optimized plan, you can see that the literal 6 in the `round` function invocation has been patched with an output attribute (6#163) from the `Expand` operator:
```
== Optimized Logical Plan ==
'Aggregate [a#123, 6#163], [round(first(sum(__auto_generated_subquery_name.b)#167, true) FILTER (WHERE (gid#162 = 0)), 6#163) AS sum1#114, count(__auto_generated_subquery_name.a#164) FILTER (WHERE (gid#162 = 1)) AS count1#115L, count(__auto_generated_subquery_name.c#165) FILTER (WHERE (gid#162 = 2)) AS count2#116L]
+- Aggregate [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, sum(__auto_generated_subquery_name.b#166) AS sum(__auto_generated_subquery_name.b)#167]
   +- Expand [[a#123, 6, null, null, 0, b#124], [a#123, 6, a#123, null, 1, null], [a#123, 6, null, c#125, 2, null]], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, __auto_generated_subquery_name.b#166]
      +- InMemoryRelation [a#123, b#124, c#125], StorageLevel(disk, memory, deserialized, 1 replicas)
            +- LocalTableScan [a#6, b#7, c#8]
```
This is because the literal 6 was used in the group-by expressions (referred to as gb in the query, and renamed 6#163 in the `Expand` operator's output attributes).

After this PR, foldable expressions in the aggregate expressions are kept as-is.

### Why are the changes needed?

Some expressions require a foldable argument. In the above example, the `round` function requires a foldable expression as the scale argument. Because the scale argument is patched with an attribute, `RoundBase#checkInputDataTypes` returns an error, which leaves the `Aggregate` operator unresolved:
```
[INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
org.apache.spark.sql.catalyst.analysis.UnresolvedException: [INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
	at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:255)
	at org.apache.spark.sql.catalyst.types.DataTypeUtils$.$anonfun$fromAttributes$1(DataTypeUtils.scala:241)
	at scala.collection.immutable.List.map(List.scala:247)
	at scala.collection.immutable.List.map(List.scala:79)
	at org.apache.spark.sql.catalyst.types.DataTypeUtils$.fromAttributes(DataTypeUtils.scala:241)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.schema$lzycompute(QueryPlan.scala:428)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.schema(QueryPlan.scala:428)
	at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:474)
        ...
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47876 from bersprockets/group_by_lit_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 1a0791d)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
bersprockets authored and dongjoon-hyun committed Sep 12, 2024
1 parent 96eebeb commit 560efed
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
(distinctAggOperatorMap.flatMap(_._2) ++
regularAggOperatorMap.map(e => (e._1, e._3))).toMap

val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable)
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
// The same GROUP BY clauses can have different forms (different names for instance) in
// the groupBy and aggregate expressions of an aggregate. This makes a map lookup
// tricky. So we do a linear search for a semantically equal group by expression.
groupByMap
groupByMapNonFoldable
.find(ge => e.semanticEquals(ge._1))
.map(_._2)
.getOrElse(transformations.getOrElse(e, e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
Expand Down Expand Up @@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}
}

test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
val relation = testRelation2
.select(Literal(6).as("gb"), $"a", $"b", $"c", $"d")
val input = relation
.groupBy($"a", $"gb")(
countDistinct($"b").as("agg1"),
countDistinct($"d").as("agg2"),
Round(sum($"c").as("sum1"), 6)).analyze
val rewriteFold = FoldablePropagation(input)
// without the fix, the below produces an unresolved plan
val rewrite = RewriteDistinctAggregates(rewriteFold)
if (!rewrite.resolved) {
fail(s"Plan is not as expected:\n$rewrite")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,27 @@ class DataFrameAggregateSuite extends QueryTest
})
}
}

test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c")
withTempView("v1") {
data.createOrReplaceTempView("v1")
val df =
sql("""SELECT
| ROUND(SUM(b), 6) AS sum1,
| COUNT(DISTINCT a) AS count1,
| COUNT(DISTINCT c) AS count2
|FROM (
| SELECT
| 6 AS gb,
| *
| FROM v1
|)
|GROUP BY a, gb
|""".stripMargin)
checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
}
}
}

case class B(c: Option[Double])
Expand Down

0 comments on commit 560efed

Please sign in to comment.