[SPARK-17187][SQL][BRANCH-2.0] Supports using arbitrary Java object as internal aggregation buffer object

## What changes were proposed in this pull request?

(This PR cherry-picks PR apache#14753 to Databricks Spark branch-2.0.)

This PR introduces an abstract class `TypedImperativeAggregate` so that an aggregation function extending `TypedImperativeAggregate` can use an **arbitrary** user-defined Java object as its intermediate aggregation buffer object.

**This has advantages such as:**
1. It can now support a larger category of aggregation functions. For example, it becomes much easier to implement an aggregation function like `percentile_approx`, which has a complex aggregation buffer definition.
2. It avoids serializing/de-serializing the domain-specific aggregation object into Spark SQL's internal storage format on every call of `update` or `merge`.
3. It is easier to integrate with existing monoid libraries such as Algebird, supporting more aggregation functions with high performance.

Please see `org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMaxAggregate` for an example of how to define a `TypedImperativeAggregate` aggregation function.
Please see the Javadoc of `TypedImperativeAggregate` and the JIRA ticket SPARK-17187 for more information.
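
A minimal sketch of such a function is shown below: a typed max over a `LongType` column. This is illustrative only and is not the actual `TypedMaxAggregate` from the test suite; the names `TypedMax` and `MaxValue`, the fixed 8-byte serialization, and the omission of input type checking are assumptions made for brevity.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types.{DataType, LongType}

// Mutable aggregation buffer object: an arbitrary Java object, not an InternalRow field.
class MaxValue(var value: Long)

// Hypothetical example (not the test suite's TypedMaxAggregate): max of a LongType child.
case class TypedMax(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {

  override def createAggregationBuffer(): MaxValue = new MaxValue(Long.MinValue)

  // Partial/Complete mode: fold one input row into the buffer object.
  override def update(buffer: MaxValue, input: InternalRow): Unit = {
    val v = child.eval(input)
    if (v != null) {
      buffer.value = math.max(buffer.value, v.asInstanceOf[Long])
    }
  }

  // PartialMerge/Final mode: fold another buffer object into this one.
  override def merge(buffer: MaxValue, input: MaxValue): Unit = {
    buffer.value = math.max(buffer.value, input.value)
  }

  override def eval(buffer: MaxValue): Any = buffer.value

  // Storage format used for shuffling: an 8-byte array holding the current max.
  override def serialize(buffer: MaxValue): Array[Byte] =
    ByteBuffer.allocate(8).putLong(buffer.value).array()

  override def deserialize(bytes: Array[Byte]): MaxValue =
    new MaxValue(ByteBuffer.wrap(bytes).getLong)

  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType

  override def withNewMutableAggBufferOffset(newOffset: Int): TypedMax =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): TypedMax =
    copy(inputAggBufferOffset = newOffset)
}
```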

## How was this patch tested?

Unit tests.

Author: Sean Zhong <seanzhong@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes apache#14753 from clockfly/object_aggregation_buffer_try_2.

Author: Sean Zhong <seanzhong@databricks.com>

Closes apache#71 from liancheng/typed-imperative-agg-db-2.0.
clockfly authored and liancheng committed Sep 1, 2016
1 parent 2c0f5de commit f003e0c
Showing 3 changed files with 456 additions and 0 deletions.
@@ -389,3 +389,144 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}

/**
* Aggregation function that allows an **arbitrary** user-defined Java object to be used as the
* internal aggregation buffer object.
*
* {{{
* aggregation buffer for normal aggregation function `avg`
* |
* v
* +--------------+---------------+-----------------------------------+
* | sum1 (Long) | count1 (Long) | generic user-defined java objects |
* +--------------+---------------+-----------------------------------+
* ^
* |
* Aggregation buffer object for `TypedImperativeAggregate` aggregation function
* }}}
*
* Workflow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side):
*
* Stage 1: Partial aggregate at Mapper side:
*
* 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
* buffer object.
* 2. Upon each input row, the framework calls
* `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
* 3. After processing all rows of the current group (group by key), the framework serializes the
* aggregation buffer object T to the storage format (Array[Byte]) and persists the Array[Byte]
* to disk if needed.
* 4. The framework moves on to the next group, until all groups have been processed.
*
* Shuffle exchanges the data to Reducer tasks...
*
* Stage 2: Final mode aggregate at Reducer side:
*
* 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
* buffer object (type T) for merging.
* 2. For each aggregation output of Stage 1, the framework de-serializes the storage
* format (Array[Byte]) and produces one input aggregation object (type T).
* 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
* to merge the input aggregation object into the aggregation buffer object.
* 4. After processing all input aggregation objects of the current group (group by key), the
* framework calls method `eval(buffer: T)` to generate the final output for this group.
* 5. The framework moves on to the next group, until all groups have been processed.
*
* NOTE: SQL queries with TypedImperativeAggregate functions are planned with sort-based
* aggregation instead of hash-based aggregation, because TypedImperativeAggregate uses
* BinaryType as the aggregation buffer's storage format, which is not supported by hash-based
* aggregation. Hash-based aggregation only supports aggregation buffers of mutable types (like
* LongType and IntegerType, which have a fixed length and can be mutated in place in an
* UnsafeRow).
*/
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {

/**
* Creates an empty aggregation buffer object. This is called before processing each key group
* (group by key).
*
* @return an aggregation buffer object
*/
def createAggregationBuffer(): T

/**
* Updates the aggregation buffer object in place with an input row: buffer = buffer + input.
* This is typically called when doing Partial or Complete mode aggregation.
*
* @param buffer The aggregation buffer object.
* @param input an input row
*/
def update(buffer: T, input: InternalRow): Unit

/**
* Merges an input aggregation object into the aggregation buffer object: buffer = buffer + input.
* This is typically called when doing PartialMerge or Final mode aggregation.
*
* @param buffer the aggregation buffer object used to store the aggregation result.
* @param input an input aggregation object. Input aggregation object can be produced by
* de-serializing the partial aggregate's output from Mapper side.
*/
def merge(buffer: T, input: T): Unit

/**
* Generates the final aggregation result value for the current key group with the aggregation
* buffer object.
*
* @param buffer aggregation buffer object.
* @return The aggregation result of the current key group
*/
def eval(buffer: T): Any

/** Serializes the aggregation buffer object T to Array[Byte] */
def serialize(buffer: T): Array[Byte]

/** De-serializes the storage format (Array[Byte]) and produces the aggregation buffer object T */
def deserialize(storageFormat: Array[Byte]): T

final override def initialize(buffer: MutableRow): Unit = {
val bufferObject = createAggregationBuffer()
buffer.update(mutableAggBufferOffset, bufferObject)
}

final override def update(buffer: MutableRow, input: InternalRow): Unit = {
val bufferObject = getField[T](buffer, mutableAggBufferOffset)
update(bufferObject, input)
}

final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
val bufferObject = getField[T](buffer, mutableAggBufferOffset)
// The inputBuffer stores the serialized aggregation buffer object produced by the partial aggregate
val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
merge(bufferObject, inputObject)
}

final override def eval(buffer: InternalRow): Any = {
val bufferObject = getField[T](buffer, mutableAggBufferOffset)
eval(bufferObject)
}

private[this] val anyObjectType = ObjectType(classOf[AnyRef])
private def getField[U](input: InternalRow, fieldIndex: Int): U = {
input.get(fieldIndex, anyObjectType).asInstanceOf[U]
}

final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
// Underlying storage type for the aggregation buffer object
Seq(AttributeReference("buf", BinaryType)())
}

final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

/**
* Replaces, in place, the aggregation buffer object stored at the buffer's index
* `mutableAggBufferOffset` with Spark SQL's internally supported underlying storage format
* (BinaryType).
*/
final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
val bufferObject = getField[T](buffer, mutableAggBufferOffset)
buffer(mutableAggBufferOffset) = serialize(bufferObject)
}
}
@@ -234,7 +234,22 @@ abstract class AggregationIterator(
val resultProjection = UnsafeProjection.create(
groupingAttributes ++ bufferAttributes,
groupingAttributes ++ bufferAttributes)

// TypedImperativeAggregate stores a generic object in the aggregation buffer and requires
// serializing it before shuffling. See [[TypedImperativeAggregate]] for more info.
val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
aggregateFunctions.collect {
case (ag: TypedImperativeAggregate[_]) => ag
}
}

(currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
// Serializes the generic objects stored in the aggregation buffer
var i = 0
while (i < typedImperativeAggregates.length) {
typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
i += 1
}
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
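
As a reading aid for the workflow documented in the `TypedImperativeAggregate` scaladoc above, the snippet below traces the Partial-to-Final sequence of calls by invoking the function directly, outside the planner. It reuses the hypothetical `TypedMax` sketch from the description; it only shows the order of calls (create, update, serialize, shuffle, deserialize, merge, eval), not how `AggregationIterator` actually drives them.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.LongType

object TypedImperativeWorkflowSketch {
  def main(args: Array[String]): Unit = {
    // Aggregate the first (and only) column of the input rows.
    val agg = TypedMax(BoundReference(0, LongType, nullable = false))

    // Stage 1 (Mapper side, Partial mode): create a buffer object for the group,
    // update it once per input row, then serialize it to Array[Byte] for shuffling.
    val mapperBuffer = agg.createAggregationBuffer()
    Seq(3L, 7L, 5L).foreach(v => agg.update(mapperBuffer, InternalRow(v)))
    val shuffled: Array[Byte] = agg.serialize(mapperBuffer)

    // Stage 2 (Reducer side, Final mode): deserialize each partial result, merge it
    // into a fresh buffer object, and produce the final value for the group.
    val reducerBuffer = agg.createAggregationBuffer()
    agg.merge(reducerBuffer, agg.deserialize(shuffled))
    println(agg.eval(reducerBuffer)) // prints 7
  }
}
```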