Various style changes and a first test for the rate controller.
dragos committed Jul 22, 2015
1 parent d32ca36 commit 34a389d
Showing 6 changed files with 122 additions and 25 deletions.
streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -53,11 +53,17 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
* A rate estimator configured by the user to compute a dynamic ingestion bound for this stream.
* @see `RateEstimator`
*/
-  protected [streaming] val rateEstimator = ssc.conf
-    .getOption("spark.streaming.RateEstimator")
-    .getOrElse("noop") match {
-      case _ => new NoopRateEstimator()
-    }
+  protected [streaming] val rateEstimator = newEstimator()
+
+  /**
+   * Return the configured estimator, or `noop` if none was specified.
+   */
+  private def newEstimator() =
+    ssc.conf.get("spark.streaming.RateEstimator", "noop") match {
+      case "noop" => new NoopRateEstimator()
+      case estimator => throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
+    }


// Keep track of the freshest rate for this stream using the rateEstimator
protected[streaming] val rateController: RateController = new RateController(id, rateEstimator) {
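
The new newEstimator() helper above reads the estimator name from SparkConf. A minimal sketch of how a user would select it when building a context, assuming the "noop" value that is the only one this commit recognizes (master and app name below are placeholders):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // "noop" is the only recognized value so far; anything else makes
    // newEstimator() throw IllegalArgumentException.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("RateEstimatorDemo")
      .set("spark.streaming.RateEstimator", "noop")
    val ssc = new StreamingContext(conf, Seconds(1))
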
streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -45,9 +45,9 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
* Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
*/
override val rateController: RateController = new RateController(id, rateEstimator) {
-      override def publish(rate: Long): Unit =
-        ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
-    }
+    override def publish(rate: Long): Unit =
+      ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+  }

/**
* Gets the receiver object that will be sent to the worker nodes
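
Since publish is the only abstract member of RateController, routing rates somewhere other than the receiver tracker is a one-method override. A hypothetical sketch, not part of this commit (it assumes a file inside the org.apache.spark.streaming.scheduler package, since RateController is private [streaming]):

    package org.apache.spark.streaming.scheduler

    import org.apache.spark.streaming.scheduler.rate.RateEstimator

    // Illustrative controller that only logs the rates it would publish.
    private[streaming] class LoggingRateController(streamId: Int, estimator: RateEstimator)
      extends RateController(streamId, estimator) {

      override def publish(rate: Long): Unit =
        println(s"stream $streamId: new rate bound $rate records/s")
    }
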
streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -67,7 +67,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
eventLoop.start()

// Estimators receive updates from batch completion
-    ssc.graph.getInputStreams.map(_.rateController).foreach(ssc.addStreamingListener(_))
+    ssc.graph.getInputStreams.foreach(is => ssc.addStreamingListener(is.rateController))
listenerBus.start(ssc.sparkContext)
receiverTracker = new ReceiverTracker(ssc)
inputInfoTracker = new InputInfoTracker(ssc)
streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
@@ -19,12 +19,12 @@ package org.apache.spark.streaming.scheduler

import java.util.concurrent.atomic.AtomicLong

+import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.ThreadUtils

-import scala.concurrent.{ExecutionContext, Future}

/**
* :: DeveloperApi ::
* A StreamingListener that receives batch completion updates, and maintains
@@ -38,32 +38,34 @@ private [streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
protected def publish(rate: Long): Unit

// Used to compute & publish the rate update asynchronously
-  @transient private val executionContext = ExecutionContext.fromExecutorService(
+  @transient
+  implicit private val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))

-  private val rateLimit : AtomicLong = new AtomicLong(-1L)
+  private val rateLimit: AtomicLong = new AtomicLong(-1L)

-  // Asynchronous computation of the rate update
+  /**
+   * Compute the new rate limit and publish it asynchronously.
+   */
private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
Future[Unit] {
val newSpeed = rateEstimator.compute(time, elems, workDelay, waitDelay)
newSpeed foreach { s =>
rateLimit.set(s.toLong)
publish(getLatestRate())
}
-    } (executionContext)
+    }

def getLatestRate(): Long = rateLimit.get()

-  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted){
-    val elements = batchCompleted.batchInfo.streamIdToInputInfo
+  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+    val elements = batchCompleted.batchInfo.streamIdToInputInfo

-    for (
+    for {
processingEnd <- batchCompleted.batchInfo.processingEndTime;
workDelay <- batchCompleted.batchInfo.processingDelay;
waitDelay <- batchCompleted.batchInfo.schedulingDelay;
elems <- elements.get(streamUID).map(_.numRecords)
-    ) computeAndPublish(processingEnd, elems, workDelay, waitDelay)
+    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
}

}
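
The rewritten onBatchCompleted gates the update behind a for-comprehension over four Options: a rate is computed only when the processing end time, both delays, and the per-stream record count are all present. A tiny standalone illustration of that short-circuiting behavior:

    // Any None in the chain stops the comprehension; the body never runs.
    val processingEnd: Option[Long] = Some(1000L)
    val workDelay: Option[Long] = None // e.g. batch info still incomplete

    for {
      end <- processingEnd
      delay <- workDelay // None here, so nothing is printed
    } println(s"would publish a rate for the batch ending at $end")
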
streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -31,16 +31,27 @@ private[streaming] trait RateEstimator extends Serializable {
* Computes the number of elements the stream attached to this `RateEstimator`
* should ingest per second, given an update on the size and completion
* times of the latest batch.
+   *
+   * @param time The timestamp of the current batch interval that just finished
+   * @param elements The number of elements that were processed in this batch
+   * @param processingDelay The time in ms it took for the job to complete
+   * @param schedulingDelay The time in ms that the job spent in the scheduling queue
*/
-  def compute(time: Long, elements: Long,
-      processingDelay: Long, schedulingDelay: Long): Option[Double]
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double]
}

/**
-  * The trivial rate estimator never sends an update
+ * The trivial rate estimator never sends an update
*/
private[streaming] class NoopRateEstimator extends RateEstimator {

-  def compute(time: Long, elements: Long,
-      processingDelay: Long, schedulingDelay: Long): Option[Double] = None
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = None
}
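
Only the noop estimator exists at this point. As a shape reference, here is a minimal constant estimator in the spirit of the ConstantEstimator that the new test below defines; illustrative only, not part of this commit:

    // Always proposes the same ingestion bound, whatever the batch statistics.
    private[streaming] class ConstantRateEstimator(rate: Double) extends RateEstimator {
      def compute(
          time: Long,
          elements: Long,
          processingDelay: Long,
          schedulingDelay: Long): Option[Double] = Some(rate)
    }
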
streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala (new file)
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.scheduler

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{StreamingContext, TestOutputStreamWithPartitions, TestSuiteBase, Time}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.scheduler.rate.RateEstimator

class RateControllerSuite extends TestSuiteBase {

test("rate controller publishes updates") {
val ssc = new StreamingContext(conf, batchDuration)
val dstream = new MockRateLimitDStream(ssc)
val output = new TestOutputStreamWithPartitions(dstream)
output.register()
runStreams(ssc, 1, 1)

eventually(timeout(2.seconds)) {
assert(dstream.publishCalls === 1)
}
}
}

/**
* An InputDStream that counts how often its rate controller `publish` method was called.
*/
private class MockRateLimitDStream(@transient ssc: StreamingContext)
extends InputDStream[Int](ssc) {

@volatile
var publishCalls = 0

private object ConstantEstimator extends RateEstimator {
def compute(
time: Long,
elements: Long,
processingDelay: Long,
schedulingDelay: Long): Option[Double] = {
Some(100.0)
}
}

override val rateController: RateController = new RateController(id, ConstantEstimator) {
override def publish(rate: Long): Unit = {
publishCalls += 1
}
}

def compute(validTime: Time): Option[RDD[Int]] = {
val data = Seq(1)
ssc.scheduler.inputInfoTracker.reportInfo(validTime, StreamInputInfo(id, data.size))
Some(ssc.sc.parallelize(data))
}

def stop(): Unit = {}

def start(): Unit = {}
}
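
To try the new suite locally, something like build/sbt "streaming/test-only *RateControllerSuite" should work; the exact task name is an assumption and depends on the sbt setup of the checkout.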
