From a40ffc656d62372da85e0fa932b67207839e7fde Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 22:40:52 +0800 Subject: [PATCH] [SPARK-23711][SQL] Add fallback generator for UnsafeProjection ## What changes were proposed in this pull request? Add fallback logic for `UnsafeProjection`. In production we can try to create unsafe projection using codegen implementation. Once any compile error happens, it fallbacks to interpreted implementation. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21106 from viirya/SPARK-23711. --- ...CodeGeneratorWithInterpretedFallback.scala | 73 +++++++++++++++++++ .../InterpretedUnsafeProjection.scala | 5 +- .../sql/catalyst/expressions/Projection.scala | 60 +++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 12 +++ ...eneratorWithInterpretedFallbackSuite.scala | 43 +++++++++++ .../expressions/ExpressionEvalHelper.scala | 52 ++++++------- .../expressions/ObjectExpressionsSuite.scala | 18 +---- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++----- .../UnsafeKVExternalSorterSuite.scala | 3 +- 9 files changed, 230 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..fb25e781e72e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.InternalCompilerException + +import org.apache.spark.TaskContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Catches compile error during code generation. + */ +object CodegenError { + def unapply(throwable: Throwable): Option[Exception] = throwable match { + case e: InternalCompilerException => Some(e) + case e: CompileException => Some(e) + case _ => None + } +} + +/** + * Defines values for `SQLConf` config of fallback mode. Use for test only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case CodegenError(_) => createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..6493f09100577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. + */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d0478d6ad250b..fb98df587129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -703,6 +704,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..531ca9a87370a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(LongType, IntegerType) + .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c2a44e0d33b18..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 77ca640f2e0bd..20d568c44258f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -81,10 +81,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -92,10 +89,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema,