From 775f8dbded4468692e6ce0bd5515689620bfbbd4 Mon Sep 17 00:00:00 2001 From: Jacek Centkowski Date: Thu, 26 Oct 2023 18:08:59 +0200 Subject: [PATCH] feat: implement `fold` operator The `fold` operation returns combined value retrieved from running function `f` on all source elements in a cumulative manner where result of the previous call is used as an input value to the next e.g.: Source.empty[Int].fold(0)((acc, n) => acc + n) // 0 Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0 Note that in case when `receive()` operation fails then ChannelClosedException.Error exception is thrown. Wheres in case when function `f` throws then this exception is propagated up to the caller. --- .../main/scala/ox/channels/SourceOps.scala | 36 ++++++++++++++ .../scala/ox/channels/SourceOpsFoldTest.scala | 47 +++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 core/src/test/scala/ox/channels/SourceOpsFoldTest.scala diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index fb077f80..43dc5b2b 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -663,6 +663,42 @@ trait SourceOps[+T] { this: Source[T] => * }}} */ def last(): T = lastOption().getOrElse(throw new NoSuchElementException("cannot obtain last element from an empty source")) + + /** Uses `zero` as the current value and applies function `f` on it and a value received from this source. The returned value is used as + * the next current value and `f` is applied again with the value received from a source. The operation is repeated until the source is + * drained. + * + * @param zero + * An initial value to be used as the first argument to function `f` call. + * @param f + * A binary function (a function that takes two arguments) that is applied to the current value and value received from a source. + * @return + * Combined value retrieved from running function `f` on all source elements in a cumulative manner where result of the previous call + * is used as an input value to the next. + * @throws ChannelClosedException.Error + * When receiving an element from this source fails. + * @throws exception + * When function `f` throws an `exception` then it is propagated up to the caller. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * supervised { + * Source.empty[Int].fold(0)((acc, n) => acc + n) // 0 + * Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0 + * } + * }}} + */ + def fold[U](zero: U)(f: (U, T) => U): U = + var current = zero + repeatWhile { + receive() match + case ChannelClosed.Done => false + case e: ChannelClosed.Error => throw e.toThrowable + case t: T @unchecked => current = f(current, t); true + } + current } trait SourceCompanionOps: diff --git a/core/src/test/scala/ox/channels/SourceOpsFoldTest.scala b/core/src/test/scala/ox/channels/SourceOpsFoldTest.scala new file mode 100644 index 00000000..47c5b61a --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsFoldTest.scala @@ -0,0 +1,47 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsFoldTest extends AnyFlatSpec with Matchers { + behavior of "Source.fold" + + it should "throw ChannelClosedException.Error with exception and message that was thrown during retrieval" in supervised { + the[ChannelClosedException.Error] thrownBy { + Source + .failed[Int](new RuntimeException("source is broken")) + .fold(0)((acc, n) => acc + n) + } should have message "java.lang.RuntimeException: source is broken" + } + + it should "throw ChannelClosedException.Error for source failed without exception" in supervised { + the[ChannelClosedException.Error] thrownBy { + Source + .failedWithoutReason[Int]() + .fold(0)((acc, n) => acc + n) + } + } + + it should "throw exception thrown in `f` when `f` throws" in supervised { + the[RuntimeException] thrownBy { + Source + .fromValues(1) + .fold(0)((_, _) => throw new RuntimeException("Function `f` is broken")) + } should have message "Function `f` is broken" + } + + it should "return `zero` value from fold on the empty source" in supervised { + Source.empty[Int].fold(0)((acc, n) => acc + n) shouldBe 0 + } + + it should "return fold on non-empty source" in supervised { + Source.fromValues(1, 2).fold(0)((acc, n) => acc + n) shouldBe 3 + } + + it should "drain the source" in supervised { + val s = Source.fromValues(1) + s.fold(0)((acc, n) => acc + n) shouldBe 1 + s.receive() shouldBe ChannelClosed.Done + } +}