From 933273b0f9832cb7ddec21def3fa2a7e52eefa62 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 31 Jan 2019 16:20:18 +0800 Subject: [PATCH] [SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/23388 . `AggUtils.createAggregate` is not the right place to normalize the grouping expressions, as final aggregate is also created by it. The grouping expressions of final aggregate should be attributes which refer to the grouping expressions in partial aggregate. This PR moves the normalization to the caller side of `AggUtils`. ## How was this patch tested? existing tests Closes #23692 from cloud-fan/follow. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../optimizer/NormalizeFloatingNumbers.scala | 16 ++++++------ .../spark/sql/execution/SparkStrategies.scala | 25 ++++++++++++++++--- .../sql/execution/aggregate/AggUtils.scala | 14 +++-------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 520f24aa22e4c..a5921ebe7751a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { } private[sql] def normalize(expr: Expression): Expression = expr match { - case _ if expr.dataType == FloatType || expr.dataType == DoubleType => - NormalizeNaNAndZero(expr) + case _ if !needNormalize(expr.dataType) => expr + + case a: Alias => + a.withNewChildren(Seq(normalize(a.child))) case CreateNamedStruct(children) => CreateNamedStruct(children.map(normalize)) @@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case CreateMap(children) => CreateMap(children.map(normalize)) - case a: Alias if needNormalize(a.dataType) => - a.withNewChildren(Seq(normalize(a.child))) + case _ if expr.dataType == FloatType || expr.dataType == DoubleType => + NormalizeNaNAndZero(expr) - case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) => + case _ if expr.dataType.isInstanceOf[StructType] => val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => normalize(GetStructField(expr, i)) } CreateStruct(fields) - case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) => + case _ if expr.dataType.isInstanceOf[ArrayType] => val ArrayType(et, containsNull) = expr.dataType val lv = NamedLambdaVariable("arg", et, containsNull) val function = normalize(lv) ArrayTransform(expr, LambdaFunction(function, Seq(lv))) - case _ => expr + case _ => throw new IllegalStateException(s"fail to normalize $expr") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b7cc373b2df12..edfa70403ad15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because + // `groupingExpressions` is not extracted during logical phase. + val normalizedGroupingExpressions = namedGroupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } + aggregate.AggUtils.planStreamingAggregation( - namedGroupingExpressions, + normalizedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, stateVersion, @@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Spark user mailing list.") } + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because + // `groupingExpressions` is not extracted during logical phase. + val normalizedGroupingExpressions = groupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } + val aggregateOperator = if (functionsWithDistinct.isEmpty) { aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, + normalizedGroupingExpressions, aggregateExpressions, resultExpressions, planLater(child)) } else { aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, + normalizedGroupingExpressions, functionsWithDistinct, functionsWithoutDistinct, resultExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 8b7556b0c6c5a..4d762c5ea9f34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -35,20 +35,12 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { - // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because - // `groupingExpressions` is not extracted during logical phase. - val normalizedGroupingExpressions = groupingExpressions.map { e => - NormalizeFloatingNumbers.normalize(e) match { - case n: NamedExpression => n - case other => Alias(other, e.name)(exprId = e.exprId) - } - } val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = normalizedGroupingExpressions, + groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -61,7 +53,7 @@ object AggUtils { if (objectHashEnabled && useObjectHash) { ObjectHashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = normalizedGroupingExpressions, + groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -70,7 +62,7 @@ object AggUtils { } else { SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = normalizedGroupingExpressions, + groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset,