diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index c2a62240..7bde4707 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -565,6 +565,46 @@ trait SourceOps[+T] { this: Source[T] => case ChannelClosed.Error(r) => throw r.getOrElse(new NoSuchElementException("getting head failed")) case t: T @unchecked => t } + + /** Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that + * the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()` + * taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source. + * + * @param elements + * Number of elements to be emitted. Must be greater than 0. + * @param per + * Per time unit. Must be greater or equal to 1 ms. + * @return + * A source that emits at most `elements` `per` time unit. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * import scala.concurrent.duration.* + * + * scoped { + * Source.empty[Int].throttle(1, 1.second).toList // List() returned without throttling + * Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds + * } + * }}} + */ + def throttle(elements: Int, per: FiniteDuration)(using Ox, StageCapacity): Source[T] = + require(elements > 0, "elements must be > 0") + require(per.toMillis > 0, "per time must be >= 1 ms") + + val c = StageCapacity.newChannel[T] + val emitEveryMillis = per.toMillis / elements + + forkDaemon { + repeatWhile { + receive() match + case ChannelClosed.Done => c.done(); false + case ChannelClosed.Error(r) => c.error(r); false + case t: T @unchecked => Thread.sleep(emitEveryMillis); c.send(t); true + } + } + c } trait SourceCompanionOps: diff --git a/core/src/test/scala/ox/channels/SourceOpsThrottleTest.scala b/core/src/test/scala/ox/channels/SourceOpsThrottleTest.scala new file mode 100644 index 00000000..6e7d8cd3 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsThrottleTest.scala @@ -0,0 +1,45 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +import scala.concurrent.duration.* + +class SourceOpsThrottleTest extends AnyFlatSpec with Matchers { + behavior of "Source.throttle" + + it should "not throttle the empty source" in supervised { + val s = Source.empty[Int] + val (result, executionTime) = measure { s.throttle(1, 1.second).toList } + result shouldBe List.empty + executionTime.toMillis should be < 1.second.toMillis + } + + it should "throttle to specified elements per time units" in supervised { + val s = Source.fromValues(1, 2) + val (result, executionTime) = measure { s.throttle(1, 50.millis).toList } + result shouldBe List(1, 2) + executionTime.toMillis should (be >= 100L and be <= 150L) + } + + it should "fail to throttle when elements <= 0" in supervised { + val s = Source.empty[Int] + the[IllegalArgumentException] thrownBy { + s.throttle(-1, 50.millis) + } should have message "requirement failed: elements must be > 0" + } + + it should "fail to throttle when per lower than 1ms" in supervised { + val s = Source.empty[Int] + the[IllegalArgumentException] thrownBy { + s.throttle(1, 50.nanos) + } should have message "requirement failed: per time must be >= 1 ms" + } + + private def measure[T](f: => T): (T, Duration) = + val before = System.currentTimeMillis() + val result = f + val after = System.currentTimeMillis(); + (result, (after - before).millis) +}