diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 65edeeffb837a..7cccf74003431 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -58,12 +58,26 @@ import org.apache.spark.util._
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
  * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
  *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
  * @param config a Spark Config object describing the application configuration. Any settings in
  *   this config overrides the default configs as well as system properties.
  */
-class SparkContext(config: SparkConf) extends Logging {
+class SparkContext(config: SparkConf) extends Logging {
+
+  // The call site where this SparkContext was constructed.
+  private val creationSite: CallSite = Utils.getCallSite()
+
+  // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+  private val allowMultipleContexts: Boolean =
+    config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+  // In order to prevent multiple SparkContexts from being active at the same time, mark this
+  // context as having started construction.
+  // NOTE: this must be placed at the beginning of the SparkContext constructor.
+  SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
+
   // This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
   // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
   // contains a map from hostname to a list of input format splits on the host.
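The hunk above gives applications a fail-fast contract: the second constructor call in a JVM throws unless the active context has been stopped. A minimal sketch of the resulting driver-side behavior (not part of the patch; assumes a local master):

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

object SingleContextDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("demo").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      new SparkContext(conf) // a second context in the same JVM
    } catch {
      case e: SparkException =>
        // With spark.driver.allowMultipleContexts left at its default (false),
        // the constructor fails fast and names the creation site of `sc`.
        println(s"Rejected as expected: ${e.getMessage}")
    } finally {
      sc.stop() // stopping clears the JVM-wide active-context slot...
    }
    val sc2 = new SparkContext(conf) // ...so a successor can be created
    sc2.stop()
  }
}
```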
@@ -1166,27 +1180,30 @@ class SparkContext(config: SparkConf) extends Logging {

   /** Shut down the SparkContext. */
   def stop() {
-    postApplicationEnd()
-    ui.foreach(_.stop())
-    // Do this only if not stopped already - best case effort.
-    // prevent NPE if stopped more than once.
-    val dagSchedulerCopy = dagScheduler
-    dagScheduler = null
-    if (dagSchedulerCopy != null) {
-      env.metricsSystem.report()
-      metadataCleaner.cancel()
-      env.actorSystem.stop(heartbeatReceiver)
-      cleaner.foreach(_.stop())
-      dagSchedulerCopy.stop()
-      taskScheduler = null
-      // TODO: Cache.stop()?
-      env.stop()
-      SparkEnv.set(null)
-      listenerBus.stop()
-      eventLogger.foreach(_.stop())
-      logInfo("Successfully stopped SparkContext")
-    } else {
-      logInfo("SparkContext already stopped")
+    SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      postApplicationEnd()
+      ui.foreach(_.stop())
+      // Do this only if not stopped already - best case effort.
+      // prevent NPE if stopped more than once.
+      val dagSchedulerCopy = dagScheduler
+      dagScheduler = null
+      if (dagSchedulerCopy != null) {
+        env.metricsSystem.report()
+        metadataCleaner.cancel()
+        env.actorSystem.stop(heartbeatReceiver)
+        cleaner.foreach(_.stop())
+        dagSchedulerCopy.stop()
+        taskScheduler = null
+        // TODO: Cache.stop()?
+        env.stop()
+        SparkEnv.set(null)
+        listenerBus.stop()
+        eventLogger.foreach(_.stop())
+        logInfo("Successfully stopped SparkContext")
+        SparkContext.clearActiveContext()
+      } else {
+        logInfo("SparkContext already stopped")
+      }
     }
   }
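Two details of the new `stop()` are worth calling out: it reuses the constructor lock, so a concurrent `new SparkContext` and `stop()` cannot interleave their updates to the active-context slot, and the existing `dagSchedulerCopy` null check keeps repeated calls harmless. A sketch of the observable effect (hypothetical application code, not from the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))
sc.stop() // tears everything down and calls SparkContext.clearActiveContext()
sc.stop() // safe no-op: logs "SparkContext already stopped"
```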
@@ -1475,6 +1492,11 @@ class SparkContext(config: SparkConf) extends Logging {
   private[spark] def cleanup(cleanupTime: Long) {
     persistentRdds.clearOldValues(cleanupTime)
   }
+
+  // In order to prevent multiple SparkContexts from being active at the same time, mark this
+  // context as having finished construction.
+  // NOTE: this must be placed at the end of the SparkContext constructor.
+  SparkContext.setActiveContext(this, allowMultipleContexts)
 }

 /**
@@ -1483,6 +1505,107 @@ class SparkContext(config: SparkConf) extends Logging {
  */
 object SparkContext extends Logging {

+  /**
+   * Lock that guards access to global variables that track SparkContext construction.
+   */
+  private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+  /**
+   * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+   *
+   * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK.
+   */
+  private var activeContext: Option[SparkContext] = None
+
+  /**
+   * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+   * constructor, or `None` if no SparkContext is being constructed.
+   *
+   * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK.
+   */
+  private var contextBeingConstructed: Option[SparkContext] = None
+
+  /**
+   * Called to ensure that no other SparkContext is running in this JVM.
+   *
+   * Throws an exception if a running context is detected and logs a warning if another thread is
+   * constructing a SparkContext. This warning is necessary because the current locking scheme
+   * prevents us from reliably distinguishing between cases where another context is being
+   * constructed and cases where another constructor threw an exception.
+   */
+  private def assertNoOtherContextIsRunning(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      contextBeingConstructed.foreach { otherContext =>
+        if (otherContext ne sc) {  // checks for reference equality
+          // Since otherContext might point to a partially-constructed context, guard against
+          // its creationSite field being null:
+          val otherContextCreationSite =
+            Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+          val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+            " constructor). This may indicate an error, since only one SparkContext may be" +
+            " running in this JVM (see SPARK-2243)." +
+            s" The other SparkContext was created at:\n$otherContextCreationSite"
+          logWarning(warnMsg)
+        }
+
+        activeContext.foreach { ctx =>
+          val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+            " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+            s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+          val exception = new SparkException(errMsg)
+          if (allowMultipleContexts) {
+            logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+          } else {
+            throw exception
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+   * running. Throws an exception if a running context is detected and logs a warning if another
+   * thread is constructing a SparkContext. This warning is necessary because the current locking
+   * scheme prevents us from reliably distinguishing between cases where another context is being
+   * constructed and cases where another constructor threw an exception.
+   */
+  private[spark] def markPartiallyConstructed(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+      contextBeingConstructed = Some(sc)
+    }
+  }
+
+  /**
+   * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+   * raced with this constructor and started.
+   */
+  private[spark] def setActiveContext(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+      contextBeingConstructed = None
+      activeContext = Some(sc)
+    }
+  }
+
+  /**
+   * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's
+   * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+   * properly clean up their SparkContexts.
+   */
+  private[spark] def clearActiveContext(): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      activeContext = None
+    }
+  }
+
   private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
   private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
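The companion-object machinery above amounts to a small two-phase construction-tracking protocol: mark yourself as "partially constructed" first, promote yourself to "active" last, and clear the slot on shutdown. A self-contained sketch of the same idea, with the `allowMultipleContexts` escape hatch omitted for brevity (names here are illustrative, not Spark's):

```scala
object ConstructionGuard {
  private val LOCK = new Object()
  private var constructing: Option[AnyRef] = None // set first in a constructor
  private var active: Option[AnyRef] = None       // set last, on success

  def markPartiallyConstructed(owner: AnyRef): Unit = LOCK.synchronized {
    // A leftover `constructing` value can mean either a concurrent constructor
    // or an earlier one that threw, which is why the real code only warns here.
    require(active.isEmpty, "another instance is already active")
    constructing = Some(owner)
  }

  def setActive(owner: AnyRef): Unit = LOCK.synchronized {
    // Runs at the end of the constructor: detects a racing winner.
    require(active.isEmpty, "lost the construction race")
    constructing = None
    active = Some(owner)
  }

  def clear(): Unit = LOCK.synchronized { active = None }
}
```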
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index d50ed32ca085c..6a6d9bf6857d3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
 /**
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
  * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
  */
 class JavaSparkContext(val sc: SparkContext)
   extends JavaSparkContextVarargsWorkaround with Closeable {
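Because `JavaSparkContext` wraps the same `SparkContext`, the one-per-JVM rule applies to Java callers as well; the class is `Closeable` (see the declaration above), so cleanup fits a try/finally. A hedged usage sketch, assuming `close()` delegates to `stop()`:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(new SparkConf().setAppName("demo").setMaster("local"))
try {
  // ... use jsc ...
} finally {
  jsc.close() // assumed to delegate to stop(), freeing the JVM-wide context slot
}
```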
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 4b27477790212..ce804f94f3267 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
       .set("spark.dynamicAllocation.enabled", "true")
     intercept[SparkException] { new SparkContext(conf) }
     SparkEnv.get.stop() // cleanup the created environment
+    SparkContext.clearActiveContext()

     // Only min
     val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1")
     intercept[SparkException] { new SparkContext(conf1) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()

     // Only max
     val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2")
     intercept[SparkException] { new SparkContext(conf2) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()

     // Both min and max, but min > max
     intercept[SparkException] { createSparkContext(2, 1) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()

     // Both min and max, and min == max
     val sc1 = createSparkContext(1, 1)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 31edad1c56c73..9e454ddcc52a6 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -21,9 +21,62 @@ import org.scalatest.FunSuite

 import org.apache.hadoop.io.BytesWritable

-class SparkContextSuite extends FunSuite {
-  //Regression test for SPARK-3121
+class SparkContextSuite extends FunSuite with LocalSparkContext {
+
+  /** Allows system properties to be changed in tests */
+  private def withSystemProperty[T](property: String, value: String)(block: => T): T = {
+    val originalValue = System.getProperty(property)
+    try {
+      System.setProperty(property, value)
+      block
+    } finally {
+      if (originalValue == null) {
+        System.clearProperty(property)
+      } else {
+        System.setProperty(property, originalValue)
+      }
+    }
+  }
+
+  test("Only one SparkContext may be active at a time") {
+    // Regression test for SPARK-4180
+    withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+      val conf = new SparkConf().setAppName("test").setMaster("local")
+      sc = new SparkContext(conf)
+      // A SparkContext is already running, so we shouldn't be able to create a second one
+      intercept[SparkException] { new SparkContext(conf) }
+      // After stopping the running context, we should be able to create a new one
+      resetSparkContext()
+      sc = new SparkContext(conf)
+    }
+  }
+
+  test("Can still construct a new SparkContext after failing to construct a previous one") {
+    withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+      // This is an invalid configuration (no app name or master URL)
+      intercept[SparkException] {
+        new SparkContext(new SparkConf())
+      }
+      // Even though those earlier calls failed, we should still be able to create a new context
+      sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test"))
+    }
+  }
+
+  test("Check for multiple SparkContexts can be disabled via undocumented debug option") {
+    withSystemProperty("spark.driver.allowMultipleContexts", "true") {
+      var secondSparkContext: SparkContext = null
+      try {
+        val conf = new SparkConf().setAppName("test").setMaster("local")
+        sc = new SparkContext(conf)
+        secondSparkContext = new SparkContext(conf)
+      } finally {
+        Option(secondSparkContext).foreach(_.stop())
+      }
+    }
+  }
+
   test("BytesWritable implicit conversion is correct") {
+    // Regression test for SPARK-3121
     val bytesWritable = new BytesWritable()
     val inputArray = (1 to 10).map(_.toByte).toArray
     bytesWritable.set(inputArray, 0, 10)
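The `withSystemProperty` loan helper above restores the previous value (or absence) of the property even when the block throws, which keeps these tests from contaminating each other. Its usage, in miniature:

```scala
withSystemProperty("spark.driver.allowMultipleContexts", "false") {
  assert(System.getProperty("spark.driver.allowMultipleContexts") == "false")
  // ... exercise context creation; the prior value is restored on exit ...
}
```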
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 9de2f914b8b4c..49f319ba775e5 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/
 how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object
 that contains information about your application.

+Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one.
+
 {% highlight scala %}
 val conf = new SparkConf().setAppName(appName).setMaster(master)
 new SparkContext(conf)
diff --git a/pom.xml b/pom.xml
index 639ea22a1fda3..cc7bce175778f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -978,6 +978,7 @@
             <spark.testing>1</spark.testing>
             <spark.ui.enabled>false</spark.ui.enabled>
             <spark.executor.extraClassPath>${test_classpath}</spark.executor.extraClassPath>
+            <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index c96a6c49545c1..1697b6d4f2d43 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -377,6 +377,7 @@ object TestSettings {
     javaOptions in Test += "-Dspark.testing=1",
     javaOptions in Test += "-Dspark.port.maxRetries=100",
     javaOptions in Test += "-Dspark.ui.enabled=false",
+    javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
    javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
       .map { case (k,v) => s"-D$k=$v" }.toSeq,
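The pom.xml and SparkBuild.scala changes above opt every test JVM into the escape hatch via a system property. An individual application or suite could do the same through its `SparkConf`, since the constructor reads the flag from its config (a sketch, not from the patch):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("multi-context experiment")
  .setMaster("local")
  .set("spark.driver.allowMultipleContexts", "true") // downgrades the error to a warning
```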
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 30a359677cc74..86b96785d7b87 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -470,32 +470,31 @@ class BasicOperationsSuite extends TestSuiteBase {
   }

   test("slice") {
-    val ssc = new StreamingContext(conf, Seconds(1))
-    val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
-    val stream = new TestInputStream[Int](ssc, input, 2)
-    stream.foreachRDD(_ => {})  // Dummy output stream
-    ssc.start()
-    Thread.sleep(2000)
-    def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
-      stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
-    }
+    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+      val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+      val stream = new TestInputStream[Int](ssc, input, 2)
+      stream.foreachRDD(_ => {})  // Dummy output stream
+      ssc.start()
+      Thread.sleep(2000)
+      def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
+        stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
+      }

-    assert(getInputFromSlice(0, 1000) == Set(1))
-    assert(getInputFromSlice(0, 2000) == Set(1, 2))
-    assert(getInputFromSlice(1000, 2000) == Set(1, 2))
-    assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
-    ssc.stop()
-    Thread.sleep(1000)
+      assert(getInputFromSlice(0, 1000) == Set(1))
+      assert(getInputFromSlice(0, 2000) == Set(1, 2))
+      assert(getInputFromSlice(1000, 2000) == Set(1, 2))
+      assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
+    }
   }
-
   test("slice - has not been initialized") {
-    val ssc = new StreamingContext(conf, Seconds(1))
-    val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
-    val stream = new TestInputStream[Int](ssc, input, 2)
-    val thrown = intercept[SparkException] {
-      stream.slice(new Time(0), new Time(1000))
+    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+      val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+      val stream = new TestInputStream[Int](ssc, input, 2)
+      val thrown = intercept[SparkException] {
+        stream.slice(new Time(0), new Time(1000))
+      }
+      assert(thrown.getMessage.contains("has not been initialized"))
     }
-    assert(thrown.getMessage.contains("has not been initialized"))
   }

   val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq

@@ -555,73 +554,72 @@ class BasicOperationsSuite extends TestSuiteBase {

   test("rdd cleanup - input blocks and persisted RDDs") {
     // Actually receive data over through receiver to create BlockRDDs
-    // Start the server
-    val testServer = new TestServer()
-    testServer.start()
-
-    // Set up the streaming context and input streams
-    val ssc = new StreamingContext(conf, batchDuration)
-    val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
-    val mappedStream = networkStream.map(_ + ".").persist()
-    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
-    val outputStream = new TestOutputStream(mappedStream, outputBuffer)
-
-    outputStream.register()
-    ssc.start()
-
-    // Feed data to the server to send to the network receiver
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    val input = Seq(1, 2, 3, 4, 5, 6)
+    withTestServer(new TestServer()) { testServer =>
+      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+        testServer.start()
+        // Set up the streaming context and input streams
+        val networkStream =
+          ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+        val mappedStream = networkStream.map(_ + ".").persist()
+        val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+        val outputStream = new TestOutputStream(mappedStream, outputBuffer)
+
+        outputStream.register()
+        ssc.start()
+
+        // Feed data to the server to send to the network receiver
+        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+        val input = Seq(1, 2, 3, 4, 5, 6)
+
+        val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
+        val persistentRddIds = new mutable.HashMap[Time, Int]
+
+        def collectRddInfo() { // get all RDD info required for verification
+          networkStream.generatedRDDs.foreach { case (time, rdd) =>
+            blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
+          }
+          mappedStream.generatedRDDs.foreach { case (time, rdd) =>
+            persistentRddIds(time) = rdd.id
+          }
+        }

-    val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
-    val persistentRddIds = new mutable.HashMap[Time, Int]
+        Thread.sleep(200)
+        for (i <- 0 until input.size) {
+          testServer.send(input(i).toString + "\n")
+          Thread.sleep(200)
+          clock.addToTime(batchDuration.milliseconds)
+          collectRddInfo()
+        }

-    def collectRddInfo() { // get all RDD info required for verification
-      networkStream.generatedRDDs.foreach { case (time, rdd) =>
-        blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
-      }
-      mappedStream.generatedRDDs.foreach { case (time, rdd) =>
-        persistentRddIds(time) = rdd.id
+        Thread.sleep(200)
+        collectRddInfo()
+        logInfo("Stopping server")
+        testServer.stop()
+
+        // verify data has been received
+        assert(outputBuffer.size > 0)
+        assert(blockRdds.size > 0)
+        assert(persistentRddIds.size > 0)
+
+        import Time._
+
+        val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
+        val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
+        val latestBlockRdd = blockRdds(blockRdds.keySet.max)
+        val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
+        // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
+        assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
+        assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
+
+        // verify that the latest input blocks are present but the earliest blocks have been removed
+        assert(latestBlockRdd.isValid)
+        assert(latestBlockRdd.collect != null)
+        assert(!earliestBlockRdd.isValid)
+        earliestBlockRdd.blockIds.foreach { blockId =>
+          assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
+        }
       }
     }
-
-    Thread.sleep(200)
-    for (i <- 0 until input.size) {
-      testServer.send(input(i).toString + "\n")
-      Thread.sleep(200)
-      clock.addToTime(batchDuration.milliseconds)
-      collectRddInfo()
-    }
-
-    Thread.sleep(200)
-    collectRddInfo()
-    logInfo("Stopping server")
-    testServer.stop()
-    logInfo("Stopping context")
-
-    // verify data has been received
-    assert(outputBuffer.size > 0)
-    assert(blockRdds.size > 0)
-    assert(persistentRddIds.size > 0)
-
-    import Time._
-
-    val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
-    val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
-    val latestBlockRdd = blockRdds(blockRdds.keySet.max)
-    val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
-    // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
-    assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
-    assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
-
-    // verify that the latest input blocks are present but the earliest blocks have been removed
-    assert(latestBlockRdd.isValid)
-    assert(latestBlockRdd.collect != null)
-    assert(!earliestBlockRdd.isValid)
-    earliestBlockRdd.blockIds.foreach { blockId =>
-      assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
-    }
-    ssc.stop()
   }

 /** Test cleanup of RDDs in DStream metadata */
@@ -635,13 +633,15 @@ class BasicOperationsSuite extends TestSuiteBase {
     // Setup the stream computation
     assert(batchDuration === Seconds(1),
       "Batch duration has changed from 1 second, check cleanup tests")
-    val ssc = setupStreams(cleanupTestInput, operation)
-    val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
-    if (rememberDuration != null) ssc.remember(rememberDuration)
-    val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    assert(clock.time === Seconds(10).milliseconds)
-    assert(output.size === numExpectedOutput)
-    operatedStream
+    withStreamingContext(setupStreams(cleanupTestInput, operation)) { ssc =>
+      val operatedStream =
+        ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
+      if (rememberDuration != null) ssc.remember(rememberDuration)
+      val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
+      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+      assert(clock.time === Seconds(10).milliseconds)
+      assert(output.size === numExpectedOutput)
+      operatedStream
+    }
   }
 }
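The refactoring above is motivated by a leak: when an assertion in the middle of a test threw, a trailing `ssc.stop()` never ran, and under the new one-context-per-JVM check the leaked context would fail every subsequent suite in the same JVM. The loan pattern closes that hole (a sketch, reusing this suite's `conf`):

```scala
// Before: cleanup is skipped when an assertion throws.
val ssc = new StreamingContext(conf, Seconds(1))
// ... assertions ...
ssc.stop()

// After: cleanup runs on every exit path.
withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
  // ... assertions may throw freely; the context is still stopped ...
}
```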
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 2154c24abda3a..52972f63c6c5c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -163,6 +163,40 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
   before(beforeFunction)
   after(afterFunction)

+  /**
+   * Run a block of code with the given StreamingContext and automatically
+   * stop the context when the block completes or when an exception is thrown.
+   */
+  def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): R = {
+    try {
+      block(ssc)
+    } finally {
+      try {
+        ssc.stop(stopSparkContext = true)
+      } catch {
+        case e: Exception =>
+          logError("Error stopping StreamingContext", e)
+      }
+    }
+  }
+
+  /**
+   * Run a block of code with the given TestServer and automatically
+   * stop the server when the block completes or when an exception is thrown.
+   */
+  def withTestServer[R](testServer: TestServer)(block: TestServer => R): R = {
+    try {
+      block(testServer)
+    } finally {
+      try {
+        testServer.stop()
+      } catch {
+        case e: Exception =>
+          logError("Error stopping TestServer", e)
+      }
+    }
+  }
+
   /**
    * Set up required DStreams to test the DStream operation using the two sequences
    * of input collections.
@@ -282,10 +316,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")

       Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
-    } catch {
-      case e: Exception => {e.printStackTrace(); throw e}
     } finally {
-      ssc.stop()
+      ssc.stop(stopSparkContext = true)
     }
     output
   }
@@ -351,9 +383,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       useSet: Boolean
     ) {
     val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
-    val ssc = setupStreams[U, V](input, operation)
-    val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
-    verifyOutput[V](output, expectedOutput, useSet)
+    withStreamingContext(setupStreams[U, V](input, operation)) { ssc =>
+      val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+      verifyOutput[V](output, expectedOutput, useSet)
+    }
   }

   /**
@@ -389,8 +422,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       useSet: Boolean
     ) {
     val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
-    val ssc = setupStreams[U, V, W](input1, input2, operation)
-    val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
-    verifyOutput[W](output, expectedOutput, useSet)
+    withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc =>
+      val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+      verifyOutput[W](output, expectedOutput, useSet)
+    }
   }
 }
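Finally, the two helpers compose by nesting, mirroring how the rdd-cleanup test uses them: resources are torn down innermost-first no matter where a failure occurs. A usage sketch, reusing the suite's `conf` and `batchDuration`:

```scala
withTestServer(new TestServer()) { testServer =>
  withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
    testServer.start()
    // ... feed data through ssc and assert on the output ...
  } // ssc.stop(stopSparkContext = true) runs here, even on failure
}   // testServer.stop() runs here
```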