diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5b5a3fe648602..c3078cd4ad35f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Get the attached executor. */ - private def executor = { + private[streaming] def executor = { assert(executor_ != null, "Executor has not been attached to this receiver") executor_ } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index eeb14ca3a49e9..944d893b9bbf7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -58,6 +58,9 @@ private[streaming] abstract class ReceiverSupervisor( /** Time between a receiver is stopped and started again */ private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + /** The current maximum rate limit for this receiver. */ + private[streaming] def getCurrentRateLimit: Option[Int] = None + /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 6e819460b1b23..edb0fc3718fc7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -100,6 +100,9 @@ private[streaming] class ReceiverSupervisorImpl( } }, streamId, env.conf) + override private[streaming] def getCurrentRateLimit: Option[Int] = + Some(blockGenerator.currentRateLimit.get) + /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { blockGenerator.addData(data) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 0d58a7b54412f..d0ac371db9aad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -537,4 +537,19 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { verifyOutput[W](output, expectedOutput, useSet) } } + + /** + * Wait until `cond` becomes true, or timeout ms have passed. This method checks the condition + * every 100ms, so it won't wait more than 100ms more than necessary. + * + * @param cond A boolean that should become `true` + * @param timemout How many millis to wait before giving up + */ + def waitUntil(cond: => Boolean, timeout: Int): Unit = { + val start = System.currentTimeMillis() + val end = start + timeout + while ((System.currentTimeMillis() < end) && !cond) { + Thread.sleep(100) + } + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index a6e783861dbe6..9da851b5e6c1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -22,6 +22,9 @@ import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.receiver._ import org.apache.spark.util.Utils +import org.apache.spark.streaming.dstream.InputDStream +import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.ReceiverInputDStream /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { @@ -72,15 +75,46 @@ class ReceiverTrackerSuite extends TestSuiteBase { assert(locations(0).length === 1) assert(locations(3).length === 1) } + + test("Receiver tracker - propagates rate limit") { + val newRateLimit = 100 + val ids = new TestReceiverInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + waitUntil(TestDummyReceiver.started, 5000) + tracker.sendRateUpdate(ids.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + waitUntil(ids.getRateLimit.get == newRateLimit, 1000) + assert(ids.getRateLimit.get === newRateLimit) + } +} + +/** An input DStream with a hard-coded receiver that gives access to internals for testing. */ +private class TestReceiverInputDStream(@transient ssc_ : StreamingContext) + extends ReceiverInputDStream[Int](ssc_) { + + override def getReceiver(): DummyReceiver = TestDummyReceiver + + def getRateLimit: Option[Int] = + TestDummyReceiver.executor.getCurrentRateLimit } +/** + * We need the receiver to be an object, otherwise serialization will create another one + * and we won't be able to read its rate limit. + */ +private object TestDummyReceiver extends DummyReceiver + /** * Dummy receiver implementation */ private class DummyReceiver(host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + var started = false + def onStart() { + started = true } def onStop() {