From 83ac9a4bbf272028d0c4639cbd1e12022b9ae77a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 8 Apr 2014 00:00:17 -0700 Subject: [PATCH] [SPARK-1331] Added graceful shutdown to Spark Streaming Current version of StreamingContext.stop() directly kills all the data receivers (NetworkReceiver) without waiting for the data already received to be persisted and processed. This PR provides the fix. Now, when the StreamingContext.stop() is called, the following sequence of steps will happen. 1. The driver will send a stop signal to all the active receivers. 2. Each receiver, when it gets a stop signal from the driver, first stop receiving more data, then waits for the thread that persists data blocks to BlockManager to finish persisting all receive data, and finally quits. 3. After all the receivers have stopped, the driver will wait for the Job Generator and Job Scheduler to finish processing all the received data. It also fixes the semantics of StreamingContext.start and stop. It will throw appropriate errors and warnings if stop() is called before start(), stop() is called twice, etc. Author: Tathagata Das Closes #247 from tdas/graceful-shutdown and squashes the following commits: 61c0016 [Tathagata Das] Updated MIMA binary check excludes. ae1d39b [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into graceful-shutdown 6b59cfc [Tathagata Das] Minor changes based on Andrew's comment on PR. d0b8d65 [Tathagata Das] Reduced time taken by graceful shutdown unit test. f55bc67 [Tathagata Das] Fix scalastyle c69b3a7 [Tathagata Das] Updates based on Patrick's comments. c43b8ae [Tathagata Das] Added graceful shutdown to Spark Streaming. --- project/MimaBuild.scala | 24 +-- .../apache/spark/streaming/Checkpoint.scala | 14 +- .../spark/streaming/StreamingContext.scala | 48 +++++- .../api/java/JavaStreamingContext.scala | 12 +- .../dstream/NetworkInputDStream.scala | 151 +++++++++++------ .../dstream/SocketInputDStream.scala | 1 - .../streaming/receivers/ActorReceiver.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 124 ++++++++++---- .../streaming/scheduler/JobScheduler.scala | 56 ++++--- .../scheduler/NetworkInputTracker.scala | 154 ++++++++++-------- .../apache/spark/streaming/util/Clock.scala | 5 +- .../spark/streaming/util/RecurringTimer.scala | 62 +++++-- .../streaming/BasicOperationsSuite.scala | 4 +- .../streaming/StreamingContextSuite.scala | 108 ++++++++++-- .../spark/streaming/TestSuiteBase.scala | 2 +- 15 files changed, 552 insertions(+), 215 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index e7c9c47c960fa..5ea4817bfde18 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -58,17 +58,19 @@ object MimaBuild { SparkBuild.SPARK_VERSION match { case v if v.startsWith("1.0") => Seq( - excludePackage("org.apache.spark.api.java"), - excludePackage("org.apache.spark.streaming.api.java"), - excludePackage("org.apache.spark.mllib") - ) ++ - excludeSparkClass("rdd.ClassTags") ++ - excludeSparkClass("util.XORShiftRandom") ++ - excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - excludeSparkClass("mllib.optimization.SquaredGradient") ++ - excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - excludeSparkClass("mllib.regression.LassoWithSGD") ++ - excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + excludePackage("org.apache.spark.api.java"), + excludePackage("org.apache.spark.streaming.api.java"), + excludePackage("org.apache.spark.mllib") + ) ++ + excludeSparkClass("rdd.ClassTags") ++ + excludeSparkClass("util.XORShiftRandom") ++ + excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + excludeSparkClass("mllib.optimization.SquaredGradient") ++ + excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + excludeSparkClass("mllib.regression.LassoWithSGD") ++ + excludeSparkClass("mllib.regression.LinearRegressionWithSGD") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver#NetworkReceiverActor") case _ => Seq() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index baf80fe2a91b7..93023e8dced57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -194,19 +194,19 @@ class CheckpointWriter( } } - def stop() { - synchronized { - if (stopped) { - return - } - stopped = true - } + def stop(): Unit = synchronized { + if (stopped) return + executor.shutdown() val startTime = System.currentTimeMillis() val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + if (!terminated) { + executor.shutdownNow() + } val endTime = System.currentTimeMillis() logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") + stopped = true } private def fs = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e198c69470c1f..a4e236c65ff86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -158,6 +158,15 @@ class StreamingContext private[streaming] ( private[streaming] val waiter = new ContextWaiter + /** Enumeration to identify current state of the StreamingContext */ + private[streaming] object StreamingContextState extends Enumeration { + type CheckpointState = Value + val Initialized, Started, Stopped = Value + } + + import StreamingContextState._ + private[streaming] var state = Initialized + /** * Return the associated Spark context */ @@ -405,9 +414,18 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. */ - def start() = synchronized { + def start(): Unit = synchronized { + // Throw exception if the context has already been started once + // or if a stopped context is being started again + if (state == Started) { + throw new SparkException("StreamingContext has already been started") + } + if (state == Stopped) { + throw new SparkException("StreamingContext has already been stopped") + } validate() scheduler.start() + state = Started } /** @@ -428,14 +446,38 @@ class StreamingContext private[streaming] ( } /** - * Stop the execution of the streams. + * Stop the execution of the streams immediately (does not wait for all received data + * to be processed). * @param stopSparkContext Stop the associated SparkContext or not + * */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { - scheduler.stop() + stop(stopSparkContext, false) + } + + /** + * Stop the execution of the streams, with option of ensuring all received data + * has been processed. + * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { + // Warn (but not fail) if context is stopped twice, + // or context is stopped before starting + if (state == Initialized) { + logWarning("StreamingContext has not been started yet") + return + } + if (state == Stopped) { + logWarning("StreamingContext has already been stopped") + return + } // no need to throw an exception as its okay to stop twice + scheduler.stop(stopGracefully) logInfo("StreamingContext stopped successfully") waiter.notifyStop() if (stopSparkContext) sc.stop() + state = Stopped } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b705d2ec9a58e..c800602d0959b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean): Unit = { - ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + + /** + * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + ssc.stop(stopSparkContext, stopGracefully) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 72ad0bae75bfb..d19a635fe8eca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} +import org.apache.spark.util.AkkaUtils /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -69,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte // then this returns an empty RDD. This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) + val blockIds = ssc.scheduler.networkInputTracker.getBlocks(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) @@ -79,7 +80,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class StopReceiver() extends NetworkReceiverMessage private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage @@ -90,13 +91,31 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe */ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging { + /** Local SparkEnv */ lazy protected val env = SparkEnv.get + /** Remote Akka actor for the NetworkInputTracker */ + lazy protected val trackerActor = { + val ip = env.conf.get("spark.driver.host", "localhost") + val port = env.conf.getInt("spark.driver.port", 7077) + val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + env.actorSystem.actorSelection(url) + } + + /** Akka actor for receiving messages from the NetworkInputTracker in the driver */ lazy protected val actor = env.actorSystem.actorOf( Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + /** Timeout for Akka actor messages */ + lazy protected val askTimeout = AkkaUtils.askTimeout(env.conf) + + /** Thread that starts the receiver and stays blocked while data is being received */ lazy protected val receivingThread = Thread.currentThread() + /** Exceptions that occurs while receiving data */ + protected lazy val exceptions = new ArrayBuffer[Exception] + + /** Identifier of the stream this receiver is associated with */ protected var streamId: Int = -1 /** @@ -112,7 +131,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging def getLocationPreference() : Option[String] = None /** - * Starts the receiver. First is accesses all the lazy members to + * Start the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start * other threads, etc required to receiver the data. */ @@ -124,83 +143,107 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging receivingThread // Call user-defined onStart() + logInfo("Starting receiver") onStart() + + // Wait until interrupt is called on this thread + while(true) Thread.sleep(100000) } catch { case ie: InterruptedException => - logInfo("Receiving thread interrupted") + logInfo("Receiving thread has been interrupted, receiver " + streamId + " stopped") case e: Exception => - stopOnError(e) + logError("Error receiving data in receiver " + streamId, e) + exceptions += e + } + + // Call user-defined onStop() + logInfo("Stopping receiver") + try { + onStop() + } catch { + case e: Exception => + logError("Error stopping receiver " + streamId, e) + exceptions += e + } + + val message = if (exceptions.isEmpty) { + null + } else if (exceptions.size == 1) { + val e = exceptions.head + "Exception in receiver " + streamId + ": " + e.getMessage + "\n" + e.getStackTraceString + } else { + "Multiple exceptions in receiver " + streamId + "(" + exceptions.size + "):\n" + exceptions.zipWithIndex.map { + case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString + }.mkString("\n") } + logInfo("Deregistering receiver " + streamId) + val future = trackerActor.ask(DeregisterReceiver(streamId, message))(askTimeout) + Await.result(future, askTimeout) + logInfo("Deregistered receiver " + streamId) + env.actorSystem.stop(actor) + logInfo("Stopped receiver " + streamId) } /** - * Stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. + * Stop the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). */ def stop() { + // Stop receiving by interrupting the receiving thread receivingThread.interrupt() - onStop() - // TODO: terminate the actor + logInfo("Interrupted receiving thread " + receivingThread + " for stopping") } /** - * Stops the receiver and reports exception to the tracker. + * Stop the receiver and reports exception to the tracker. * This should be called whenever an exception is to be handled on any thread * of the receiver. */ protected def stopOnError(e: Exception) { logError("Error receiving data", e) + exceptions += e stop() - actor ! ReportError(e.toString) } - /** - * Pushes a block (as an ArrayBuffer filled with data) into the block manager. + * Push a block (as an ArrayBuffer filled with data) into the block manager. */ def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId, metadata) + trackerActor ! AddBlocks(streamId, Array(blockId), metadata) + logDebug("Pushed block " + blockId) } /** - * Pushes a block (as bytes) into the block manager. + * Push a block (as bytes) into the block manager. */ def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) + trackerActor ! AddBlocks(streamId, Array(blockId), metadata) + } + + /** Set the ID of the DStream that this receiver is associated with */ + protected[streaming] def setStreamId(id: Int) { + streamId = id } /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorSelection(url) - val timeout = 5.seconds override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) + logInfo("Registered receiver " + streamId) + val future = trackerActor.ask(RegisterReceiver(streamId, self))(askTimeout) + Await.result(future, askTimeout) } override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => + case StopReceiver => + logInfo("Received stop signal") stop() - tracker ! DeregisterReceiver(streamId, msg) } } - protected[streaming] def setStreamId(id: Int) { - streamId = id - } - /** * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts * them into appropriately named blocks at regular intervals. This class starts two threads, @@ -214,23 +257,26 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging val clock = new SystemClock() val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer, + "BlockGenerator") val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } var currentBuffer = new ArrayBuffer[T] + var stopped = false def start() { blockIntervalTimer.start() blockPushingThread.start() - logInfo("Data handler started") + logInfo("Started BlockGenerator") } def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") + blockIntervalTimer.stop(false) + stopped = true + blockPushingThread.join() + logInfo("Stopped BlockGenerator") } def += (obj: T): Unit = synchronized { @@ -248,24 +294,35 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging } } catch { case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") + logInfo("Block updating timer thread was interrupted") case e: Exception => - NetworkReceiver.this.stop() + NetworkReceiver.this.stopOnError(e) } } private def keepPushingBlocks() { - logInfo("Block pushing thread started") + logInfo("Started block pushing thread") try { - while(true) { + while(!stopped) { + Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + case Some(block) => + NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + case None => + } + } + // Push out the blocks that are still left + logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") + while (!blocksForPushing.isEmpty) { val block = blocksForPushing.take() NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + logInfo("Blocks left to push " + blocksForPushing.size()) } + logInfo("Stopped blocks pushing thread") } catch { case ie: InterruptedException => - logInfo("Block pushing thread interrupted") + logInfo("Block pushing thread was interrupted") case e: Exception => - NetworkReceiver.this.stop() + NetworkReceiver.this.stopOnError(e) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 2cdd13f205313..63d94d1cc670a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -67,7 +67,6 @@ class SocketReceiver[T: ClassTag]( protected def onStop() { blockGenerator.stop() } - } private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index bd78bae8a5c51..44eb2750c6c7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -174,10 +174,10 @@ private[streaming] class ActorReceiver[T: ClassTag]( blocksGenerator.start() supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + } protected def onStop() = { supervisor ! PoisonPill } - } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index c7306248b1950..92d885c4bc5a5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { private val ssc = jobScheduler.ssc private val graph = ssc.graph + val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime))) - private lazy val checkpointWriter = - if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + + // This is marked lazy so that this is initialized after checkpoint duration has been set + // in the context and the generator has been started. + private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null + + private lazy val checkpointWriter = if (shouldCheckpoint) { + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } @@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // This not being null means the scheduler has been started and not stopped private var eventActor: ActorRef = null + // last batch whose completion,checkpointing and metadata cleanup has been completed + private var lastProcessedBatch: Time = null + /** Start generation of jobs */ - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobGenerator already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) + case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") if (ssc.isCheckpointPresent) { @@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - /** Stop generation of jobs */ - def stop() = synchronized { - if (eventActor != null) { - timer.stop() - ssc.env.actorSystem.stop(eventActor) - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** + * Stop generation of jobs. processReceivedData = true makes this wait until jobs + * of current ongoing time interval has been generated, processed and corresponding + * checkpoints written. + */ + def stop(processReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // generator has already been stopped + + if (processReceivedData) { + logInfo("Stopping JobGenerator gracefully") + val timeWhenStopStarted = System.currentTimeMillis() + val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds + val pollTime = 100 + + // To prevent graceful stop to get stuck permanently + def hasTimedOut = { + val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + if (timedOut) logWarning("Timed out while stopping the job generator") + timedOut + } + + // Wait until all the received blocks in the network input tracker has + // been consumed by network input DStreams, and jobs have been generated with them + logInfo("Waiting for all received blocks to be consumed for job generation") + while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) { + Thread.sleep(pollTime) + } + logInfo("Waited for all received blocks to be consumed for job generation") + + // Stop generating jobs + val stopTime = timer.stop(false) + graph.stop() + logInfo("Stopped generation timer") + + // Wait for the jobs to complete and checkpoints to be written + def haveAllBatchesBeenProcessed = { + lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime + } + logInfo("Waiting for jobs to be processed and checkpoints to be written") + while (!hasTimedOut && !haveAllBatchesBeenProcessed) { + Thread.sleep(pollTime) + } + logInfo("Waited for jobs to be processed and checkpoints to be written") + } else { + logInfo("Stopping JobGenerator immediately") + // Stop timer and graph immediately, ignore unprocessed data and pending jobs + timer.stop(true) + graph.stop() } + + // Stop the actor and checkpoint writer + if (shouldCheckpoint) checkpointWriter.stop() + ssc.env.actorSystem.stop(eventActor) + logInfo("Stopped JobGenerator") } /** - * On batch completion, clear old metadata and checkpoint computation. + * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { eventActor ! ClearMetadata(time) } - + + /** + * Callback called when the checkpoint of a batch has been written. + */ def onCheckpointCompletion(time: Time) { eventActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { + logDebug("Got event " + event) event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) @@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val startTime = new Time(timer.getStartTime()) graph.start(startTime - graph.batchDuration) timer.start(startTime.milliseconds) - logInfo("JobGenerator started at " + startTime) + logInfo("Started JobGenerator at " + startTime) } /** Restarts the generator based on the information in checkpoint */ @@ -152,15 +206,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Restart the timer timer.start(restartTime.milliseconds) - logInfo("JobGenerator restarted at " + restartTime) + logInfo("Restarted JobGenerator at " + restartTime) } /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) Try(graph.generateJobs(time)) match { - case Success(jobs) => jobScheduler.runJobs(time, jobs) - case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + case Success(jobs) => + jobScheduler.runJobs(time, jobs) + case Failure(e) => + jobScheduler.reportError("Error generating jobs for time " + time, e) } eventActor ! DoCheckpoint(time) } @@ -168,20 +224,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - eventActor ! DoCheckpoint(time) + + // If checkpointing is enabled, then checkpoint, + // else mark batch to be fully processed + if (shouldCheckpoint) { + eventActor ! DoCheckpoint(time) + } else { + markBatchFullyProcessed(time) + } } /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + markBatchFullyProcessed(time) } /** Perform checkpoint for the give `time`. */ - private def doCheckpoint(time: Time) = synchronized { - if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + private def doCheckpoint(time: Time) { + if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) } } + + private def markBatchFullyProcessed(time: Time) { + lastProcessedBatch = time + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index de675d3c7fb94..04e0a6a283cfb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private val jobSets = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -50,36 +50,54 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private var eventActor: ActorRef = null - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobScheduler already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // scheduler has already been started + logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") + listenerBus.start() networkInputTracker = new NetworkInputTracker(ssc) networkInputTracker.start() - Thread.sleep(1000) jobGenerator.start() - logInfo("JobScheduler started") + logInfo("Started JobScheduler") } - def stop() = synchronized { - if (eventActor != null) { - jobGenerator.stop() - networkInputTracker.stop() - executor.shutdown() - if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { - executor.shutdownNow() - } - listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - logInfo("JobScheduler stopped") + def stop(processAllReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // scheduler has already been stopped + logDebug("Stopping JobScheduler") + + // First, stop receiving + networkInputTracker.stop() + + // Second, stop generating jobs. If it has to process all received data, + // then this will wait for all the processing through JobScheduler to be over. + jobGenerator.stop(processAllReceivedData) + + // Stop the executor for receiving new jobs + logDebug("Stopping job executor") + jobExecutor.shutdown() + + // Wait for the queued jobs to complete if indicated + val terminated = if (processAllReceivedData) { + jobExecutor.awaitTermination(1, TimeUnit.HOURS) // just a very large period of time + } else { + jobExecutor.awaitTermination(2, TimeUnit.SECONDS) } + if (!terminated) { + jobExecutor.shutdownNow() + } + logDebug("Stopped job executor") + + // Stop everything else + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + eventActor = null + logInfo("Stopped JobScheduler") } def runJobs(time: Time, jobs: Seq[Job]) { @@ -88,7 +106,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } else { val jobSet = new JobSet(time, jobs) jobSets.put(time, jobSet) - jobSet.jobs.foreach(job => executor.execute(new JobHandler(job))) + jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + time) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index cad68e248ab29..067e804202236 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -17,20 +17,14 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} -import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.{SparkException, Logging, SparkEnv} -import org.apache.spark.SparkContext._ - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.concurrent.duration._ +import scala.collection.mutable.{HashMap, Queue, SynchronizedMap} import akka.actor._ -import akka.pattern.ask -import akka.dispatch._ +import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.SparkContext._ import org.apache.spark.storage.BlockId -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream.{NetworkReceiver, StopReceiver} import org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage @@ -52,8 +46,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[BlockId]] + val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[BlockId]] with SynchronizedMap[Int, Queue[BlockId]] val timeout = AkkaUtils.askTimeout(ssc.conf) @@ -63,7 +57,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { var currentTime: Time = null /** Start the actor and receiver execution thread. */ - def start() { + def start() = synchronized { if (actor != null) { throw new SparkException("NetworkInputTracker already started") } @@ -77,72 +71,99 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } /** Stop the receiver execution thread. */ - def stop() { + def stop() = synchronized { if (!networkInputStreams.isEmpty && actor != null) { - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + // First, stop the receivers + receiverExecutor.stop() + + // Finally, stop the actor ssc.env.actorSystem.stop(actor) + actor = null logInfo("NetworkInputTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) + /** Register a receiver */ + def registerReceiver(streamId: Int, receiverActor: ActorRef, sender: ActorRef) { + if (!networkInputStreamMap.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) } - val result = queue.synchronized { - queue.dequeueAll(x => true) - } - logInfo("Stream " + receiverId + " received " + result.size + " blocks") - result.toArray + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) + } + + /** Deregister a receiver */ + def deregisterReceiver(streamId: Int, message: String) { + receiverInfo -= streamId + logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message) + } + + /** Get all the received blocks for the given stream. */ + def getBlocks(streamId: Int, time: Time): Array[BlockId] = { + val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]()) + val result = queue.dequeueAll(x => true).toArray + logInfo("Stream " + streamId + " received " + result.size + " blocks") + result + } + + /** Add new blocks for the given stream */ + def addBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) = { + val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]) + queue ++= blockIds + networkInputStreamMap(streamId).addMetadata(metadata) + logDebug("Stream " + streamId + " received new blocks: " + blockIds.mkString("[", ", ", "]")) + } + + /** Check if any blocks are left to be processed */ + def hasMoreReceivedBlockIds: Boolean = { + !receivedBlockIds.forall(_._2.isEmpty) } /** Actor to receive messages from the receivers. */ private class NetworkInputTrackerActor extends Actor { def receive = { - case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) - } - receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId + " from " - + sender.path.address) + case RegisterReceiver(streamId, receiverActor) => + registerReceiver(streamId, receiverActor, sender) + sender ! true + case AddBlocks(streamId, blockIds, metadata) => + addBlocks(streamId, blockIds, metadata) + case DeregisterReceiver(streamId, message) => + deregisterReceiver(streamId, message) sender ! true - } - case AddBlocks(streamId, blockIds, metadata) => { - val tmp = receivedBlockIds.synchronized { - if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[BlockId])) - } - receivedBlockIds(streamId) - } - tmp.synchronized { - tmp ++= blockIds - } - networkInputStreamMap(streamId).addMetadata(metadata) - } - case DeregisterReceiver(streamId, msg) => { - receiverInfo -= streamId - logError("De-registered receiver for network stream " + streamId - + " with message " + msg) - // TODO: Do something about the corresponding NetworkInputDStream - } } } /** This thread class runs all the receivers on the cluster. */ - class ReceiverExecutor extends Thread { - val env = ssc.env - - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") - } finally { - stopReceivers() + class ReceiverExecutor { + @transient val env = ssc.env + @transient val thread = new Thread() { + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } + } + } + + def start() { + thread.start() + } + + def stop() { + // Send the stop signal to all the receivers + stopReceivers() + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + thread.join(10000) + + // Check if all the receivers have been deregistered or not + if (!receiverInfo.isEmpty) { + logWarning("All of the receivers have not deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") } } @@ -150,7 +171,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { * Get the receivers from the NetworkInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. */ - def startReceivers() { + private def startReceivers() { val receivers = networkInputStreams.map(nis => { val rcvr = nis.getReceiver() rcvr.setStreamId(nis.id) @@ -186,13 +207,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } // Distribute the receivers and start them + logInfo("Starting " + receivers.length + " receivers") ssc.sparkContext.runJob(tempRDD, startReceiver) + logInfo("All of the receivers have been terminated") } /** Stops the receivers. */ - def stopReceivers() { + private def stopReceivers() { // Signal the receivers to stop receiverInfo.values.foreach(_ ! StopReceiver) + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index c3a849d2769a7..c5ef2cc8c390d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -48,14 +48,11 @@ class SystemClock() extends Clock { minPollTime } } - - + while (true) { currentTime = System.currentTimeMillis() waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime } val sleepTime = diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 559c2473851b3..f71938ac55ccb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -17,44 +17,84 @@ package org.apache.spark.streaming.util +import org.apache.spark.Logging + private[streaming] -class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { +class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) + extends Logging { - private val thread = new Thread("RecurringTimer") { + private val thread = new Thread("RecurringTimer - " + name) { + setDaemon(true) override def run() { loop } } - - private var nextTime = 0L + @volatile private var prevTime = -1L + @volatile private var nextTime = -1L + @volatile private var stopped = false + + /** + * Get the time when this timer will fire if it is started right now. + * The time will be a multiple of this timer's period and more than + * current system time. + */ def getStartTime(): Long = { (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period } + /** + * Get the time when the timer will fire if it is restarted right now. + * This time depends on when the timer was started the first time, and was stopped + * for whatever reason. The time must be a multiple of this timer's period and + * more than current time. + */ def getRestartTime(originalStartTime: Long): Long = { val gap = clock.currentTime - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } - def start(startTime: Long): Long = { + /** + * Start at the given start time. + */ + def start(startTime: Long): Long = synchronized { nextTime = startTime thread.start() + logInfo("Started timer for " + name + " at time " + nextTime) nextTime } + /** + * Start at the earliest time it can start based on the period. + */ def start(): Long = { start(getStartTime()) } - def stop() { - thread.interrupt() + /** + * Stop the timer, and return the last time the callback was made. + * interruptTimer = true will interrupt the callback + * if it is in progress (not guaranteed to give correct time in this case). + */ + def stop(interruptTimer: Boolean): Long = synchronized { + if (!stopped) { + stopped = true + if (interruptTimer) thread.interrupt() + thread.join() + logInfo("Stopped timer for " + name + " after time " + prevTime) + } + prevTime } - + + /** + * Repeatedly call the callback every interval. + */ private def loop() { try { - while (true) { + while (!stopped) { clock.waitTillTime(nextTime) callback(nextTime) + prevTime = nextTime nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) } } catch { case e: InterruptedException => @@ -74,10 +114,10 @@ object RecurringTimer { println("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } - val timer = new RecurringTimer(new SystemClock(), period, onRecur) + val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") timer.start() Thread.sleep(30 * 1000) - timer.stop() + timer.stop(true) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bcb0c28bf07a0..bb73dbf29b649 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase { val updateStateOperation = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) + Some(values.sum + state.getOrElse(0)) } s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) } @@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase { // updateFunc clears a state when a StateObject is seen without new values twice in a row val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { val stateObj = state.getOrElse(new StateObject) - values.foldLeft(0)(_ + _) match { + values.sum match { case 0 => stateObj.expireCounter += 1 // no new values case n => { // has new values, increment and reset expireCounter stateObj.counter += n diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 717da8e00462b..9cc27ef7f03b5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,19 +17,22 @@ package org.apache.spark.streaming -import org.scalatest.{FunSuite, BeforeAndAfter} -import org.scalatest.exceptions.TestFailedDueToTimeoutException +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver} +import org.apache.spark.util.{MetadataCleaner, Utils} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.apache.spark.util.{Utils, MetadataCleaner} -import org.apache.spark.streaming.dstream.DStream -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName - val batchDuration = Seconds(1) + val batchDuration = Milliseconds(500) val sparkHome = "someDir" val envPair = "key" -> "value" val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 @@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.cleaner.ttl", ttl.toString) val ssc1 = new StreamingContext(myConf, batchDuration) + addInputStream(ssc1).register + ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) - ssc = new StreamingContext(null, cp, null) + ssc = new StreamingContext(null, newCp, null) assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) } - test("start multiple times") { + test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register + assert(ssc.state === ssc.StreamingContextState.Initialized) + ssc.start() + assert(ssc.state === ssc.StreamingContextState.Started) + ssc.stop() + assert(ssc.state === ssc.StreamingContextState.Stopped) + } + + test("start multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register ssc.start() intercept[SparkException] { ssc.start() @@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { ssc.start() ssc.stop() ssc.stop() - ssc = null } + test("stop before start and start after stop") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.stop() // stop before start should not throw exception + ssc.start() + ssc.stop() + intercept[SparkException] { + ssc.start() // start after stop should throw exception + } + } + + test("stop only streaming context") { ssc = new StreamingContext(master, appName, batchDuration) sc = ssc.sparkContext addInputStream(ssc).register ssc.start() ssc.stop(false) - ssc = null assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() + } + + test("stop gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.cleaner.ttl", "3600") + sc = new SparkContext(conf) + for (i <- 1 to 4) { + logInfo("==================================") + ssc = new StreamingContext(sc, batchDuration) + var runningCount = 0 + TestReceiver.counter.set(1) + val input = ssc.networkStream(new TestReceiver) + input.count.foreachRDD(rdd => { + val count = rdd.first() + logInfo("Count = " + count) + runningCount += count.toInt + }) + ssc.start() + ssc.awaitTermination(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) + assert(runningCount > 0) + assert( + (TestReceiver.counter.get() == runningCount + 1) || + (TestReceiver.counter.get() == runningCount + 2), + "Received records = " + TestReceiver.counter.get() + ", " + + "processed records = " + runningCount + ) + } } test("awaitTermination") { @@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { test("awaitTermination with error in job generation") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register val exception = intercept[TestException] { ssc.start() @@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { } } -class TestException(msg: String) extends Exception(msg) \ No newline at end of file +class TestException(msg: String) extends Exception(msg) + +/** Custom receiver for testing whether all data received by a receiver gets processed or not */ +class TestReceiver extends NetworkReceiver[Int] { + protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY) + protected def onStart() { + blockGenerator.start() + logInfo("BlockGenerator started on thread " + receivingThread) + try { + while(true) { + blockGenerator += TestReceiver.counter.getAndIncrement + Thread.sleep(0) + } + } finally { + logInfo("Receiving stopped at count value of " + TestReceiver.counter.get()) + } + } + + protected def onStop() { + blockGenerator.stop() + } +} + +object TestReceiver { + val counter = new AtomicInteger(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 201630672ab4c..aa2d5c2fc2454 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") - Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + Thread.sleep(100) // Give some time for the forgetting old RDDs to complete } catch { case e: Exception => {e.printStackTrace(); throw e} } finally {