Skip to content

Commit

Permalink
feat: implement fold operator
Browse files Browse the repository at this point in the history
The `fold` operation returns combined value retrieved from running function
`f` on all source elements in a cumulative manner where result of the previous
call is used as an input value to the next e.g.:

  Source.empty[Int].fold(0)((acc, n) => acc + n)       // 0
  Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0

Note that in case when `receive()` operation fails then
ChannelClosedException.Error exception is thrown. Wheres in case when
function `f` throws then this exception is propagated up to the caller.
  • Loading branch information
geminicaprograms committed Oct 30, 2023
1 parent 856823e commit ba005a8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
36 changes: 36 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,42 @@ trait SourceOps[+T] { this: Source[T] =>
* }}}
*/
def last(): T = lastOption().getOrElse(throw new NoSuchElementException("cannot obtain last element from an empty source"))

/** Uses `zero` as the current value and applies function `f` on it and a value received from this source. The returned value is used as
* the next current value and `f` is applied again with the value received from a source. The operation is repeated until the source is
* drained.
*
* @param zero
* An initial value to be used as the first argument to function `f` call.
* @param f
* A binary function (a function that takes two arguments) that is applied to the current value and value received from a source.
* @return
* Combined value retrieved from running function `f` on all source elements in a cumulative manner where result of the previous call
* is used as an input value to the next.
* @throws ChannelClosedException.Error
* When receiving an element from this source fails.
* @throws exception
* When function `f` throws an `exception` then it is propagated up to the caller.
* @example
* {{{
* import ox.*
* import ox.channels.Source
*
* supervised {
* Source.empty[Int].fold(0)((acc, n) => acc + n) // 0
* Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0
* }
* }}}
*/
def fold[U](zero: U)(f: (U, T) => U): U =
var current = zero
repeatWhile {
receive() match
case ChannelClosed.Done => false
case e: ChannelClosed.Error => throw e.toThrowable
case t: T @unchecked => current = f(current, t); true
}
current
}

trait SourceCompanionOps:
Expand Down
47 changes: 47 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsFoldTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsFoldTest extends AnyFlatSpec with Matchers {
behavior of "Source.fold"

it should "throw ChannelClosedException.Error with exception and message that was thrown during retrieval" in supervised {
the[ChannelClosedException.Error] thrownBy {
Source
.failed[Int](new RuntimeException("source is broken"))
.fold(0)((acc, n) => acc + n)
} should have message "java.lang.RuntimeException: source is broken"
}

it should "throw ChannelClosedException.Error for source failed without exception" in supervised {
the[ChannelClosedException.Error] thrownBy {
Source
.failedWithoutReason[Int]()
.fold(0)((acc, n) => acc + n)
}
}

it should "throw exception thrown in `f` when `f` throws" in supervised {
the[RuntimeException] thrownBy {
Source
.fromValues(1)
.fold(0)((_, _) => throw new RuntimeException("Function `f` is broken"))
} should have message "Function `f` is broken"
}

it should "return `zero` value from fold on the empty source" in supervised {
Source.empty[Int].fold(0)((acc, n) => acc + n) shouldBe 0
}

it should "return fold on non-empty source" in supervised {
Source.fromValues(1, 2).fold(0)((acc, n) => acc + n) shouldBe 3
}

it should "drain the source" in supervised {
val s = Source.fromValues(1)
s.fold(0)((acc, n) => acc + n) shouldBe 1
s.receive() shouldBe ChannelClosed.Done
}
}

0 comments on commit ba005a8

Please sign in to comment.