Skip to content

Commit

Permalink
Further refactored receiver to allow restarting of a receiver.
Browse files Browse the repository at this point in the history
  • Loading branch information
tdas committed Apr 16, 2014
1 parent 43f5290 commit 028bde6
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,17 @@ class MQTTReceiver(
storageLevel: StorageLevel
) extends NetworkReceiver[String](storageLevel) {

/**
 * Called when the receiver is stopped. This implementation is a no-op and
 * releases nothing itself.
 * NOTE(review): the MqttClient created in onStart() is a local value there, so it
 * cannot be disconnected from here - confirm whether that is intended.
 */
def onStop(): Unit = {
}

def onStart() {

// Set up persistence for messages
val peristance: MqttClientPersistence = new MemoryPersistence()
val persistence = new MemoryPersistence()

// Initialize the MQTT client with the brokerUrl, a generated client ID and the MqttClientPersistence
val client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance)
val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)

// Connect to MqttBroker
client.connect()
Expand All @@ -97,8 +99,7 @@ class MQTTReceiver(
}

/**
 * Callback for a lost connection: asks for the receiver to be restarted,
 * passing along the cause.
 *
 * @param arg0 the throwable describing why the connection was lost
 */
override def connectionLost(arg0: Throwable): Unit = {
  // restart(...) is the only recovery action needed here: it is documented to call
  // onStop() and to report the message and exception to the driver, so a separate
  // reportError(...)/stop() beforehand would be redundant and conflicting.
  restart("Connection lost ", arg0)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,11 @@ class TwitterReceiver(
def onScrubGeo(l: Long, l1: Long) {}
def onStallWarning(stallWarning: StallWarning) {}
/**
 * Callback for an exception raised while receiving tweets: asks for the
 * receiver to be restarted, passing along the exception.
 *
 * @param e the exception that was raised
 */
def onException(e: Exception): Unit = {
  // restart(...) is documented to call onStop() and to report the message and
  // exception to the driver, so no separate reportError(...)/stop() is issued.
  restart("Error receiving tweets", e)
}
})

val query: FilterQuery = new FilterQuery
val query = new FilterQuery
if (filters.size > 0) {
query.track(filters.toArray)
twitterStream.filter(query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.util.NextIterator
import scala.reflect.ClassTag

import java.io._
import java.net.Socket
import java.net.{UnknownHostException, Socket}
import org.apache.spark.Logging
import org.apache.spark.streaming.receiver.NetworkReceiver

Expand All @@ -51,19 +51,49 @@ class SocketReceiver[T: ClassTag](
) extends NetworkReceiver[T](storageLevel) with Logging {

var socket: Socket = null
var receivingThread: Thread = null

/**
 * Starts receiving by spawning a dedicated thread and returning right away.
 * The thread first calls connect() and then receive().
 */
def onStart(): Unit = {
  // Connecting and reading are left entirely to the spawned thread; doing them
  // here as well would block the caller and open a second connection.
  receivingThread = new Thread("Socket Receiver") {
    override def run(): Unit = {
      connect()
      receive()
    }
  }
  receivingThread.start()
}

/**
 * Releases what onStart() set up: closes the socket, if there is one, and waits
 * for the receiving thread to finish.
 */
def onStop(): Unit = {
  // Close once, through a local snapshot of the field, then clear the field.
  Option(socket).foreach(_.close())
  socket = null

  // Thread.join() on the current thread never returns. onStop() can be reached
  // from the receiving thread itself when connect()/receive() call restart(...),
  // so only wait when called from a different thread.
  val worker = receivingThread
  if (worker != null && (worker ne Thread.currentThread())) {
    worker.join()
  }
}

/**
 * Opens the socket to `host`:`port` and keeps it in `socket`.
 * Any Exception raised on the way results in a restart(...) request that carries
 * the exception as its cause.
 */
def connect(): Unit = {
  val endpoint = host + ":" + port
  try {
    logInfo("Connecting to " + endpoint)
    socket = new Socket(host, port)
  } catch {
    case failure: Exception =>
      restart("Could not connect to " + endpoint, failure)
  }
}

/**
 * Reads objects from the socket's input stream and stores each of them until
 * the receiver is stopped or the stream has no more elements.
 * Any Exception raised while reading results in a restart(...) request.
 * Does nothing when there is no socket to read from.
 */
def receive(): Unit = {
  // Snapshot the field: onStop() sets it to null, possibly from another thread.
  val activeSocket = socket
  if (activeSocket != null) {
    try {
      logInfo("Connected to " + host + ":" + port)
      val iterator = bytesToObjects(activeSocket.getInputStream())
      while (!isStopped && iterator.hasNext) {
        store(iterator.next)
      }
    } catch {
      case e: Exception =>
        restart("Error receiving data from socket", e)
    }
  }
  // else: no socket was established. connect() requests a restart on failure, so
  // raising a NullPointerException here would only trigger a second, misleading one.
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ import org.apache.spark.storage.StorageLevel
* class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) {
* def onStart() {
* // Setup stuff (start threads, open sockets, etc.) to start receiving data.
* // Call store(...) to store received data into Spark's memory.
* // Optionally, wait for other threads to complete or watch for exceptions.
* // Call reportError(...) if there is an error that you cannot ignore and need
* // the receiver to be terminated.
* // Must start a new thread to receive data, as onStart() must be non-blocking.
*
* // Call store(...) in those threads to store received data into Spark's memory.
*
* // Call stop(...), restart(...) or reportError(...) on any thread based on how
* // different errors should be handled.
*
* // See corresponding method documentation for more details.
* }
*
* def onStop() {
Expand All @@ -47,17 +51,24 @@ import org.apache.spark.storage.StorageLevel
abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serializable {

/**
* This method is called by the system when the receiver is started to start receiving data.
* All threads and resources set up in this method must be cleaned up in onStop().
* If there are exceptions on other threads such that the receiver must be terminated,
* then you must call reportError(exception). However, the thread that called onStart() must
* never catch and ignore InterruptedException (it can catch and rethrow).
* This method is called by the system when the receiver is started. This function
* must initialize all resources (threads, buffers, etc.) necessary for receiving data.
* This function must be non-blocking, so receiving the data must occur on a different
* thread. Received data can be stored with Spark by calling `store(data)`.
*
* If there are errors in threads started here, then the following options are available:
* (i) `reportError(...)` can be called to report the error to the driver.
* The receiving of data will continue uninterrupted.
* (ii) `stop(...)` can be called to stop receiving data. This will call `onStop()` to
* clear up all resources allocated (threads, buffers, etc.) during `onStart()`.
* (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()`
* immediately, and then `onStart()` after a delay.
*/
def onStart(): Unit

/**
* This method is called by the system when the receiver is stopped to stop receiving data.
* All threads and resources setup in onStart() must be cleaned up in this method.
* This method is called by the system when the receiver is stopped. All resources
* (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method.
*/
def onStop(): Unit

Expand Down Expand Up @@ -95,6 +106,7 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial
def store(dataIterator: Iterator[T], metadata: Any): Unit = {
  // The metadata is always wrapped in Some; the last argument is always None.
  val wrappedMetadata = Some(metadata)
  executor.pushIterator(dataIterator, wrappedMetadata, None)
}

/** Store the bytes of received data into Spark's memory. */
def store(bytes: ByteBuffer) {
executor.pushBytes(bytes, None, None)
Expand All @@ -107,24 +119,70 @@ abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serial
def store(bytes: ByteBuffer, metadata: Any = null): Unit = {
  // Option(...) maps a null metadata (including the default) to None instead of
  // wrapping it as Some(null); any non-null value is still passed as Some(value).
  executor.pushBytes(bytes, Option(metadata), None)
}

/**
 * Report an error encountered while receiving data. The message and the
 * throwable are handed to the executor's `reportError`.
 */
def reportError(message: String, throwable: Throwable): Unit = {
  executor.reportError(message, throwable)
}

/** Stop the receiver. */
def stop() {
executor.stop()
/**
 * Restart the receiver, reporting `message` to the driver.
 * `onStop()` is called immediately and this method returns; `onStart()` is
 * then called asynchronously after a delay. The delay comes from the Spark
 * configuration `spark.streaming.receiverRestartDelay`.
 */
def restart(message: String): Unit = {
  executor.restartReceiver(message)
}

/**
 * Restart the receiver, reporting `message` and `exception` to the driver.
 * `onStop()` is called immediately and this method returns; `onStart()` is
 * then called asynchronously after a delay. The delay comes from the Spark
 * configuration `spark.streaming.receiverRestartDelay`.
 */
def restart(message: String, exception: Throwable): Unit = {
  executor.restartReceiver(message, exception)
}

/**
 * Restart the receiver with an explicit delay.
 * `onStop()` is called immediately and this method returns; `onStart()` is
 * then called asynchronously once the given delay has passed.
 * NOTE(review): `millisecond` is presumably the delay in milliseconds, going by
 * its name - confirm against the executor's `restartReceiver`.
 */
def restart(message: String, throwable: Throwable, millisecond: Int): Unit = {
  executor.restartReceiver(message, throwable, millisecond)
}

/**
 * Stop the receiver completely. The `message` is passed on to the executor's
 * `stop`.
 */
def stop(message: String): Unit = {
  executor.stop(message)
}

/**
 * Stop the receiver completely due to an exception. Both the `message` and the
 * `exception` are passed on to the executor's `stop`.
 */
def stop(message: String, exception: Throwable): Unit = {
  executor.stop(message, exception)
}

/** Check if the receiver has been started, as reported by the executor. */
def isStarted(): Boolean = executor.isReceiverStarted()

/**
 * Check if the receiver is stopped. This is the exact negation of what the
 * executor reports through `isReceiverStarted()`.
 */
def isStopped(): Boolean = {
  // Only this expression determines the result; a preceding, discarded read of
  // another executor flag would have no effect on it.
  !executor.isReceiverStarted()
}

/** Get the unique identifier of this receiver (the value of the private `id`). */
def receiverId: Int = id

/*
* =================
* Private methods
* =================
*/

/** Identifier of the stream this receiver is associated with. */
private var id: Int = -1

Expand Down
Loading

0 comments on commit 028bde6

Please sign in to comment.