From 3223e95bc259c80f082f32702f8835bb754c6117 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Apr 2014 23:21:08 -0700 Subject: [PATCH] Refactored the code that runs the NetworkReceiver into further classes and traits to make them more testable. --- .../streaming/flume/FlumeInputDStream.scala | 1 + .../streaming/kafka/KafkaInputDStream.scala | 1 + .../streaming/mqtt/MQTTInputDStream.scala | 5 +- .../twitter/TwitterInputDStream.scala | 6 +- .../spark/streaming/StreamingContext.scala | 1 + .../dstream/NetworkInputDStream.scala | 405 +----------------- .../dstream/PluggableInputDStream.scala | 1 + .../streaming/dstream/RawInputDStream.scala | 1 + .../dstream/SocketInputDStream.scala | 1 + .../streaming/receiver/BlockGenerator.scala | 147 +++++++ .../streaming/receiver/NetworkReceiver.scala | 151 +++++++ .../receiver/NetworkReceiverExecutor.scala | 132 ++++++ .../NetworkReceiverExecutorImpl.scala | 188 ++++++++ .../receiver/NetworkReceiverMessage.scala | 23 + .../streaming/receivers/ActorReceiver.scala | 26 +- .../scheduler/NetworkInputTracker.scala | 24 +- .../spark/streaming/InputStreamsSuite.scala | 2 +- .../streaming/NetworkReceiverSuite.scala | 186 ++++++++ .../streaming/StreamingContextSuite.scala | 4 +- 19 files changed, 871 insertions(+), 434 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverMessage.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 4b2373473c7cc..cbcebb812cfca 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -35,6 +35,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.Logging +import org.apache.spark.streaming.receiver.NetworkReceiver private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 7c10c4a0d6a16..d685a3b7f737c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -33,6 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.receiver.NetworkReceiver /** * Input stream that pulls messages from a Kafka Broker. 
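Note: the receiver changes below (MQTT, Twitter) replace the old stopOnError(e) call with the new reportError(...) followed by stop() pattern introduced by this refactoring. As a rough, hedged sketch of how a custom receiver might use that pattern against the new org.apache.spark.streaming.receiver.NetworkReceiver API (the receiver class and its error callback here are hypothetical, for illustration only and not part of this patch):

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receiver.NetworkReceiver

// Hypothetical example: a custom receiver built on the refactored API,
// reporting a fatal error and then asking the executor to stop it.
class MyStringReceiver(storageLevel: StorageLevel)
  extends NetworkReceiver[String](storageLevel) {

  def onStart() {
    // Start threads / open connections here and call store(...) for each received record.
  }

  def onStop() {
    // Close connections and stop the threads started in onStart().
  }

  // Hypothetical callback invoked by an underlying client library on failure
  def onConnectionLost(e: Throwable) {
    reportError("Connection lost", e)
    stop()
  }
}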
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 5f8d1463dc46b..2896e42019fe2 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -39,6 +39,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.receiver.NetworkReceiver /** * Input stream that subscribe messages from a Mqtt Broker. @@ -96,8 +97,8 @@ class MQTTReceiver( } override def connectionLost(arg0: Throwable) { - store("Connection lost " + arg0) - stopOnError(new Exception(arg0)) + reportError("Connection lost ", arg0) + stop() } } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 30cf3bd1a8efe..59957c05c9f76 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -26,6 +26,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ import org.apache.spark.storage.StorageLevel import org.apache.spark.Logging +import org.apache.spark.streaming.receiver.NetworkReceiver /* A stream of Twitter statuses, potentially filtered by one or more keywords. * @@ -75,7 +76,10 @@ class TwitterReceiver( def onTrackLimitationNotice(i: Int) {} def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) { stopOnError(e) } + def onException(e: Exception) { + reportError("Error receiving tweets", e) + stop() + } }) val query: FilterQuery = new FilterQuery diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e198c69470c1f..42a70ead7e40f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -40,6 +40,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receivers._ import org.apache.spark.streaming.scheduler._ import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.receiver.NetworkReceiver /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 77cf5ee4cc075..423c2ae72d691 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,31 +17,20 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} -import scala.concurrent.Await +import scala.Array import scala.reflect.ClassTag -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} - -import akka.actor.{Props, Actor} -import akka.pattern.ask - -import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} +import org.apache.spark.rdd.{BlockRDD, RDD} +import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.rdd.{RDD, BlockRDD} -import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} -import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} -import org.apache.spark.util.AkkaUtils +import org.apache.spark.streaming.receiver.NetworkReceiver /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] * that has to start a receiver on worker nodes to receive external data. * Specific implementations of NetworkInputDStream must * define the getReceiver() function that gets the receiver object of type - * [[org.apache.spark.streaming.dstream.NetworkReceiver]] that will be sent + * [[org.apache.spark.streaming.receiver.NetworkReceiver]] that will be sent * to the workers to receive data. * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream @@ -79,390 +68,6 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte } -private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) - extends NetworkReceiverMessage -private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage - -/** - * Abstract class of a receiver that can be run on worker nodes to receive external data. A - * custom receiver can be defined by defining the functions onStart() and onStop(). onStart() - * should define the setup steps necessary to start receiving data, - * and onStop() should define the cleanup steps necessary to stop receiving data. A custom - * receiver would look something like this. - * - * class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) { - * def onStart() { - * // Setup stuff (start threads, open sockets, etc.) to start receiving data. - * // Call store(...) to store received data into Spark's memory. - * // Optionally, wait for other threads to complete or watch for exceptions. - * // Call stopOnError(...) if there is an error that you cannot ignore and need - * // the receiver to be terminated. - * } - * - * def onStop() { - * // Cleanup stuff (stop threads, close sockets, etc.) to stop receiving data. 
- * } - * } - */ -abstract class NetworkReceiver[T: ClassTag](val storageLevel: StorageLevel) - extends Serializable { - - /** - * This method is called by the system when the receiver is started to start receiving data. - * All threads and resources set up in this method must be cleaned up in onStop(). - * If there are exceptions on other threads such that the receiver must be terminated, - * then you must call stopOnError(exception). However, the thread that called onStart() must - * never catch and ignore InterruptedException (it can catch and rethrow). - */ - def onStart() - - /** - * This method is called by the system when the receiver is stopped to stop receiving data. - * All threads and resources setup in onStart() must be cleaned up in this method. - */ - def onStop() - - /** Override this to specify a preferred location (hostname). */ - def preferredLocation : Option[String] = None - - /** Store a single item of received data to Spark's memory/ */ - def store(dataItem: T) { - handler.pushSingle(dataItem) - } - - /** Store a sequence of received data block into Spark's memory. */ - def store(dataBuffer: ArrayBuffer[T]) { - handler.pushArrayBuffer(dataBuffer) - } - - /** Store a sequence of received data block into Spark's memory. */ - def store(dataIterator: Iterator[T]) { - handler.pushIterator(dataIterator) - } - - /** Store the bytes of received data block into Spark's memory. */ - def store(bytes: ByteBuffer) { - handler.pushBytes(bytes) - } - - /** Stop the receiver. */ - def stop() { - handler.stop() - } - - /** Stop the receiver when an error occurred. */ - def stopOnError(e: Exception) { - handler.stop(e) - } - - /** Check if receiver has been marked for stopping */ - def isStopped: Boolean = { - handler.isStopped - } - /** Get unique identifier of this receiver. */ - def receiverId = id - - /** Identifier of the stream this receiver is associated with. */ - private var id: Int = -1 - - /** Handler object that runs the receiver. This is instantiated lazily in the worker. 
*/ - private[streaming] lazy val handler = new NetworkReceiverHandler(this) - - /** Set the ID of the DStream that this receiver is associated with */ - private[streaming] def setReceiverId(id_ : Int) { - id = id_ - } -} - - -private[streaming] class NetworkReceiverHandler(receiver: NetworkReceiver[_]) extends Logging { - - val env = SparkEnv.get - val receiverId = receiver.receiverId - val storageLevel = receiver.storageLevel - - /** Remote Akka actor for the NetworkInputTracker */ - private val trackerActor = { - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - env.actorSystem.actorSelection(url) - } - - /** Timeout for Akka actor messages */ - private val askTimeout = AkkaUtils.askTimeout(env.conf) - - /** Akka actor for receiving messages from the NetworkInputTracker in the driver */ - private val actor = env.actorSystem.actorOf( - Props(new Actor { - override def preStart() { - logInfo("Registered receiver " + receiverId) - val future = trackerActor.ask(RegisterReceiver(receiverId, self))(askTimeout) - Await.result(future, askTimeout) - } - - override def receive() = { - case StopReceiver => - logInfo("Received stop signal") - stop() - } - }), "NetworkReceiver-" + receiverId) - - /** Divides received data records into data blocks for pushing in BlockManager */ - private val blockGenerator = new BlockGenerator(this) - - /** Exceptions that occurs while receiving data */ - private val exceptions = new ArrayBuffer[Exception] with SynchronizedBuffer[Exception] - - /** Unique block ids if one wants to add blocks directly */ - private val newBlockId = new AtomicLong(System.currentTimeMillis()) - - /** Thread that starts the receiver and stays blocked while data is being received */ - private var receivingThread: Option[Thread] = None - - /** Has the receiver been marked for stop */ - private var stopped = false - - /** - * Starts the receiver. First is accesses all the lazy members to - * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc. required to receive the data. 
- */ - def run() { - // Remember this thread as the receiving thread - receivingThread = Some(Thread.currentThread()) - - // Starting the block generator - blockGenerator.start() - - try { - // Call user-defined onStart() - logInfo("Calling onStart") - receiver.onStart() - - // Wait until interrupt is called on this thread - while(true) Thread.sleep(100000) - } catch { - case ie: InterruptedException => - logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped") - case e: Exception => - logError("Error receiving data in receiver " + receiverId, e) - exceptions += e - } - - // Call user-defined onStop() - try { - logInfo("Calling onStop") - receiver.onStop() - } catch { - case e: Exception => - logError("Error stopping receiver " + receiverId, e) - exceptions += e - } - // Stopping BlockGenerator - blockGenerator.stop() - val message = if (exceptions.isEmpty) { - null - } else if (exceptions.size == 1) { - val e = exceptions.head - "Exception in receiver " + receiverId + ": " + e.getMessage + "\n" + e.getStackTraceString - } else { - "Multiple exceptions in receiver " + receiverId + "(" + exceptions.size + "):\n" - exceptions.zipWithIndex.map { - case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString - }.mkString("\n") - } - logInfo("Deregistering receiver " + receiverId) - val future = trackerActor.ask(DeregisterReceiver(receiverId, message))(askTimeout) - Await.result(future, askTimeout) - logInfo("Deregistered receiver " + receiverId) - env.actorSystem.stop(actor) - logInfo("Stopped receiver " + receiverId) - } - - - /** Push a single record of received data into block generator. */ - def pushSingle(data: Any) { - blockGenerator += data - } - - /** Push a block of received data into block manager. */ - def pushArrayBuffer( - arrayBuffer: ArrayBuffer[_], - blockId: StreamBlockId = nextBlockId, - metadata: Any = null - ) { - logDebug("Pushing block " + blockId) - val time = System.currentTimeMillis - env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], storageLevel, true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - trackerActor ! AddBlocks(receiverId, Array(blockId), null) - logDebug("Reported block " + blockId) - } - - /** - * Push a received data into Spark as . Call this method from the data receiving - * thread to submit - * a block of data. - */ - def pushIterator( - iterator: Iterator[_], - blockId: StreamBlockId = nextBlockId, - metadata: Any = null - ) { - env.blockManager.put(blockId, iterator, storageLevel, true) - trackerActor ! AddBlocks(receiverId, Array(blockId), null) - logInfo("Pushed block " + blockId) - } - - - /** - * Push a block (as bytes) into the block manager. - */ - def pushBytes( - bytes: ByteBuffer, - blockId: StreamBlockId = nextBlockId, - metadata: Any = null - ) { - env.blockManager.putBytes(blockId, bytes, storageLevel, true) - trackerActor ! AddBlocks(receiverId, Array(blockId), null) - logInfo("Pushed block " + blockId) - } - - /** - * Stop receiving data. 
- */ - def stop(e: Exception = null) { - // Mark has stopped - stopped = true - logInfo("Marked as stop") - - // Store the exception if any - if (e != null) { - logError("Error receiving data", e) - exceptions += e - } - - if (receivingThread.isDefined) { - // Wait for the receiving thread to finish on its own - receivingThread.get.join(env.conf.getLong("spark.streaming.receiverStopTimeout", 2000)) - - // Stop receiving by interrupting the receiving thread - receivingThread.get.interrupt() - logInfo("Interrupted receiving thread of receiver " + receiverId + " for stopping") - } - } - - /** Check if receiver has been marked for stopping. */ - def isStopped = stopped - - private def nextBlockId = StreamBlockId(receiverId, newBlockId.getAndIncrement) -} - -/** - * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts them into - * appropriately named blocks at regular intervals. This class starts two threads, - * one to periodically start a new batch and prepare the previous batch of as a block, - * the other to push the blocks into the block manager. - */ -private[streaming] class BlockGenerator(handler: NetworkReceiverHandler) extends Logging { - - private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any], metadata: Any = null) - - private val env = handler.env - private val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) - private val blockIntervalTimer = - new RecurringTimer(new SystemClock(), blockInterval, updateCurrentBuffer) - private val blocksForPushing = new ArrayBlockingQueue[Block](10) - private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - private var currentBuffer = new ArrayBuffer[Any] - private var stopped = false - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") - } - - def stop() { - // Stop generating blocks - blockIntervalTimer.stop() - - // Mark as stopped - synchronized { stopped = true } - - // Wait for all blocks to be pushed - logDebug("Waiting for block pushing thread to terminate") - blockPushingThread.join() - logInfo("Stopped BlockGenerator") - } - - def += (obj: Any): Unit = synchronized { - currentBuffer += obj - } - - private def isStopped = synchronized { stopped } - - private def updateCurrentBuffer(time: Long): Unit = synchronized { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[Any] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(handler.receiverId, time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer) - blocksForPushing.add(newBlock) - logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) - } - } catch { - case ie: InterruptedException => - logInfo("Block updating timer thread was interrupted") - case e: Exception => - handler.stop(e) - } - } - - private def keepPushingBlocks() { - logInfo("Started block pushing thread") - - def pushNextBlock() { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { - case Some(block) => - handler.pushArrayBuffer(block.buffer, block.id, block.metadata) - logInfo("Pushed block "+ block.id) - case None => - } - } - - try { - while(!isStopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { - case Some(block) => - handler.pushArrayBuffer(block.buffer, block.id, block.metadata) - logInfo("Pushed block "+ block.id) - case None => - } - } - // Push out the blocks that are still left - logInfo("Pushing out the last " + 
blocksForPushing.size() + " blocks") - while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") - val block = blocksForPushing.take() - logDebug("Got block") - handler.pushArrayBuffer(block.buffer, block.id, block.metadata) - logInfo("Blocks left to push " + blocksForPushing.size()) - } - logInfo("Stopped block pushing thread") - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread was interrupted") - case e: Exception => - handler.stop(e) - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 6f9477020a459..0438b83a4d05e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.streaming.StreamingContext import scala.reflect.ClassTag +import org.apache.spark.streaming.receiver.NetworkReceiver private[streaming] class PluggableInputDStream[T: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index b920dae60cd66..55a689285bc45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -28,6 +28,7 @@ import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, SocketChannel} import java.io.EOFException import java.util.concurrent.ArrayBlockingQueue +import org.apache.spark.streaming.receiver.NetworkReceiver /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 53ead3d22f736..701e4920ec9cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -26,6 +26,7 @@ import scala.reflect.ClassTag import java.io._ import java.net.Socket import org.apache.spark.Logging +import org.apache.spark.streaming.receiver.NetworkReceiver private[streaming] class SocketInputDStream[T: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala new file mode 100644 index 0000000000000..5157e20927533 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
+
+/** Listener object for BlockGenerator events */
+private[streaming] trait BlockGeneratorListener {
+  /** Called when a new block needs to be pushed */
+  def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_])
+  /** Called when an error has occurred in BlockGenerator */
+  def onError(message: String, throwable: Throwable)
+}
+
+/**
+ * Generates batches of objects received by a
+ * [[org.apache.spark.streaming.receiver.NetworkReceiver]] and puts them into appropriately
+ * named blocks at regular intervals. This class starts two threads,
+ * one to periodically start a new batch and prepare the previous batch as a block,
+ * the other to push the blocks into the block manager.
+ */
+private[streaming] class BlockGenerator(
+    listener: BlockGeneratorListener,
+    receiverId: Int,
+    conf: SparkConf
+  ) extends Logging {
+
+  private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any])
+
+  private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200)
+  private val blockIntervalTimer =
+    new RecurringTimer(new SystemClock(), blockInterval, updateCurrentBuffer)
+  private val blocksForPushing = new ArrayBlockingQueue[Block](10)
+  private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+  private var currentBuffer = new ArrayBuffer[Any]
+  private var stopped = false
+
+  /** Start block generating and pushing threads. */
+  def start() {
+    blockIntervalTimer.start()
+    blockPushingThread.start()
+    logInfo("Started BlockGenerator")
+  }
+
+  /** Stop all threads. */
+  def stop() {
+    // Stop generating blocks
+    blockIntervalTimer.stop()
+
+    // Mark as stopped
+    synchronized { stopped = true }
+
+    // Wait for all blocks to be pushed
+    logDebug("Waiting for block pushing thread to terminate")
+    blockPushingThread.join()
+    logInfo("Stopped BlockGenerator")
+  }
+
+  /**
+   * Push a single data item into the buffer. All received data items
+   * will be periodically coalesced into blocks and pushed into BlockManager.
+   */
+  def += (data: Any): Unit = synchronized {
+    currentBuffer += data
+  }
+
+  /** Change the buffer to which single records are added. */
+  private def updateCurrentBuffer(time: Long): Unit = synchronized {
+    try {
+      val newBlockBuffer = currentBuffer
+      currentBuffer = new ArrayBuffer[Any]
+      if (newBlockBuffer.size > 0) {
+        val blockId = StreamBlockId(receiverId, time - blockInterval)
+        val newBlock = new Block(blockId, newBlockBuffer)
+        blocksForPushing.add(newBlock)
+        logDebug("Last element in " + blockId + " is " + newBlockBuffer.last)
+      }
+    } catch {
+      case ie: InterruptedException =>
+        logInfo("Block updating timer thread was interrupted")
+      case t: Throwable =>
+        reportError("Error in block updating thread", t)
+    }
+  }
+
+  /** Keep pushing blocks to the BlockManager.
*/ + private def keepPushingBlocks() { + logInfo("Started block pushing thread") + + try { + while(!isStopped) { + Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + case Some(block) => pushBlock(block) + case None => + } + } + // Push out the blocks that are still left + logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") + while (!blocksForPushing.isEmpty) { + logDebug("Getting block ") + val block = blocksForPushing.take() + pushBlock(block) + logInfo("Blocks left to push " + blocksForPushing.size()) + } + logInfo("Stopped block pushing thread") + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread was interrupted") + case t: Throwable => + reportError("Error in block pushing thread", t) + } + } + + private def reportError(message: String, t: Throwable) { + logError(message, t) + listener.onError(message, t) + } + + private def pushBlock(block: Block) { + listener.onPushBlock(block.id, block.buffer) + logInfo("Pushed block " + block.id) + } + + private def isStopped = synchronized { stopped } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala new file mode 100644 index 0000000000000..50c5648daef60 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiver.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.storage.StorageLevel + +/** + * Abstract class of a receiver that can be run on worker nodes to receive external data. A + * custom receiver can be defined by defining the functions onStart() and onStop(). onStart() + * should define the setup steps necessary to start receiving data, + * and onStop() should define the cleanup steps necessary to stop receiving data. A custom + * receiver would look something like this. + * + * class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) { + * def onStart() { + * // Setup stuff (start threads, open sockets, etc.) to start receiving data. + * // Call store(...) to store received data into Spark's memory. + * // Optionally, wait for other threads to complete or watch for exceptions. + * // Call reportError(...) if there is an error that you cannot ignore and need + * // the receiver to be terminated. + * } + * + * def onStop() { + * // Cleanup stuff (stop threads, close sockets, etc.) to stop receiving data. 
+ * } + * } + */ +abstract class NetworkReceiver[T](val storageLevel: StorageLevel) extends Serializable { + + /** + * This method is called by the system when the receiver is started to start receiving data. + * All threads and resources set up in this method must be cleaned up in onStop(). + * If there are exceptions on other threads such that the receiver must be terminated, + * then you must call reportError(exception). However, the thread that called onStart() must + * never catch and ignore InterruptedException (it can catch and rethrow). + */ + def onStart() + + /** + * This method is called by the system when the receiver is stopped to stop receiving data. + * All threads and resources setup in onStart() must be cleaned up in this method. + */ + def onStop() + + /** Override this to specify a preferred location (hostname). */ + def preferredLocation : Option[String] = None + + /** Store a single item of received data to Spark's memory. */ + def store(dataItem: T) { + executor.pushSingle(dataItem) + } + + /** Store a sequence of received data into Spark's memory. */ + def store(dataBuffer: ArrayBuffer[T]) { + executor.pushArrayBuffer(dataBuffer, None, None) + } + + /** + * Store a sequence of received data into Spark's memory. + * The metadata will be associated with this block of data + * for being used in the corresponding InputDStream. + */ + def store(dataBuffer: ArrayBuffer[T], metadata: Any) { + executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + } + /** Store a sequence of received data into Spark's memory. */ + def store(dataIterator: Iterator[T]) { + executor.pushIterator(dataIterator, None, None) + } + + /** + * Store a sequence of received data into Spark's memory. + * The metadata will be associated with this block of data + * for being used in the corresponding InputDStream. + */ + def store(dataIterator: Iterator[T], metadata: Any) { + executor.pushIterator(dataIterator, Some(metadata), None) + } + /** Store the bytes of received data into Spark's memory. */ + def store(bytes: ByteBuffer) { + executor.pushBytes(bytes, None, None) + } + + /** Store the bytes of received data into Spark's memory. + * The metadata will be associated with this block of data + * for being used in the corresponding InputDStream. + */ + def store(bytes: ByteBuffer, metadata: Any = null) { + executor.pushBytes(bytes, Some(metadata), None) + } + /** Report exceptions in receiving data. */ + def reportError(message: String, throwable: Throwable) { + executor.reportError(message, throwable) + } + + /** Stop the receiver. */ + def stop() { + executor.stop() + } + + /** Check if receiver has been marked for stopping. */ + def isStopped(): Boolean = { + executor.isStopped + } + + /** Get unique identifier of this receiver. */ + def receiverId = id + + /** Identifier of the stream this receiver is associated with. */ + private var id: Int = -1 + + /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ + private[streaming] var executor_ : NetworkReceiverExecutor = null + + /** Set the ID of the DStream that this receiver is associated with. */ + private[streaming] def setReceiverId(id_ : Int) { + id = id_ + } + + /** Attach Network Receiver executor to this receiver. */ + private[streaming] def attachExecutor(exec: NetworkReceiverExecutor) { + assert(executor_ == null) + executor_ = exec + } + + /** Get the attached executor. 
*/ + private def executor = { + assert(executor_ != null, "Executor has not been attached to this receiver") + executor_ + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala new file mode 100644 index 0000000000000..77c53112493c9 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StreamBlockId + +/** + * Abstract class that is responsible for executing a NetworkReceiver in the worker. + * It provides all the necessary interfaces for handling the data received by the receiver. + */ +private[streaming] abstract class NetworkReceiverExecutor( + receiver: NetworkReceiver[_], + conf: SparkConf = new SparkConf() + ) extends Logging { + + receiver.attachExecutor(this) + + /** Receiver id */ + protected val receiverId = receiver.receiverId + + /** Thread that starts the receiver and stays blocked while data is being received. */ + @volatile protected var receivingThread: Option[Thread] = None + + /** Has the receiver been marked for stop. */ + @volatile private var stopped = false + + /** Push a single data item to backend data store. */ + def pushSingle(data: Any) + + /** Push a byte buffer to backend data store. */ + def pushBytes( + bytes: ByteBuffer, + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) + + /** Push an iterator of objects as a block to backend data store. */ + def pushIterator( + iterator: Iterator[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) + + /** Push an ArrayBuffer of object as a block to back data store. */ + def pushArrayBuffer( + arrayBuffer: ArrayBuffer[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) + + /** Report errors. */ + def reportError(message: String, throwable: Throwable) + + /** + * Run the receiver. 
The thread that calls this is supposed to stay blocked
+   * in this function until stop() is called or an exception occurs.
+   */
+  def run() {
+    // Remember this thread as the receiving thread
+    receivingThread = Some(Thread.currentThread())
+
+    try {
+      // Call user-defined onStart()
+      logInfo("Calling onStart")
+      receiver.onStart()
+
+      // Wait until interrupt is called on this thread
+      while(true) {
+        Thread.sleep(100)
+      }
+    } catch {
+      case ie: InterruptedException =>
+        logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped")
+      case t: Throwable =>
+        reportError("Error receiving data in receiver " + receiverId, t)
+    }
+
+    // Call user-defined onStop()
+    try {
+      logInfo("Calling onStop")
+      receiver.onStop()
+    } catch {
+      case t: Throwable =>
+        reportError("Error stopping receiver " + receiverId, t)
+    }
+  }
+
+  /**
+   * Stop receiving data.
+   */
+  def stop() {
+    // Stop the receiving thread and mark the receiver as stopped
+
+    if (receivingThread.isDefined) {
+      // Interrupt the receiving thread
+      receivingThread.get.interrupt()
+
+      // Wait for the receiving thread to finish on its own
+      receivingThread.get.join(conf.getLong("spark.streaming.receiverStopTimeout", 2000))
+
+      // Interrupt again in case the thread is still blocked after the timeout
+      receivingThread.get.interrupt()
+      logInfo("Interrupted receiving thread of receiver " + receiverId + " for stopping")
+    }
+
+    stopped = true
+    logInfo("Marked as stopped")
+  }
+
+  /** Check if receiver has been marked for stopping. */
+  def isStopped = stopped
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala
new file mode 100644
index 0000000000000..5ac28405462f4
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.concurrent.Await
+
+import akka.actor.{Actor, Props}
+import akka.pattern.ask
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.scheduler.{AddBlocks, DeregisterReceiver, RegisterReceiver}
+import org.apache.spark.util.AkkaUtils
+
+/**
+ * Concrete implementation of [[org.apache.spark.streaming.receiver.NetworkReceiverExecutor]]
+ * which provides all the necessary functionality for handling the data received by
+ * the receiver.
Specifically, it creates a [[org.apache.spark.streaming.receiver.BlockGenerator]] + * object that is used to divide the received data stream into blocks of data. + */ +private[streaming] class NetworkReceiverExecutorImpl( + receiver: NetworkReceiver[_], + env: SparkEnv + ) extends NetworkReceiverExecutor(receiver) with Logging { + + private val blockManager = env.blockManager + + private val storageLevel = receiver.storageLevel + + /** Remote Akka actor for the NetworkInputTracker */ + private val trackerActor = { + val ip = env.conf.get("spark.driver.host", "localhost") + val port = env.conf.getInt("spark.driver.port", 7077) + val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + env.actorSystem.actorSelection(url) + } + + /** Timeout for Akka actor messages */ + private val askTimeout = AkkaUtils.askTimeout(env.conf) + + /** Akka actor for receiving messages from the NetworkInputTracker in the driver */ + private val actor = env.actorSystem.actorOf( + Props(new Actor { + override def preStart() { + logInfo("Registered receiver " + receiverId) + val future = trackerActor.ask(RegisterReceiver(receiverId, self))(askTimeout) + Await.result(future, askTimeout) + } + + override def receive() = { + case StopReceiver => + logInfo("Received stop signal") + stop() + } + }), "NetworkReceiver-" + receiverId) + + /** Unique block ids if one wants to add blocks directly */ + private val newBlockId = new AtomicLong(System.currentTimeMillis()) + + /** Divides received data records into data blocks for pushing in BlockManager. */ + private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + def onError(message: String, throwable: Throwable) { + reportError(message, throwable) + } + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { + pushArrayBuffer(arrayBuffer, None, Some(blockId)) + } + }, receiverId, env.conf) + + /** Exceptions that occurs while receiving data */ + val exceptions = new ArrayBuffer[Exception] with SynchronizedBuffer[Exception] + + /** Push a single record of received data into block generator. */ + def pushSingle(data: Any) { + blockGenerator += (data) + } + + /** Push a block of received data into block generator. */ + def pushArrayBuffer( + arrayBuffer: ArrayBuffer[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { + val blockId = optionalBlockId.getOrElse(nextBlockId) + val time = System.currentTimeMillis + blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], + storageLevel, tellMaster = true) + logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") + reportPushedBlock(blockId, optionalMetadata) + } + + /** Push a block of received data into block generator. */ + def pushIterator( + iterator: Iterator[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { + val blockId = optionalBlockId.getOrElse(nextBlockId) + val time = System.currentTimeMillis + blockManager.put(blockId, iterator, storageLevel, tellMaster = true) + logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") + reportPushedBlock(blockId, optionalMetadata) + } + + /** Push a block (as bytes) into the block generator. 
*/
+  def pushBytes(
+      bytes: ByteBuffer,
+      optionalMetadata: Option[Any],
+      optionalBlockId: Option[StreamBlockId]
+    ) {
+    val blockId = optionalBlockId.getOrElse(nextBlockId)
+    val time = System.currentTimeMillis
+    blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true)
+    logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
+    reportPushedBlock(blockId, optionalMetadata)
+  }
+
+  /** Report pushed block */
+  def reportPushedBlock(blockId: StreamBlockId, optionalMetadata: Option[Any]) {
+    trackerActor ! AddBlocks(receiverId, Array(blockId), optionalMetadata.orNull)
+    logDebug("Reported block " + blockId)
+  }
+
+  /** Add exceptions to a list */
+  def reportError(message: String, throwable: Throwable) {
+    exceptions += new Exception(message, throwable)
+  }
+
+  /**
+   * Starts the receiver. First it accesses all the lazy members to
+   * materialize them. Then it calls the user-defined onStart() method to start
+   * other threads, etc. required to receive the data.
+   */
+  override def run() {
+    // Starting the block generator
+    blockGenerator.start()
+
+    super.run()
+
+    // Stopping BlockGenerator
+    blockGenerator.stop()
+    reportStop()
+  }
+
+  /** Report to the NetworkInputTracker that the receiver has stopped */
+  private def reportStop() {
+    val message = if (exceptions.isEmpty) {
+      null
+    } else if (exceptions.size == 1) {
+      val e = exceptions.head
+      "Exception in receiver " + receiverId + ": " + e.getMessage + "\n" + e.getStackTraceString
+    } else {
+      "Multiple exceptions in receiver " + receiverId + "(" + exceptions.size + "):\n" +
+        exceptions.zipWithIndex.map {
+          case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString
+        }.mkString("\n")
+    }
+    logInfo("Deregistering receiver " + receiverId)
+    val future = trackerActor.ask(DeregisterReceiver(receiverId, message))(askTimeout)
+    Await.result(future, askTimeout)
+    logInfo("Deregistered receiver " + receiverId)
+    env.actorSystem.stop(actor)
+    logInfo("Stopped receiver " + receiverId)
+  }
+
+  /** Generate new block ID */
+  private def nextBlockId = StreamBlockId(receiverId, newBlockId.getAndIncrement)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverMessage.scala
new file mode 100644
index 0000000000000..6ab3ca6ea5fa6
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverMessage.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+/** Messages sent to the NetworkReceiver.
*/ +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] object StopReceiver extends NetworkReceiverMessage + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index da07878cc3070..66c736e114372 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -17,21 +17,17 @@ package org.apache.spark.streaming.receivers -import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy } -import akka.actor.{ actorRef2Scala, ActorRef } -import akka.actor.{ PossiblyHarmful, OneForOneStrategy } -import akka.actor.SupervisorStrategy._ +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ import scala.reflect.ClassTag -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.dstream.NetworkReceiver - -import java.util.concurrent.atomic.AtomicInteger +import akka.actor.{Actor, OneForOneStrategy, PoisonPill, PossiblyHarmful, Props, SupervisorStrategy, actorRef2Scala} +import akka.actor.SupervisorStrategy._ -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver.NetworkReceiver /** A helper with set of defaults for supervisor strategy */ object ReceiverSupervisorStrategy { @@ -117,11 +113,11 @@ private[streaming] case class Data[T: ClassTag](data: T) * }}} */ private[streaming] class ActorReceiver[T: ClassTag]( - props: Props, - name: String, - storageLevel: StorageLevel, - receiverSupervisorStrategy: SupervisorStrategy) - extends NetworkReceiver[T](storageLevel) with Logging { + props: Props, + name: String, + storageLevel: StorageLevel, + receiverSupervisorStrategy: SupervisorStrategy + ) extends NetworkReceiver[T](storageLevel) with Logging { protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + receiverId) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 6ac54cf7be29e..cb0021143381b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -17,20 +17,15 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} -import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.{SparkException, Logging, SparkEnv} -import org.apache.spark.SparkContext._ - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.concurrent.duration._ +import scala.collection.mutable.{HashMap, Queue} import akka.actor._ -import akka.pattern.ask -import akka.dispatch._ + +import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.SparkContext._ import org.apache.spark.storage.BlockId -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.receiver.{NetworkReceiver, NetworkReceiverExecutorImpl, StopReceiver} import 
org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage @@ -173,9 +168,12 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { if (!iterator.hasNext) { - throw new Exception("Could not start receiver as details not found.") + throw new SparkException( + "Could not start receiver as NetworkReceiver object not found.") } - iterator.next().handler.run() + val receiver = iterator.next() + val executor = new NetworkReceiverExecutorImpl(receiver, SparkEnv.get) + executor.run() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index e29685bc91fb6..4ae23184d7c80 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -36,10 +36,10 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.NetworkReceiver import org.apache.spark.streaming.receivers.Receiver import org.apache.spark.streaming.util.ManualClock import org.apache.spark.util.Utils +import org.apache.spark.streaming.receiver.NetworkReceiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala new file mode 100644 index 0000000000000..4c3ac00cf36b0 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, NetworkReceiver, NetworkReceiverExecutor} + +class NetworkReceiverSuite extends FunSuite with Timeouts { + + test("network receiver with fake executor") { + val receiver = new MockReceiver + val executor = new MockReceiverExecutor(receiver) + + val receivingThread = new Thread() { + override def run() { + println("Running receiver") + executor.run() + println("Finished receiver") + } + } + receivingThread.start() + + // Verify that NetworkReceiver.run() blocks + intercept[Exception] { + failAfter(200 millis) { + receivingThread.join() + } + } + + // Verify that onStart was called, and onStop wasn't called + assert(receiver.started) + assert(!receiver.stopped) + assert(executor.isAllEmpty) + + // Verify whether the data stored by the receiver was + // sent to the executor + val byteBuffer = ByteBuffer.allocate(100) + val arrayBuffer = new ArrayBuffer[Int]() + val iterator = arrayBuffer.iterator + receiver.store(1) + receiver.store(byteBuffer) + receiver.store(arrayBuffer) + receiver.store(iterator) + assert(executor.singles.size === 1) + assert(executor.singles.head === 1) + assert(executor.byteBuffers.size === 1) + assert(executor.byteBuffers.head.eq(byteBuffer)) + assert(executor.iterators.size === 1) + assert(executor.iterators.head.eq(iterator)) + assert(executor.arrayBuffers.size === 1) + assert(executor.arrayBuffers.head.eq(arrayBuffer)) + + // Verify whether the exceptions reported by the receiver + // was sent to the executor + val exception = new Exception + receiver.reportError("Error", exception) + assert(executor.errors.size === 1) + assert(executor.errors.head.eq(exception)) + + // Verify that stopping actually stops the thread + failAfter(500 millis) { + receiver.stop() + receivingThread.join() + } + + // Verify that onStop was called + assert(receiver.stopped) + } + + test("block generator") { + val blockGeneratorListener = new MockBlockGeneratorListener + val blockInterval = 200 + val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString) + val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) + val expectedBlocks = 5 + val waitTime = expectedBlocks * blockInterval + (blockInterval / 2) + val generatedData = new ArrayBuffer[Int] + + // Generate blocks + val startTime = System.currentTimeMillis() + blockGenerator.start() + var count = 0 + while(System.currentTimeMillis - startTime < waitTime) { + blockGenerator += count + generatedData += count + count += 1 + Thread.sleep(10) + } + blockGenerator.stop() + + val recordedData = blockGeneratorListener.arrayBuffers.flatten + assert(blockGeneratorListener.arrayBuffers.size > 0) + assert(recordedData.size <= count) + //assert(generatedData.toList === recordedData.toList) + } +} + +class MockReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) { + var started = false + var stopped = false + def onStart() { started = true } + def onStop() { stopped = true } +} + +class MockReceiverExecutor(receiver: MockReceiver) extends NetworkReceiverExecutor(receiver) { + val singles = new ArrayBuffer[Any] + val byteBuffers = new ArrayBuffer[ByteBuffer] + val iterators = new 
ArrayBuffer[Iterator[_]] + val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]] + val errors = new ArrayBuffer[Throwable] + + def isAllEmpty = { + singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && + arrayBuffers.isEmpty && errors.isEmpty + } + + def pushSingle(data: Any) { + singles += data + } + + def pushBytes( + bytes: ByteBuffer, + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { + byteBuffers += bytes + } + + def pushIterator( + iterator: Iterator[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { + iterators += iterator + } + + def pushArrayBuffer( + arrayBuffer: ArrayBuffer[_], + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { + arrayBuffers += arrayBuffer + } + + def reportError(message: String, throwable: Throwable) { + errors += throwable + } +} + +class MockBlockGeneratorListener extends BlockGeneratorListener { + val arrayBuffers = new ArrayBuffer[ArrayBuffer[Int]] + val errors = new ArrayBuffer[Throwable] + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { + val bufferOfInts = arrayBuffer.map(_.asInstanceOf[Int]) + arrayBuffers += bufferOfInts + } + + def onError(message: String, throwable: Throwable) { + errors += throwable + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4d8c82d78ba40..6e16bbfb4a109 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -160,12 +160,12 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { ssc.start() } - // test whether waitForStop() exits after give amount of time + // test whether awaitTermination() exits after give amount of time failAfter(1000 millis) { ssc.awaitTermination(500) } - // test whether waitForStop() does not exit if not time is given + // test whether awaitTermination() does not exit if not time is given val exception = intercept[Exception] { failAfter(1000 millis) { ssc.awaitTermination()
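For reference, a minimal sketch of driving the new BlockGenerator directly through a BlockGeneratorListener, mirroring the "block generator" test above. This is illustrative only and not part of the patch; because BlockGenerator and BlockGeneratorListener are private[streaming], the sketch assumes it lives in the org.apache.spark.streaming package, as the test suite does, and the object name and timings are arbitrary.

package org.apache.spark.streaming

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkConf
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener}

// Hypothetical example: observe generated blocks through the listener callbacks.
object BlockGeneratorExample {
  def main(args: Array[String]) {
    val conf = new SparkConf().set("spark.streaming.blockInterval", "200")
    val listener = new BlockGeneratorListener {
      def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
        println("Generated " + blockId + " with " + arrayBuffer.size + " records")
      }
      def onError(message: String, throwable: Throwable) {
        println("Error in BlockGenerator: " + message)
      }
    }

    val blockGenerator = new BlockGenerator(listener, receiverId = 0, conf)
    blockGenerator.start()
    (1 to 100).foreach { i =>
      blockGenerator += i      // single records are buffered and coalesced into blocks
      Thread.sleep(10)
    }
    blockGenerator.stop()      // waits for remaining buffered records to be pushed
  }
}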