Skip to content

Commit

Permalink
feat: implement throttle function
Browse files Browse the repository at this point in the history
Sends elements to the returned channel limiting the throughput to specific
number of elements (evenly spaced) per time unit. Note that the element's
`receive()` time is included in the resulting throughput. For instance having
`throttle(1, 1.second)` and `receive()` taking `Xms` means that resulting
channel will receive elements every `1s + Xms` time.
Throttling is not applied to the empty source.

Examples:

  Source.empty[Int].throttle(1, 1.second).toList       // List() returned without throttling
  Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds

Note that implementation relies on `Thread.sleep` that is according to
[1] project Loom compatible.

[1] https://softwaremill.com/what-is-blocking-in-loom/
  • Loading branch information
geminicaprograms committed Oct 26, 2023
1 parent 7d7897a commit dfd7e87
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
40 changes: 40 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,46 @@ trait SourceOps[+T] { this: Source[T] =>
}
}
c

/** Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that
* the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()`
* taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source.
*
* @param elements
* Number of elements to be emitted. Must be greater than 0.
* @param per
* Per time unit. Must be greater or equal to 1 ms.
* @return
* A source that emits at most `elements` `per` time unit.
* @example
* {{{
* import ox.*
* import ox.channels.Source
*
* import scala.concurrent.duration.*
*
* scoped {
* Source.empty[Int].throttle(1, 1.second).toList // List() returned without throttling
* Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds
* }
* }}}
*/
def throttle(elements: Int, per: FiniteDuration)(using Ox, StageCapacity): Source[T] =
require(elements > 0, "elements must be > 0")
require(per.toMillis > 0, "per time must be >= 1 ms")

val c = StageCapacity.newChannel[T]
val emitEveryMillis = per.toMillis / elements

forkDaemon {
repeatWhile {
receive() match
case ChannelClosed.Done => c.done(); false
case ChannelClosed.Error(r) => c.error(r); false
case t: T @unchecked => Thread.sleep(emitEveryMillis); c.send(t); true
}
}
c
}

trait SourceCompanionOps:
Expand Down
46 changes: 46 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsThrottleTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

import java.util.concurrent.TimeUnit
import scala.concurrent.duration.*

class SourceOpsThrottleTest extends AnyFlatSpec with Matchers {
behavior of "Source.throttle"

it should "not throttle the empty source" in supervised {
val s = Source.empty[Int]
val (result, executionTime) = measure { s.throttle(1, 1.second).toList }
result shouldBe List.empty
executionTime.toMillis should be < 1.second.toMillis
}

it should "throttle to specified elements per time units" in supervised {
val s = Source.fromValues(1, 2)
val (result, executionTime) = measure { s.throttle(1, 50.millis).toList }
result shouldBe List(1, 2)
executionTime.toMillis should (be >= 100L and be <= 150L)
}

it should "fail to throttle when elements <= 0" in supervised {
val s = Source.empty[Int]
the[IllegalArgumentException] thrownBy {
s.throttle(-1, 50.millis)
} should have message "requirement failed: elements must be > 0"
}

it should "fail to throttle when per lower than 1ms" in supervised {
val s = Source.empty[Int]
the[IllegalArgumentException] thrownBy {
s.throttle(1, 50.nanos)
} should have message "requirement failed: per time must be >= 1 ms"
}

private def measure[T](f: => T): (T, Duration) =
val before = System.currentTimeMillis()
val result = f
val after = System.currentTimeMillis();
(result, FiniteDuration.apply(after - before, TimeUnit.MILLISECONDS))
}

0 comments on commit dfd7e87

Please sign in to comment.