From 5cbaaac088346077bcc27229073a5fd3de8eb3a3 Mon Sep 17 00:00:00 2001
From: guowei2
Date: Tue, 19 Aug 2014 10:51:29 +0800
Subject: [PATCH] merge PR 1822

---
 .../sql/catalyst/expressions/aggregates.scala      |  56 +++++-
 .../expressions/AggregatesSuite.scala              |  86 +++++++++
 .../scala/org/apache/spark/sql/SQLConf.scala       |   8 +
 .../spark/sql/execution/Aggregate.scala            | 174 +++++++++++++-----
 .../spark/sql/execution/SparkStrategies.scala      |  35 +++-
 .../org/apache/spark/sql/hive/hiveUdfs.scala       |   6 +
 6 files changed, 312 insertions(+), 53 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AggregatesSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 01947273b6ccc..50e6b40c7db39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -82,6 +82,7 @@ abstract class AggregateFunction
   override def dataType = base.dataType
 
   def update(input: Row): Unit
+  def merge(input: AggregateFunction): Unit
 
   // Do we really need this?
   override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
@@ -109,11 +110,19 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def update(input: Row): Unit = {
     if (currentMin == null) {
       currentMin = expr.eval(input)
-    } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
+    } else if (GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
       currentMin = expr.eval(input)
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    if (currentMin == null) {
+      currentMin = input.eval(EmptyRow)
+    } else if (GreaterThan(this, input).eval(EmptyRow) == true) {
+      currentMin = input.eval(EmptyRow)
+    }
+  }
+
   override def eval(input: Row): Any = currentMin
 }
 
@@ -139,11 +148,19 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def update(input: Row): Unit = {
     if (currentMax == null) {
       currentMax = expr.eval(input)
-    } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
+    } else if (LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
       currentMax = expr.eval(input)
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    if (currentMax == null) {
+      currentMax = input.eval(EmptyRow)
+    } else if (LessThan(this, input).eval(EmptyRow) == true) {
+      currentMax = input.eval(EmptyRow)
+    }
+  }
+
   override def eval(input: Row): Any = currentMax
 }
 
@@ -292,6 +309,11 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
       sum.update(addFunction(evaluatedExpr), input)
     }
   }
+
+  override def merge(input: AggregateFunction): Unit = {
+    count += input.asInstanceOf[AverageFunction].count
+    sum.update(Add(sum, input.asInstanceOf[AverageFunction].sum), EmptyRow)
+  }
 }
 
 case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -306,6 +328,10 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    count += input.eval(EmptyRow).asInstanceOf[Long]
+  }
+
   override def eval(input: Row): Any = count
 }
 
@@ -325,6 +351,10 @@ case class ApproxCountDistinctPartitionFunction(
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    hyperLogLog.addAll(input.eval(EmptyRow).asInstanceOf[HyperLogLog])
+  }
+
   override def eval(input: Row): Any = hyperLogLog
 }
 
@@ -342,6 +372,10 @@ case class ApproxCountDistinctMergeFunction(
     hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    hyperLogLog.addAll(input.asInstanceOf[ApproxCountDistinctMergeFunction].hyperLogLog)
+  }
+
   override def eval(input: Row): Any = hyperLogLog.cardinality()
 }
 
@@ -358,6 +392,10 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
     sum.update(addFunction, input)
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    sum.update(Add(this, input), EmptyRow)
+  }
+
   override def eval(input: Row): Any = sum.eval(null)
 }
 
@@ -375,6 +413,10 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    seen ++= input.asInstanceOf[SumDistinctFunction].seen
+  }
+
   override def eval(input: Row): Any =
     seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
 }
 
@@ -393,6 +435,10 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    seen ++= input.asInstanceOf[CountDistinctFunction].seen
+  }
+
   override def eval(input: Row): Any = seen.size.toLong
 }
 
@@ -407,5 +453,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
     }
   }
 
+  override def merge(input: AggregateFunction): Unit = {
+    if (result == null) {
+      result = input.eval(EmptyRow)
+    }
+  }
+
   override def eval(input: Row): Any = result
 }
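The new `merge` method gives every `AggregateFunction` a combiner contract: `update` folds one raw input row into a partial result, and `merge` folds another partial result of the same function into this one. The invariant the patch depends on can be shown with a toy sketch in plain Scala (illustrative names only, no Spark dependencies):

```scala
// Toy model of the update/merge contract added by this patch (not Spark code).
object MergeContractDemo extends App {
  final class SumAgg {
    var sum = 0L
    def update(v: Long): Unit = sum += v              // fold in one raw row
    def merge(other: SumAgg): Unit = sum += other.sum // fold in a partial result
    def eval: Long = sum
  }

  val rows = Seq(1L, 2L, 3L, 4L)

  // One instance that sees every row directly...
  val whole = new SumAgg
  rows.foreach(whole.update)

  // ...must agree with two instances that each see half and are then merged.
  val (left, right) = rows.splitAt(2)
  val a = new SumAgg
  val b = new SumAgg
  left.foreach(a.update)
  right.foreach(b.update)
  a.merge(b)

  assert(whole.eval == a.eval) // 10 == 10
}
```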
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AggregatesSuite.scala
new file mode 100644
index 0000000000000..f6248c0d92a4d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AggregatesSuite.scala
@@ -0,0 +1,86 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+
+class AggregatesSuite extends FunSuite {
+
+  val testRows = Seq(1, 1, 2, 2, 3, 3, 4, 4).map(x => {
+    val row = new GenericMutableRow(1)
+    row(0) = x
+    row
+  })
+
+  val dataType: DataType = IntegerType
+  val exp = BoundReference(0, dataType, true)
+
+  /**
+   * Check that each aggregate's merge is consistent with update: merging two
+   * partial aggregates must match updating a single aggregate with all rows.
+   */
+  def checkMethod(f: AggregateExpression) = {
+    val combiner = f.newInstance()
+    val combiner1 = f.newInstance()
+    val combiner2 = f.newInstance()
+
+    // update combiner with each row of testRows, twice
+    testRows.foreach(combiner.update(_))
+    testRows.foreach(combiner.update(_))
+
+    // update combiner1 with each row of testRows
+    testRows.foreach(combiner1.update(_))
+
+    // update combiner2 with each row of testRows
+    testRows.foreach(combiner2.update(_))
+
+    // merge combiner2 into combiner1
+    combiner1.merge(combiner2)
+
+    val r1 = combiner.eval(EmptyRow)
+    val r2 = combiner1.eval(EmptyRow)
+
+    // the two paths above must produce the same result
+    assert(r1 == r2, s"update result $r1 does not match merge result $r2")
+  }
+
+  test("max merge test") {
+    checkMethod(max(exp))
+  }
+
+  test("min merge test") {
+    checkMethod(min(exp))
+  }
+
+  test("sum merge test") {
+    checkMethod(sum(exp))
+  }
+
+  test("count merge test") {
+    checkMethod(count(exp))
+  }
+
+  test("avg merge test") {
+    checkMethod(avg(exp))
+  }
+
+  test("sum distinct merge test") {
+    checkMethod(sumDistinct(exp))
+  }
+
+  test("count distinct merge test") {
+    checkMethod(countDistinct(exp))
+  }
+
+  test("first merge test") {
+    checkMethod(first(exp))
+  }
+
+  // NOTE: this case looks wrong: it does not actually exercise
+  // ApproxCountDistinctPartitionFunction or ApproxCountDistinctMergeFunction.
+  test("approx count distinct merge test") {
+    checkMethod(approxCountDistinct(exp))
+  }
+
+}
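`checkMethod` relies on a simple identity: one combiner updated with the fixture rows twice must agree with two combiners that each saw the rows once and were then merged. Worked out for `avg` on the fixture `Seq(1, 1, 2, 2, 3, 3, 4, 4)` in plain Scala (purely illustrative, no Spark types):

```scala
val rows = Seq(1, 1, 2, 2, 3, 3, 4, 4)

// Path 1: one combiner updated with every row twice.
val twice = rows ++ rows
val avgTwice = twice.sum.toDouble / twice.size // 2.5

// Path 2: two partial (count, sum) buffers combined the way
// AverageFunction.merge does it: add the counts, add the sums.
val (count1, sum1) = (rows.size, rows.sum)
val (count2, sum2) = (rows.size, rows.sum)
val avgMerged = (sum1 + sum2).toDouble / (count1 + count2) // 2.5

assert(avgTwice == avgMerged)
```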
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4f2adb006fbc7..e35f4fd455e04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -29,6 +29,7 @@ private[spark] object SQLConf {
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+  val EXTERNAL_AGGREGATE = "spark.sql.aggregate.external"
   val CODEGEN_ENABLED = "spark.sql.codegen"
   val DIALECT = "spark.sql.dialect"
   val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
@@ -83,6 +84,13 @@ trait SQLConf {
   /** Number of partitions to use for shuffle operators. */
   private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt
 
+  /**
+   * When set to true, Spark SQL uses ExternalAggregate, which can spill to disk.
+   * Defaults to false, which uses the in-memory OnHeapAggregate.
+   */
+  private[spark] def externalAggregate: Boolean =
+    getConf(EXTERNAL_AGGREGATE, "false") == "true"
+
   /**
    * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
    * that evaluates expressions found in queries. In general this custom code runs much faster
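A usage sketch for the new flag (the `employees` table is a hypothetical placeholder; assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Route the final aggregation stage through ExternalAggregate, which can spill.
sqlContext.setConf("spark.sql.aggregate.external", "true")

// Grouped aggregations now use the spillable path.
sqlContext.sql("SELECT dept, SUM(salary) FROM employees GROUP BY dept").collect()
```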
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 463a1d32d7fd7..c9debbbce3fca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -25,25 +25,30 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap}
 
 /**
- * :: DeveloperApi ::
  * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
  * group.
  *
- * @param partial if true then aggregation is done partially on local data without shuffling to
- *                ensure all values where `groupingExpressions` are equal are present.
- * @param groupingExpressions expressions that are evaluated to determine grouping.
- * @param aggregateExpressions expressions that are computed for each group.
- * @param child the input data source.
+ * - `partial`: if true, aggregation is done partially on local data without shuffling,
+ *   to ensure all values where `groupingExpressions` are equal are present.
+ * - `groupingExpressions`: expressions that are evaluated to determine grouping.
+ * - `aggregateExpressions`: expressions that are computed for each group.
+ * - `child`: the input data source.
  */
-@DeveloperApi
-case class Aggregate(
-    partial: Boolean,
-    groupingExpressions: Seq[Expression],
-    aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
+trait Aggregate {
+
+  self: SparkPlan =>
+
+  /** If true, aggregation is done partially on local data without shuffling. */
+  val partial: Boolean
+  /** Expressions that are evaluated to determine grouping. */
+  val groupingExpressions: Seq[Expression]
+  /** Expressions that are computed for each group. */
+  val aggregateExpressions: Seq[NamedExpression]
+  /** The input data source. */
+  val child: SparkPlan
 
   override def requiredChildDistribution =
     if (partial) {
@@ -58,7 +63,7 @@ case class Aggregate(
 
   // HACK: Generators don't correctly preserve their output through serializations so we grab
   // out child's output attributes statically here.
-  private[this] val childOutput = child.output
+  protected[this] val childOutput = child.output
 
   override def output = aggregateExpressions.map(_.toAttribute)
 
@@ -76,7 +81,7 @@ case class Aggregate(
       resultAttribute: AttributeReference)
 
   /** A list of aggregates that need to be computed for each group. */
-  private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
+  protected[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
     agg.collect {
       case a: AggregateExpression =>
         ComputedAggregate(
@@ -87,21 +92,21 @@ case class Aggregate(
   }.toArray
 
   /** The schema of the result of all aggregate evaluations */
-  private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
+  protected[this] val computedSchema = computedAggregates.map(_.resultAttribute)
 
   /** Creates a new aggregate buffer for a group. */
-  private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
-    val buffer = new Array[AggregateFunction](computedAggregates.length)
+  protected[this] def newAggregateBuffer(): CompactBuffer[AggregateFunction] = {
+    val buffer = new CompactBuffer[AggregateFunction]
     var i = 0
     while (i < computedAggregates.length) {
-      buffer(i) = computedAggregates(i).aggregate.newInstance()
+      buffer += computedAggregates(i).aggregate.newInstance()
       i += 1
     }
     buffer
   }
 
   /** Named attributes used to substitute grouping attributes into the final result. */
-  private[this] val namedGroups = groupingExpressions.map {
+  protected[this] val namedGroups = groupingExpressions.map {
     case ne: NamedExpression => ne -> ne.toAttribute
     case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
   }
@@ -117,40 +122,52 @@ case class Aggregate(
    * Substituted version of aggregateExpressions expressions which are used to compute final
    * output rows given a group and the result of all aggregate computations.
    */
-  private[this] val resultExpressions = aggregateExpressions.map { agg =>
+  protected[this] val resultExpressions = aggregateExpressions.map { agg =>
     agg.transform {
       case e: Expression if resultMap.contains(e) => resultMap(e)
     }
   }
 
-  override def execute() = attachTree(this, "execute") {
-    if (groupingExpressions.isEmpty) {
-      child.execute().mapPartitions { iter =>
-        val buffer = newAggregateBuffer()
-        var currentRow: Row = null
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          var i = 0
-          while (i < buffer.length) {
-            buffer(i).update(currentRow)
-            i += 1
-          }
-        }
-        val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
-        val aggregateResults = new GenericMutableRow(computedAggregates.length)
-
+  protected[this] def aggregateNoGrouping() = {
+    child.execute().mapPartitions { iter =>
+      val buffer = newAggregateBuffer()
+      var currentRow: Row = null
+      while (iter.hasNext) {
+        currentRow = iter.next()
         var i = 0
         while (i < buffer.length) {
-          aggregateResults(i) = buffer(i).eval(EmptyRow)
+          buffer(i).update(currentRow)
           i += 1
         }
+      }
+      val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
+      val aggregateResults = new GenericMutableRow(computedAggregates.length)
 
-        Iterator(resultProjection(aggregateResults))
+      var i = 0
+      while (i < buffer.length) {
+        aggregateResults(i) = buffer(i).eval(EmptyRow)
+        i += 1
       }
+
+      Iterator(resultProjection(aggregateResults))
+    }
+  }
+}
+
+case class OnHeapAggregate(
+    partial: Boolean,
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with Aggregate {
+
+  override def execute() = attachTree(this, "execute") {
+    if (groupingExpressions.isEmpty) {
+      aggregateNoGrouping()
     } else {
       child.execute().mapPartitions { iter =>
-        val hashTable = new HashMap[Row, Array[AggregateFunction]]
-        val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput)
+        val hashTable = new HashMap[Row, CompactBuffer[AggregateFunction]]
+        val groupingProjection =
+          new InterpretedMutableProjection(groupingExpressions, childOutput)
 
         var currentRow: Row = null
         while (iter.hasNext) {
@@ -184,6 +201,81 @@ case class Aggregate(
               val currentGroup = currentEntry.getKey
               val currentBuffer = currentEntry.getValue
 
+              var i = 0
+              while (i < currentBuffer.length) {
+                // Evaluating an aggregate buffer returns the result.  No row is required since we
+                // already added all rows in the group using update.
+                aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+                i += 1
+              }
+              resultProjection(joinedRow(aggregateResults, currentGroup))
+            }
+          }
+        }
+      }
+    }
+  }
+
+}
+
+case class ExternalAggregate(
+    partial: Boolean,
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with Aggregate {
+
+  override def execute() = attachTree(this, "execute") {
+    if (groupingExpressions.isEmpty) {
+      aggregateNoGrouping()
+    } else {
+      child.execute().mapPartitions { iter =>
+        val groupingProjection =
+          new InterpretedMutableProjection(groupingExpressions, childOutput)
+
+        val createCombiner = (v: Row) => {
+          val c = newAggregateBuffer()
+          var i = 0
+          while (i < c.length) {
+            c(i).update(v)
+            i += 1
+          }
+          c
+        }
+        val mergeValue = (c: CompactBuffer[AggregateFunction], v: Row) => {
+          var i = 0
+          while (i < c.length) {
+            c(i).update(v)
+            i += 1
+          }
+          c
+        }
+        val mergeCombiners = (c1: CompactBuffer[AggregateFunction], c2: CompactBuffer[AggregateFunction]) => {
+          var i = 0
+          while (i < c1.length) {
+            c1(i).merge(c2(i))
+            i += 1
+          }
+          c1
+        }
+        val combiners = new ExternalAppendOnlyMap[Row, Row, CompactBuffer[AggregateFunction]](
+          createCombiner, mergeValue, mergeCombiners)
+        while (iter.hasNext) {
+          val row = iter.next()
+          combiners.insert(groupingProjection(row).copy(), row)
+        }
+        new Iterator[Row] {
+          private[this] val externalIter = combiners.iterator
+          private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+          private[this] val resultProjection =
+            new InterpretedMutableProjection(
+              resultExpressions, computedSchema ++ namedGroups.map(_._2))
+          private[this] val joinedRow = new JoinedRow
+
+          override final def hasNext: Boolean = externalIter.hasNext
+          override final def next(): Row = {
+            val currentEntry = externalIter.next()
+            val currentGroup = currentEntry._1
+            val currentBuffer = currentEntry._2
+
               var i = 0
               while (i < currentBuffer.length) {
                 // Evaluating an aggregate buffer returns the result.  No row is required since we
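The three callbacks that `ExternalAggregate` hands to `ExternalAppendOnlyMap` follow the same contract as `combineByKey` on an ordinary RDD: `createCombiner` builds a buffer from a key's first row, `mergeValue` is the per-row `update` path, and `mergeCombiners` is the new `merge` path used when partial buffers (for example, spilled ones) are combined. The same shape with simple types, as a word count (assumes a `SparkContext` named `sc`):

```scala
val counts = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
  .map(word => (word, 1L))
  .combineByKey[Long](
    (v: Long) => v,                  // createCombiner: first value seen for a key
    (c: Long, v: Long) => c + v,     // mergeValue: like AggregateFunction.update
    (c1: Long, c2: Long) => c1 + c2) // mergeCombiners: like AggregateFunction.merge

counts.collect() // Array((a,3), (b,2), (c,1)), in some order
```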
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0c958fdb537f..44e5f6db8353e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -134,15 +134,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             groupingExpressions,
             partialComputation,
             child) =>
-        execution.Aggregate(
-          partial = false,
-          namedGroupingAttributes,
-          rewrittenAggregateExpressions,
-          execution.Aggregate(
-            partial = true,
-            groupingExpressions,
-            partialComputation,
-            planLater(child))) :: Nil
+
+        val preAggregate = execution.OnHeapAggregate(
+          partial = true,
+          groupingExpressions,
+          partialComputation,
+          planLater(child))
+
+        if (self.sqlContext.externalAggregate) {
+          execution.ExternalAggregate(
+            partial = false,
+            namedGroupingAttributes,
+            rewrittenAggregateExpressions,
+            preAggregate) :: Nil
+        } else {
+          execution.OnHeapAggregate(
+            partial = false,
+            namedGroupingAttributes,
+            rewrittenAggregateExpressions,
+            preAggregate) :: Nil
+        }
 
       case _ => Nil
     }
@@ -265,7 +276,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+        if (self.sqlContext.externalAggregate) {
+          execution.ExternalAggregate(partial = false, group, agg, planLater(child)) :: Nil
+        } else {
+          execution.OnHeapAggregate(partial = false, group, agg, planLater(child)) :: Nil
+        }
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index c6497a15efa0c..9ffcca8e4e58b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -347,4 +347,10 @@ private[hive] case class HiveUdafFunction(
     val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
     function.iterate(buffer, inputs)
   }
+
+  // HiveUdafFunction does not support external aggregation: merging would require
+  // spilling to disk, and the fields above are not serializable.
+  override def merge(input: AggregateFunction): Unit = {
+    throw new NotImplementedError("HiveUdaf does not support external aggregate")
+  }
 }
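One consequence of the stub above, worth flagging in review: with `spark.sql.aggregate.external=true`, a query over a Hive UDAF can fail at runtime once partial buffers actually have to be merged (for example, after a spill), because `merge` throws. A hedged sketch (assumes a `HiveContext` named `hiveContext` and Hive's usual `src` test table):

```scala
hiveContext.setConf("spark.sql.aggregate.external", "true")

// May throw NotImplementedError from HiveUdafFunction.merge once the
// external map needs to combine two partial buffers for the same key.
hiveContext.sql("SELECT histogram_numeric(key, 5) FROM src").collect()
```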