Skip to content

Commit

Permalink
par-eval-v3 - remove unneeded cancellation handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nikiforo committed Nov 21, 2021
1 parent 4cb3ad4 commit 921460c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 50 deletions.
24 changes: 12 additions & 12 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2042,13 +2042,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
maxConcurrent: Int
)(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] = {

def init(ch: Channel[F2, F2[Outcome[F2, Throwable, O2]]], release: F2[Unit]) =
Deferred[F2, Outcome[F2, Throwable, O2]].flatTap { value =>
def init(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) =
Deferred[F2, Either[Throwable, O2]].flatTap { value =>
ch.send(release *> value.get)
}

def send(v: Deferred[F2, Outcome[F2, Throwable, O2]]) =
(el: Outcome[F2, Throwable, O2]) => v.complete(el).void
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

parEvalMapAction(maxConcurrent, f)((ch, release) => init(ch, release).map(send))
}
Expand All @@ -2070,8 +2070,8 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,

val init = ().pure[F2]

def send(ch: Channel[F2, F2[Outcome[F2, Throwable, O2]]], release: F2[Unit]) =
(el: Outcome[F2, Throwable, O2]) => release <* ch.send(el.pure[F2])
def send(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) =
(el: Either[Throwable, O2]) => release <* ch.send(el.pure[F2])

parEvalMapAction(maxConcurrent, f)((ch, release) => init.as(send(ch, release)))
}
Expand All @@ -2081,9 +2081,9 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
f: O => F2[O2]
)(
initFork: (
Channel[F2, F2[Outcome[F2, Throwable, O2]]],
Channel[F2, F2[Either[Throwable, O2]]],
F2[Unit]
) => F2[Outcome[F2, Throwable, O2] => F2[Unit]]
) => F2[Either[Throwable, O2] => F2[Unit]]
)(implicit F: Concurrent[F2]): Stream[F2, O2] =
if (maxConcurrent == 1) evalMap(f)
else {
Expand All @@ -2094,7 +2094,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
val action =
(
Semaphore[F2](concurrency.toLong),
Channel.bounded[F2, F2[Outcome[F2, Throwable, O2]]](concurrency),
Channel.bounded[F2, F2[Either[Throwable, O2]]](concurrency),
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, stop, end) =>
Expand All @@ -2111,8 +2111,8 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(channel, pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = f(el).guaranteeCase(send) *> pushed.get
F.start(stop.get.race(action).guarantee(releaseAndCheckCompletion))
val action = f(el).attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
}
}
}
Expand All @@ -2126,7 +2126,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
case _ => stop.complete(()) *> releaseAndCheckCompletion
}

val foreground = channel.stream.evalMap(_.flatMap(_.embed(F.canceled >> F.never)))
val foreground = channel.stream.evalMap(_.rethrow)
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
}

Expand Down
97 changes: 59 additions & 38 deletions core/shared/src/test/scala/fs2/StreamSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -987,11 +987,7 @@ class StreamSuite extends Fs2Suite {
}

private implicit class verifyOps[T](val action: IO[T]) {
def assertNotCompletes(): IO[Unit] =
IO.race(IO.sleep(1.second), action).assertEquals(Left(()))

def assertCompletes(expected: T): IO[Unit] =
IO.race(IO.sleep(1.second), action).assertEquals(Right(expected))
def assertNotCompletes(): IO[Unit] = IO.race(IO.sleep(1.second), action).assertEquals(Left(()))
}

val u: IO[Unit] = ().pure[IO]
Expand All @@ -1005,7 +1001,7 @@ class StreamSuite extends Fs2Suite {

test("can exceed maxConcurrent in parEvalMapUnordered") {
val action = run(_.parEvalMapUnordered(2)(identity))
action.assertCompletes(Right(()))
action.assertEquals(Right(()))
}

def run(pipe: Pipe[IO, IO[Unit], Unit]): IO[Either[Unit, Unit]] =
Expand Down Expand Up @@ -1046,7 +1042,7 @@ class StreamSuite extends Fs2Suite {
val parallel = math.abs(p % 20) + 2
val requested = math.min(length, parallel)
val action = runWithLatch(length, requested, _.parEvalMapUnordered(parallel)(identity))
action.assertCompletes(())
action.assertEquals(())
}
}

Expand All @@ -1056,7 +1052,7 @@ class StreamSuite extends Fs2Suite {
val parallel = math.abs(p % 20) + 2
val requested = math.min(length, parallel)
val action = runWithLatch(length, requested, _.parEvalMap(parallel)(identity))
action.assertCompletes(())
action.assertEquals(())
}
}

Expand Down Expand Up @@ -1090,25 +1086,26 @@ class StreamSuite extends Fs2Suite {
test("parEvalMapUnordered") {
forAllF { (i: Int) =>
val amount = math.abs(i % 10) + 1
CountDownLatch[IO](amount).flatMap { latch =>
val stream = Stream(latch.release *> latch.await *> ex).repeatN(amount).covary[IO]
stream
.parEvalMapUnordered(amount)(identity)
.compile
.drain
.intercept[RuntimeException]
.void
}
CountDownLatch[IO](amount)
.flatMap { latch =>
val stream = Stream(latch.release *> latch.await *> ex).repeatN(amount).covary[IO]
stream.parEvalMapUnordered(amount)(identity).compile.drain
}
.intercept[RuntimeException]
.void
}
}

test("parEvalMap") {
forAllF { (i: Int) =>
val amount = math.abs(i % 10) + 1
CountDownLatch[IO](amount).flatMap { latch =>
val stream = Stream(latch.release *> latch.await *> ex).repeatN(amount).covary[IO]
stream.parEvalMap(amount)(identity).compile.drain.intercept[RuntimeException].void
}
CountDownLatch[IO](amount)
.flatMap { latch =>
val stream = Stream(latch.release *> latch.await *> ex).repeatN(amount).covary[IO]
stream.parEvalMap(amount)(identity).compile.drain
}
.intercept[RuntimeException]
.void
}
}
}
Expand All @@ -1123,11 +1120,34 @@ class StreamSuite extends Fs2Suite {
}

def check(pipe: Pipe[IO, IO[Unit], Unit]) =
Deferred[IO, Unit].flatMap { d =>
val simple = Stream(u, (d.get *> ex).uncancelable).covary[IO]
val stream = simple.through(pipe).take(1).productL(Stream.eval(d.complete(()).void))
stream.compile.toList.assertEquals(List(()))
}
IO.deferred[Unit]
.flatMap { d =>
val simple = Stream(u, (d.get *> ex).uncancelable).covary[IO]
val stream = simple.through(pipe).take(1).productL(Stream.eval(d.complete(()).void))
stream.compile.toList
}
.assertEquals(List(()))
}

group("cancels running computations when error raised") {

test("parEvalMapUnordered") {
check(_.parEvalMap(Int.MaxValue)(identity))
}

test("parEvalMap") {
check(_.parEvalMap(Int.MaxValue)(identity))
}

def check(pipe: Pipe[IO, IO[Unit], Unit]) =
(CountDownLatch[IO](2), IO.deferred[Unit])
.mapN { (latch, d) =>
val w = latch.release *> latch.await
val s = Stream(w *> ex, w *> IO.never.onCancel(d.complete(()).void)).covary[IO]
IO.race(pipe(s).compile.drain, d.get)
}
.flatten
.assertEquals(Right(()))
}

group("cancels unneeded") {
Expand All @@ -1136,17 +1156,19 @@ class StreamSuite extends Fs2Suite {
check(_.parEvalMapUnordered(2)(identity))
}

test("parEvalMapUnordered") {
test("parEvalMap") {
check(_.parEvalMap(2)(identity))
}

def check(pipe: Pipe[IO, IO[Unit], Unit]) =
Deferred[IO, Unit].flatMap { d =>
val cancelled = IO.never.onCancel(d.complete(()).void)
val stream = Stream(u, cancelled).covary[IO]
val action = stream.through(pipe).take(1).compile.drain
action *> d.get.assertCompletes(())
}
IO.deferred[Unit]
.flatMap { d =>
val cancelled = IO.never.onCancel(d.complete(()).void)
val stream = Stream(u, cancelled).covary[IO]
val action = stream.through(pipe).take(1).compile.drain
action *> d.get
}
.assertEquals(())
}

group("waits for uncancellable completion") {
Expand All @@ -1162,19 +1184,18 @@ class StreamSuite extends Fs2Suite {
val uncancMsg = "uncancellable"
val onFin2Msg = "onFin2"

Ref[IO]
.of(List.empty[String])
IO.ref(Vector.empty[String])
.flatMap { ref =>
val io = ref.update(uncancMsg :: _).void
val onFin2 = ref.update(onFin2Msg :: _)
val io = ref.update(_ :+ uncancMsg).void
val onFin2 = ref.update(_ :+ onFin2Msg)
CountDownLatch[IO](2).flatMap { latch =>
val w = latch.release *> latch.await
val stream = Stream(w *> u, (w *> io).uncancelable).covary[IO]
val action = stream.through(pipe).take(1).compile.drain <* onFin2
action *> ref.get
}
}
.assertEquals(List(onFin2Msg, uncancMsg))
.assertEquals(Vector(uncancMsg, onFin2Msg))
}
}
}

0 comments on commit 921460c

Please sign in to comment.