diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala index a9d9832b8..297e093b6 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala @@ -2,7 +2,6 @@ package kyo.interop.reactivestreams import StreamSubscriber.* import kyo.* -import kyo.Emit.Ack import org.reactivestreams.* final class StreamSubscriber[V] private ( @@ -174,19 +173,24 @@ final class StreamSubscriber[V] private ( end interupt private[interop] def emit(ack: Ack)(using Tag[V]): Ack < (Emit[Chunk[V]] & Async) = - ack match - case Ack.Stop => interupt.andThen(Ack.Stop) - case Ack.Continue(_) => - await.map { - if _ then - request.andThen(Ack.Continue()) - else - poll.map { - case Result.Success(nextChunk) => Emit(nextChunk) - case Result.Error(e: Throwable) => Abort.panic(e) - case _ => Ack.Stop + Emit.andMap(Chunk.empty) { ack => + Loop(ack) { + case Ack.Stop => interupt.andThen(Loop.done(Ack.Stop)) + case Ack.Continue(_) => + await + .map { + if _ then + request.andThen(Ack.Continue()) + else + poll.map { + case Result.Success(nextChunk) => Emit(nextChunk) + case Result.Error(e: Throwable) => Abort.panic(e) + case _ => Ack.Stop + } } - }.map(emit) + .map(Loop.continue(_)) + } + } def stream(using Tag[V]): Stream[V, Async] = Stream(Emit.andMap(Chunk.empty)(emit)) diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala index c3605aa23..98c926bf9 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala @@ -2,8 +2,7 @@ package kyo.interop.reactivestreams import StreamSubscription.* import kyo.* -import kyo.Emit.Ack -import kyo.interop.reactivestreams.StreamSubscription.StreamFinishState +import kyo.interop.reactivestreams.* import kyo.kernel.ArrowEffect import kyo.kernel.Boundary import kyo.kernel.Safepoint @@ -22,185 +21,79 @@ final class StreamSubscription[V, Ctx] private[reactivestreams] ( frame: Frame ) extends Subscription: - private enum DownstreamState derives CanEqual: - case Uninitialized extends DownstreamState - case Requesting( - requested: Long, - maybePromise: Maybe[(Long, Fiber.Promise.Unsafe[Unit, Long])] - ) extends DownstreamState - case Finished extends DownstreamState - end DownstreamState - - private val state = AtomicRef.Unsafe.init(DownstreamState.Uninitialized) - - private def offer(n: Long)(using Frame): Result[Unit, Long] < Async = - var sideEffect: () => Result[Unit, Long] < Async = () => null.asInstanceOf - state.update { - case DownstreamState.Requesting(0L, Absent) => - // No one requested, accumulate offerring and wait - val promise = Fiber.Promise.Unsafe.init[Unit, Long]() - sideEffect = () => promise.safe.getResult - DownstreamState.Requesting(0L, Present(n -> promise)) - case DownstreamState.Requesting(requested, Absent) => - // Someone requested, we offer right away - val accepted = Math.min(requested, n) - val nextRequested = requested - accepted - sideEffect = () => IO(Result.success(accepted)) - DownstreamState.Requesting(nextRequested, Absent) - case DownstreamState.Finished => - // Downstream cancelled - sideEffect = () => IO(Result.fail(())) - DownstreamState.Finished - case other => - sideEffect = () => IO(Result.success(0L)) - other - } - sideEffect() - end offer + private val requestChannel = Channel.Unsafe.init[Long](Int.MaxValue, Access.SingleProducerSingleConsumer) override def request(n: Long): Unit = if n <= 0 then subscriber.onError(new IllegalArgumentException("non-positive subscription request")) - var sideEffect: () => Unit = () => () - state.update { - case DownstreamState.Requesting(0L, Present(offered -> promise)) => - val accepted = Math.min(offered, n) - val nextRequested = n - accepted - val nextOfferred = offered - accepted - sideEffect = () => promise.completeDiscard(Result.success(accepted)) - DownstreamState.Requesting(nextRequested, Absent) - case DownstreamState.Requesting(requested, Absent) => - val nextRequested = Math.min(Long.MaxValue - requested, n) + requested - sideEffect = () => () - DownstreamState.Requesting(nextRequested, Absent) - case other => - sideEffect = () => () - other - } - sideEffect() + discard(requestChannel.offer(n)) end request override def cancel(): Unit = - given Frame = Frame.internal - var sideEffect: () => Unit = () => () - state.update { - case DownstreamState.Requesting(_, Present(_ -> promise)) => - sideEffect = () => promise.completeDiscard(Result.fail(())) - DownstreamState.Finished - case other => - sideEffect = () => () - DownstreamState.Finished - } - sideEffect() + given Frame = Frame.internal + discard(requestChannel.close()) end cancel - def subscribe: Unit < (IO & Ctx) = - var sideEffect: () => Unit < (IO & Ctx) = () => IO.unit - state.update { - case DownstreamState.Uninitialized => - sideEffect = () => - IO { - subscriber.onSubscribe(this) - } - DownstreamState.Requesting(0L, Absent) - case other => - sideEffect = () => IO.unit - other + private[reactivestreams] inline def subscribe: Unit < IO = IO(subscriber.onSubscribe(this)) + + private[reactivestreams] def poll: StreamFinishState < (Async & Poll[Chunk[V]]) = + def loopPoll(requesting: Long): (Chunk[V] | StreamFinishState) < (IO & Poll[Chunk[V]]) = + Loop(requesting) { requesting => + Poll.one[Chunk[V]](Ack.Continue()).map { + case Present(values) => + if values.size <= requesting then + IO(values.foreach(subscriber.onNext(_))) + .andThen(Loop.continue(requesting - values.size)) + else + IO(values.take(requesting.intValue).foreach(subscriber.onNext(_))) + .andThen(Loop.done[Long, Chunk[V] | StreamFinishState]( + values.drop(requesting.intValue) + )) + case Absent => + IO(Loop.done[Long, Chunk[V] | StreamFinishState](StreamFinishState.StreamComplete)) + } + } + + Loop[Chunk[V] | StreamFinishState, StreamFinishState, Async & Poll[Chunk[V]]](Chunk.empty[V]) { + case leftOver: Chunk[V] => + for + requestingResult <- (requestChannel.poll(): @unchecked) match + case Result.Success(Present(requesting)) => IO(Result.Success(requesting)) + case Result.Success(Absent) => requestChannel.takeFiber().safe.getResult + case error: Result.Error[Closed] => IO(error) + outcome <- requestingResult match + case Result.Success(requesting) => + if requesting < leftOver.size then + IO(leftOver.take(requesting.intValue).foreach(subscriber.onNext(_))) + .andThen(Loop.continue[StreamFinishState | Chunk[V], StreamFinishState, Async & Poll[Chunk[V]]]( + leftOver.drop(requesting.intValue) + )) + else + IO(leftOver.foreach(subscriber.onNext(_))) + .andThen(loopPoll(requesting - leftOver.size)) + .map(Loop.continue[StreamFinishState | Chunk[V], StreamFinishState, Async & Poll[Chunk[V]]](_)) + case Result.Fail(_) => + IO(Loop.continue[StreamFinishState | Chunk[V], StreamFinishState, Async & Poll[Chunk[V]]]( + StreamFinishState.StreamCanceled + )) + case Result.Panic(exception) => IO(throw exception).andThen(Loop.continue[ + StreamFinishState | Chunk[V], + StreamFinishState, + Async & Poll[Chunk[V]] + ](StreamFinishState.StreamCanceled)) + yield outcome + case state: StreamFinishState => Loop.done(state) } - sideEffect() - end subscribe + end poll private[reactivestreams] def consume( using - tag: Tag[Emit[Chunk[V]]], + emitTag: Tag[Emit[Chunk[V]]], + pollTag: Tag[Poll[Chunk[V]]], frame: Frame, safepoint: Safepoint ): Fiber[Nothing, StreamFinishState] < (IO & Ctx) = - def consumeStream: StreamFinishState < (Abort[Nothing] & Async & Ctx) = - ArrowEffect.handleState(tag, 0: (Long | StreamFinishState), stream.emit.unit)( - handle = - [C] => - (input, state, cont) => - // Handle the input chunk - if input.nonEmpty then - // Input chunk contains values that we need to feed the subscriber - Loop[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState), Abort[Nothing] & Async & Ctx]( - input, - state - ) { - (curChunk, curState) => - curState match - case leftOver: Long => - if curChunk.isEmpty then - // We finish the current chunk, go next - Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( - Ack.Continue() -> leftOver - ) - else - if leftOver > 0 then - // Some requests left from last loop, feed them - val taken = Math.min(curChunk.size, leftOver) - val nextLeftOver = leftOver - taken - curChunk.take(taken.toInt).foreach { value => - subscriber.onNext(value) - } - // Loop the rest - Loop.continue(curChunk.drop(taken.toInt), nextLeftOver) - else - for - // We signal that we can `offer` "curChunk.size" elements - // then we wait until subscriber picks up that offer - acceptedResult <- offer(curChunk.size) - outcome = acceptedResult match - // Subscriber requests "accepted" elements - case Result.Success(accepted) => - val taken = Math.min(curChunk.size, accepted) - val nextLeftOver = accepted - taken - curChunk.take(taken.toInt).foreach { value => - subscriber.onNext(value) - } - // Loop the rest - IO(Loop.continue(curChunk.drop(taken.toInt), nextLeftOver)) - case Result.Error(e) => - e match - case t: Throwable => - Abort.panic(t) - .andThen(Loop.done[ - Chunk[V], - Long | StreamFinishState, - (Ack, Long | StreamFinishState) - ]( - Ack.Stop -> StreamFinishState.StreamCanceled - )) - case _: Unit => - IO(Loop.done[ - Chunk[V], - Long | StreamFinishState, - (Ack, Long | StreamFinishState) - ]( - Ack.Stop -> StreamFinishState.StreamCanceled - )) - yield outcome - end for - end if - case finishState: StreamFinishState => - Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( - Ack.Stop -> finishState - ) - }.map { case (ack, state) => - state -> cont(ack) - } - else - // The input chunk is empty, we go next - state -> cont(Ack.Continue()) - , - done = (state, _) => - state match - case _: Long => StreamFinishState.StreamComplete - case finishState: StreamFinishState => finishState - ) - boundary { (trace, context) => - val fiber = Fiber.fromTask(IOTask(consumeStream, safepoint.copyTrace(trace), context)) + val fiber = Fiber.fromTask(IOTask(Poll.run(stream.emit)(poll).map(_._2), safepoint.copyTrace(trace), context)) fiber.unsafe.onComplete { case Result.Success(StreamFinishState.StreamComplete) => subscriber.onComplete() case Result.Panic(e) => subscriber.onError(e) diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala index c12ec394e..8c10d8050 100644 --- a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala @@ -2,7 +2,6 @@ package kyo.interop.reactivestreams import kyo.* import kyo.Duration -import kyo.Emit.Ack import kyo.interop.reactivestreams.* final class PublisherToSubscriberTest extends Test: diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala index be0322e62..07c207080 100644 --- a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala @@ -1,7 +1,6 @@ package kyo.interop.reactivestreams import kyo.* -import kyo.Emit.Ack import kyo.interop.reactivestreams.StreamPublisher import org.reactivestreams.tck.PublisherVerification import org.reactivestreams.tck.TestEnvironment