Skip to content

Commit

Permalink
Merge pull request #2626 from nikiforo/par-eval-performance
Browse files Browse the repository at this point in the history
parEvalMapUnordered performance
  • Loading branch information
mpilquist authored Sep 23, 2021
2 parents 1de1e8b + e693eed commit 2ac85a2
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 5 deletions.
49 changes: 49 additions & 0 deletions benchmark/src/main/scala/fs2/benchmark/ParEvalMapBenchmark.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2013 Functional Streams for Scala
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2
package benchmark

import cats.effect.IO
import cats.effect.unsafe.implicits.global
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, State}

@State(Scope.Thread)
class ParEvalMapBenchmark {
@Param(Array("100", "10000"))
var size: Int = _

@Param(Array("10", "100"))
var chunkSize: Int = _

private def dummyLoad = IO.delay(())

@Benchmark
def evalMap(): Unit =
execute(getStream.evalMap(_ => dummyLoad))

@Benchmark
def parEvalMapUnordered10(): Unit =
execute(getStream.parEvalMapUnordered(10)(_ => dummyLoad))

private def getStream: Stream[IO, Unit] = Stream.constant((), chunkSize).take(size).covary[IO]
private def execute(s: Stream[IO, Unit]): Unit = s.compile.drain.unsafeRunSync()
}
6 changes: 6 additions & 0 deletions core/shared/src/main/scala/fs2/CompositeFailure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ object CompositeFailure {
}
}

def fromNel(errors: NonEmptyList[Throwable]): Throwable =
errors match {
case NonEmptyList(hd, Nil) => hd
case NonEmptyList(first, second :: rest) => apply(first, second, rest)
}

def fromList(errors: List[Throwable]): Option[Throwable] =
errors match {
case Nil => None
Expand Down
80 changes: 75 additions & 5 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ package fs2
import scala.annotation.{nowarn, tailrec}
import scala.concurrent.TimeoutException
import scala.concurrent.duration._

import cats.{Eval => _, _}
import cats.data.Ior
import cats.data.{Ior, NonEmptyList}
import cats.effect.{Concurrent, SyncIO}
import cats.effect.kernel._
import cats.effect.kernel.implicits._
import cats.effect.std.{Console, Queue, QueueSink, QueueSource, Semaphore}
import cats.effect.Resource.ExitCase
import cats.syntax.all._

import fs2.compat._
import fs2.concurrent._
import fs2.internal._
Expand Down Expand Up @@ -1757,7 +1755,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
]: Concurrent, O2](
maxConcurrent: Int
)(f: O => F2[O2]): Stream[F2, O2] =
map(o => Stream.eval(f(o))).parJoin(maxConcurrent)
parEvalMapUnordered[F2, O2](maxConcurrent)(f)

/** Applies the specified pure function to each chunk in this stream.
*
Expand Down Expand Up @@ -2069,7 +2067,79 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
]: Concurrent, O2](
maxConcurrent: Int
)(f: O => F2[O2]): Stream[F2, O2] =
map(o => Stream.eval(f(o))).parJoin(maxConcurrent)
if (maxConcurrent == 1) evalMap(f)
else {
assert(maxConcurrent > 0, "maxConcurrent must be > 0, was: " + maxConcurrent)

// One is taken by inner stream read.
val concurrency = if (maxConcurrent == Int.MaxValue) Int.MaxValue else maxConcurrent + 1
val action =
(
Semaphore[F2](concurrency.toLong),
Channel.bounded[F2, O2](concurrency),
Ref[F2].of(none[Either[NonEmptyList[Throwable], Unit]]),
Deferred[F2, Unit]
).mapN { (semaphore, channel, result, stopReading) =>
val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available
.product(result.get)
.flatMap {
case (`concurrency`, Some(_)) => channel.close.void
case _ => ().pure[F2]
}

val succeed =
result.update {
case None => ().asRight.some
case other => other
}

val cancelled = stopReading.complete(()) *> succeed

def failed(ex: Throwable) =
stopReading.complete(()) *>
result.update {
case Some(Left(nel)) => nel.prepend(ex).asLeft.some
case _ => NonEmptyList.one(ex).asLeft.some
}

val completeStream =
Stream.force {
result.get.map {
case Some(Left(nel)) => Stream.raiseError[F2](CompositeFailure.fromNel(nel))
case _ => Stream.empty
}
}

def forkOnElem(el: O): F2[Unit] =
semaphore.acquire *>
f(el).attempt
.race(stopReading.get)
.flatMap {
case Left(Left(ex)) => failed(ex)
case Left(Right(a)) => channel.send(a).void
case Right(_) => ().pure[F2]
}
.guarantee(releaseAndCheckCompletion)
.start
.void

val background =
Stream.exec(semaphore.acquire) ++
interruptWhen(stopReading.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => succeed *> releaseAndCheckCompletion
case ExitCase.Errored(ex) => failed(ex) *> releaseAndCheckCompletion
case ExitCase.Canceled => cancelled *> releaseAndCheckCompletion
}

channel.stream.concurrently(background) ++ completeStream
}

Stream.force(action)
}

/** Concurrent zip.
*
Expand Down

0 comments on commit 2ac85a2

Please sign in to comment.