# [SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection

## What changes were proposed in this pull request?

When creating some unsafe projections, Spark rebuilds the map of schema attributes once for each expression in the projection. Some file format readers create one unsafe projection per input file, while others create one per task; ProjectExec also creates one unsafe projection per task. As a result, for wide queries on wide tables, Spark might build the map of schema attributes hundreds of thousands of times.

This PR changes two functions to reuse the same AttributeSeq instance when creating BoundReference objects for each expression in the projection. Since the attribute lookup map is held by the AttributeSeq instance, reusing a single instance means the map is built once per projection rather than once per expression.
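This is the pattern captured by the new `bindReferences` helper in the diff below. As a minimal sketch of the before/after behavior (assuming Spark's catalyst module on the classpath; `bindAllOld` and `bindAllNew` are illustrative names, not Spark APIs):

```scala
import org.apache.spark.sql.catalyst.expressions._

// Before: each bindReference call receives a Seq[Attribute], which the
// implicit conversion wraps in a fresh AttributeSeq, so the attribute
// lookup map is rebuilt once per expression.
def bindAllOld(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
  exprs.map(BindReferences.bindReference(_, inputSchema))

// After: convert the schema to an AttributeSeq once and reuse it, so the
// lookup map is built a single time for the whole projection.
def bindAllNew(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
  val attrs: AttributeSeq = inputSchema // one conversion, one lookup map
  exprs.map(BindReferences.bindReference(_, attrs))
}
```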

### Benchmarks

The time saved by this PR depends on the size of the schema, the size of the projection, the number of input files (or file splits), the number of tasks, and the file format. I chose a few representative cases.

In the following tests, I ran the query
```sql
select * from table where id1 = 1
```

Matching rows are about 0.2% of the table.

#### Orc table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.772306 min | 1.487267 min | 16.082943%

#### Orc table 6000 columns, 500K rows, *17* input files

baseline | pr | improvement
----|----|----
1.656400 min | 1.423550 min | 14.057595%

#### Orc table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.299878 min | 0.290339 min | 3.180926%

#### Parquet table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.478306 min | 1.373728 min | 7.074165%

Note: the Parquet reader does not create an unsafe projection. However, the filter in the query causes the planner to add a ProjectExec, which does create one unsafe projection per task, so the savings in these results come from ProjectExec rather than from Parquet itself.
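For context, ProjectExec builds its unsafe projection inside each task; a rough, self-contained sketch of that per-task pattern (paraphrased, not the verbatim Spark source; `projectPerTask` and its parameters are illustrative):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection}

// One UnsafeProjection is created per partition (i.e. per task), so the
// cost of creating it, including the schema-map rebuild this PR removes,
// is paid again by every task.
def projectPerTask(
    childRdd: RDD[InternalRow],
    projectList: Seq[NamedExpression],
    childOutput: Seq[Attribute]): RDD[InternalRow] = {
  childRdd.mapPartitionsWithIndex { (index, iter) =>
    val project = UnsafeProjection.create(projectList, childOutput)
    project.initialize(index)
    iter.map(project)
  }
}
```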

#### Parquet table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.245006 min | 0.242200 min | 1.145099%

#### CSV table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
2.390117 min | 2.182778 min | 8.674844%

#### CSV table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
1.520911 min | 1.510211 min | 0.703526%

## How was this patch tested?

SQL unit tests
Python core and SQL tests

Closes #23392 from bersprockets/norebuild.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
bersprockets authored and hvanhovell committed Jan 13, 2019
1 parent c01152d commit 09b0548
Showing 17 changed files with 68 additions and 56 deletions.
```diff
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
}

+ /**
+ * A helper function to bind given expressions to an input schema.
+ */
+ def bindReferences[A <: Expression](
+ expressions: Seq[A],
+ input: AttributeSeq): Seq[A] = {
+ expressions.map(BindReferences.bindReference(_, input))
+ }
}
```
---

```diff
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp


@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
*/
class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
- this(toBoundExprs(expressions, inputSchema))
+ this(bindReferences(expressions, inputSchema))

private[this] val buffer = new Array[Any](expressions.size)

```
---

```diff
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
- this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(expressions, inputSchema))

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}

@@ -162,7 +163,7 @@ object UnsafeProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}

@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}
```
---

```diff
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp

// MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
```
---

```diff
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)

/**
* Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
extends Ordering[InternalRow] with KryoSerializable {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
- this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(ordering, inputSchema))

@transient
private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
```
---

```diff
@@ -21,6 +21,7 @@ import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)

private def createCodeForStruct(
ctx: CodegenContext,
```
---

```diff
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
```
---

```diff
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types._


@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
- this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(ordering, inputSchema))

def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
```
---

```diff
@@ -85,13 +85,6 @@ package object expressions {
override def apply(row: InternalRow): InternalRow = row
}

- /**
- * A helper function to bind given expressions to an input schema.
- */
- def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
- exprs.map(BindReferences.bindReference(_, inputSchema))
- }
-
/**
* Helper functions for working with `Seq[Attribute]`.
*/
```
---

```diff
@@ -145,11 +145,12 @@ case class ExpandExec(
// Part 1: declare variables for each column
// If a column has the same value for all output rows, then we also generate its computation
// right after declaration. Otherwise its value is computed in the part 2.
+ lazy val attributeSeq: AttributeSeq = child.output
val outputColumns = output.indices.map { col =>
val firstExpr = projections.head(col)
if (sameOutput(col)) {
// This column is the same across all output rows. Just generate code for it here.
- BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+ BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
@@ -170,7 +171,7 @@
var updateCode = ""
for (col <- exprs.indices) {
if (!sameOutput(col)) {
- val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+ val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
s"""
|${ev.code}
```
---

```diff
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
val expressionsLength = expressions.length
val functions = new Array[AggregateFunction](expressionsLength)
var i = 0
+ val inputAttributeSeq: AttributeSeq = inputAttributes
while (i < expressionsLength) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, inputAttributes)
+ BindReferences.bindReference(func, inputAttributeSeq)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
```
---

```diff
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
- val aggResults = functions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
- }
+ val aggResults = bindReferences(
+ functions.map(_.evaluateExpression),
+ aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
- }
+ val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
(resultVars, s"""
|$evaluateAggResults
|${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@
}
}
ctx.currentVars = bufVars ++ input
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
- val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
- }
+ val aggResults = bindReferences(
+ declFunctions.map(_.evaluateExpression),
+ aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, inputAttrs).genCode(ctx)
- }
+ val resultVars = bindReferences[Expression](
+ resultExpressions,
+ inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateBufferVars
@@ -494,9 +493,9 @@

ctx.currentVars = keyVars ++ resultBufferVars
val inputAttrs = resultExpressions.map(_.toAttribute)
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, inputAttrs).genCode(ctx)
- }
+ val resultVars = bindReferences[Expression](
+ resultExpressions,
+ inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateResultBufferVars
@@ -506,9 +505,9 @@
// generate result based on grouping key
ctx.INPUT_ROW = keyTerm
ctx.currentVars = null
- val eval = resultExpressions.map{ e =>
- BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
- }
+ val eval = bindReferences[Expression](
+ resultExpressions,
+ groupingAttributes).map(_.genCode(ctx))
consume(ctx, eval)
}
ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// create grouping key
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
- ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ ctx, bindReferences[Expression](groupingExpressions, child.output))
val fastRowKeys = ctx.generateExpressions(
- groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ bindReferences[Expression](groupingExpressions, child.output))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@

val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
```
---

```diff
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+ val exprs = bindReferences[Expression](projectList, child.output)
val resultVars = exprs.map(_.genCode(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
```
---

```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
- val orderingExpr = requiredOrdering
- .map(SortOrder(_, Ascending))
- .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+ val orderingExpr = bindReferences(
+ requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
SortExec(
orderingExpr,
global = false,
```
---

```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
- val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
- val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
- .map(BindReferences.bindReference(_, right.output))
+ val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+ val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
```