From db136d360e54e13f1d7071a0428964a202cf7e31 Mon Sep 17 00:00:00 2001 From: Simeon Simeonov Date: Tue, 20 Nov 2018 21:29:56 +0100 Subject: [PATCH] [SPARK-26084][SQL] Fixes unresolved AggregateExpression.references exception ## What changes were proposed in this pull request? This PR fixes an exception in `AggregateExpression.references` called on unresolved expressions. It implements the solution proposed in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084), a minor refactoring that removes the unnecessary dependence on `AttributeSet.toSeq`, which requires expression IDs and, therefore, can only execute successfully for resolved expressions. The refactored implementation is both simpler and faster, eliminating the conversion of a `Set` to a `Seq` and back to `Set`. ## How was this patch tested? Added a new test based on the failing case in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084). hvanhovell Closes #23075 from ssimeonov/ss_SPARK-26084. Authored-by: Simeon Simeonov Signed-off-by: Herman van Hovell --- .../expressions/aggregate/interfaces.scala | 8 ++--- .../aggregate/AggregateExpressionSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e1d16a2cd38b0..56c2ee6b53fe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -128,12 +128,10 @@ case class AggregateExpression( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferences = mode match { - case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.aggBufferAttributes + mode match { + case Partial | Complete => aggregateFunction.references + case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } - - AttributeSet(childReferences) } override def toString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala new file mode 100644 index 0000000000000..8e9c9972071ad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet} + +class AggregateExpressionSuite extends SparkFunSuite { + + test("test references from unresolved aggregate functions") { + val x = UnresolvedAttribute("x") + val y = UnresolvedAttribute("y") + val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references + val expected = AttributeSet(x :: y :: Nil) + assert(expected == actual, s"Expected: $expected. Actual: $actual") + } + +}