From 028bde678f1f039b765bd8a76978eb9fb05fc77a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Apr 2014 12:23:39 -0700 Subject: [PATCH] Further refactored receiver to allow restarting of a receiver. --- .../streaming/mqtt/MQTTInputDStream.scala | 11 +- .../twitter/TwitterInputDStream.scala | 5 +- .../dstream/SocketInputDStream.scala | 46 ++++-- .../streaming/receiver/NetworkReceiver.scala | 88 ++++++++++-- .../receiver/NetworkReceiverExecutor.scala | 134 +++++++++++++----- .../NetworkReceiverExecutorImpl.scala | 19 ++- .../scheduler/NetworkInputTracker.scala | 3 +- .../streaming/NetworkReceiverSuite.scala | 81 +++++++---- 8 files changed, 287 insertions(+), 100 deletions(-) diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index c9c85f0a88f13..2f2380237e572 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -69,15 +69,17 @@ class MQTTReceiver( storageLevel: StorageLevel ) extends NetworkReceiver[String](storageLevel) { - def onStop() { } + def onStop() { + + } def onStart() { // Set up persistence for messages - val peristance: MqttClientPersistence = new MemoryPersistence() + val persistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance - val client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) // Connect to MqttBroker client.connect() @@ -97,8 +99,7 @@ class MQTTReceiver( } override def connectionLost(arg0: Throwable) { - reportError("Connection lost ", arg0) - stop() + restart("Connection lost ", arg0) } } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 372b4c269a634..980dbc30eaf75 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -77,12 +77,11 @@ class TwitterReceiver( def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} def onException(e: Exception) { - reportError("Error receiving tweets", e) - stop() + restart("Error receiving tweets", e) } }) - val query: FilterQuery = new FilterQuery + val query = new FilterQuery if (filters.size > 0) { query.track(filters.toArray) twitterStream.filter(query) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 731cb84cd45ad..81152eb4c0586 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.NextIterator import scala.reflect.ClassTag import java.io._ -import java.net.Socket +import java.net.{UnknownHostException, Socket} import org.apache.spark.Logging import org.apache.spark.streaming.receiver.NetworkReceiver @@ -51,19 +51,49 @@ class SocketReceiver[T: ClassTag]( ) extends NetworkReceiver[T](storageLevel) with Logging { var socket: Socket = null + var receivingThread: Thread = null def onStart() { - logInfo("Connecting to " + host + ":" + port) - socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - val iterator = bytesToObjects(socket.getInputStream()) - while(!isStopped && iterator.hasNext) { - store(iterator.next) + receivingThread = new Thread("Socket Receiver") { + override def run() { + connect() + receive() + } } + receivingThread.start() } def onStop() { - if (socket != null) socket.close() + if (socket != null) { + socket.close() + } + socket = null + if (receivingThread != null) { + receivingThread.join() + } + } + + def connect() { + try { + logInfo("Connecting to " + host + ":" + port) + socket = new Socket(host, port) + } catch { + case e: Exception => + restart("Could not connect to " + host + ":" + port, e) + } + } + + def receive() { + try { + logInfo("Connected to " + host + ":" + port) + val iterator = bytesToObjects(socket.getInputStream()) + while(!isStopped && iterator.hasNext) { + store(iterator.next) + } + } catch { + case e: Exception => + restart("Error receiving data from socket", e) + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala index 50c5648daef60..be3590810277f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala @@ -33,10 +33,14 @@ import org.apache.spark.storage.StorageLevel * class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) { * def onStart() { * // Setup stuff (start threads, open sockets, etc.) to start receiving data. - * // Call store(...) to store received data into Spark's memory. - * // Optionally, wait for other threads to complete or watch for exceptions. - * // Call reportError(...) if there is an error that you cannot ignore and need - * // the receiver to be terminated. + * // Must start new thread to receive data, as onStart() must be non-blocking. + * + * // Call store(...) in those threads to store received data into Spark's memory. + * + * // Call stop(...), restart() or reportError(...) on any thread based on how + * // different errors should be handled. + * + * // See corresponding method documentation for more details. * } * * def onStop() { @@ -47,17 +51,24 @@ import org.apache.spark.storage.StorageLevel abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serializable { /** - * This method is called by the system when the receiver is started to start receiving data. - * All threads and resources set up in this method must be cleaned up in onStop(). - * If there are exceptions on other threads such that the receiver must be terminated, - * then you must call reportError(exception). However, the thread that called onStart() must - * never catch and ignore InterruptedException (it can catch and rethrow). + * This method is called by the system when the receiver is started. This function + * must initialize all resources (threads, buffers, etc.) necessary for receiving data. + * This function must be non-blocking, so receiving the data must occur on a different + * thread. Received data can be stored with Spark by calling `store(data)`. + * + * If there are errors in threads started here, then following options can be done + * (i) `reportError(...)` can be called to report the error to the driver. + * The receiving of data will continue uninterrupted. + * (ii) `stop(...)` can be called to stop receiving data. This will call `onStop()` to + * clear up all resources allocated (threads, buffers, etc.) during `onStart()`. + * (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()` + * immediately, and then `onStart()` after a delay. */ def onStart() /** - * This method is called by the system when the receiver is stopped to stop receiving data. - * All threads and resources setup in onStart() must be cleaned up in this method. + * This method is called by the system when the receiver is stopped. All resources + * (threads, buffers, etc.) setup in `onStart()` must be cleaned up in this method. */ def onStop() @@ -95,6 +106,7 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial def store(dataIterator: Iterator[T], metadata: Any) { executor.pushIterator(dataIterator, Some(metadata), None) } + /** Store the bytes of received data into Spark's memory. */ def store(bytes: ByteBuffer) { executor.pushBytes(bytes, None, None) @@ -107,24 +119,70 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial def store(bytes: ByteBuffer, metadata: Any = null) { executor.pushBytes(bytes, Some(metadata), None) } + /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { executor.reportError(message, throwable) } - /** Stop the receiver. */ - def stop() { - executor.stop() + /** + * Restart the receiver. This will call `onStop()` immediately and return. + * Asynchronously, after a delay, `onStart()` will be called. + * The `message` will be reported to the driver. + * The delay is defined by the Spark configuration + * `spark.streaming.receiverRestartDelay`. + */ + def restart(message: String) { + executor.restartReceiver(message) + } + + /** + * Restart the receiver. This will call `onStop()` immediately and return. + * Asynchronously, after a delay, `onStart()` will be called. + * The `message` and `exception` will be reported to the driver. + * The delay is defined by the Spark configuration + * `spark.streaming.receiverRestartDelay`. + */ + def restart(message: String, exception: Throwable) { + executor.restartReceiver(message, exception) + } + + /** + * Restart the receiver. This will call `onStop()` immediately and return. + * Asynchronously, after the given delay, `onStart()` will be called. + */ + def restart(message: String, throwable: Throwable, millisecond: Int) { + executor.restartReceiver(message, throwable, millisecond) + } + + /** Stop the receiver completely. */ + def stop(message: String) { + executor.stop(message) + } + + /** Stop the receiver completely due to an exception */ + def stop(message: String, exception: Throwable) { + executor.stop(message, exception) + } + + def isStarted(): Boolean = { + executor.isReceiverStarted() } /** Check if receiver has been marked for stopping. */ def isStopped(): Boolean = { - executor.isStopped + !executor.isReceiverStarted() } /** Get unique identifier of this receiver. */ def receiverId = id + /* + * ================= + * Private methods + * ================= + */ + /** Identifier of the stream this receiver is associated with. */ private var id: Int = -1 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala index 01b9283568dcf..a22d93e6a04be 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala @@ -24,6 +24,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import java.util.concurrent.CountDownLatch +import scala.concurrent._ +import ExecutionContext.Implicits.global /** * Abstract class that is responsible for executing a NetworkReceiver in the worker. @@ -31,20 +33,37 @@ import java.util.concurrent.CountDownLatch */ private[streaming] abstract class NetworkReceiverExecutor( receiver: NetworkReceiver[_], - conf: SparkConf = new SparkConf() + conf: SparkConf ) extends Logging { + + /** Enumeration to identify current state of the StreamingContext */ + object NetworkReceiverState extends Enumeration { + type CheckpointState = Value + val Initialized, Started, Stopped = Value + } + import NetworkReceiverState._ + + // Attach the executor to the receiver receiver.attachExecutor(this) /** Receiver id */ protected val receiverId = receiver.receiverId - /** Thread that starts the receiver and stays blocked while data is being received. */ - @volatile protected var executionThread: Option[Thread] = None + /** Message associated with the stopping of the receiver */ + protected var stopMessage = "" + + /** Exception associated with the stopping of the receiver */ + protected var stopException: Throwable = null /** Has the receiver been marked for stop. */ - //@volatile private var stopped = false - val stopLatch = new CountDownLatch(1) + private val stopLatch = new CountDownLatch(1) + + /** Time between a receiver is stopped */ + private val restartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + + /** State of the receiver */ + private[streaming] var receiverState = Initialized /** Push a single data item to backend data store. */ def pushSingle(data: Any) @@ -73,50 +92,99 @@ private[streaming] abstract class NetworkReceiverExecutor( /** Report errors. */ def reportError(message: String, throwable: Throwable) + /** Start the executor */ + def start() { + startReceiver() + } + /** - * Run the receiver. The thread that calls this is supposed to stay blocked - * in this function until the stop() is called or there is an exception + * Mark the executor and the receiver for stopping */ - def run() { - // Remember this thread as the receiving thread - executionThread = Some(Thread.currentThread()) + def stop(message: String, exception: Throwable = null) { + stopMessage = message + stopException = exception + stopReceiver() + stopLatch.countDown() + if (exception != null) { + logError("Stopped executor: " + message, exception) + } else { + logWarning("Stopped executor: " + message) + } + } + /** Start receiver */ + def startReceiver(): Unit = synchronized { try { - // Call user-defined onStart() - logInfo("Calling onStart") - receiver.onStart() - // Wait until interrupt is called on this thread - awaitStop() - logInfo("Outside latch") + logInfo("Starting receiver") + stopMessage = "" + stopException = null + onReceiverStart() + receiverState = Started } catch { - case ie: InterruptedException => - logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped") case t: Throwable => - reportError("Error receiving data in receiver " + receiverId, t) + stop("Error starting receiver " + receiverId, t) } + } - // Call user-defined onStop() + /** Stop receiver */ + def stopReceiver(): Unit = synchronized { try { - logInfo("Calling onStop") - receiver.onStop() + receiverState = Stopped + onReceiverStop() } catch { - case t: Throwable => - reportError("Error stopping receiver " + receiverId, t) + case t: Throwable => + stop("Error stopping receiver " + receiverId, t) } } - /** - * Mark the executor and the receiver as stopped - */ - def stop() { - // Mark for stop - stopLatch.countDown() - logInfo("Marked for stop " + stopLatch.getCount) + /** Restart receiver with delay */ + def restartReceiver(message: String, throwable: Throwable = null) { + val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + restartReceiver(message, throwable, defaultRestartDelay) + } + + /** Restart receiver with delay */ + def restartReceiver(message: String, exception: Throwable, delay: Int) { + logWarning("Restarting receiver with delay " + delay + " ms: " + message, exception) + reportError(message, exception) + stopReceiver() + future { + logDebug("Sleeping for " + delay) + Thread.sleep(delay) + logDebug("Starting receiver again") + startReceiver() + logInfo("Receiver started again") + } + } + + /** Called when the receiver needs to be started */ + protected def onReceiverStart(): Unit = synchronized { + // Call user-defined onStart() + logInfo("Calling receiver onStart") + receiver.onStart() + logInfo("Called receiver onStart") + } + + /** Called when the receiver needs to be stopped */ + protected def onReceiverStop(): Unit = synchronized { + // Call user-defined onStop() + logInfo("Calling receiver onStop") + receiver.onStop() + logInfo("Called receiver onStop") } /** Check if receiver has been marked for stopping */ - def isStopped() = (stopLatch.getCount == 0L) + def isReceiverStarted() = { + logDebug("state = " + receiverState) + receiverState == Started + } /** Wait the thread until the executor is stopped */ - def awaitStop() = stopLatch.await() + def awaitStop() { + stopLatch.await() + logInfo("Waiting for executor stop is over") + if (stopException != null) { + throw new Exception(stopMessage, stopException) + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala index dcdd14637e3d7..7796d2a64bf8a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala @@ -45,7 +45,7 @@ import org.apache.spark.streaming.scheduler.RegisterReceiver private[streaming] class NetworkReceiverExecutorImpl( receiver: NetworkReceiver[_], env: SparkEnv - ) extends NetworkReceiverExecutor(receiver) with Logging { + ) extends NetworkReceiverExecutor(receiver, env.conf) with Logging { private val blockManager = env.blockManager @@ -76,7 +76,7 @@ private[streaming] class NetworkReceiverExecutorImpl( override def receive() = { case StopReceiver => logInfo("Received stop signal") - stop() + stop("Stopped by driver") } }), "NetworkReceiver-" + receiverId + "-" + System.currentTimeMillis()) @@ -153,16 +153,13 @@ private[streaming] class NetworkReceiverExecutorImpl( exceptions += new Exception(message, throwable) } - /** - * Starts the receiver. First is accesses all the lazy members to - * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc. required to receive the data. - */ - override def run() { - // Starting the block generator + override def onReceiverStart() { blockGenerator.start() - super.run() - // Stopping BlockGenerator + super.onReceiverStart() + } + + override def onReceiverStop() { + super.onReceiverStop() blockGenerator.stop() reportStop() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index c80defb23f071..8f9fca9365759 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -232,7 +232,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } val receiver = iterator.next() val executor = new NetworkReceiverExecutorImpl(receiver, SparkEnv.get) - executor.run() + executor.start() + executor.awaitStop() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index 5e0a9d7238ac9..f29ea065f8767 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -23,24 +23,28 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, NetworkReceiver, NetworkReceiverExecutor} +import org.apache.spark.streaming.receiver.NetworkReceiverExecutor /** Testsuite for testing the network receiver behavior */ class NetworkReceiverSuite extends FunSuite with Timeouts { test("network receiver life cycle") { + val receiver = new FakeReceiver val executor = new FakeReceiverExecutor(receiver) + assert(executor.isAllEmpty) + // Thread that runs the executor val executingThread = new Thread() { override def run() { - println("Running receiver") - executor.run() - println("Finished receiver") + executor.start() + executor.awaitStop() } } @@ -54,11 +58,15 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { } } - // Verify that onStart was called, and onStop wasn't called - assert(receiver.started) + // Verify that receiver was started + assert(receiver.onStartCalled) + assert(executor.isReceiverStarted) + assert(receiver.isStarted) + assert(!receiver.isStopped()) assert(receiver.otherThread.isAlive) - assert(!receiver.stopped) - assert(executor.isAllEmpty) + eventually(timeout(100 millis), interval(10 millis)) { + assert(receiver.receiving) + } // Verify whether the data stored by the receiver was sent to the executor val byteBuffer = ByteBuffer.allocate(100) @@ -83,15 +91,28 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { assert(executor.errors.size === 1) assert(executor.errors.head.eq(exception)) + // Verify restarting actually stops and starts the receiver + receiver.restart("restarting", null, 100) + assert(receiver.isStopped) + assert(receiver.onStopCalled) + eventually(timeout(1000 millis), interval(100 millis)) { + assert(receiver.onStartCalled) + assert(executor.isReceiverStarted) + assert(receiver.isStarted) + assert(!receiver.isStopped) + assert(receiver.receiving) + } + // Verify that stopping actually stops the thread failAfter(100 millis) { - receiver.stop() - executingThread.join() + receiver.stop("test") + assert(receiver.isStopped) assert(!receiver.otherThread.isAlive) - } - // Verify that onStop was called - assert(receiver.stopped) + // The thread that started the executor should complete + // as stop() stops everything + executingThread.join() + } } test("block generator") { @@ -125,24 +146,35 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. */ class FakeReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) { - var started = false - var stopped = false - val otherThread = new Thread() { - override def run() { - while(!stopped) { - Thread.sleep(10) - } - } - } + var otherThread: Thread = null + var receiving = false + var onStartCalled = false + var onStopCalled = false def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + while(!isStopped()) { + Thread.sleep(10) + } + } + } + onStartCalled = true otherThread.start() - started = true + } + def onStop() { - stopped = true + onStopCalled = true otherThread.join() } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } } /** @@ -150,7 +182,8 @@ class FakeReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) { * Instead of storing the data in the BlockManager, it stores all the data in a local buffer * that can used for verifying that the data has been forwarded correctly. */ -class FakeReceiverExecutor(receiver: FakeReceiver) extends NetworkReceiverExecutor(receiver) { +class FakeReceiverExecutor(receiver: FakeReceiver) + extends NetworkReceiverExecutor(receiver, new SparkConf()) { val singles = new ArrayBuffer[Any] val byteBuffers = new ArrayBuffer[ByteBuffer] val iterators = new ArrayBuffer[Iterator[_]]