Skip to content

Commit

Permalink
Merge pull request #24 from softwaremill/feat_throttle
Browse files Browse the repository at this point in the history
feat: implement `throttle` function
  • Loading branch information
adamw authored Oct 30, 2023
2 parents e28a268 + 13f8b50 commit 931591f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
40 changes: 40 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,46 @@ trait SourceOps[+T] { this: Source[T] =>
case ChannelClosed.Error(r) => throw r.getOrElse(new NoSuchElementException("getting head failed"))
case t: T @unchecked => t
}

/** Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that
* the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()`
* taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source.
*
* @param elements
* Number of elements to be emitted. Must be greater than 0.
* @param per
* Per time unit. Must be greater or equal to 1 ms.
* @return
* A source that emits at most `elements` `per` time unit.
* @example
* {{{
* import ox.*
* import ox.channels.Source
*
* import scala.concurrent.duration.*
*
* scoped {
* Source.empty[Int].throttle(1, 1.second).toList // List() returned without throttling
* Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds
* }
* }}}
*/
def throttle(elements: Int, per: FiniteDuration)(using Ox, StageCapacity): Source[T] =
require(elements > 0, "elements must be > 0")
require(per.toMillis > 0, "per time must be >= 1 ms")

val c = StageCapacity.newChannel[T]
val emitEveryMillis = per.toMillis / elements

forkDaemon {
repeatWhile {
receive() match
case ChannelClosed.Done => c.done(); false
case ChannelClosed.Error(r) => c.error(r); false
case t: T @unchecked => Thread.sleep(emitEveryMillis); c.send(t); true
}
}
c
}

trait SourceCompanionOps:
Expand Down
45 changes: 45 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsThrottleTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

import scala.concurrent.duration.*

class SourceOpsThrottleTest extends AnyFlatSpec with Matchers {
behavior of "Source.throttle"

it should "not throttle the empty source" in supervised {
val s = Source.empty[Int]
val (result, executionTime) = measure { s.throttle(1, 1.second).toList }
result shouldBe List.empty
executionTime.toMillis should be < 1.second.toMillis
}

it should "throttle to specified elements per time units" in supervised {
val s = Source.fromValues(1, 2)
val (result, executionTime) = measure { s.throttle(1, 50.millis).toList }
result shouldBe List(1, 2)
executionTime.toMillis should (be >= 100L and be <= 150L)
}

it should "fail to throttle when elements <= 0" in supervised {
val s = Source.empty[Int]
the[IllegalArgumentException] thrownBy {
s.throttle(-1, 50.millis)
} should have message "requirement failed: elements must be > 0"
}

it should "fail to throttle when per lower than 1ms" in supervised {
val s = Source.empty[Int]
the[IllegalArgumentException] thrownBy {
s.throttle(1, 50.nanos)
} should have message "requirement failed: per time must be >= 1 ms"
}

private def measure[T](f: => T): (T, Duration) =
val before = System.currentTimeMillis()
val result = f
val after = System.currentTimeMillis();
(result, (after - before).millis)
}

0 comments on commit 931591f

Please sign in to comment.