Skip to content

Commit

Permalink
Fixed race condition between offer, take, and close
Browse files Browse the repository at this point in the history
  • Loading branch information
djspiewak committed Oct 1, 2022
1 parent 2fab62e commit 3ff9ee2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 23 deletions.
79 changes: 56 additions & 23 deletions core/shared/src/main/scala/fs2/concurrent/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package fs2
package concurrent

import cats.Applicative
import cats.effect._
import cats.effect.std.Queue
import cats.effect.syntax.all._
Expand Down Expand Up @@ -132,7 +133,7 @@ object Channel {
Queue.bounded[F, A](capacity).flatMap(impl(_))

private[this] def impl[F[_]: Concurrent, A](q: Queue[F, A]): F[Channel[F, A]] =
(Concurrent[F].deferred[Unit], Concurrent[F].ref(0)).mapN { (closedR, leasesR) =>
(Concurrent[F].deferred[Unit], Lease[F]).mapN { (closedR, lease) =>
new Channel[F, A] {

def sendAll: Pipe[F, A, Nothing] =
Expand All @@ -151,15 +152,15 @@ object Channel {
def isClosed: F[Boolean] = closedR.tryGet.map(_.isDefined)

private[this] val leasesDrained: F[Boolean] =
leasesR.get.map(_ <= 0)
lease.isEmpty

private[this] val isQuiesced: F[Boolean] =
isClosed.ifM(leasesDrained, false.pure[F])

def send(a: A): F[Either[Channel.Closed, Unit]] =
isClosed.ifM(
Channel.Closed.asLeft[Unit].pure[F],
(leasesR.update(_ + 1) *> q.offer(a)).guarantee(leasesR.update(_ - 1)).map(Right(_))
lease.permit(q.offer(a)).map(Right(_))
)

def trySend(a: A): F[Either[Channel.Closed, Boolean]] =
Expand All @@ -169,27 +170,21 @@ object Channel {
val takeN: F[Chunk[A]] =
q.tryTakeN(None).flatMap {
case Nil =>
val fallback = leasesDrained.flatMap { b =>
if (b) {
MonadCancel[F].uncancelable { poll =>
poll(Spawn[F].racePair(q.take, closedR.get)).flatMap {
case Left((oca, fiber)) =>
oca.embedNever.flatMap(a => fiber.cancel.as(Chunk.singleton(a)))

case Right((fiber, ocb)) =>
ocb.embedNever.flatMap { _ =>
(fiber.cancel *> fiber.join).flatMap { oca =>
oca.fold(
Chunk.empty[A].pure[F],
_ => Chunk.empty[A].pure[F],
_.map(Chunk.singleton(_))
)
}
}
val fallback = MonadCancel[F].uncancelable { poll =>
poll(Spawn[F].racePair(q.take, closedR.get.both(lease.await))).flatMap {
case Left((oca, fiber)) =>
oca.embedNever.flatMap(a => fiber.cancel.as(Chunk.singleton(a)))

case Right((fiber, ocb)) =>
ocb.embedNever.flatMap { _ =>
(fiber.cancel *> fiber.join).flatMap { oca =>
oca.fold(
Chunk.empty[A].pure[F],
_ => Chunk.empty[A].pure[F],
_.map(Chunk.singleton(_))
)
}
}
}
} else {
q.take.map(Chunk.singleton(_))
}
}

Expand All @@ -206,4 +201,42 @@ object Channel {
def closed: F[Unit] = closedR.get
}
}

// this is like an inverse semaphore and I'm surprised we haven't added it in std yet
private final class Lease[F[_]: Concurrent] private (
leasesR: Ref[F, Int],
latchR: Ref[F, Deferred[F, Unit]]
) {

def permit[A](fa: F[A]): F[A] =
MonadCancel[F].uncancelable { poll =>
val init = leasesR.modify(i => (i + 1, i + 1)).flatMap { leases =>
if (leases == 1)
Concurrent[F].deferred[Unit].flatMap(latchR.set(_))
else
Applicative[F].unit
}

(init *> poll(fa)).guarantee {
leasesR.modify(i => (i - 1, i - 1)).flatMap { leases =>
if (leases == 0)
latchR.get.flatMap(_.complete(())).void
else
Applicative[F].unit
}
}
}

def isEmpty: F[Boolean] = leasesR.get.map(_ == 0)

def await: F[Unit] = latchR.get.flatMap(_.get)
}

private object Lease {
def apply[F[_]: Concurrent]: F[Lease[F]] =
Concurrent[F].deferred[Unit].flatMap { latch =>
latch.complete(()) *> (Concurrent[F].ref(0), Concurrent[F].ref(latch))
.mapN(new Lease[F](_, _))
}
}
}
15 changes: 15 additions & 0 deletions core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,19 @@ class ChannelSuite extends Fs2Suite {
_ <- IO(assert(count == 6)) // we have to overrun the closure to detect it
} yield ()
}

test("sendPull") {
def blackHole(s: Stream[IO, Unit]) =
s.repeatPull(_.uncons.flatMap {
case None => Pull.pure(None)
case Some((hd, tl)) =>
val action = IO.delay(0.until(hd.size).foreach(_ => ()))
Pull.eval(action).as(Some(tl))
})

Channel.bounded[IO, Unit](8).flatMap { channel =>
val action = List.fill(64)(()).traverse_(_ => channel.send(()).void) *> channel.close
action.start *> channel.stream.through(blackHole).compile.drain
}
}
}

0 comments on commit 3ff9ee2

Please sign in to comment.