Commit e57c66b
Added a couple of tests for the full scenario from driver to receivers, with several rate updates.
dragos committed Jul 23, 2015
1 parent b425d32 commit e57c66b
Showing 3 changed files with 114 additions and 39 deletions.
org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -45,12 +45,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
    * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
    */
   override protected[streaming] val rateController: Option[RateController] =
-    RateEstimator.makeEstimator(ssc.conf).map { estimator =>
-      new RateController(id, estimator) {
-        override def publish(rate: Long): Unit =
-          ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
-      }
-    }
+    RateEstimator.makeEstimator(ssc.conf).map { new ReceiverRateController(id, _) }
 
   /**
    * Gets the receiver object that will be sent to the worker nodes
@@ -122,4 +117,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
     }
     Some(blockRDD)
   }
+
+  /**
+   * A RateController that sends the new rate to receivers, via the receiver tracker.
+   */
+  private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
+      extends RateController(id, estimator) {
+    override def publish(rate: Long): Unit =
+      ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+  }
 }
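The refactoring above is behavior-preserving: the anonymous `RateController` becomes the named `ReceiverRateController`, whose `publish` forwards each new rate to the receiver tracker, and `RateEstimator.makeEstimator(ssc.conf).map { ... }` still installs a controller only when an estimator is configured. As orientation, here is a self-contained sketch of that publish path — plain Scala, not Spark code; the queue-backed `FakeReceiverTracker` and the `onBatchCompleted` hook are stand-ins invented for illustration, with only the `compute`/`publish`/`sendRateUpdate` shapes taken from this diff:

```scala
import java.util.concurrent.ConcurrentLinkedQueue

// Mirrors the RateEstimator.compute signature visible in this diff.
trait RateEstimator {
  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double]
}

// Mirrors the RateController shape: subclasses decide how a new rate is published.
abstract class RateController(val streamId: Int, estimator: RateEstimator) {
  def publish(rate: Long): Unit

  // Stand-in for the hook the real controller would run after a batch completes:
  // ask the estimator for a fresh rate and publish it if one is produced.
  def onBatchCompleted(time: Long, elements: Long, procDelay: Long, schedDelay: Long): Unit =
    estimator.compute(time, elements, procDelay, schedDelay).foreach(r => publish(r.toLong))
}

// Stand-in for ReceiverTracker.sendRateUpdate: just records (streamId, rate) pairs.
object FakeReceiverTracker {
  val updates = new ConcurrentLinkedQueue[(Int, Long)]()
  def sendRateUpdate(streamId: Int, rate: Long): Unit = updates.add((streamId, rate))
}

// The pattern this hunk introduces: publishing goes through the tracker.
class ReceiverRateController(id: Int, estimator: RateEstimator)
    extends RateController(id, estimator) {
  override def publish(rate: Long): Unit =
    FakeReceiverTracker.sendRateUpdate(id, rate)
}

object PublishPathDemo extends App {
  val constant = new RateEstimator {
    def compute(t: Long, e: Long, p: Long, s: Long): Option[Double] = Some(150.0)
  }
  val controller = new ReceiverRateController(0, constant)
  controller.onBatchCompleted(time = 0L, elements = 1000L, procDelay = 10L, schedDelay = 0L)
  assert(FakeReceiverTracker.updates.peek() == (0 -> 150L))
}
```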

org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -17,20 +17,26 @@
 
 package org.apache.spark.streaming.scheduler
 
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.scalatest.Matchers._
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{StreamingContext, TestOutputStreamWithPartitions, TestSuiteBase, Time}
-import org.apache.spark.streaming.dstream.InputDStream
+import org.apache.spark.streaming._
 import org.apache.spark.streaming.scheduler.rate.RateEstimator
 
+
+
 class RateControllerSuite extends TestSuiteBase {
 
+  override def actuallyWait: Boolean = true
+
   test("rate controller publishes updates") {
     val ssc = new StreamingContext(conf, batchDuration)
-    val dstream = new MockRateLimitDStream(ssc)
+    val dstream = new MockRateLimitDStream(ssc, Seq(Seq(1)), 1)
     val output = new TestOutputStreamWithPartitions(dstream)
     output.register()
     runStreams(ssc, 1, 1)
@@ -39,41 +45,98 @@ class RateControllerSuite extends TestSuiteBase {
       assert(dstream.publishCalls === 1)
     }
   }
+
+  test("receiver rate controller updates reach receivers") {
+    val ssc = new StreamingContext(conf, batchDuration)
+
+    val dstream = new RateLimitInputDStream(ssc) {
+      override val rateController =
+        Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+    }
+    SingletonDummyReceiver.reset()
+
+    val output = new TestOutputStreamWithPartitions(dstream)
+    output.register()
+    runStreams(ssc, 2, 2)
+
+    eventually(timeout(5.seconds)) {
+      assert(dstream.getCurrentRateLimit === Some(200))
+    }
+  }
+
+  test("multiple rate controller updates reach receivers") {
+    val ssc = new StreamingContext(conf, batchDuration)
+    val rates = Seq(100L, 200L, 300L)
+
+    val dstream = new RateLimitInputDStream(ssc) {
+      override val rateController =
+        Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
+    }
+    SingletonDummyReceiver.reset()
+
+    val output = new TestOutputStreamWithPartitions(dstream)
+    output.register()
+
+    val observedRates = mutable.HashSet.empty[Long]
+
+    @volatile var done = false
+    runInBackground {
+      while (!done) {
+        try {
+          dstream.getCurrentRateLimit.foreach(observedRates += _)
+        } catch {
+          case NonFatal(_) => () // don't stop if the executor wasn't installed yet
+        }
+        Thread.sleep(20)
+      }
+    }
+    runStreams(ssc, 4, 4)
+    done = true
+
+    // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
+    observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+  }
+
+  private def runInBackground(f: => Unit): Unit = {
+    new Thread {
+      override def run(): Unit = {
+        f
+      }
+    }.start()
+  }
 }
 
 /**
  * An InputDStream that counts how often its rate controller `publish` method was called.
  */
-private class MockRateLimitDStream(@transient ssc: StreamingContext)
-  extends InputDStream[Int](ssc) {
+private class MockRateLimitDStream[T: ClassTag](
+    @transient ssc: StreamingContext,
+    input: Seq[Seq[T]],
+    numPartitions: Int) extends TestInputStream[T](ssc, input, numPartitions) {
 
   @volatile
   var publishCalls = 0
 
-  private object ConstantEstimator extends RateEstimator {
-    def compute(
-        time: Long,
-        elements: Long,
-        processingDelay: Long,
-        schedulingDelay: Long): Option[Double] = {
-      Some(100.0)
-    }
-  }
-
   override val rateController: Option[RateController] =
-    Some(new RateController(id, ConstantEstimator) {
+    Some(new RateController(id, new ConstantEstimator(100.0)) {
       override def publish(rate: Long): Unit = {
         publishCalls += 1
       }
     })
+}
 
-  def compute(validTime: Time): Option[RDD[Int]] = {
-    val data = Seq(1)
-    ssc.scheduler.inputInfoTracker.reportInfo(validTime, StreamInputInfo(id, data.size))
-    Some(ssc.sc.parallelize(data))
-  }
+private class ConstantEstimator(rates: Double*) extends RateEstimator {
+  private var idx: Int = 0
 
-  def stop(): Unit = {}
+  private def nextRate(): Double = {
+    val rate = rates(idx)
+    idx = (idx + 1) % rates.size
+    rate
+  }
 
-  def start(): Unit = {}
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = Some(nextRate())
 }
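Note how `ConstantEstimator` above hands out its constructor arguments round-robin, one per `compute` call; that is what lets the multi-update test push three distinct rates through four batches and then expect exactly the set {100, 200, 300} plus the receiver's initial `Long.MaxValue` limit. Below is a standalone check of that cycling behaviour — the `RateEstimator` stub only repeats the `compute` signature from this diff, and `CycleDemo` is illustrative, not part of the commit:

```scala
// Stub with the same compute signature as the RateEstimator trait used above.
trait RateEstimator {
  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double]
}

// Same round-robin logic as the test helper in the diff.
class ConstantEstimator(rates: Double*) extends RateEstimator {
  private var idx: Int = 0

  private def nextRate(): Double = {
    val rate = rates(idx)
    idx = (idx + 1) % rates.size
    rate
  }

  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double] = Some(nextRate())
}

object CycleDemo extends App {
  val estimator = new ConstantEstimator(100.0, 200.0, 300.0)
  // Four calls wrap around: the fourth returns the first rate again.
  val observed = (1 to 4).flatMap(_ => estimator.compute(0, 0, 0, 0))
  assert(observed == Seq(100.0, 200.0, 300.0, 100.0))
}
```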
org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -18,16 +18,15 @@
 package org.apache.spark.streaming.scheduler
 
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
-import org.apache.spark.streaming._
 
 import org.apache.spark.SparkConf
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.receiver._
-import org.apache.spark.util.Utils
-import org.apache.spark.streaming.dstream.InputDStream
-import scala.reflect.ClassTag
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisor}
+
+
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
@@ -129,12 +128,20 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
 }
 
 /**
- * A Receiver as an object so we can read its rate limit.
+ * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
+ * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
+ * serialized when receivers are installed on executors.
  *
  * @note It's necessary to be a top-level object, or else serialization would create another
  * one on the executor side and we won't be able to read its rate limit.
  */
-private object SingletonDummyReceiver extends DummyReceiver
+private object SingletonDummyReceiver extends DummyReceiver {
+
+  /** Reset the object to be usable in another test. */
+  def reset(): Unit = {
+    executor_ = null
+  }
+}
 
 /**
  * Dummy receiver implementation
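The new `reset()` guards against a Java-serialization pitfall: a test run leaves a live supervisor in the singleton's `executor_` field, and since that object is not serializable, the next test could no longer ship the receiver to executors. A minimal Spark-free illustration — `Supervisor`, `SingletonReceiver`, and the field name are hypothetical stand-ins for the scenario the comment above describes:

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Deliberately NOT Serializable, standing in for a live ReceiverSupervisor.
class Supervisor

object SingletonReceiver extends Serializable {
  // Hypothetical counterpart of the `executor_` field mentioned above.
  var executor_ : Supervisor = null

  /** Null out the field so the singleton can be serialized (and reused) again. */
  def reset(): Unit = { executor_ = null }
}

object SerializationDemo extends App {
  def serializable(obj: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    }

  SingletonReceiver.executor_ = new Supervisor // a previous test attached a supervisor
  assert(!serializable(SingletonReceiver))     // the non-null field blocks serialization

  SingletonReceiver.reset()
  assert(serializable(SingletonReceiver))      // with the field nulled, it serializes fine
}
```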
