Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp implementation of parEvalMap, fixing numerous issues and increasing performance #2673

Merged
merged 8 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class ParEvalMapBenchmark {
def evalMap(): Unit =
execute(getStream.evalMap(_ => dummyLoad))

@Benchmark
def parEvalMap10(): Unit =
execute(getStream.parEvalMap(10)(_ => dummyLoad))

@Benchmark
def parEvalMapUnordered10(): Unit =
execute(getStream.parEvalMapUnordered(10)(_ => dummyLoad))
Expand Down
140 changes: 54 additions & 86 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import scala.annotation.{nowarn, tailrec}
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import cats.{Eval => _, _}
import cats.data.{Ior, NonEmptyList}
import cats.effect.{Concurrent, SyncIO}
import cats.data.Ior
import cats.effect.{Concurrent, IO, SyncIO}
import cats.effect.kernel._
import cats.effect.kernel.implicits._
import cats.effect.std.{Console, Queue, QueueSink, QueueSource, Semaphore}
Expand Down Expand Up @@ -2054,40 +2054,18 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
*/
def parEvalMap[F2[x] >: F[x], O2](
maxConcurrent: Int
)(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] =
if (maxConcurrent === 1) evalMap(f)
else {
val fstream: F2[Stream[F2, O2]] = for {
chan <- Channel.bounded[F2, F2[Either[Throwable, O2]]](maxConcurrent)
chanReadDone <- F.deferred[Unit]
} yield {
def forkOnElem(o: O): F2[Stream[F2, Unit]] =
for {
value <- F.deferred[Either[Throwable, O2]]
send = chan.send(value.get).as {
Stream.eval(f(o).attempt.flatMap(value.complete(_).void))
}
eit <- chanReadDone.get.race(send)
} yield eit match {
case Left(()) => Stream.empty
case Right(stream) => stream
}
)(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] = {

val background = this
.evalMap(forkOnElem)
.parJoin(maxConcurrent)
.onFinalize(chanReadDone.get.race(chan.close).void)
def init(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) =
Deferred[F2, Either[Throwable, O2]].flatTap { value =>
ch.send(release *> value.get)
}

val foreground =
chan.stream
.evalMap(identity)
.rethrow
.onFinalize(chanReadDone.complete(()).void)
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

foreground.concurrently(background)
}
Stream.eval(fstream).flatten
}
parEvalMapAction(maxConcurrent, f)((ch, release) => init(ch, release).map(send))
}

/** Like [[Stream#evalMap]], but will evaluate effects in parallel, emitting the results
* downstream. The number of concurrent effects is limited by the `maxConcurrent` parameter.
Expand All @@ -2100,11 +2078,27 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
* res0: Unit = ()
* }}}
*/
def parEvalMapUnordered[F2[x] >: F[
x
]: Concurrent, O2](
def parEvalMapUnordered[F2[x] >: F[x], O2](
maxConcurrent: Int
)(f: O => F2[O2]): Stream[F2, O2] =
)(f: O => F2[O2])(implicit F: Concurrent[F2]): Stream[F2, O2] = {

val init = ().pure[F2]

def send(ch: Channel[F2, F2[Either[Throwable, O2]]], release: F2[Unit]) =
(el: Either[Throwable, O2]) => release <* ch.send(el.pure[F2])

parEvalMapAction(maxConcurrent, f)((ch, release) => init.as(send(ch, release)))
}

private def parEvalMapAction[F2[x] >: F[x], O2, T](
maxConcurrent: Int,
f: O => F2[O2]
)(
initFork: (
Channel[F2, F2[Either[Throwable, O2]]],
F2[Unit]
) => F2[Either[Throwable, O2] => F2[Unit]]
)(implicit F: Concurrent[F2]): Stream[F2, O2] =
if (maxConcurrent == 1) evalMap(f)
else {
assert(maxConcurrent > 0, "maxConcurrent must be > 0, was: " + maxConcurrent)
Expand All @@ -2114,66 +2108,40 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
val action =
(
Semaphore[F2](concurrency.toLong),
Channel.bounded[F2, O2](concurrency),
Ref[F2].of(none[Either[NonEmptyList[Throwable], Unit]]),
Channel.bounded[F2, F2[Either[Throwable, O2]]](concurrency),
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, result, stopReading) =>
).mapN { (semaphore, channel, stop, end) =>
val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available
.product(result.get)
.flatMap {
case (`concurrency`, Some(_)) => channel.close.void
case _ => ().pure[F2]
}

val succeed =
result.update {
case None => ().asRight.some
case other => other
}

val cancelled = stopReading.complete(()) *> succeed

def failed(ex: Throwable) =
stopReading.complete(()) *>
result.update {
case Some(Left(nel)) => nel.prepend(ex).asLeft.some
case _ => NonEmptyList.one(ex).asLeft.some
}

val completeStream =
Stream.force {
result.get.map {
case Some(Left(nel)) => Stream.raiseError[F2](CompositeFailure.fromNel(nel))
case _ => Stream.empty
semaphore.available.flatMap {
case `concurrency` => channel.close *> end.complete(()).void
case _ => ().pure[F2]
}
}

def forkOnElem(el: O): F2[Unit] =
semaphore.acquire *>
f(el).attempt
.race(stopReading.get)
.flatMap {
case Left(Left(ex)) => failed(ex)
case Left(Right(a)) => channel.send(a).void
case Right(_) => ().pure[F2]
F.uncancelable { poll =>
poll(semaphore.acquire) <*
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(channel, pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = f(el).attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
}
}
.guarantee(releaseAndCheckCompletion)
.start
.void
}

val background =
Stream.exec(semaphore.acquire) ++
interruptWhen(stopReading.get.map(_.asRight[Throwable]))
interruptWhen(stop.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => succeed *> releaseAndCheckCompletion
case ExitCase.Errored(ex) => failed(ex) *> releaseAndCheckCompletion
case ExitCase.Canceled => cancelled *> releaseAndCheckCompletion
case ExitCase.Succeeded => releaseAndCheckCompletion
case _ => stop.complete(()) *> releaseAndCheckCompletion
}

channel.stream.concurrently(background) ++ completeStream
val foreground = channel.stream.evalMap(_.rethrow)
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
}

Stream.force(action)
Expand Down Expand Up @@ -4487,17 +4455,17 @@ object Stream extends StreamLowPriority {
*
* As a quick example, let's write a timed pull which emits the
* string "late!" whenever a chunk of the stream is not emitted
* within 150 milliseconds:
* within 450 milliseconds:
*
* @example {{{
* scala> import cats.effect.IO
* scala> import cats.effect.unsafe.implicits.global
* scala> import scala.concurrent.duration._
* scala> val s = (Stream("elem") ++ Stream.sleep_[IO](200.millis)).repeat.take(3)
* scala> val s = (Stream("elem") ++ Stream.sleep_[IO](600.millis)).repeat.take(3)
* scala> s.pull
* | .timed { timedPull =>
* | def go(timedPull: Pull.Timed[IO, String]): Pull[IO, String, Unit] =
* | timedPull.timeout(150.millis) >> // starts new timeout and stops the previous one
* | timedPull.timeout(450.millis) >> // starts new timeout and stops the previous one
* | timedPull.uncons.flatMap {
* | case Some((Right(elems), next)) => Pull.output(elems) >> go(next)
* | case Some((Left(_), next)) => Pull.output1("late!") >> go(next)
Expand Down
Loading