[SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation #11583

Closed · wants to merge 16 commits · showing changes from 13 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -309,38 +309,64 @@ class Analyzer(

object ResolvePivot extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
      | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
    case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
      val singleAgg = aggregates.size == 1
      def outputName(value: Literal, aggregate: Expression): String = {
        if (singleAgg) value.toString else value + "_" + aggregate.sql
      }
      if (pivotValues.length >= 10
          && aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
Contributor:
Would it be better if we removed the pivotValues.length check and just kept aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))? Also, which data types does the new code path not support?

        // Since evaluating |pivotValues| if statements for each input row can get
        // slow, this is an alternate plan that instead uses two steps of aggregation.
        val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
        val namedPivotCol = pivotColumn match {
          case n: NamedExpression => n
          case _ => Alias(pivotColumn, "__pivot_col")()
        }
        val bigGroup = groupByExprs :+ namedPivotCol
        val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
        val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
        val pivotAggs = namedAggExps.map { a =>
          Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
            .toAggregateExpression(), "__pivot_" + a.sql)()
        }
        val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
        val pivotAggAttribute = pivotAggs.map(_.toAttribute)
        val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
          aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
            Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
          }
        }
Contributor:
This map is not needed anymore?

Contributor (author):
Nope, added a check for !p.groupByExprs.forall(_.resolved) to the guard case.

        Project(groupByExprs ++ pivotOutputs, secondAgg)
      } else {
Contributor:
Since we will decide which branch to use based on the data types, do we still have enough test coverage for this else branch?

        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
          def ifExpr(expr: Expression) = {
            If(EqualTo(pivotColumn, value), expr, Literal(null))
          }
          aggregates.map { aggregate =>
            val filteredAggregate = aggregate.transformDown {
              // Assumption is the aggregate function ignores nulls. This is true for all current
              // AggregateFunctions with the exception of First and Last in their default mode
              // (which we handle) and possibly some Hive UDAFs.
              case First(expr, _) =>
                First(ifExpr(expr), Literal(true))
              case Last(expr, _) =>
                Last(ifExpr(expr), Literal(true))
              case a: AggregateFunction =>
                a.withNewChildren(a.children.map(ifExpr))
            }
            if (filteredAggregate.fastEquals(aggregate)) {
              throw new AnalysisException(
                s"Aggregate expression required for pivot, found '$aggregate'")
            }
            Alias(filteredAggregate, outputName(value, aggregate))()
          }
        }
        Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
      }
  }
}
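
To make the rewrite concrete, here is a rough sketch of the two plan shapes (illustrative only, not output from the patch) for a query like courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")).agg(sum($"earnings")):

// Single-phase path: one If per pivot value inside every aggregate, evaluated
// for each input row.
//   Aggregate [year], [year,
//     sum(if (course = 'dotNET') earnings else null) AS dotNET,
//     sum(if (course = 'Java') earnings else null) AS Java]
//
// Two-phase path: a plain aggregate over (year, course), then PivotFirst packs
// the per-course results into an array that a final Project unpacks.
//   Project [year, __pivot[0] AS dotNET, __pivot[1] AS Java]
//   +- Aggregate [year], [year, pivotfirst(course, sum(earnings), ...) AS __pivot]
//      +- Aggregate [year, course], [year, course, sum(earnings)]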

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala (new file)
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.immutable.HashMap

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

object PivotFirst {

  def supportsDataType(dataType: DataType): Boolean = {
    try {
      updateFunction(dataType)
      true
    } catch {
      case _: UnsupportedOperationException => false
    }
  }
Contributor:
I guess it is better to avoid try/catch to determine whether a data type is supported.
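
One way to avoid it (a sketch, not part of this patch) would be to enumerate the supported types directly, mirroring the cases that updateFunction below handles:

// Sketch only: must be kept in sync with the cases in updateFunction.
def supportsDataType(dataType: DataType): Boolean = dataType match {
  case DoubleType | IntegerType | LongType | FloatType |
       BooleanType | ShortType | ByteType => true
  case _: DecimalType => true
  case _ => false
}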


  // Currently UnsafeRow does not support the generic update method (it throws
  // UnsupportedOperationException), so we need to explicitly support each DataType.
  private def updateFunction(dataType: DataType): (MutableRow, Int, Any) => Unit = dataType match {
    case DoubleType =>
      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
    case IntegerType =>
      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
    case LongType =>
      (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
    case FloatType =>
      (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
    case BooleanType =>
      (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
    case ShortType =>
      (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
    case ByteType =>
      (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
    case d: DecimalType =>
      (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
    case _ => throw new UnsupportedOperationException(
      s"Unsupported datatype ($dataType) used in PivotFirst, this is a bug.")
  }
}

Contributor:
Is there any data type that works with the existing pivot but will not work with this new version?

Contributor:
Oh, I see. If we have an unsupported data type, we will fall back to the previous code path.
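
(As an illustration with a hypothetical query: an aggregate whose result type is StringType, e.g. first($"course"), is not matched by updateFunction in this revision, so ResolvePivot would keep the single-phase plan for it.)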

case class PivotFirst(pivotColumn: Expression,
    valueColumn: Expression,
    pivotColumnValues: Seq[Any],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends ImperativeAggregate {

Contributor:
Let's have a Scala doc comment explaining this function. Also, for the format, we can use:

case class PivotFirst(
    pivotColumn: Expression,
    valueColumn: Expression,
    ...)

  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)

  val valueDataType = valueColumn.dataType

  val indexSize = pivotIndex.size

  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)

  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
    val pivotColValue = pivotColumn.eval(inputRow)
    if (pivotColValue != null) {
      // The pivot column value may not be one of the pivot values the user asked
      // for; in that case getOrElse yields -1 and the row is skipped.
      val index = pivotIndex.getOrElse(pivotColValue, -1)
      if (index >= 0) {
        val value = valueColumn.eval(inputRow)
        if (value != null) {
          updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
        }
      }
    }
  }

Contributor:
Can we add a comment to explain when index will be -1?

Contributor:
Also, for two different inputRows, we should not get the same index, right?

  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
    for (i <- 0 until indexSize) {
      if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
        val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
        updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
      }
    }
  }

  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
    case d: DecimalType =>
      // Decimal slots carry a precision, so presumably a plain setNullAt is not
      // enough to clear them; null them out via setDecimal with the right precision.
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
      }
    case _ =>
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
      }
  }

Contributor:
Let's add a comment to explain why we need special care for DecimalType.

  override def eval(input: InternalRow): Any = {
    val result = new Array[Any](indexSize)
    for (i <- 0 until indexSize) {
      result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
    }
    new GenericArrayData(result)
  }

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())
Contributor:
How about we avoid using lazy val for aggBufferAttributes, aggBufferSchema, and inputAggBufferAttributes?

  override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())

  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
Contributor:
How about we use inputTypes to ask the analyzer to do type casting? Then, if a value column has an invalid data type, the analyzer will complain.

Contributor (author):
I'm not sure what you mean by this, but no casting is needed.

  override val nullable: Boolean = false

  override val dataType: DataType = ArrayType(valueDataType)

  override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
Contributor:
(I feel it will be better for readers if we put inputTypes, nullable, dataType, and children at the beginning of the class body.)
}
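
To see what the aggregation buffer holds, here is a minimal, self-contained sketch of the PivotFirst idea in plain Scala (no Spark APIs; names and data are illustrative only):

object PivotFirstSketch extends App {
  // One buffer slot per pivot value, in the order the user requested them.
  val pivotValues = Seq("dotNET", "Java")
  val pivotIndex = pivotValues.zipWithIndex.toMap

  // Rows from the first aggregation for a single group (say, year = 2012):
  // (course, sum(earnings)) pairs, at most one per course.
  val firstAggOutput = Seq(("dotNET", 15000.0), ("Java", 20000.0), ("Scala", 1.0))

  // update(): put each value into its pivot value's slot; pivot column values the
  // user did not ask for fail the index lookup and are skipped, mirroring the
  // index >= 0 check in update() above.
  val buffer = Array.fill[Option[Double]](pivotValues.size)(None)
  for ((course, earnings) <- firstAggOutput) {
    pivotIndex.get(course).foreach(i => buffer(i) = Some(earnings))
  }

  // eval(): emit the buffer as an array; the Project built by ResolvePivot then
  // extracts slot i as the output column for pivot value i.
  println(buffer.mkString(", "))  // Some(15000.0), Some(20000.0)
}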

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -17,29 +17,31 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._

class DataFramePivotSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("pivot courses") {
    checkAnswer(
      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
        .agg(sum($"earnings")),
      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
    )
  }

  test("pivot year") {
    checkAnswer(
      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }

  test("pivot courses with multiple aggregations") {
    checkAnswer(
      courseSales.groupBy($"year")
        .pivot("course", Seq("dotNET", "Java"))

@@ -94,4 +96,88 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
    )
  }

  // Tests for optimized pivot (with PivotFirst) below

  test("optimized pivot planned") {
    val df = courseSales.groupBy("year")
      // pivot with extra columns to trigger optimization
      .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
      .agg(sum($"earnings"))
    val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
    assert(queryExecution.simpleString.contains("pivotfirst"))
  }


test("optimized pivot courses with literals") {
checkAnswer(
courseSales.groupBy("year")
// pivot with extra columns to trigger optimization
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
.agg(sum($"earnings"))
.select("year", "dotNET", "Java"),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
}

test("optimized pivot year with literals") {
checkAnswer(
courseSales.groupBy($"course")
// pivot with extra columns to trigger optimization
.pivot("year", Seq(2012, 2013) ++ (1 to 10))
.agg(sum($"earnings"))
.select("course", "2012", "2013"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("optimized pivot year with string values (cast)") {
checkAnswer(
courseSales.groupBy("course")
// pivot with extra columns to trigger optimization
.pivot("year", Seq("2012", "2013") ++ (1 to 10).map(_.toString))
.sum("earnings")
.select("course", "2012", "2013"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("optimized pivot DecimalType") {
val df = courseSales.select($"course", $"year", $"earnings".cast(DecimalType(10, 2)))
.groupBy("year")
// pivot with extra columns to trigger optimization
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
.agg(sum($"earnings"))
.select("year", "dotNET", "Java")

assertResult(IntegerType)(df.schema("year").dataType)
assertResult(DecimalType(20, 2))(df.schema("Java").dataType)
assertResult(DecimalType(20, 2))(df.schema("dotNET").dataType)

checkAnswer(df, Row(2012, BigDecimal(1500000, 2), BigDecimal(2000000, 2)) ::
Row(2013, BigDecimal(4800000, 2), BigDecimal(3000000, 2)) :: Nil)
}

test("PivotFirst supported datatypes") {
val supportedDataTypes: Seq[DataType] = DoubleType :: IntegerType :: LongType :: FloatType ::
BooleanType :: ShortType :: ByteType :: Nil
for (datatype <- supportedDataTypes) {
assertResult(true)(PivotFirst.supportsDataType(datatype))
}
assertResult(true)(PivotFirst.supportsDataType(DecimalType(10, 1)))
assertResult(false)(PivotFirst.supportsDataType(null))
assertResult(false)(PivotFirst.supportsDataType(ArrayType(IntegerType)))
}

test("optimized pivot with multiple aggregations") {
checkAnswer(
courseSales.groupBy($"year")
// pivot with extra columns to trigger optimization
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
.agg(sum($"earnings"), avg($"earnings")),
Row(Seq(2012, 15000.0, 7500.0, 20000.0, 20000.0) ++ Seq.fill(20)(null): _*) ::
Row(Seq(2013, 48000.0, 48000.0, 30000.0, 30000.0) ++ Seq.fill(20)(null): _*) :: Nil
)
}

}