[SPARK-47430][SQL] Rework group by map type to fix bind reference exception

### What changes were proposed in this pull request?

This PR reworks GROUP BY over map types to fix two issues (an illustrative rewrite sketch follows the list):
- A "couldn't bind reference" internal error at runtime: the grouping attribute was wrapped in `MapSort`, but the plan was not transformed with the new output.
- The rule that inserts `MapSort` now runs together with `PullOutGroupingExpressions` (immediately after it, in the Finish Analysis batch), so no complex expressions remain in the grouping keys when `MapSort` is inserted.
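
A hedged sketch of the intended rewrite, mirroring the rule's updated doc comment (`map_column` and `_groupingmapsort` are illustrative names, not fixed identifiers):

```
-- before the rewrite
SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column
-- after the rewrite: the sorted map key is pulled into a project below the aggregate
SELECT _groupingmapsort AS map_column, COUNT(*) FROM (
  SELECT map_sort(map_column) AS _groupingmapsort FROM TABLE
) GROUP BY _groupingmapsort
```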

### Why are the changes needed?

To fix the issues described above. Before this change, for example, the following query failed with an internal error:

```
select map(1, id) from range(10) group by map(1, id);

[INTERNAL_ERROR] Couldn't find _groupingexpression#18 in [mapsort(_groupingexpression#18)#19] SQLSTATE: XX000
org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find _groupingexpression#18 in [mapsort(_groupingexpression#18)#19] SQLSTATE: XX000
	at org.apache.spark.SparkException$.internalError(SparkException.scala:92)
	at org.apache.spark.SparkException$.internalError(SparkException.scala:96)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:81)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:470)
```
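
After this change the same query plans and runs, because both the grouping key and the aggregate output refer to the pulled-out, map-sorted attribute. A rough sketch of the resulting plan shape (attribute names and IDs are illustrative, not actual EXPLAIN output):

```
Aggregate [_groupingmapsort#19], [_groupingmapsort#19 AS map(1, id)#20]
+- Project [id#0L, _groupingexpression#18, mapsort(_groupingexpression#18) AS _groupingmapsort#19]
   +- Project [id#0L, map(1, id#0L) AS _groupingexpression#18]
      +- Range (0, 10, ...)
```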

### Does this PR introduce _any_ user-facing change?

No. The change fixes behavior that has not been released yet.

### How was this patch tested?

Improved the existing tests in `DataFrameAggregateSuite` to cover more cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47545 from ulysses-you/maptype.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: youxiduo <youxiduo@corp.netease.com>
ulysses-you authored and IvanK-db committed Sep 19, 2024
1 parent 42d9730 commit 359792f
Showing 4 changed files with 123 additions and 49 deletions.
@@ -17,37 +17,73 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

/**
* Adds MapSort to group expressions containing map columns, as the key/value paris need to be
* Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
* in the correct order before grouping:
* SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
*
* SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT _groupingmapsort as map_column, COUNT(*) FROM (
* SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
* ) GROUP BY _groupingmapsort
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(groupingExpr, _, _) =>
val newGrouping = groupingExpr.map { expr =>
if (!expr.exists(_.isInstanceOf[MapSort])
&& expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
insertMapSortRecursively(expr)
} else {
expr
private def shouldAddMapSort(expr: Expression): Boolean = {
expr.dataType.existsRecursively(_.isInstanceOf[MapType])
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.containsPattern(TreePattern.AGGREGATE)) {
return plan
}
val shouldRewrite = plan.exists {
case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
case _ => false
}
if (!shouldRewrite) {
return plan
}

plan transformUpWithNewOutput {
case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
if agg.groupingExpressions.exists(shouldAddMapSort) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
val inserted = insertMapSortRecursively(expr)
if (expr.ne(inserted)) {
exprToMapSort.getOrElseUpdate(
expr.canonicalized,
Alias(inserted, "_groupingmapsort")()
).toAttribute
} else {
expr
}
}
}
a.copy(groupingExpressions = newGrouping)
val newAggregateExprs = aggregateExpressions.map {
case named if exprToMapSort.contains(named.canonicalized) =>
// If we replace the top-level named expr, then we should add back the original name
exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
case other =>
other.transformUp {
case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
}.asInstanceOf[NamedExpression]
}
val newChild = Project(child.output ++ exprToMapSort.values, child)
val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
newAgg -> agg.output.zip(newAgg.output)
}
}

/*
Inserts MapSort recursively taking into account when
it is nested inside a struct or array.
/**
* Inserts MapSort recursively, taking into account when it is nested inside a struct or array.
*/
private def insertMapSortRecursively(e: Expression): Expression = {
e.dataType match {
@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
}

val batches = (
Batch("Finish Analysis", Once, FinishAnalysis) ::
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
Batch("InsertMapSortInGroupingExpressions", Once,
InsertMapSortInGroupingExpressions) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -297,6 +295,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
ReplaceExpressions,
RewriteNonCorrelatedExists,
PullOutGroupingExpressions,
// Put `InsertMapSortInGroupingExpressions` after `PullOutGroupingExpressions`,
// so the grouping keys can only be attributes and literals, which makes it
// easy for `InsertMapSortInGroupingExpressions` to insert `MapSort`.
InsertMapSortInGroupingExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
@@ -127,7 +127,6 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
)
}

private def assertAggregateOnDataframe(df: DataFrame,
expected: Int, aggregateColumn: String): Unit = {
private def assertAggregateOnDataframe(
df: => DataFrame,
expected: Int): Unit = {
val configurations = Seq(
Seq.empty[(String, String)], // hash aggregate is used by default
Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@
Seq("spark.sql.test.forceApplySortAggregate" -> "true")
)

for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(createAggregate(df).count() == expected)
// Make tests faster
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(df.count() == expected, df.queryExecution.simpleString)
}
}
}

def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
}

test("SPARK-47430 Support GROUP BY MapType") {
val numRows = 50

val dfSameInt = (0 until numRows)
.map(_ => Tuple1(Map(1 -> 1)))
.toDF("m0")
assertAggregateOnDataframe(dfSameInt, 1, "m0")

val dfSameFloat = (0 until numRows)
.map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
.toDF("m0")
assertAggregateOnDataframe(dfSameFloat, 1, "m0")

val dfDifferent = (0 until numRows)
.map(i => Tuple1(Map(i -> i)))
.toDF("m0")
assertAggregateOnDataframe(dfDifferent, numRows, "m0")
def genMapData(dataType: String): String = {
s"""
|case when id % 4 == 0 then map()
|when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
|when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
| cast(1 as $dataType), cast(1 as $dataType))
|else map(cast(1 as $dataType), cast(1 as $dataType),
| cast(0 as $dataType), cast(0 as $dataType))
|end
|""".stripMargin
}
Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
withTempView("v") {
spark.range(20)
.selectExpr(
s"cast(1 as $dt) as c1",
s"${genMapData(dt)} as c2",
"map(c1, null) as c3",
s"cast(null as map<$dt, $dt>) as c4")
.createOrReplaceTempView("v")

assertAggregateOnDataframe(
spark.sql("SELECT count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
3)
}
}
}

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@
val dfSame = (0 until numRows)
.map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
.toDF("c0")
assertAggregateOnDataframe(dfSame, 1, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfSame, 1)

val dfDifferent = (0 until numRows)
.map(i => Tuple1(new CalendarInterval(i, i, i)))
.toDF("c0")
assertAggregateOnDataframe(dfDifferent, numRows, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfDifferent, numRows)
}

test("SPARK-46779: Group by subquery with a cached relation") {