[SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection #23392

Closed
wants to merge 12 commits
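The change below replaces per-expression calls to `BindReferences.bindReference(_, inputSchema)` — each of which implicitly wraps the `Seq[Attribute]` in a fresh `AttributeSeq`, rebuilding its exprId-to-ordinal lookup — with a single `bindReferences` helper that converts the schema once and reuses it for every column. A minimal standalone sketch of the cost pattern being avoided; the classes here are illustrative stand-ins, not Spark's real `Attribute`/`AttributeSeq`:

```scala
import scala.language.implicitConversions

// Stand-ins for Attribute / AttributeSeq, only to show the cost pattern.
case class Attr(name: String, exprId: Long)

class AttrSeq(val attrs: Seq[Attr]) {
  // Built once per AttrSeq instance; O(n) work.
  val exprIdToOrdinal: Map[Long, Int] = attrs.map(_.exprId).zipWithIndex.toMap
  def indexOf(exprId: Long): Int = exprIdToOrdinal.getOrElse(exprId, -1)
}

object BindSketch {
  implicit def toAttrSeq(attrs: Seq[Attr]): AttrSeq = new AttrSeq(attrs)

  // Analogue of bindReference: resolves one expression against the schema.
  def bindOne(a: Attr, input: AttrSeq): Int = input.indexOf(a.exprId)

  def main(args: Array[String]): Unit = {
    val schema = (0L until 100L).map(i => Attr(s"c$i", i))
    // Before: the implicit conversion fires per column, so the map is rebuilt n times.
    val before = schema.map(a => bindOne(a, schema))
    // After: convert once, then bind every column against the same AttrSeq.
    val attrSeq: AttrSeq = schema
    val after = schema.map(a => bindOne(a, attrSeq))
    assert(before == after)
  }
}
```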
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
}

+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
}
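For reference, a small sketch of how the new helper is called. This uses Spark's internal catalyst API, so it assumes `spark-catalyst` on the classpath; everything except `bindReferences` itself is just example data with illustrative names:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSeq, Expression}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.IntegerType

object BindReferencesUsage {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    // Convert Seq[Attribute] to AttributeSeq once; the exprId -> ordinal map
    // inside it is then shared by every bindReference call below.
    val schema: AttributeSeq = Seq(a, b)
    val bound: Seq[Expression] = bindReferences[Expression](Seq(b, a), schema)
    // bound now holds BoundReference(1, ...) and BoundReference(0, ...).
    bound.foreach(println)
  }
}
```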
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp


@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
*/
class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))

private[this] val buffer = new Array[Any](expressions.size)

@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(expressions, inputSchema))

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}

@@ -162,7 +163,7 @@ object UnsafeProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}

@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}
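The `create(exprs, inputSchema)` overloads above now all funnel through `bindReferences`. A rough usage sketch of that API, again assuming `spark-catalyst` on the classpath; the expressions and object name are arbitrary examples, not part of the PR:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal, UnsafeProjection}
import org.apache.spark.sql.types.IntegerType

object ProjectionSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType, nullable = false)()
    // Unbound expression `a + 1` plus the input schema; binding happens inside create().
    val proj: UnsafeProjection = UnsafeProjection.create(Seq(Add(a, Literal(1))), Seq(a))
    val out = proj(InternalRow(41))
    println(out.getInt(0)) // 42
  }
}
```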
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp

// MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

/**
* Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
extends Ordering[InternalRow] with KryoSerializable {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))

@transient
private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
@@ -21,6 +21,7 @@ import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

private def createCodeForStruct(
ctx: CodegenContext,
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types._


@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))

def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
@@ -85,13 +85,6 @@ package object expressions {
override def apply(row: InternalRow): InternalRow = row
}

-  /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }

/**
* Helper functions for working with `Seq[Attribute]`.
*/
@@ -145,11 +145,12 @@ case class ExpandExec(
// Part 1: declare variables for each column
// If a column has the same value for all output rows, then we also generate its computation
// right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
val outputColumns = output.indices.map { col =>
val firstExpr = projections.head(col)
if (sameOutput(col)) {
// This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
@@ -170,7 +171,7 @@
var updateCode = ""
for (col <- exprs.indices) {
if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
s"""
|${ev.code}
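The `lazy val attributeSeq: AttributeSeq = child.output` above pins the `Seq[Attribute]` to `AttributeSeq` conversion, so the per-column `bindReference` calls in both loops reuse one exprId-to-ordinal map instead of triggering the implicit conversion on every column. The shape of that pattern in isolation, as a sketch with placeholder names rather than the real `ExpandExec` fields:

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, BindReferences, Expression}

object BindColumnsSketch {
  // Bind a whole projection against one cached AttributeSeq.
  def bindColumns(exprs: Seq[Expression], output: Seq[Attribute]): Seq[Expression] = {
    // One conversion builds one exprId -> ordinal map...
    val attributeSeq: AttributeSeq = output
    // ...shared by every column, instead of converting `output` again per call.
    exprs.map(e => BindReferences.bindReference(e, attributeSeq))
  }
}
```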
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
val expressionsLength = expressions.length
val functions = new Array[AggregateFunction](expressionsLength)
var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
while (i < expressionsLength) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
(resultVars, s"""
|$evaluateAggResults
|${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@
}
}
ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
-      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        declFunctions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateBufferVars
@@ -494,9 +493,9 @@

ctx.currentVars = keyVars ++ resultBufferVars
val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateResultBufferVars
@@ -506,9 +505,9 @@
// generate result based on grouping key
ctx.INPUT_ROW = keyTerm
ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](
+        resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
consume(ctx, eval)
}
ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// create grouping key
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@

val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
-        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
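A note on the explicit `[Expression]` type parameter at several of the HashAggregateExec call sites above: `resultExpressions` and `groupingExpressions` are sequences of `NamedExpression`, but binding produces `BoundReference`s, which are not named expressions, and the `asInstanceOf[A]` inside `bindReference` means asking for `Seq[NamedExpression]` back can surface as a runtime `ClassCastException` when the elements are later treated as named expressions. Requesting the result as `Seq[Expression]` keeps the static type honest. A small sketch of the distinction, with illustrative names and assuming `spark-catalyst` on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.LongType

object TypeParamSketch {
  def main(args: Array[String]): Unit = {
    val count = AttributeReference("count", LongType)()
    val resultExpressions: Seq[NamedExpression] = Seq(count)
    // Asking for Seq[Expression] is honest: the bound result is a BoundReference,
    // which is an Expression but not a NamedExpression.
    val bound: Seq[Expression] =
      bindReferences[Expression](resultExpressions, Seq(count): AttributeSeq)
    bound.foreach(e => println(e.getClass.getSimpleName)) // BoundReference
  }
}
```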
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
val resultVars = exprs.map(_.genCode(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-    val orderingExpr = requiredOrdering
-      .map(SortOrder(_, Ascending))
-      .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+    val orderingExpr = bindReferences(
+      requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
SortExec(
orderingExpr,
global = false,
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)