diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 132a11e439..9a8f32b557 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -26,7 +26,7 @@ import scala.concurrent.TimeoutException import scala.concurrent.duration._ import cats.{Eval => _, _} import cats.data.Ior -import cats.effect.{Concurrent, SyncIO} +import cats.effect.{Concurrent, IO, SyncIO} import cats.effect.kernel._ import cats.effect.kernel.implicits._ import cats.effect.std.{Console, Queue, QueueSink, QueueSource, Semaphore} @@ -2041,10 +2041,15 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, def parEvalMap[F2[x] >: F[x], O2]( maxConcurrent: Int )(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] = { - def init(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) = - Deferred[F2, Either[Throwable, O2]].flatTap(value => ch.send(value.get <* release)) - def send(v: Deferred[F2, Either[Throwable, O2]]) = - (el: Either[Throwable, O2]) => v.complete(el).void + + def init(ch: Channel[F2, F2[Outcome[F2, Throwable, O2]]], release: F2[Unit]) = + Deferred[F2, Outcome[F2, Throwable, O2]].flatTap { value => + ch.send(release *> value.get) + } + + def send(v: Deferred[F2, Outcome[F2, Throwable, O2]]) = + (el: Outcome[F2, Throwable, O2]) => v.complete(el).void + parEvalMapAction(maxConcurrent, f)((ch, release) => init(ch, release).map(send)) } @@ -2059,31 +2064,29 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, * res0: Unit = () * }}} */ - def parEvalMapUnordered[F2[x] >: F[ - x - ]: Concurrent, O2]( + def parEvalMapUnordered[F2[x] >: F[x], O2]( maxConcurrent: Int - )(f: O => F2[O2]): Stream[F2, O2] = { + )(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] = { + val init = ().pure[F2] - def send(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) = - (el: Either[Throwable, O2]) => ch.send(el.pure[F2]) *> release + + def send(ch: Channel[F2, F2[Outcome[F2, Throwable, O2]]], release: F2[Unit]) = + (el: Outcome[F2, Throwable, O2]) => release <* ch.send(el.pure[F2]) + parEvalMapAction(maxConcurrent, f)((ch, release) => init.as(send(ch, release))) } - private def parEvalMapAction[F2[x] >: F[ - x - ]: Concurrent, O2, T]( + private def parEvalMapAction[F2[x] >: F[x], O2, T]( maxConcurrent: Int, f: O => F2[O2] )( initFork: ( - Channel[F2, F2[Either[Throwable, O2]]], + Channel[F2, F2[Outcome[F2, Throwable, O2]]], F2[Unit] - ) => F2[Either[Throwable, O2] => F2[Unit]] - ): Stream[F2, O2] = + ) => F2[Outcome[F2, Throwable, O2] => F2[Unit]] + )(implicit F: Concurrent[F2]): Stream[F2, O2] = if (maxConcurrent == 1) evalMap(f) else { - val F = Concurrent[F2] assert(maxConcurrent > 0, "maxConcurrent must be > 0, was: " + maxConcurrent) // One is taken by inner stream read. @@ -2091,7 +2094,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, val action = ( Semaphore[F2](concurrency.toLong), - Channel.bounded[F2, F2[Either[Throwable, O2]]](concurrency), + Channel.bounded[F2, F2[Outcome[F2, Throwable, O2]]](concurrency), Deferred[F2, Unit], Deferred[F2, Unit] ).mapN { (semaphore, channel, stop, end) => @@ -2108,8 +2111,8 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, Deferred[F2, Unit].flatMap { pushed => val init = initFork(channel, pushed.complete(()).void) poll(init).onCancel(releaseAndCheckCompletion).flatMap { send => - val action = stop.get.race(f(el).attempt.flatMap(send) <* pushed.get) - F.start(poll(action).guarantee(releaseAndCheckCompletion)) + val action = f(el).guaranteeCase(send) *> pushed.get + F.start(stop.get.race(action).guarantee(releaseAndCheckCompletion)) } } } @@ -2123,7 +2126,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, case _ => stop.complete(()) *> releaseAndCheckCompletion } - val foreground = channel.stream.evalMap(identity).rethrow + val foreground = channel.stream.evalMap(_.flatMap(_.embed(F.canceled >> F.never))) foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background) } diff --git a/core/shared/src/test/scala/fs2/StreamSuite.scala b/core/shared/src/test/scala/fs2/StreamSuite.scala index 4348c458d3..cb2e2f82d9 100644 --- a/core/shared/src/test/scala/fs2/StreamSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamSuite.scala @@ -1019,12 +1019,8 @@ class StreamSuite extends Fs2Suite { test("should be preserved in parEvalMap") { forAllF { s: Stream[Pure, Int] => - s.zipWithIndex - .covary[IO] - .parEvalMap(Int.MaxValue) { case (i, ind) => IO.sleep((ind % 3).millis).as(i) } - .compile - .toList - .assertEquals(s.toList) + val s2 = s.covary[IO].parEvalMap(Int.MaxValue)(i => IO.sleep(math.abs(i % 3).millis).as(i)) + s2.compile.toList.assertEquals(s.toList) } }