Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-2873] [SQL] using ExternalAppendOnlyMap to resolve OOM when aggregating #2029

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ abstract class AggregateFunction
override def dataType = base.dataType

def update(input: Row): Unit
def merge(input: AggregateFunction): Unit

// Do we really need this?
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
Expand Down Expand Up @@ -109,11 +110,19 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def update(input: Row): Unit = {
if (currentMin == null) {
currentMin = expr.eval(input)
} else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
} else if (GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
currentMin = expr.eval(input)
}
}

// Merges another partial aggregate's state into this one: adopt the peer's
// value outright when we have not seen any input yet, otherwise keep the
// smaller of the two. GreaterThan(this, input) compares this function's
// current minimum (via eval) against the peer's.
// NOTE(review): assumes `input` is a MinFunction over the same expression —
// confirm the caller guarantees this.
override def merge(input: AggregateFunction): Unit = {
if (currentMin == null) {
currentMin = input.eval(EmptyRow)
} else if (GreaterThan(this, input).eval(EmptyRow) == true) {
currentMin = input.eval(EmptyRow)
}
}

override def eval(input: Row): Any = currentMin
}

Expand All @@ -139,11 +148,19 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def update(input: Row): Unit = {
if (currentMax == null) {
currentMax = expr.eval(input)
} else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
} else if (LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
currentMax = expr.eval(input)
}
}

// Merges another partial aggregate's state into this one: adopt the peer's
// value outright when we have not seen any input yet, otherwise keep the
// larger of the two. LessThan(this, input) compares this function's current
// maximum (via eval) against the peer's.
// NOTE(review): assumes `input` is a MaxFunction over the same expression —
// confirm the caller guarantees this.
override def merge(input: AggregateFunction): Unit = {
if (currentMax == null) {
currentMax = input.eval(EmptyRow)
} else if (LessThan(this, input).eval(EmptyRow) == true) {
currentMax = input.eval(EmptyRow)
}
}

override def eval(input: Row): Any = currentMax
}

Expand Down Expand Up @@ -292,6 +309,11 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
sum.update(addFunction(evaluatedExpr), input)
}
}

/**
 * Merges another partial average into this one by summing both the row
 * counts and the running sums; the final average is computed in eval().
 *
 * NOTE(review): assumes `input` is an AverageFunction over the same
 * expression — the cast throws ClassCastException otherwise.
 */
override def merge(input: AggregateFunction): Unit = {
  // Cast once instead of once per field access.
  val other = input.asInstanceOf[AverageFunction]
  count += other.count
  sum.update(Add(sum, other.sum), EmptyRow)
}
}

case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
Expand All @@ -306,6 +328,10 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
}
}

/**
 * Merges a partial count: evaluating the peer yields its running row
 * count (a Long), which is added to this function's count.
 */
override def merge(input: AggregateFunction): Unit = {
  count += input.eval(EmptyRow).asInstanceOf[Long]
}

override def eval(input: Row): Any = count
}

Expand All @@ -325,6 +351,10 @@ case class ApproxCountDistinctPartitionFunction(
}
}

// Merges a peer's partial state: evaluating the peer returns its
// HyperLogLog sketch (see eval below), which is unioned into ours
// via addAll.
override def merge(input: AggregateFunction): Unit = {
hyperLogLog.addAll(input.eval(EmptyRow).asInstanceOf[HyperLogLog])
}

override def eval(input: Row): Any = hyperLogLog
}

Expand All @@ -342,6 +372,10 @@ case class ApproxCountDistinctMergeFunction(
hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
}

// Unions the peer's sketch into ours. The peer's hyperLogLog field is read
// directly (rather than via eval) because eval() on this class returns the
// cardinality, not the sketch itself.
override def merge(input: AggregateFunction): Unit = {
hyperLogLog.addAll(input.asInstanceOf[ApproxCountDistinctMergeFunction].hyperLogLog)
}

override def eval(input: Row): Any = hyperLogLog.cardinality()
}

Expand All @@ -358,6 +392,10 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
sum.update(addFunction, input)
}

/**
 * Merges a partial sum: both `this` and `input` evaluate to their running
 * totals, so Add(this, input) yields the combined sum, which replaces the
 * mutable `sum` accumulator.
 *
 * NOTE(review): relies on both sides evaluating to values of a compatible
 * numeric type — confirm the planner guarantees this.
 */
override def merge(input: AggregateFunction): Unit = {
  sum.update(Add(this, input), EmptyRow)
}

override def eval(input: Row): Any = sum.eval(null)
}

Expand All @@ -375,6 +413,10 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}

/**
 * Merges a peer's partial state by unioning its set of distinct values
 * into ours; eval() then sums the combined set.
 */
override def merge(input: AggregateFunction): Unit = {
  val other = input.asInstanceOf[SumDistinctFunction]
  seen ++= other.seen
}

override def eval(input: Row): Any =
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
Expand All @@ -393,6 +435,10 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
}
}

// Merges a peer's partial state: union the peer's set of distinct rows into
// this one; eval() later reports seen.size as the distinct count.
override def merge(input: AggregateFunction): Unit = {
seen ++= input.asInstanceOf[CountDistinctFunction].seen
}

override def eval(input: Row): Any = seen.size.toLong
}

Expand All @@ -407,5 +453,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
}
}

// First-wins across partial aggregates: only adopt the peer's result when
// this function has not yet captured a value (result is still null).
// NOTE(review): merge order across partitions is presumably not
// deterministic, so FIRST may return different values run to run — confirm
// this is acceptable.
override def merge(input: AggregateFunction): Unit = {
if (result == null) {
result = input.eval(EmptyRow)
}
}

override def eval(input: Row): Any = result
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.dsl.expressions._


class AggregatesSuite extends FunSuite {

  // Eight single-column rows holding the values 1, 1, 2, 2, 3, 3, 4, 4.
  val testRows = Seq(1, 1, 2, 2, 3, 3, 4, 4).map { x =>
    val row = new GenericMutableRow(1)
    row(0) = x
    row
  }

  val dataType: DataType = IntegerType
  val exp = BoundReference(0, dataType, true)

  /**
   * Checks that an AggregateExpression's merge() is consistent with its
   * update(): feeding every row twice into one instance must produce the
   * same result as feeding every row once into each of two instances and
   * then merging them.
   */
  def checkMethod(f: AggregateExpression): Unit = {
    val combiner = f.newInstance()
    val combiner1 = f.newInstance()
    val combiner2 = f.newInstance()

    // Reference path: update a single combiner with each row twice.
    // (foreach, not map — update is called only for its side effect.)
    testRows.foreach(combiner.update(_))
    testRows.foreach(combiner.update(_))

    // Merge path: update two partial combiners with each row once...
    testRows.foreach(combiner1.update(_))
    testRows.foreach(combiner2.update(_))

    // ...then combine the partials.
    combiner1.merge(combiner2)

    val r1 = combiner.eval(EmptyRow)
    val r2 = combiner1.eval(EmptyRow)

    // Both paths must agree.
    assert(r1 == r2, s"merge result $r2 does not match update result $r1")
  }

  test("max merge test") {
    checkMethod(max(exp))
  }

  test("min merge test") {
    checkMethod(min(exp))
  }

  test("sum merge test") {
    checkMethod(sum(exp))
  }

  test("count merge test") {
    checkMethod(count(exp))
  }

  test("avg merge test") {
    checkMethod(avg(exp))
  }

  test("sum distinct merge test") {
    checkMethod(sumDistinct(exp))
  }

  test("count distinct merge test") {
    checkMethod(countDistinct(exp))
  }

  test("first merge test") {
    checkMethod(first(exp))
  }

  // NOTE(review): approxCountDistinct(exp) does not appear to exercise
  // ApproxCountDistinctPartitionFunction or ApproxCountDistinctMergeFunction
  // directly, so their merge() implementations may be uncovered — confirm.
  test("approx count distinct merge test") {
    checkMethod(approxCountDistinct(exp))
  }

}
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ private[spark] object SQLConf {
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val EXTERNAL_AGGREGATE = "spark.sql.aggregate.external"
val CODEGEN_ENABLED = "spark.sql.codegen"
val DIALECT = "spark.sql.dialect"
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
Expand Down Expand Up @@ -83,6 +84,13 @@ trait SQLConf {
/** Number of partitions to use for shuffle operators. */
private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt

/**
 * When set to true, Spark SQL will use ExternalAggregation.
 * Defaults to false, in which case on-heap aggregation is used.
 */
private[spark] def externalAggregate: Boolean =
  getConf(EXTERNAL_AGGREGATE, "false") == "true"

/**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
Expand Down
Loading