Skip to content

Commit

Permalink
[SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expression…
Browse files Browse the repository at this point in the history
…s for final aggregate

## What changes were proposed in this pull request?

A follow-up to apache#23388.

`AggUtils.createAggregate` is not the right place to normalize the grouping expressions, because the final aggregate is also created by it. The grouping expressions of the final aggregate should be attributes that refer to the grouping expressions in the partial aggregate.

This PR moves the normalization to the caller side of `AggUtils`.

## How was this patch tested?

existing tests

Closes apache#23692 from cloud-fan/follow.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cloud-fan authored and jackylee-ch committed Feb 18, 2019
1 parent ce754a5 commit 933273b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
}

private[sql] def normalize(expr: Expression): Expression = expr match {
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
NormalizeNaNAndZero(expr)
case _ if !needNormalize(expr.dataType) => expr

case a: Alias =>
a.withNewChildren(Seq(normalize(a.child)))

case CreateNamedStruct(children) =>
CreateNamedStruct(children.map(normalize))
Expand All @@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateMap(children) =>
CreateMap(children.map(normalize))

case a: Alias if needNormalize(a.dataType) =>
a.withNewChildren(Seq(normalize(a.child)))
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
NormalizeNaNAndZero(expr)

case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
case _ if expr.dataType.isInstanceOf[StructType] =>
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
normalize(GetStructField(expr, i))
}
CreateStruct(fields)

case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
val lv = NamedLambdaVariable("arg", et, containsNull)
val function = normalize(lv)
ArrayTransform(expr, LambdaFunction(function, Seq(lv)))

case _ => expr
case _ => throw new IllegalStateException(s"fail to normalize $expr")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
val normalizedGroupingExpressions = namedGroupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}

aggregate.AggUtils.planStreamingAggregation(
namedGroupingExpressions,
normalizedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
Expand Down Expand Up @@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"Spark user mailing list.")
}

// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
val normalizedGroupingExpressions = groupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}

val aggregateOperator =
if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
normalizedGroupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
normalizedGroupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
resultExpressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,12 @@ object AggUtils {
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
val normalizedGroupingExpressions = groupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}
val useHash = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand All @@ -61,7 +53,7 @@ object AggUtils {
if (objectHashEnabled && useObjectHash) {
ObjectHashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand All @@ -70,7 +62,7 @@ object AggUtils {
} else {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand Down

0 comments on commit 933273b

Please sign in to comment.