diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 1f4ff9c4b184e..c6f6699ae941f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * The Collect aggregate function collects all seen expression values into a list of values.
@@ -47,28 +48,56 @@ abstract class Collect extends ImperativeAggregate {
 
   override def supportsPartial: Boolean = false
 
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
+  override def aggBufferAttributes: Seq[AttributeReference] =
+    AttributeReference("groupIndex", IntegerType)() :: Nil
 
   override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+  override def inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
+  private[this] lazy val mapToBuffer =
+    new mutable.HashMap[Integer, Growable[Any] with Iterable[Any]]
 
-  override def initialize(b: MutableRow): Unit = {
-    buffer.clear()
+  protected[this] def createBuffer: Growable[Any] with Iterable[Any]
+
+  override def initialize(buffer: MutableRow): Unit = {
+    buffer.setInt(mutableAggBufferOffset, -1)
   }
 
-  override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val groupIndex = buffer.getInt(mutableAggBufferOffset) match {
+      case -1 =>
+        // If not found, create a buffer for a new group
+        val newIndex = mapToBuffer.size
+        buffer.setInt(mutableAggBufferOffset, newIndex)
+        mapToBuffer.put(newIndex, createBuffer)
+        newIndex
+
+      case index =>
+        index
+    }
+    val data = child.eval(input) match {
+      case struct: UnsafeRow => struct.copy
+      case array: UnsafeArrayData => array.copy
+      case map: UnsafeMapData => map.copy
+      case str: UTF8String => str.clone
+      case d => d
+    }
+    mapToBuffer.get(groupIndex).map(_ += data).getOrElse {
+      sys.error(s"Group index $groupIndex not found in HashMap")
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {
     sys.error("Collect cannot be used in partial aggregations.")
   }
 
-  override def eval(input: InternalRow): Any = {
-    new GenericArrayData(buffer.toArray)
+  override def eval(buffer: InternalRow): Any = {
+    val groupIndex = buffer.getInt(mutableAggBufferOffset)
+    new GenericArrayData(mapToBuffer.get(groupIndex).map(_.toArray).getOrElse {
+      sys.error(s"Group index $groupIndex not found in HashMap")
+    })
   }
 }
 
@@ -92,7 +121,7 @@ case class CollectList(
 
   override def prettyName: String = "collect_list"
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+  override protected[this] def createBuffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 }
 
 /**
@@ -115,5 +144,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+  override protected[this] def createBuffer: mutable.HashSet[Any] = mutable.HashSet.empty
 }
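
The core of the patch is visible above: instead of declaring a shared `buffer` val, each Collect instance now stores only an Int group index in its fixed-width aggregation buffer and keeps the actual growable collection in a JVM-side HashMap keyed by that index. Values backed by reusable memory (UnsafeRow, UnsafeArrayData, UnsafeMapData, UTF8String) are copied before buffering, since Spark reuses the underlying bytes across input rows. Below is a minimal standalone sketch of this indexing scheme outside Spark; the class and method names are illustrative only, and it skips the defensive copy because it buffers plain JVM objects.

import scala.collection.mutable

// Standalone sketch of the scheme: the fixed-width buffer slot holds only an
// Int index, while the variable-size collection lives in a side map.
class GroupedCollector[A] {
  private val buffers = mutable.HashMap.empty[Int, mutable.ArrayBuffer[A]]

  // Mirrors Collect.initialize: -1 marks "no buffer allocated for this group yet".
  def initialize(): Int = -1

  // Mirrors Collect.update: allocate a buffer on first use, then append.
  // Returns the (possibly newly assigned) group index, which the caller
  // writes back into its fixed-width buffer slot.
  def update(groupIndex: Int, value: A): Int = {
    val index =
      if (groupIndex == -1) {
        val newIndex = buffers.size
        buffers.put(newIndex, mutable.ArrayBuffer.empty[A])
        newIndex
      } else {
        groupIndex
      }
    buffers(index) += value
    index
  }

  // Mirrors Collect.eval: read the collected values back out.
  def result(groupIndex: Int): Seq[A] =
    buffers.getOrElse(groupIndex, sys.error(s"Group index $groupIndex not found")).toSeq
}

object GroupedCollectorDemo extends App {
  val c = new GroupedCollector[String]
  var idx = c.initialize()
  idx = c.update(idx, "1")
  idx = c.update(idx, "1")
  println(c.result(idx)) // prints ArrayBuffer(1, 1), like collect_list
}

Because the real per-group state lives outside the serializable aggregation buffer, Collect still declares supportsPartial = false: the side map never has to be shipped between partial and final aggregation stages.
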
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c9bf..f967d44fb090c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -35,7 +35,7 @@ object AggUtils {
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
     val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
 
-    SortAggregateExec(
+    createAggregate(
       requiredChildDistributionExpressions = Some(groupingExpressions),
       groupingExpressions = groupingExpressions,
       aggregateExpressions = completeAggregateExpressions,
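
This one-line change stops hard-coding sort-based aggregation for the non-partial path and routes it through the same helper the other planning paths use. For context, here is a simplified sketch of the dispatch that AggUtils.createAggregate performs in Spark 2.0, reconstructed from memory rather than taken from this patch: it prefers HashAggregateExec whenever the aggregation-buffer schema is mutable enough for the unsafe-row hash map, and falls back to SortAggregateExec otherwise.

// Hedged sketch (not part of this patch) of createAggregate's dispatch;
// assumes Spark's execution internals are on the classpath.
private def createAggregate(
    requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
    groupingExpressions: Seq[NamedExpression] = Nil,
    aggregateExpressions: Seq[AggregateExpression] = Nil,
    aggregateAttributes: Seq[Attribute] = Nil,
    initialInputBufferOffset: Int = 0,
    resultExpressions: Seq[NamedExpression] = Nil,
    child: SparkPlan): SparkPlan = {
  val aggBufferAttributes =
    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
  if (HashAggregateExec.supportsAggregate(aggBufferAttributes)) {
    HashAggregateExec(requiredChildDistributionExpressions, groupingExpressions,
      aggregateExpressions, aggregateAttributes, initialInputBufferOffset,
      resultExpressions, child)
  } else {
    SortAggregateExec(requiredChildDistributionExpressions, groupingExpressions,
      aggregateExpressions, aggregateAttributes, initialInputBufferOffset,
      resultExpressions, child)
  }
}

With Collect's buffer now a single IntegerType field, the hash-based operator becomes eligible for collect_list and collect_set, which is the point of the buffer redesign in collect.scala above.
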
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 69a990789bcfd..c95efae7906a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -457,6 +457,36 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("collect functions array") {
+    val df = Seq((1, 3, 3, 3), (2, 3, 3, 3), (3, 4, 1, 2))
+      .toDF("a", "x", "y", "z")
+      .select($"a", array($"x", $"y", $"z").as("b"))
+    checkAnswer(
+      df.select(collect_list($"a"), sort_array(collect_list($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Seq(3, 3, 3), Seq(3, 3, 3), Seq(4, 1, 2))))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), sort_array(collect_set($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Seq(3, 3, 3), Seq(4, 1, 2))))
+    )
+  }
+
+  test("collect functions map") {
+    val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
+      .toDF("a", "x", "y")
+      .select($"a", map($"x", $"y").as("b"))
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(Map(3 -> 0), Map(3 -> 0), Map(4 -> 1))))
+    )
+    // TODO: We need to implement `UnsafeMapData#hashCode` and `UnsafeMapData#equals` for getting
+    // a set of input data.
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(Map(3 -> 0), Map(3 -> 0), Map(4 -> 1))))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c8a0f7134d5dd..98ce4c972dc79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1339,6 +1339,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("non-partial collect_set/collect_list aggregate test") {
+    checkAnswer(
+      testData4.groupBy("key").agg(collect_set($"value")),
+      Row(1, Array("1")) :: Row(2, Array("2")) :: Row(3, Array("3")) :: Nil
+    )
+    checkAnswer(
+      testData4.groupBy("key").agg(collect_list($"value")),
+      Row(1, Array("1", "1")) :: Row(2, Array("2", "2")) :: Row(3, Array("3", "3")) :: Nil
+    )
+  }
+
   test("fix case sensitivity of partition by") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
       withTempPath { path =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 0cfe260e52152..e762a3f54e883 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -72,6 +72,18 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val testData4: DataFrame = {
+    val df = spark.sparkContext.parallelize(
+      TestData(1, "1") ::
+      TestData(1, "1") ::
+      TestData(2, "2") ::
+      TestData(2, "2") ::
+      TestData(3, "3") ::
+      TestData(3, "3") :: Nil, 2).toDF()
+    df.createOrReplaceTempView("testData4")
+    df
+  }
+
   protected lazy val negativeData: DataFrame = {
     val df = spark.sparkContext.parallelize(
       (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
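
To make the behavior these new tests pin down concrete, here is a small self-contained usage sketch (not from the patch) mirroring testData4: collect_set deduplicates within each group while collect_list keeps duplicates, and with this change both run under groupBy through the regular aggregation planner. Runnable against a stock Spark 2.x build.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

object CollectExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("collect-example")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "1"), (1, "1"), (2, "2"), (2, "2"), (3, "3"), (3, "3"))
      .toDF("key", "value")

    // collect_set deduplicates per group, e.g. key 1 -> ["1"]
    df.groupBy("key").agg(collect_set($"value")).show()
    // collect_list keeps duplicates per group, e.g. key 1 -> ["1", "1"]
    df.groupBy("key").agg(collect_list($"value")).show()

    spark.stop()
  }
}
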