Skip to content

Commit

Permalink
[SPARK-23711][SQL] Add fallback generator for UnsafeProjection
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Add fallback logic for `UnsafeProjection`. In production, we first try to create the unsafe projection using the codegen implementation. If any compile error happens, it falls back to the interpreted implementation.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #21106 from viirya/SPARK-23711.
  • Loading branch information
viirya authored and cloud-fan committed May 23, 2018
1 parent 00c13cf commit a40ffc6
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 88 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException

import org.apache.spark.TaskContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
 * Extractor for the exceptions that signal a compile error during code generation.
 *
 * Matches `InternalCompilerException` and `CompileException`; any other throwable
 * (including `null`) does not match.
 */
object CodegenError {
  def unapply(throwable: Throwable): Option[Exception] =
    Option(throwable).collect[Exception] {
      case internalError: InternalCompilerException => internalError
      case compileError: CompileException => compileError
    }
}

/**
 * Enumerates the values accepted by the `SQLConf` config of fallback mode. Use for test only.
 */
object CodegenObjectFactoryMode extends Enumeration {
  val FALLBACK = Value
  val CODEGEN_ONLY = Value
  val NO_CODEGEN = Value
}

/**
 * A factory that builds objects through the codegen path first and, when a compile error
 * (see `CodegenError`) is raised, falls back to the interpreted implementation. In tests, the
 * SQL config `SQLConf.CODEGEN_FACTORY_MODE` can force either path.
 *
 * @tparam IN  the input handed to both creation paths
 * @tparam OUT the type of object produced
 */
abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {

  /**
   * Creates the object for `in`, honouring the configured factory mode.
   *
   * `CODEGEN_ONLY` and `NO_CODEGEN` are only respected when `Utils.isTesting` is true;
   * in every other case the codegen path is tried first, with interpreted fallback.
   */
  def createObject(in: IN): OUT = {
    val mode =
      CodegenObjectFactoryMode.withName(SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE))

    // The forced modes are a testing aid only, so each is additionally gated on isTesting.
    if (mode == CodegenObjectFactoryMode.CODEGEN_ONLY && Utils.isTesting) {
      createCodeGeneratedObject(in)
    } else if (mode == CodegenObjectFactoryMode.NO_CODEGEN && Utils.isTesting) {
      createInterpretedObject(in)
    } else {
      try {
        createCodeGeneratedObject(in)
      } catch {
        // Only compile errors trigger the fallback; anything else propagates.
        case CodegenError(_) => createInterpretedObject(in)
      }
    }
  }

  /** Builds the object using generated code. */
  protected def createCodeGeneratedObject(in: IN): OUT

  /** Builds the object using the interpreted implementation. */
  protected def createInterpretedObject(in: IN): OUT
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
/**
* Helper functions for creating an [[InterpretedUnsafeProjection]].
*/
object InterpretedUnsafeProjection extends UnsafeProjectionCreator {

object InterpretedUnsafeProjection {
/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection {
override def apply(row: InternalRow): UnsafeRow
}

trait UnsafeProjectionCreator {
/**
* The factory object for `UnsafeProjection`.
*/
object UnsafeProjection
extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {

override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(in)
}

override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
InterpretedUnsafeProjection.createProjection(in)
}

protected def toBoundExprs(
exprs: Seq[Expression],
inputSchema: Seq[Attribute]): Seq[Expression] = {
exprs.map(BindReferences.bindReference(_, inputSchema))
}

protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
}

/**
* Returns an UnsafeProjection for given StructType.
*
Expand All @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator {
* Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
createProjection(unsafeExprs)
createObject(toUnsafeExprs(exprs))
}

def create(expr: Expression): UnsafeProjection = create(Seq(expr))
Expand All @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator {
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}

/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
}

object UnsafeProjection extends UnsafeProjectionCreator {

override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
create(toBoundExprs(exprs, inputSchema))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
* The param `subexpressionEliminationEnabled` is not guaranteed to take effect. For example,
* it is not supported when falling back to interpreted execution.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
try {
GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
} catch {
case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -703,6 +704,17 @@ object SQLConf {
.intConf
.createWithDefault(100)

// Internal string config naming the codegen factory mode. The value is validated against
// the names of the `CodegenObjectFactoryMode` values and defaults to `FALLBACK`.
val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode")
.doc("This config determines the fallback behavior of several codegen generators " +
"during tests. `FALLBACK` means trying codegen first and then fallbacking to " +
"interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " +
"`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " +
"this config works only for tests.")
.internal()
.stringConf
.checkValues(CodegenObjectFactoryMode.values.map(_.toString))
.createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString)

val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
.doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" +
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Checks that `UnsafeProjection.createObject` honours `SQLConf.CODEGEN_FACTORY_MODE`.
 */
class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {

  test("UnsafeProjection with codegen factory mode") {
    val modeKey = SQLConf.CODEGEN_FACTORY_MODE.key
    val input = Seq(LongType, IntegerType).zipWithIndex.map {
      case (dataType, ordinal) => BoundReference(ordinal, dataType, nullable = true)
    }

    // CODEGEN_ONLY must yield the generated projection class.
    withSQLConf(modeKey -> CodegenObjectFactoryMode.CODEGEN_ONLY.toString) {
      val projection = UnsafeProjection.createObject(input)
      assert(projection.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
    }

    // NO_CODEGEN must yield the interpreted projection.
    withSQLConf(modeKey -> CodegenObjectFactoryMode.NO_CODEGEN.toString) {
      val projection = UnsafeProjection.createObject(input)
      assert(projection.isInstanceOf[InterpretedUnsafeProjection])
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils
/**
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
*/
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase {
self: SparkFunSuite =>

protected def create_row(values: Any*): InternalRow = {
Expand Down Expand Up @@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
}

protected def checkEvaluationWithUnsafeProjection(
expression: Expression,
expected: Any,
inputRow: InternalRow,
factory: UnsafeProjectionCreator): Unit = {
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
} else {
val lit = InternalRow(expected, expected)
val expectedRow =
factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
for (fallbackMode <- modes) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
} else {
val lit = InternalRow(expected, expected)
val expectedRow =
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
}
}
}
}

protected def evaluateWithUnsafeProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow,
factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
inputRow: InternalRow = EmptyRow): InternalRow = {
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val structExpected = new GenericArrayData(
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
checkEvaluationWithUnsafeProjection(
structEncoder.serializer.head,
structExpected,
structInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
structEncoder.serializer.head, structExpected, structInputRow)

// test UnsafeArray-backed data
val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
val arrayExpected = new GenericArrayData(
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
checkEvaluationWithUnsafeProjection(
arrayEncoder.serializer.head,
arrayExpected,
arrayInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

// test UnsafeMap-backed data
val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
Expand All @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new GenericArrayData(Array(3, 4)),
new GenericArrayData(Array(300, 400)))))
checkEvaluationWithUnsafeProjection(
mapEncoder.serializer.head,
mapExpected,
mapInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
mapEncoder.serializer.head, mapExpected, mapInputRow)
}

test("SPARK-23582: StaticInvoke should support interpreted execution") {
Expand Down Expand Up @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluationWithUnsafeProjection(
expr,
expected,
inputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
inputRow)
}
checkEvaluationWithOptimization(expr, expected, inputRow)
}
Expand Down
Loading

0 comments on commit a40ffc6

Please sign in to comment.