diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala index 2b8cdb72b8d0e..a96e2924a0b44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala @@ -22,12 +22,20 @@ import scala.annotation.tailrec import java.io.OutputStream import java.util.concurrent.TimeUnit._ +import org.apache.spark.Logging + + private[streaming] -class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) - val CHUNK_SIZE = 8192 - var lastSyncTime = System.nanoTime - var bytesWrittenSinceSync: Long = 0 +class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int) + extends OutputStream + with Logging { + + require(desiredBytesPerSec > 0) + + private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private val CHUNK_SIZE = 8192 + private var lastSyncTime = System.nanoTime + private var bytesWrittenSinceSync = 0L override def write(b: Int) { waitToWrite(1) @@ -59,9 +67,9 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu @tailrec private def waitToWrite(numBytes: Int) { val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { + val elapsedNanosecs = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs + if (rate < desiredBytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes if (now > lastSyncTime + SYNC_INTERVAL) { @@ -71,13 +79,14 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu } } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. - // Based on throttler in Kafka - // scalastyle:off - // (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - // scalastyle:on - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), - SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) + val targetTimeInMillis = bytesWrittenSinceSync * 1000 / desiredBytesPerSec + val elapsedTimeInMillis = elapsedNanosecs / 1000000 + val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis + if (sleepTimeInMillis > 0) { + logTrace("Natural rate is " + rate + " per second but desired rate is " + + desiredBytesPerSec + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") + Thread.sleep(sleepTimeInMillis) + } waitToWrite(numBytes) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index e5bf6d70db5f9..7d18a0fcf7ba8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.util -import org.scalatest.FunSuite import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ +import org.scalatest.FunSuite + class RateLimitedOutputStreamSuite extends FunSuite { private def benchmark[U](f: => U): Long = { @@ -29,12 +30,14 @@ class RateLimitedOutputStreamSuite extends FunSuite { System.nanoTime - start } - ignore("write") { + test("write") { val underlying = new ByteArrayOutputStream val data = "X" * 41000 - val stream = new RateLimitedOutputStream(underlying, 10000) + val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) - assert(underlying.toString("UTF-8") == data) + + // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down. + assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4) + assert(underlying.toString("UTF-8") === data) } }