From 6b667116b3c39d4157566b2cefc4880357a112c5 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Wed, 12 Dec 2018 10:59:59 -0800
Subject: [PATCH 01/12] Initial commit

---
 .../expressions/codegen/GenerateUnsafeProjection.scala      | 6 ++++--
 .../org/apache/spark/sql/catalyst/expressions/package.scala | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0ecd0de8d8203..107029e71b95f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -316,8 +316,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
     in.map(ExpressionCanonicalizer.execute)
 
-  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
+    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
+    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+  }
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index bf18e8bcb52df..73e2fb567729e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -89,7 +89,8 @@ package object expressions {
    * A helper function to bind given expressions to an input schema.
    */
  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
+    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
+    exprs.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
   }
 
   /**
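Why PATCH 01 helps: `BindReferences.bindReference` takes an `AttributeSeq`, so passing a plain `Seq[Attribute]` inside the `map` fires the implicit `Seq[Attribute] -> AttributeSeq` conversion once per expression, and each conversion re-indexes the whole schema. Hoisting the conversion into a single val pays that cost once per projection; making it `lazy` also skips it entirely when the expression list is empty. A minimal, self-contained sketch of the cost difference — the types below are illustrative stand-ins, not Catalyst's real classes:

    import scala.language.implicitConversions

    object ImplicitConversionCostSketch {
      case class Attribute(name: String)

      // Stand-in for Catalyst's AttributeSeq: construction indexes the schema (O(m)).
      class AttributeSeq(val attrs: Seq[Attribute]) {
        val byName: Map[String, Int] = attrs.map(_.name).zipWithIndex.toMap
        def indexOf(name: String): Int = byName.getOrElse(name, -1)
      }
      implicit def toAttributeSeq(attrs: Seq[Attribute]): AttributeSeq = new AttributeSeq(attrs)

      def bindOne(name: String, input: AttributeSeq): Int = input.indexOf(name)

      def main(args: Array[String]): Unit = {
        val schema = (1 to 1000).map(i => Attribute(s"c$i"))
        val names = schema.map(_.name)

        // Before: the implicit conversion fires inside the map, re-indexing the
        // schema for every expression -- O(n * m) overall.
        names.map(n => bindOne(n, schema))

        // After: convert once and reuse the index for every expression -- O(n + m).
        val indexed: AttributeSeq = schema
        names.map(n => bindOne(n, indexed))
      }
    }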
From a9dbbec649438a5079f8d9826ebe40bdecfef93c Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 28 Dec 2018 18:53:31 -0800
Subject: [PATCH 02/12] Initial try at helper function

---
 .../spark/sql/catalyst/expressions/BoundAttribute.scala     | 7 +++++++
 .../catalyst/expressions/codegen/GenerateOrdering.scala     | 5 +++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369ee49ed..e75ea77938fc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,11 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
   }
+
+  def bindReferences[A <: Expression](
+      in: Seq[A],
+      inputSchema: Seq[Attribute]): Seq[A] = {
+    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
+    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 283fd2a6e9383..e84509b5accc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -45,8 +45,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
   protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
-  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = {
+    BindReferences.bindReferences(in, inputSchema)
+  }
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.

From 67a943afcbeb4be5942f72a50961656379ba961a Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 30 Dec 2018 18:43:14 -0800
Subject: [PATCH 03/12] Use helper function

---
 .../spark/sql/catalyst/expressions/BoundAttribute.scala     | 7 -------
 .../expressions/codegen/GenerateMutableProjection.scala     | 2 +-
 .../catalyst/expressions/codegen/GenerateOrdering.scala     | 2 +-
 .../expressions/codegen/GenerateSafeProjection.scala        | 2 +-
 .../expressions/codegen/GenerateUnsafeProjection.scala      | 3 +--
 .../apache/spark/sql/catalyst/expressions/package.scala     | 6 ++++--
 6 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index e75ea77938fc4..ea8c369ee49ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,11 +86,4 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
   }
-
-  def bindReferences[A <: Expression](
-      in: Seq[A],
-      inputSchema: Seq[Attribute]): Seq[A] = {
-    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
-    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d588e7f081303..823422cbed514 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -41,7 +41,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
       expressions: Seq[Expression],
       inputSchema: Seq[Attribute],
       useSubexprElimination: Boolean): MutableProjection = {
-    create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
+    create(canonicalize(toBoundExprs(expressions, inputSchema)), useSubexprElimination)
   }
 
   def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index e84509b5accc3..c48ab524f5f3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -46,7 +46,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = {
-    BindReferences.bindReferences(in, inputSchema)
+    toBoundExprs(in, inputSchema)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 39778661d1c48..d5c427c946975 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -41,7 +41,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)
 
   private def createCodeForStruct(
       ctx: CodegenContext,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 107029e71b95f..bb501312242a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -317,8 +317,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
-    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+    toBoundExprs(in, inputSchema)
   }
 
   def generate(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 73e2fb567729e..ad13e6b2c9f74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -88,9 +88,11 @@ package object expressions {
   /**
    * A helper function to bind given expressions to an input schema.
    */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
+  def toBoundExprs[A <: Expression](
+      in: Seq[A],
+      inputSchema: Seq[Attribute]): Seq[A] = {
     lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
-    exprs.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
   }
 
   /**
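PATCH 02 and PATCH 03 converge on a single generic helper. The type parameter `A <: Expression` matters: it lets a caller bind a `Seq[SortOrder]` and get a `Seq[SortOrder]` back — which `GenerateOrdering.bind` must return — rather than a widened `Seq[Expression]`. A compilable sketch of that shape, using hypothetical `Expr`/`SortOrder` stand-ins rather than Catalyst's types:

    object GenericBindSketch {
      trait Expr { def bound: Expr = this }
      final case class Column(name: String) extends Expr
      final case class SortOrder(child: Expr) extends Expr

      // Returning Seq[A] keeps the caller's element type: binding a Seq[SortOrder]
      // yields a Seq[SortOrder], not a widened Seq[Expr].
      def bindAll[A <: Expr](in: Seq[A]): Seq[A] =
        in.map(_.bound.asInstanceOf[A])

      val orders: Seq[SortOrder] = Seq(SortOrder(Column("a")))
      val stillOrders: Seq[SortOrder] = bindAll(orders) // no cast at the call site
    }

The `asInstanceOf` mirrors the cast-based contract that `bindReference` itself carries ("kind of a hack, but safe").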
From 3362744689084478b4fdfa9326f3294f506f1a78 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 30 Dec 2018 22:38:44 -0800
Subject: [PATCH 04/12] More uses of helper function

---
 .../apache/spark/sql/catalyst/expressions/Projection.scala   | 2 +-
 .../sql/catalyst/expressions/codegen/GenerateOrdering.scala  | 5 ++---
 .../expressions/codegen/GenerateUnsafeProjection.scala       | 3 +--
 .../org/apache/spark/sql/catalyst/expressions/ordering.scala | 2 +-
 .../spark/sql/execution/aggregate/HashAggregateExec.scala    | 4 ++--
 .../org/apache/spark/sql/execution/joins/HashJoin.scala      | 5 ++---
 .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 3 ++-
 7 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b48f7ba655b2f..32bea22c66a8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(expressions, inputSchema))
 
   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index c48ab524f5f3c..001eca5ae9a4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -45,9 +45,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
   protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
-  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = {
+  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
     toBoundExprs(in, inputSchema)
-  }
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -189,7 +188,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(ordering, inputSchema))
 
   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index bb501312242a0..1985ff966e188 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -316,9 +316,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
     in.map(ExpressionCanonicalizer.execute)
 
-  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
+  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
     toBoundExprs(in, inputSchema)
-  }
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index e24a3de3cfdbe..764924f90d74e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(ordering, inputSchema))
 
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4827f838fc514..a8521ae3f5ce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -264,7 +264,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = toBoundExprs(updateExpr, inputAttrs)
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -825,7 +825,7 @@ case class HashAggregateExec(
 
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1aef5f6864263..af0aec7c8ff7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -63,9 +63,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = toBoundExprs(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = toBoundExprs(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d7d3f6d6078b4..e37a85779003e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -393,7 +393,8 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+    val inputAttributeSeq: AttributeSeq = input
+    keys.map(BindReferences.bindReference(_, inputAttributeSeq).genCode(ctx))
   }
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {

From 306e0a5424f3bde6b9deadd292d85a4860e66b97 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Mon, 31 Dec 2018 10:26:59 -0800
Subject: [PATCH 05/12] Even more uses of helper function

---
 .../expressions/codegen/GenerateMutableProjection.scala     | 2 +-
 .../spark/sql/execution/aggregate/HashAggregateExec.scala   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 823422cbed514..a36f19d4c9091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -35,7 +35,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index a8521ae3f5ce8..a6e592d1c7049 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -849,7 +849,7 @@ case class HashAggregateExec(
     if (isFastHashMapEnabled) {
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = fastRowBuffer
-        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
         val effectiveCodes = subExprs.codes.mkString("\n")
         val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
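PATCH 04 and PATCH 05 route the remaining projection, ordering, aggregate, and join call sites through the helper. For readers new to Catalyst: binding rewrites attribute references into ordinal-based bound references, so per-row access at execution time is a plain index lookup instead of a schema search. A miniature model of that idea — simplified name-based stand-ins, where Catalyst actually matches on expression IDs:

    object BoundReferenceSketch {
      type Row = IndexedSeq[Any]

      sealed trait Expr { def eval(row: Row, schema: Seq[String]): Any }
      // Unbound: resolves the name against the schema on every evaluation.
      final case class UnresolvedRef(name: String) extends Expr {
        def eval(row: Row, schema: Seq[String]): Any = row(schema.indexOf(name))
      }
      // Bound: the ordinal was computed once, up front.
      final case class BoundRef(ordinal: Int) extends Expr {
        def eval(row: Row, schema: Seq[String]): Any = row(ordinal)
      }

      def bind(e: Expr, schema: Seq[String]): Expr = e match {
        case UnresolvedRef(n) => BoundRef(schema.indexOf(n))
        case other => other
      }

      val schema = Seq("id", "name", "age")
      val expr = bind(UnresolvedRef("age"), schema)        // BoundRef(2)
      val value = expr.eval(Vector(1, "ada", 36), schema)  // 36
    }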
From a25b59ca756958370dd7ba14d6c1e33dec424ea8 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Tue, 1 Jan 2019 19:59:03 -0800
Subject: [PATCH 06/12] Some clean up

---
 .../expressions/codegen/GenerateMutableProjection.scala     | 2 +-
 .../org/apache/spark/sql/catalyst/expressions/package.scala | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index a36f19d4c9091..58f017758af88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -41,7 +41,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
       expressions: Seq[Expression],
       inputSchema: Seq[Attribute],
       useSubexprElimination: Boolean): MutableProjection = {
-    create(canonicalize(toBoundExprs(expressions, inputSchema)), useSubexprElimination)
+    create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
   }
 
   def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index ad13e6b2c9f74..69d684cb9b35b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -89,10 +89,10 @@ package object expressions {
    * A helper function to bind given expressions to an input schema.
    */
   def toBoundExprs[A <: Expression](
-      in: Seq[A],
+      exprs: Seq[A],
       inputSchema: Seq[Attribute]): Seq[A] = {
     lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
-    in.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+    exprs.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
   }
 
   /**

From b977d3e70ba59f537168a778533a6d1ff9a07a34 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Wed, 2 Jan 2019 09:26:57 -0800
Subject: [PATCH 07/12] Review feedback

---
 .../org/apache/spark/sql/catalyst/expressions/package.scala  | 5 ++---
 .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 3 +--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 69d684cb9b35b..f5799d756f4f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -90,9 +90,8 @@ package object expressions {
    */
   def toBoundExprs[A <: Expression](
       exprs: Seq[A],
-      inputSchema: Seq[Attribute]): Seq[A] = {
-    lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
-    exprs.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
+      inputSchema: AttributeSeq): Seq[A] = {
+    exprs.map(BindReferences.bindReference(_, inputSchema))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index e37a85779003e..da341454f8946 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -393,8 +393,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val inputAttributeSeq: AttributeSeq = input
-    keys.map(BindReferences.bindReference(_, inputAttributeSeq).genCode(ctx))
+    toBoundExprs(keys, input).map(_.genCode(ctx))
   }
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {

From 1497d3a92ad82eb1b47d7273477c24e23d09d50c Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Wed, 2 Jan 2019 13:13:59 -0800
Subject: [PATCH 08/12] Fix style issue

---
 .../org/apache/spark/sql/catalyst/expressions/package.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f5799d756f4f8..c8c47abb4653a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -89,8 +89,8 @@ package object expressions {
    * A helper function to bind given expressions to an input schema.
   */
   def toBoundExprs[A <: Expression](
-    exprs: Seq[A],
-    inputSchema: AttributeSeq): Seq[A] = {
+      exprs: Seq[A],
+      inputSchema: AttributeSeq): Seq[A] = {
     exprs.map(BindReferences.bindReference(_, inputSchema))
   }
 
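The review feedback in PATCH 06 through PATCH 08 tightens the helper itself: the parameter type becomes `AttributeSeq`, so the `Seq[Attribute] -> AttributeSeq` implicit conversion fires exactly once at the call boundary and the internal `lazy val` becomes unnecessary. The same idea in miniature, again with illustrative types rather than Spark's:

    import scala.language.implicitConversions

    object ParameterTypeSketch {
      class AttributeSeq(attrs: Seq[String]) {
        val index: Map[String, Int] = attrs.zipWithIndex.toMap
      }
      implicit def fromSeq(attrs: Seq[String]): AttributeSeq = new AttributeSeq(attrs)

      // The implicit conversion runs once here, when a caller passes a Seq[String];
      // the helper body never needs to convert again.
      def bindAll(names: Seq[String], input: AttributeSeq): Seq[Int] =
        names.map(input.index)

      val bound = bindAll(Seq("b", "a"), Seq("a", "b", "c")) // Seq(1, 0)
    }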
From 777c5b4c57144efb1a04057da2ecc64078411b70 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 3 Jan 2019 22:42:56 -0800
Subject: [PATCH 09/12] Move helper function to BindReferences

---
 .../spark/sql/catalyst/expressions/BoundAttribute.scala     | 9 +++++++++
 .../expressions/InterpretedMutableProjection.scala          | 3 ++-
 .../spark/sql/catalyst/expressions/Projection.scala         | 9 +++++----
 .../expressions/codegen/GenerateMutableProjection.scala     | 3 ++-
 .../catalyst/expressions/codegen/GenerateOrdering.scala     | 5 +++--
 .../expressions/codegen/GenerateSafeProjection.scala        | 3 ++-
 .../expressions/codegen/GenerateUnsafeProjection.scala      | 3 ++-
 .../apache/spark/sql/catalyst/expressions/ordering.scala    | 3 ++-
 .../apache/spark/sql/catalyst/expressions/package.scala     | 9 ---------
 .../sql/execution/aggregate/HashAggregateExec.scala         | 7 ++++---
 .../org/apache/spark/sql/execution/joins/HashJoin.scala     | 5 +++--
 .../spark/sql/execution/joins/SortMergeJoinExec.scala       | 3 ++-
 12 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369ee49ed..7ae5924b20faf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
   }
+
+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 122a564da61be..5c8aa4e2e9d83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
  */
 class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   private[this] val buffer = new Array[Any](expressions.size)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 32bea22c66a8e..eaaf94baac216 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
    * `inputSchema`.
   */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -162,7 +163,7 @@ object UnsafeProjection
    * `inputSchema`.
   */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
    * `inputSchema`.
   */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 58f017758af88..838bd1c679e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 // MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    toBoundExprs(in, inputSchema)
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 001eca5ae9a4a..b66b80ad31dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    toBoundExprs(in, inputSchema)
+    bindReferences(in, inputSchema)
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(ordering, inputSchema))
+    this(bindReferences(ordering, inputSchema))
 
   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index d5c427c946975..e285398ba1958 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    toBoundExprs(in, inputSchema)
+    bindReferences(in, inputSchema)
 
   private def createCodeForStruct(
       ctx: CodegenContext,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1985ff966e188..fb1d8a3c8e739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
 
@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    toBoundExprs(in, inputSchema)
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index 764924f90d74e..c8d667143f452 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types._
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(ordering, inputSchema))
+    this(bindReferences(ordering, inputSchema))
 
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index c8c47abb4653a..932c364737249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -85,15 +85,6 @@ package object expressions {
     override def apply(row: InternalRow): InternalRow = row
   }
 
-  /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs[A <: Expression](
-      exprs: Seq[A],
-      inputSchema: AttributeSeq): Seq[A] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }
-
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index a6e592d1c7049..5cad8a33fbeb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -264,7 +265,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = toBoundExprs(updateExpr, inputAttrs)
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -825,7 +826,7 @@ case class HashAggregateExec(
 
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +850,7 @@ case class HashAggregateExec(
     if (isFastHashMapEnabled) {
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = fastRowBuffer
-        val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
+        val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
         val effectiveCodes = subExprs.codes.mkString("\n")
         val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index af0aec7c8ff7b..5ee4c7ffb1911 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,8 +64,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = toBoundExprs(HashJoin.rewriteKeyExpr(leftKeys), left.output)
-    val rkeys = toBoundExprs(HashJoin.rewriteKeyExpr(rightKeys), right.output)
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index da341454f8946..f829f07e80720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
@@ -393,7 +394,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    toBoundExprs(keys, input).map(_.genCode(ctx))
+    bindReferences(keys, input).map(_.genCode(ctx))
   }
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
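PATCH 09 settles on a home for the helper: `bindReferences` (plural) sits on the `BindReferences` object next to the singular `bindReference`, and each call site imports the method directly. That keeps both binding entry points in one discoverable place instead of hiding one in the package object. A compilable miniature of the structure, with illustrative types:

    object BindHelperSketch {
      final class IndexedInput(val names: Seq[String])

      object BindReferences {
        // Single-expression binder (stand-in for Catalyst's bindReference).
        def bindReference(name: String, input: IndexedInput): Int =
          input.names.indexOf(name)
        // The plural helper added by PATCH 09 sits beside it.
        def bindReferences(names: Seq[String], input: IndexedInput): Seq[Int] =
          names.map(bindReference(_, input))
      }

      // Call sites import the method directly, as the patch does in each file:
      import BindReferences.bindReferences
      val bound = bindReferences(Seq("y", "x"), new IndexedInput(Seq("x", "y"))) // Seq(1, 0)
    }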
From fbd6787ae45487e4baf722f6be728b6e72553e8f Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 4 Jan 2019 08:20:55 -0800
Subject: [PATCH 10/12] Handle cases missed before

---
 .../spark/sql/execution/ExpandExec.scala              |  3 ++-
 .../execution/aggregate/AggregationIterator.scala     |  3 ++-
 .../sql/execution/aggregate/HashAggregateExec.scala   | 33 ++++++++-----------
 .../spark/sql/execution/basicPhysicalOperators.scala  |  3 ++-
 .../execution/datasources/FileFormatWriter.scala      |  7 ++---
 .../sql/execution/window/WindowFunctionFrame.scala    |  9 ++---
 6 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 5b4edf5136e3f..392e5b14eb937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -168,9 +168,10 @@ case class ExpandExec(
     // Part 2: switch/case statements
     val cases = projections.zipWithIndex.map { case (exprs, row) =>
       var updateCode = ""
+      val attributeSeq: AttributeSeq = child.output
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
           updateCode +=
             s"""
                |${ev.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 98c4a51299958..a1fb23d621d49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
     val expressionsLength = expressions.length
     val functions = new Array[AggregateFunction](expressionsLength)
     var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
     while (i < expressionsLength) {
       val func = expressions(i).aggregateFunction
       val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 5cad8a33fbeb0..f1eae9bf8857f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -200,15 +200,12 @@ case class HashAggregateExec(
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
       (resultVars, s"""
        |$evaluateAggResults
       |${evaluateVariables(resultVars)}
@@ -457,16 +454,14 @@ case class HashAggregateExec(
     val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
-    val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-      BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-    }
+    val aggResults = bindReferences(declFunctions.map(_.evaluateExpression),
+      aggregateBufferAttributes).map(_.genCode(ctx))
     val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
-    val resultVars = resultExpressions.map { e =>
-      BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-    }
+    val resultVars = bindReferences[Expression](resultExpressions,
+      inputAttrs).map(_.genCode(ctx))
     s"""
     $evaluateKeyVars
     $evaluateBufferVars
@@ -495,9 +490,8 @@ case class HashAggregateExec(
 
       ctx.currentVars = keyVars ++ resultBufferVars
       val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
       $evaluateKeyVars
       $evaluateResultBufferVars
@@ -507,9 +501,8 @@ case class HashAggregateExec(
       // generate result based on grouping key
      ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
       consume(ctx, eval)
     }
     ctx.addNewFunction(funcName,
@@ -731,9 +724,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 09effe087e195..7e87a150c0a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 774fe38f5c2e6..b3cca622a99eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,9 @@ object FileFormatWriter extends Logging {
       // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
       // the physical plan may have different attribute ids due to optimizer removing some
       // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-      val orderingExpr = requiredOrdering
-        .map(SortOrder(_, Ascending))
-        .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+      val orderingExpr = bindReferences(
+        requiredOrdering.map(SortOrder(_, Ascending)),
+        outputSpec.outputColumns)
       SortExec(
         orderingExpr,
         global = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 156002ef58fbe..433204578a9b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
@@ -89,9 +90,9 @@ private[window] final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
-      BindReferences.bindReference(e.input, inputAttrs)
-    }
+    val boundExpressions = bindReferences(
+      Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map(_.input),
+      inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)
@@ -100,7 +101,7 @@ private[window] final class OffsetWindowFunctionFrame(
   /** Create the projection used when the offset row DOES NOT exists. */
   private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.
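PATCH 10 catches the call sites where the conversion used to happen inside a loop: `ExpandExec` converted `child.output` once per projection, and `AggregationIterator` once per aggregate function. The fix is classic hoisting of a loop-invariant computation. A minimal sketch, with a hypothetical `expensiveIndex` standing in for the `AttributeSeq` construction:

    object HoistConversionSketch {
      def expensiveIndex(schema: Seq[String]): Map[String, Int] =
        schema.zipWithIndex.toMap

      // Before: the index is rebuilt on every loop iteration.
      def bindAllBefore(cols: Seq[String], schema: Seq[String]): Seq[Int] =
        cols.map(c => expensiveIndex(schema)(c))

      // After: the loop-invariant index is built once and reused.
      def bindAllAfter(cols: Seq[String], schema: Seq[String]): Seq[Int] = {
        val index = expensiveIndex(schema)
        cols.map(index)
      }
    }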
From 686e7b5a4c4f3523970bf41d7d48c8cde66005c9 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sat, 5 Jan 2019 15:29:01 -0800
Subject: [PATCH 11/12] Review feedback

---
 .../sql/execution/aggregate/HashAggregateExec.scala   | 15 ++++++++++-----
 .../sql/execution/datasources/FileFormatWriter.scala  |  3 +--
 .../sql/execution/window/WindowFunctionFrame.scala    |  5 ++---
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index f1eae9bf8857f..28801774418a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -200,7 +200,8 @@ case class HashAggregateExec(
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = bindReferences(functions.map(_.evaluateExpression),
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
         aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
@@ -454,13 +455,15 @@ case class HashAggregateExec(
     val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
-    val aggResults = bindReferences(declFunctions.map(_.evaluateExpression),
+    val aggResults = bindReferences(
+      declFunctions.map(_.evaluateExpression),
       aggregateBufferAttributes).map(_.genCode(ctx))
     val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
-    val resultVars = bindReferences[Expression](resultExpressions,
+    val resultVars = bindReferences[Expression](
+      resultExpressions,
       inputAttrs).map(_.genCode(ctx))
     s"""
     $evaluateKeyVars
@@ -490,7 +493,8 @@ case class HashAggregateExec(
       ctx.currentVars = keyVars ++ resultBufferVars
       val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = bindReferences[Expression](resultExpressions,
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
         inputAttrs).map(_.genCode(ctx))
       s"""
       $evaluateKeyVars
@@ -501,7 +505,8 @@ case class HashAggregateExec(
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = bindReferences[Expression](resultExpressions,
+      val eval = bindReferences[Expression](
+        resultExpressions,
         groupingAttributes).map(_.genCode(ctx))
       consume(ctx, eval)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index b3cca622a99eb..260ad97506a85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -147,8 +147,7 @@ object FileFormatWriter extends Logging {
       // the physical plan may have different attribute ids due to optimizer removing some
       // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
       val orderingExpr = bindReferences(
-        requiredOrdering.map(SortOrder(_, Ascending)),
-        outputSpec.outputColumns)
+        requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
       SortExec(
         orderingExpr,
         global = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 433204578a9b8..5bf34558fe493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -90,9 +90,8 @@ private[window] final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = bindReferences(
-      Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map(_.input),
-      inputAttrs)
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
+      expressions.toSeq.map(_.input), inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)

From b1afb3a7bb0e1a67ce12c43ac545650e54a37e7c Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 13 Jan 2019 10:40:17 -0800
Subject: [PATCH 12/12] Close gap in ExpandExec

---
 .../scala/org/apache/spark/sql/execution/ExpandExec.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 392e5b14eb937..85f49140a4b41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -145,11 +145,12 @@ case class ExpandExec(
     // Part 1: declare variables for each column
     // If a column has the same value for all output rows, then we also generate its computation
     // right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
     val outputColumns = output.indices.map { col =>
       val firstExpr = projections.head(col)
       if (sameOutput(col)) {
         // This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
@@ -168,7 +169,6 @@ case class ExpandExec(
     // Part 2: switch/case statements
     val cases = projections.zipWithIndex.map { case (exprs, row) =>
       var updateCode = ""
-      val attributeSeq: AttributeSeq = child.output
      for (col <- exprs.indices) {
         if (!sameOutput(col)) {
           val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)