diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 44cdff10aca45..c699e92cf0190 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -34,7 +35,7 @@ import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
 import org.apache.spark.sql.util.SchemaUtils._
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
@@ -95,22 +96,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {
         case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
-          if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
+          if filters.isEmpty && CollapseProject.canCollapseExpressions(
+            resultExpressions, project, alwaysInline = true) =>
           sHolder.builder match {
             case r: SupportsPushDownAggregates =>
+              val aliasMap = getAliasMap(project)
+              val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
+              val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap))
+
               val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-              val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
+              val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
               val normalizedAggregates = DataSourceStrategy.normalizeExprs(
                 aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
               val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
-                groupingExpressions, sHolder.relation.output)
+                actualGroupExprs, sHolder.relation.output)
               val translatedAggregates = DataSourceStrategy.translateAggregation(
                 normalizedAggregates, normalizedGroupingExpressions)
               val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
                 if (translatedAggregates.isEmpty ||
                   r.supportCompletePushDown(translatedAggregates.get) ||
                   translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
-                  (resultExpressions, aggregates, translatedAggregates)
+                  (actualResultExprs, aggregates, translatedAggregates)
                 } else {
                   // scalastyle:off
                   // The data source doesn't support the complete push-down of this aggregation.
@@ -127,7 +133,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
                   // +- ScanOperation[...]
                   // scalastyle:on
-                  val newResultExpressions = resultExpressions.map { expr =>
+                  val newResultExpressions = actualResultExprs.map { expr =>
                     expr.transform {
                       case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
                         val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
@@ -206,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
             if (r.supportCompletePushDown(pushedAggregates.get)) {
-              val projectExpressions = resultExpressions.map { expr =>
+              val projectExpressions = finalResultExpressions.map { expr =>
                 // TODO At present, only push down group by attribute is supported.
                 // In future, more attribute conversion is extended here. e.g. GetStructField
                 expr.transform {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 47740c5274616..26dfe1a50971f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite
     }
   }
 
-  test("aggregate over alias not push down") {
+  test("aggregate over alias push down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
     withDataSourceTable(data, "t") {
@@ -194,7 +194,7 @@
       query.queryExecution.optimizedPlan.collect {
         case _: DataSourceV2ScanRelation =>
           val expected_plan_fragment =
-            "PushedAggregation: []" // aggregate alias not pushed down
+            "PushedAggregation: [MIN(_1)]"
           checkKeywordsExistsInExplain(query, expected_plan_fragment)
       }
       checkAnswer(query, Seq(Row(-2)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index d6f098f1d5189..31fdb022b625f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -974,15 +974,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias NOT push down") {
+  test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
-    checkAggregateRemoved(df2, false)
+    checkAggregateRemoved(df2)
     df2.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.aggregation.isEmpty)
+      case relation: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+        relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.aggregation.nonEmpty)
       }
     }
     checkAnswer(df2, Seq(Row(53000.00)))
@@ -1228,4 +1232,76 @@
         |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin)
     checkAnswer(df, Seq.empty[Row])
   }
+
+  test("scan with aggregate push-down: complete push-down aggregate with alias") {
+    val df = spark.table("h2.test.employee")
+      .select($"DEPT", $"SALARY".as("mySalary"))
+      .groupBy($"DEPT")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+
+    val df2 = spark.table("h2.test.employee")
+      .select($"DEPT".as("myDept"), $"SALARY".as("mySalary"))
+      .groupBy($"myDept")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+  }
+
+  test("scan with aggregate push-down: partial push-down aggregate with alias") {
+    val df = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME", $"SALARY".as("mySalary"))
+      .groupBy($"NAME")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME".as("myName"), $"SALARY".as("mySalary"))
+      .groupBy($"myName")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2, false)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+  }
 }
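
Illustrative sketch (not part of the patch): with AliasHelper and CollapseProject.canCollapseExpressions wired into V2ScanRelationPushDown, an aggregate that reaches a column only through an alias can now be pushed to a DSv2 source implementing SupportsPushDownAggregates. The snippet below assumes a SparkSession `spark` with the same H2 test catalog and h2.test.employee table that JDBCV2Suite registers; outside that suite the table name and catalog setup are assumptions.

    // Sketch only; requires the h2.test.employee table from JDBCV2Suite to be available.
    import org.apache.spark.sql.functions.sum
    import spark.implicits._

    val df = spark.table("h2.test.employee")
      .select($"DEPT", $"SALARY".as("mySalary"))   // alias introduced by the projection
      .groupBy($"DEPT")
      .agg(sum($"mySalary").as("total"))           // resolves through the alias back to SALARY

    // With this change the optimized plan is expected to report something like:
    //   PushedAggregates: [SUM(SALARY)], PushedGroupByColumns: [DEPT]
    df.explain()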