diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index d66a6902b0510..cbef1c7828319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override def afterEach() { try { resetSparkContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a3ae93810aa3c..d305ce3e698ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter /** @@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("toUnsafeRow() test helper method") { - // This currently doesnt work because the generic getter throws an exception. + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) @@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-10466: external sorter spilling with unsafe row serializer") { - var sc: SparkContext = null - var outputFile: File = null - val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten - Utils.tryWithSafeFinally { - val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") - - sc = new SparkContext("local", "test", conf) - outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") - // prepare data - val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 10000).iterator.map { i => - (i, converter(Row(i))) - } - val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( - taskContext, - partitioner = Some(new HashPartitioner(10)), - serializer = new UnsafeRowSerializer(numFields = 1)) - - // Ensure we spilled something and have to merge them later - assert(sorter.numSpills === 0) - sorter.insertAll(data) - assert(sorter.numSpills > 0) + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() + val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + outputFile.deleteOnExit() + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - // Merging spilled files should not throw assertion error - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) - } { - // Clean up - if (sc != null) { - sc.stop() - } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = new UnsafeRowSerializer(numFields = 1)) - // restore the spark env - SparkEnv.set(oldEnv) + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) - if (outputFile != null) { - outputFile.delete() - } - } + // Merging spilled files should not throw assertion error + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { val conf = new SparkConf().set("spark.shuffle.manager", "sort") - sc = new SparkContext("local", "test", conf) + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) - .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val rowsRDD = spark.sparkContext.parallelize( + Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) + ).asInstanceOf[RDD[Product2[Int, InternalRow]]] val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD,