diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
index 51f6ca374909b..b2e8d79ef02cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
@@ -17,37 +17,73 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 /**
- * Adds MapSort to group expressions containing map columns, as the key/value paris need to be
+ * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
  * in the correct order before grouping:
- * SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
- * SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ *
+ * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ * SELECT _groupingmapsort as map_column, COUNT(*) FROM (
+ *   SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
+ * ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(groupingExpr, _, _) =>
-      val newGrouping = groupingExpr.map { expr =>
-        if (!expr.exists(_.isInstanceOf[MapSort])
-          && expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          insertMapSortRecursively(expr)
-        } else {
-          expr
+  private def shouldAddMapSort(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+      return plan
+    }
+    val shouldRewrite = plan.exists {
+      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+      case _ => false
+    }
+    if (!shouldRewrite) {
+      return plan
+    }
+
+    plan transformUpWithNewOutput {
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
+          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+        val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
+        val newGroupingKeys = groupingExprs.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(
+              expr.canonicalized,
+              Alias(inserted, "_groupingmapsort")()
+            ).toAttribute
+          } else {
+            expr
+          }
         }
-      }
-      a.copy(groupingExpressions = newGrouping)
+        val newAggregateExprs = aggregateExpressions.map {
+          case named if exprToMapSort.contains(named.canonicalized) =>
+            // If we replace the top-level named expr, we should add back its original name.
+            exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
+          case other =>
+            other.transformUp {
+              case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+            }.asInstanceOf[NamedExpression]
+        }
+        val newChild = Project(child.output ++ exprToMapSort.values, child)
+        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+        newAgg -> agg.output.zip(newAgg.output)
+    }
   }
 
-  /*
-    Inserts MapSort recursively taking into account when
-    it is nested inside a struct or array.
+  /**
+   * Inserts MapSort recursively, taking into account when it is nested inside a struct or array.
    */
   private def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
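The rewrite above normalizes map grouping keys by sorting their entries, so two maps holding the same key/value pairs in different insertion order land in the same group. A minimal sketch of the behavior this enables, assuming a local SparkSession named `spark` (the data and column name are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// Same entries, different insertion order: without MapSort normalization these
// could hash to different grouping keys; with it they collapse into one group.
val df = Seq(Map(1 -> "a", 2 -> "b"), Map(2 -> "b", 1 -> "a")).toDF("m")
df.groupBy($"m").count().show()
// Expected: one row with count = 2
```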
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 647812ff80e78..a2a26924885c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     }
 
     val batches = (
-    Batch("Finish Analysis", Once, FinishAnalysis) ::
+    Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
     // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
     Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveRedundantAliases,
       RemoveNoopOperators) :+
-    Batch("InsertMapSortInGroupingExpressions", Once,
-      InsertMapSortInGroupingExpressions) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -297,6 +295,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
+      // Put `InsertMapSortInGroupingExpressions` after `PullOutGroupingExpressions`,
+      // so the grouping keys can only be attributes or literals, which makes it easy for
+      // `InsertMapSortInGroupingExpressions` to insert `MapSort`.
+      InsertMapSortInGroupingExpressions,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
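Because the rule now runs in the `Finish Analysis` batch right after `PullOutGroupingExpressions`, the map-sorted alias shows up in the optimized plan as an extra `Project` below the `Aggregate`. One way to see this, sketched under the same local-session assumption as above (exact tree text varies by Spark version):

```scala
// Inspect the optimized plan for a map grouping key. The point is the Project
// that aliases map_sort(m) as _groupingmapsort feeding the Aggregate.
val grouped = Seq(Map(1 -> 1)).toDF("m").groupBy($"m").count()
println(grouped.queryExecution.optimizedPlan.treeString)
```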
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d36ce37406063..8be7aac7bebf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -127,7 +127,6 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
       "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
-      "org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
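The entry is removed because only rules that transform with an explicit rule id, enabling tree-pattern pruning to skip already-processed subtrees, are registered here. The rewritten rule uses `transformUpWithNewOutput`, which has no rule-id variant, and guards itself with a plain `containsPattern` check instead. For contrast, a minimal sketch of the pruned-transform idiom the old implementation used (the rule name and body are placeholders):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE

object SomePrunedRule extends Rule[LogicalPlan] {
  // `ruleId` is inherited from Rule and resolved through RuleIdCollection,
  // which is why rules that want id-based pruning appear in that list.
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.transformWithPruning(_.containsPattern(AGGREGATE), ruleId) {
      case agg: Aggregate => agg // placeholder: rewrite the aggregate here
    }
}
```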
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 90ac4f351ff4e..66b1883b91d5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
-  private def assertAggregateOnDataframe(df: DataFrame,
-    expected: Int, aggregateColumn: String): Unit = {
+  private def assertAggregateOnDataframe(
+      df: => DataFrame,
+      expected: Int): Unit = {
     val configurations = Seq(
       Seq.empty[(String, String)], // hash aggregate is used by default
       Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@ class DataFrameAggregateSuite extends QueryTest
       Seq("spark.sql.test.forceApplySortAggregate" -> "true")
     )
 
-    for (conf <- configurations) {
-      withSQLConf(conf: _*) {
-        assert(createAggregate(df).count() == expected)
+    // Make tests faster
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
+      for (conf <- configurations) {
+        withSQLConf(conf: _*) {
+          assert(df.count() == expected, df.queryExecution.simpleString)
+        }
       }
     }
-
-    def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
   }
 
   test("SPARK-47430 Support GROUP BY MapType") {
-    val numRows = 50
-
-    val dfSameInt = (0 until numRows)
-      .map(_ => Tuple1(Map(1 -> 1)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameInt, 1, "m0")
-
-    val dfSameFloat = (0 until numRows)
-      .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameFloat, 1, "m0")
-
-    val dfDifferent = (0 until numRows)
-      .map(i => Tuple1(Map(i -> i)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "m0")
+    def genMapData(dataType: String): String = {
+      s"""
+         |case when id % 4 == 0 then map()
+         |when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
+         |when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
+         |  cast(1 as $dataType), cast(1 as $dataType))
+         |else map(cast(1 as $dataType), cast(1 as $dataType),
+         |  cast(0 as $dataType), cast(0 as $dataType))
+         |end
+         |""".stripMargin
+    }
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
+      withTempView("v") {
+        spark.range(20)
+          .selectExpr(
+            s"cast(1 as $dt) as c1",
+            s"${genMapData(dt)} as c2",
+            "map(c1, null) as c3",
+            s"cast(null as map<$dt, $dt>) as c4")
+          .createOrReplaceTempView("v")
+
+        assertAggregateOnDataframe(
+          spark.sql("SELECT count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
+          3)
+      }
+    }
   }
 
   test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@ class DataFrameAggregateSuite extends QueryTest
     val dfSame = (0 until numRows)
       .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
       .toDF("c0")
-    assertAggregateOnDataframe(dfSame, 1, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfSame, 1)
 
     val dfDifferent = (0 until numRows)
       .map(i => Tuple1(new CalendarInterval(i, i, i)))
       .toDF("c0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfDifferent, numRows)
   }
 
   test("SPARK-46779: Group by subquery with a cached relation") {
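A note on the expected counts: for `c2`, `genMapData` produces four branches per cycle of `id % 4`, but the `map(0, 0, 1, 1)` and `map(1, 1, 0, 0)` branches contain the same entries in opposite insertion order, so after MapSort normalization GROUP BY sees only three distinct keys: the empty map, `{0 -> 0}`, and `{0 -> 0, 1 -> 1}`. A standalone sketch of that check for the `int` case, assuming a local SparkSession named `spark`:

```scala
// Branches id % 4 == 2 and id % 4 == 3 build the same entries in opposite
// insertion order; MapSort normalization makes them one grouping key.
spark.range(20).selectExpr(
  """case when id % 4 == 0 then map()
    |when id % 4 == 1 then map(0, 0)
    |when id % 4 == 2 then map(0, 0, 1, 1)
    |else map(1, 1, 0, 0) end as c2""".stripMargin)
  .createOrReplaceTempView("v")
spark.sql("SELECT c2, count(*) FROM v GROUP BY c2").show()
// Expect 3 rows: {}, {0 -> 0}, {0 -> 0, 1 -> 1}
```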