From fb04b3a0cf47fb465ac6f0d88b8c6d12764a931b Mon Sep 17 00:00:00 2001 From: HollandDM Date: Tue, 10 Dec 2024 23:33:35 +0700 Subject: [PATCH] [kyo-reactive-stream]: fix feedbacks --- .../main/scala/kyo/interop/flow/package.scala | 52 ++- .../reactive-streams/StreamPublisher.scala | 86 ++--- .../reactive-streams/StreamSubscriber.scala | 352 +++++++++++------- .../reactive-streams/StreamSubscription.scala | 91 +++-- .../interop/reactive-streams/package.scala | 54 ++- .../reactive-streams/CancellationTest.scala | 25 +- .../PublisherToSubscriberTest.scala | 154 +++++--- .../StreamSubscriberTest.scala | 38 +- 8 files changed, 552 insertions(+), 300 deletions(-) diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala index 2be446b82..bcab7b7bf 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala @@ -3,31 +3,71 @@ package kyo.interop import java.util.concurrent.Flow.* import kyo.* import kyo.interop.reactivestreams +import kyo.interop.reactivestreams.StreamSubscriber.EmitStrategy import kyo.kernel.Boundary import org.reactivestreams.FlowAdapters +import scala.annotation.nowarn package object flow: inline def fromPublisher[T]( publisher: Publisher[T], - bufferSize: Int + bufferSize: Int, + emitStrategy: EmitStrategy = EmitStrategy.Eager )( using Frame, - Tag[T] - ): Stream[T, Async] < IO = reactivestreams.fromPublisher(FlowAdapters.toPublisher(publisher), bufferSize) + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Stream[T, Async] < IO = reactivestreams.fromPublisher(FlowAdapters.toPublisher(publisher), bufferSize, emitStrategy) - def subscribeToStream[T, Ctx]( + @nowarn("msg=anonymous") + inline def subscribeToStream[T, Ctx]( stream: Stream[T, Ctx], subscriber: Subscriber[? >: T] )( using - Boundary[Ctx, IO], Frame, - Tag[T] + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] ): Subscription < (Resource & IO & Ctx) = reactivestreams.subscribeToStream(stream, FlowAdapters.toSubscriber(subscriber)).map { subscription => new Subscription: override def request(n: Long): Unit = subscription.request(n) override def cancel(): Unit = subscription.cancel() } + + inline def streamToPublisher[T, Ctx]( + stream: Stream[T, Ctx] + )( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Publisher[T] < (Resource & IO & Ctx) = reactivestreams.streamToPublisher(stream).map { publisher => + FlowAdapters.toFlowPublisher(publisher) + } + + object StreamReactiveStreamsExtensions: + extension [T, Ctx](stream: Stream[T, Ctx]) + inline def subscribe( + subscriber: Subscriber[? >: T] + )( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Subscription < (Resource & IO & Ctx) = + subscribeToStream(stream, subscriber) + + inline def toPublisher( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Publisher[T] < (Resource & IO & Ctx) = + streamToPublisher(stream) + end extension + end StreamReactiveStreamsExtensions + + export StreamReactiveStreamsExtensions.* end flow diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala index d34205364..312afb1de 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala @@ -7,8 +7,9 @@ import kyo.kernel.ContextEffect.Isolated import kyo.kernel.Safepoint import kyo.scheduler.IOTask import org.reactivestreams.* +import scala.annotation.nowarn -abstract class StreamPublisher[V, Ctx] private ( +abstract private[kyo] class StreamPublisher[V, Ctx]( stream: Stream[V, Ctx] ) extends Publisher[V]: @@ -24,16 +25,18 @@ abstract class StreamPublisher[V, Ctx] private ( end StreamPublisher object StreamPublisher: + def apply[V, Ctx]( stream: Stream[V, Ctx], capacity: Int = Int.MaxValue )( using - boundary: Boundary[Ctx, IO], - frame: Frame, - tag: Tag[V] + Boundary[Ctx, IO & Abort[Nothing]], + Frame, + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] ): StreamPublisher[V, Ctx] < (Resource & IO & Ctx) = - inline def interruptPanic = Result.Panic(Fiber.Interrupted(frame)) + inline def interruptPanic = Result.Panic(Fiber.Interrupted(scala.compiletime.summonInline[Frame])) def discardSubscriber(subscriber: Subscriber[? >: V]): Unit = subscriber.onSubscribe(new Subscription: @@ -45,80 +48,67 @@ object StreamPublisher: def consumeChannel( channel: Channel[Subscriber[? >: V]], - supervisorPromise: Fiber.Promise[Nothing, Unit] + supervisor: Fiber.Promise[Nothing, Unit] ): Unit < (Async & Ctx) = Loop(()) { _ => channel.closed.map { - if _ then - Loop.done - else - val result = Abort.run[Closed] { + case true => Loop.done + case false => + Abort.run[Closed] { for subscriber <- channel.take subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) fiber <- subscription.subscribe.andThen(subscription.consume) - _ <- supervisorPromise.onComplete(_ => discard(fiber.interrupt(interruptPanic))) + _ <- supervisor.onComplete(_ => discard(fiber.interrupt(interruptPanic))) yield () - } - result.map { + }.map { case Result.Success(_) => Loop.continue(()) case _ => Loop.done } } } - IO.Unsafe { - for - channel <- - Resource.acquireRelease(Channel.init[Subscriber[? >: V]](capacity))( - _.close.map(_.foreach(_.foreach(discardSubscriber(_)))) - ) - publisher <- IO.Unsafe { - new StreamPublisher[V, Ctx](stream): - override protected def bind( - subscriber: Subscriber[? >: V] - ): Unit = - channel.unsafe.offer(subscriber) match - case Result.Success(true) => () - case _ => discardSubscriber(subscriber) - } - supervisorPromise <- Fiber.Promise.init[Nothing, Unit] - _ <- Resource.acquireRelease(boundary((trace, context) => - Fiber.fromTask(IOTask(consumeChannel(channel, supervisorPromise), trace, context)) - ))( - _.interrupt.map(discard(_)) + for + channel <- + Resource.acquireRelease(Channel.init[Subscriber[? >: V]](capacity))( + _.close.map(_.foreach(_.foreach(discardSubscriber(_)))) ) - yield publisher - end for - } + publisher <- IO.Unsafe { + new StreamPublisher[V, Ctx](stream): + override protected def bind( + subscriber: Subscriber[? >: V] + ): Unit = + channel.unsafe.offer(subscriber) match + case Result.Success(true) => () + case _ => discardSubscriber(subscriber) + } + supervisor <- Resource.acquireRelease(Fiber.Promise.init[Nothing, Unit])(_.interrupt.map(discard(_))) + _ <- Resource.acquireRelease(Async._run(consumeChannel(channel, supervisor)))(_.interrupt.map(discard(_))) + yield publisher + end for end apply object Unsafe: - def apply[V, Ctx]( + @nowarn("msg=anonymous") + inline def apply[V, Ctx]( stream: Stream[V, Ctx], subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit )( using - allowance: AllowUnsafe, - boundary: Boundary[Ctx, IO], - frame: Frame, - tag: Tag[V] + AllowUnsafe, + Frame, + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] ): StreamPublisher[V, Ctx] = new StreamPublisher[V, Ctx](stream): override protected def bind( subscriber: Subscriber[? >: V] ): Unit = - discard(StreamSubscription.Unsafe.subscribe( + discard(StreamSubscription.Unsafe._subscribe( stream, subscriber )( subscribeCallback - )( - using - allowance, - boundary, - frame, - tag )) end Unsafe end StreamPublisher diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala index 297e093b6..9c8a546e5 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala @@ -3,13 +3,13 @@ package kyo.interop.reactivestreams import StreamSubscriber.* import kyo.* import org.reactivestreams.* +import scala.annotation.tailrec -final class StreamSubscriber[V] private ( - bufferSize: Int +final private[kyo] class StreamSubscriber[V]( + bufferSize: Int, + strategy: EmitStrategy )( - using - allowance: AllowUnsafe, - frame: Frame + using AllowUnsafe ) extends Subscriber[V]: private enum UpstreamState derives CanEqual: @@ -22,167 +22,235 @@ final class StreamSubscriber[V] private ( UpstreamState.Uninitialized -> Maybe.empty[Fiber.Promise.Unsafe[Nothing, Unit]] ) - private def throwIfNull[A](b: A): Unit = if isNull(b) then throw new NullPointerException() + private inline def throwIfNull[A](b: A): Unit = if isNull(b) then throw new NullPointerException() override def onSubscribe(subscription: Subscription): Unit = throwIfNull(subscription) - var sideEffect: () => Unit = () => () - state.update { - case (UpstreamState.Uninitialized, maybePromise) => - // Notify if someone wait - sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) - UpstreamState.WaitForRequest(subscription, Chunk.empty, 0) -> Absent - case other => - // wrong state, cancel incoming subscription - sideEffect = () => subscription.cancel() - other - } - sideEffect() + @tailrec def handleSubscribe(): Unit = + val curState = state.get() + curState match + case (UpstreamState.Uninitialized, maybePromise) => + val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty, 0) -> Absent + if state.compareAndSet(curState, nextState) then + maybePromise.foreach(_.completeDiscard(Result.success(()))) + else + handleSubscribe() + end if + case other => + if state.compareAndSet(curState, other) then + subscription.cancel() + else + handleSubscribe() + end match + end handleSubscribe + handleSubscribe() end onSubscribe override def onNext(item: V): Unit = throwIfNull(item) - var sideEffect: () => Unit = () => () - state.update { - case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => - sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) - UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> Absent - case other => - sideEffect = () => () - other - } - sideEffect() + @tailrec def handleNext(): Unit = + val curState = state.get() + curState match + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + if (strategy == EmitStrategy.Eager) || (strategy == EmitStrategy.Buffer && remaining == 1) then + val nextState = UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> Absent + if state.compareAndSet(curState, nextState) then + maybePromise.foreach(_.completeDiscard(Result.success(()))) + else + handleNext() + end if + else + val nextState = UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> maybePromise + if !state.compareAndSet(curState, nextState) then handleNext() + case other => + if !state.compareAndSet(curState, other) then handleNext() + end match + end handleNext + handleNext() end onNext override def onError(throwable: Throwable): Unit = throwIfNull(throwable) - var sideEffect: () => Unit = () => () - state.update { - case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => - sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) - UpstreamState.Finished(Maybe(throwable), items) -> Absent - case other => - sideEffect = () => () - other - } - sideEffect() + @tailrec def handleError(): Unit = + val curState = state.get() + curState match + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + val nextState = UpstreamState.Finished(Maybe(throwable), items) -> Absent + if state.compareAndSet(curState, nextState) then + maybePromise.foreach(_.completeDiscard(Result.success(()))) + else + handleError() + end if + case other => + if !state.compareAndSet(curState, other) then handleError() + end match + end handleError + handleError() end onError override def onComplete(): Unit = - var sideEffect: () => Unit = () => () - state.update { - case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => - sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) - UpstreamState.Finished(Absent, items) -> Absent - case other => - sideEffect = () => () - other - } - sideEffect() + @tailrec def handleComplete(): Unit = + val curState = state.get() + curState match + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + val nextState = UpstreamState.Finished(Absent, items) -> Absent + if state.compareAndSet(curState, nextState) then + maybePromise.foreach(_.completeDiscard(Result.success(()))) + else + handleComplete() + end if + case other => + if !state.compareAndSet(curState, other) then handleComplete() + end match + end handleComplete + handleComplete() end onComplete - private[interop] def await: Boolean < Async = - var sideEffect: () => (Boolean < Async) = () => IO(false) - state.update { - case (UpstreamState.Uninitialized, Absent) => - val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() - sideEffect = () => promise.safe.use(_ => false) - UpstreamState.Uninitialized -> Present(promise) - case s @ (UpstreamState.Uninitialized, Present(promise)) => - sideEffect = () => promise.safe.use(_ => false) - s - case s @ (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => - if items.isEmpty then - if remaining == 0 then - sideEffect = () => IO(true) - UpstreamState.WaitForRequest(subscription, items, remaining) -> Absent - else - val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() - sideEffect = () => promise.safe.use(_ => false) - UpstreamState.WaitForRequest(subscription, items, remaining) -> Present(promise) - else - sideEffect = () => IO(false) - s - case s @ (UpstreamState.Finished(_, _), maybePromise) => - sideEffect = () => - maybePromise match - case Present(promise) => promise.safe.use(_ => false) - case Absent => IO(false) - s - case other => - sideEffect = () => IO(false) - other - } - sideEffect() + private[interop] def await(using Frame): Boolean < Async = + @tailrec def handleAwait: Boolean < Async = + val curState = state.get() + curState match + case (UpstreamState.Uninitialized, Absent) => + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + val nextState = UpstreamState.Uninitialized -> Present(promise) + if state.compareAndSet(curState, nextState) then + promise.safe.use(_ => false) + else + handleAwait + end if + case s @ (UpstreamState.Uninitialized, Present(promise)) => + if state.compareAndSet(curState, s) then + promise.safe.use(_ => false) + else + handleAwait + case s @ (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + if items.isEmpty then + if remaining == 0 then + val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty[V], 0) -> Absent + if state.compareAndSet(curState, nextState) then + IO(true) + else + handleAwait + end if + else + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty[V], remaining) -> Present(promise) + if state.compareAndSet(curState, nextState) then + promise.safe.use(_ => false) + else + handleAwait + end if + else + if state.compareAndSet(curState, s) then + IO(false) + else + handleAwait + case other => + if state.compareAndSet(curState, other) then + IO(false) + else + handleAwait + end match + end handleAwait + IO(handleAwait) end await - private[interop] def request: Long < IO = - var sideEffect: () => Long < IO = () => IO(0L) - state.update { - case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => - sideEffect = () => IO(subscription.request(bufferSize)).andThen(bufferSize.toLong) - UpstreamState.WaitForRequest(subscription, items, remaining + bufferSize) -> maybePromise - case other => - sideEffect = () => IO(0L) - other - } - sideEffect() + private[interop] def request(using Frame): Long < IO = + @tailrec def handleRequest: Long < IO = + val curState = state.get() + curState match + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + val nextState = UpstreamState.WaitForRequest(subscription, items, remaining + bufferSize) -> maybePromise + if state.compareAndSet(curState, nextState) then + IO(subscription.request(bufferSize)).andThen(bufferSize.toLong) + else + handleRequest + end if + case other => + if state.compareAndSet(curState, other) then + IO(0L) + else + handleRequest + end match + end handleRequest + IO(handleRequest) end request - private[interop] def poll: Result[Throwable | SubscriberDone, Chunk[V]] < IO = - var sideEffect: () => (Result[Throwable | SubscriberDone, Chunk[V]] < IO) = () => IO(Result.success(Chunk.empty)) - state.update { - case (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => - sideEffect = () => IO(Result.success(items)) - UpstreamState.WaitForRequest(subscription, Chunk.empty, remaining) -> Absent - case s @ (UpstreamState.Finished(reason, leftOver), Absent) => - if leftOver.isEmpty then - sideEffect = () => - IO { - reason match - case Present(error) => Result.fail(error) - case Absent => Result.fail(SubscriberDone) - - } - s - else - sideEffect = () => IO(Result.success(leftOver)) - UpstreamState.Finished(reason, Chunk.empty) -> Absent - case other => - sideEffect = () => IO(Result.success(Chunk.empty)) - other - } - sideEffect() + private[interop] def poll(using Frame): Result[Throwable | SubscriberDone, Chunk[V]] < IO = + @tailrec def handlePoll: Result[Throwable | SubscriberDone, Chunk[V]] < IO = + val curState = state.get() + curState match + case (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty, remaining) -> Absent + if state.compareAndSet(curState, nextState) then + IO(Result.success(items)) + else + handlePoll + end if + case s @ (UpstreamState.Finished(reason, leftOver), Absent) => + if leftOver.isEmpty then + if state.compareAndSet(curState, s) then + IO { + reason match + case Present(error) => Result.fail(error) + case Absent => Result.fail(SubscriberDone) + } + else + handlePoll + else + val nextState = UpstreamState.Finished(reason, Chunk.empty) -> Absent + if state.compareAndSet(curState, nextState) then + IO(Result.success(leftOver)) + else + handlePoll + end if + case other => + if state.compareAndSet(curState, other) then + IO(Result.success(Chunk.empty)) + else + handlePoll + end match + end handlePoll + IO(handlePoll) end poll - private[interop] def interupt: Unit < IO = - var sideEffect: () => (Unit < IO) = () => IO.unit - state.update { - case (UpstreamState.Uninitialized, maybePromise) => - // Notify if someone wait - sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) - UpstreamState.Finished(Absent, Chunk.empty) -> Absent - case (UpstreamState.WaitForRequest(subscription, _, _), Absent) => - sideEffect = () => IO(subscription.cancel()) - UpstreamState.Finished(Absent, Chunk.empty) -> Absent - case other => - sideEffect = () => IO.unit - other - } - sideEffect() + private[interop] def interupt(using Frame): Unit < IO = + @tailrec def handleInterupt: Unit < IO = + val curState = state.get() + curState match + case (UpstreamState.Uninitialized, maybePromise) => + val nextState = UpstreamState.Finished(Absent, Chunk.empty) -> Absent + if state.compareAndSet(curState, nextState) then + IO(maybePromise.foreach(_.completeDiscard(Result.success(())))) + else + handleInterupt + end if + case (UpstreamState.WaitForRequest(subscription, _, _), Absent) => + val nextState = UpstreamState.Finished(Absent, Chunk.empty) -> Absent + if state.compareAndSet(curState, nextState) then + IO(subscription.cancel()) + else + handleInterupt + end if + case other => + if state.compareAndSet(curState, other) then + IO.unit + else + handleInterupt + end match + end handleInterupt + IO(handleInterupt) end interupt - private[interop] def emit(ack: Ack)(using Tag[V]): Ack < (Emit[Chunk[V]] & Async) = + private[interop] def emit(using Frame, Tag[Emit[Chunk[V]]]): Ack < (Emit[Chunk[V]] & Async) = Emit.andMap(Chunk.empty) { ack => Loop(ack) { case Ack.Stop => interupt.andThen(Loop.done(Ack.Stop)) case Ack.Continue(_) => await .map { - if _ then - request.andThen(Ack.Continue()) - else - poll.map { + case true => request.andThen(Ack.Continue()) + case false => poll.map { case Result.Success(nextChunk) => Emit(nextChunk) case Result.Error(e: Throwable) => Abort.panic(e) case _ => Ack.Stop @@ -192,7 +260,7 @@ final class StreamSubscriber[V] private ( } } - def stream(using Tag[V]): Stream[V, Async] = Stream(Emit.andMap(Chunk.empty)(emit)) + def stream(using Frame, Tag[Emit[Chunk[V]]]): Stream[V, Async] = Stream(emit) end StreamSubscriber @@ -201,9 +269,15 @@ object StreamSubscriber: abstract private[reactivestreams] class SubscriberDone private[reactivestreams] case object SubscriberDone extends SubscriberDone + enum EmitStrategy derives CanEqual: + case Eager // Emit value to downstream stream as soon as the subscriber receives one + case Buffer // Subscriber buffers received values and emit them only when reaching bufferSize + end EmitStrategy + def apply[V]( - bufferSize: Int + bufferSize: Int, + strategy: EmitStrategy = EmitStrategy.Eager )( using Frame - ): StreamSubscriber[V] < IO = IO.Unsafe(new StreamSubscriber(bufferSize)) + ): StreamSubscriber[V] < IO = IO.Unsafe(new StreamSubscriber(bufferSize, strategy)) end StreamSubscriber diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala index 98c926bf9..d19df6cdc 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala @@ -10,18 +10,16 @@ import kyo.scheduler.IOPromise import kyo.scheduler.IOTask import org.reactivestreams.* -final class StreamSubscription[V, Ctx] private[reactivestreams] ( +final private[kyo] class StreamSubscription[V, Ctx]( private val stream: Stream[V, Ctx], subscriber: Subscriber[? >: V] )( using - allowance: AllowUnsafe, - boundary: Boundary[Ctx, IO], - tag: Tag[V], - frame: Frame + AllowUnsafe, + Frame ) extends Subscription: - private val requestChannel = Channel.Unsafe.init[Long](Int.MaxValue, Access.SingleProducerSingleConsumer) + private val requestChannel = Channel.Unsafe.init[Long](Int.MaxValue) override def request(n: Long): Unit = if n <= 0 then subscriber.onError(new IllegalArgumentException("non-positive subscription request")) @@ -33,10 +31,10 @@ final class StreamSubscription[V, Ctx] private[reactivestreams] ( discard(requestChannel.close()) end cancel - private[reactivestreams] inline def subscribe: Unit < IO = IO(subscriber.onSubscribe(this)) + private[interop] inline def subscribe(using Frame): Unit < IO = IO(subscriber.onSubscribe(this)) - private[reactivestreams] def poll: StreamFinishState < (Async & Poll[Chunk[V]]) = - def loopPoll(requesting: Long): (Chunk[V] | StreamFinishState) < (IO & Poll[Chunk[V]]) = + private[interop] def poll(using Tag[Poll[Chunk[V]]], Frame): StreamFinishState < (Async & Poll[Chunk[V]]) = + inline def loopPoll(requesting: Long): (Chunk[V] | StreamFinishState) < (IO & Poll[Chunk[V]]) = Loop(requesting) { requesting => Poll.one[Chunk[V]](Ack.Continue()).map { case Present(values) => @@ -85,40 +83,52 @@ final class StreamSubscription[V, Ctx] private[reactivestreams] ( } end poll - private[reactivestreams] def consume( + private[interop] def consume( using - emitTag: Tag[Emit[Chunk[V]]], - pollTag: Tag[Poll[Chunk[V]]], - frame: Frame, - safepoint: Safepoint + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]], + Frame, + Boundary[Ctx, IO & Abort[Nothing]] ): Fiber[Nothing, StreamFinishState] < (IO & Ctx) = - boundary { (trace, context) => - val fiber = Fiber.fromTask(IOTask(Poll.run(stream.emit)(poll).map(_._2), safepoint.copyTrace(trace), context)) - fiber.unsafe.onComplete { - case Result.Success(StreamFinishState.StreamComplete) => subscriber.onComplete() - case Result.Panic(e) => subscriber.onError(e) - case _ => () + Async + ._run[Nothing, StreamFinishState, Ctx](Poll.run(stream.emit)(poll).map(_._2)) + .map { fiber => + fiber.onComplete { + case Result.Success(StreamFinishState.StreamComplete) => IO(subscriber.onComplete()) + case Result.Panic(e) => IO(subscriber.onError(e)) + case _ => IO.unit + }.andThen(fiber) } - fiber - } end consume end StreamSubscription object StreamSubscription: - private[reactivestreams] enum StreamFinishState derives CanEqual: + private[interop] enum StreamFinishState derives CanEqual: case StreamComplete, StreamCanceled end StreamFinishState - def subscribe[V, Ctx]( + inline def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + using + Frame, + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] + ): StreamSubscription[V, Ctx] < (IO & Ctx & Resource) = + _subscribe(stream, subscriber) + + private[kyo] inline def _subscribe[V, Ctx]( stream: Stream[V, Ctx], subscriber: Subscriber[? >: V] )( using - boundary: Boundary[Ctx, IO], - frame: Frame, - tag: Tag[V] + Frame, + Boundary[Ctx, IO & Abort[Nothing]], + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] ): StreamSubscription[V, Ctx] < (IO & Ctx & Resource) = for subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) @@ -127,22 +137,37 @@ object StreamSubscription: yield subscription object Unsafe: - def subscribe[V, Ctx]( + inline def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit + )( + using + AllowUnsafe, + Frame, + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] + ): StreamSubscription[V, Ctx] = + _subscribe(stream, subscriber)(subscribeCallback) + + private[kyo] inline def _subscribe[V, Ctx]( stream: Stream[V, Ctx], subscriber: Subscriber[? >: V] )( subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit )( using - allowance: AllowUnsafe, - boundary: Boundary[Ctx, IO], - frame: Frame, - tag: Tag[V] + AllowUnsafe, + Boundary[Ctx, IO & Abort[Nothing]], + Frame, + Tag[Emit[Chunk[V]]], + Tag[Poll[Chunk[V]]] ): StreamSubscription[V, Ctx] = val subscription = new StreamSubscription[V, Ctx](stream, subscriber) subscribeCallback(subscription.subscribe.andThen(subscription.consume)) subscription - end subscribe + end _subscribe end Unsafe end StreamSubscription diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala index 4463bd4b1..f5700bfaa 100644 --- a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala @@ -2,32 +2,68 @@ package kyo.interop import kyo.* import kyo.interop.reactivestreams.* +import kyo.interop.reactivestreams.StreamSubscriber.EmitStrategy import kyo.kernel.Boundary import org.reactivestreams.* import org.reactivestreams.FlowAdapters package object reactivestreams: - def fromPublisher[T]( + inline def fromPublisher[T]( publisher: Publisher[T], - bufferSize: Int + bufferSize: Int, + emitStrategy: EmitStrategy = EmitStrategy.Eager )( using Frame, - Tag[T] + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] ): Stream[T, Async] < IO = for - subscriber <- StreamSubscriber[T](bufferSize) + subscriber <- StreamSubscriber[T](bufferSize, emitStrategy) _ <- IO(publisher.subscribe(subscriber)) yield subscriber.stream - def subscribeToStream[T, Ctx]( + inline def subscribeToStream[T, Ctx]( stream: Stream[T, Ctx], subscriber: Subscriber[? >: T] )( using - boundary: Boundary[Ctx, IO], - frame: Frame, - tag: Tag[T] + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] ): Subscription < (Resource & IO & Ctx) = - StreamSubscription.subscribe(stream, subscriber)(using boundary, frame, tag) + StreamSubscription.subscribe(stream, subscriber) + + inline def streamToPublisher[T, Ctx]( + stream: Stream[T, Ctx] + )( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Publisher[T] < (Resource & IO & Ctx) = StreamPublisher[T, Ctx](stream) + + object StreamReactiveStreamsExtensions: + extension [T, Ctx](stream: Stream[T, Ctx]) + inline def subscribe( + subscriber: Subscriber[? >: T] + )( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Subscription < (Resource & IO & Ctx) = + subscribeToStream(stream, subscriber) + + inline def toPublisher( + using + Frame, + Tag[Emit[Chunk[T]]], + Tag[Poll[Chunk[T]]] + ): Publisher[T] < (Resource & IO & Ctx) = + streamToPublisher(stream) + end extension + end StreamReactiveStreamsExtensions + + export StreamReactiveStreamsExtensions.* end reactivestreams diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala index 62b9959cd..2c844e0f1 100644 --- a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala @@ -6,31 +6,28 @@ import org.reactivestreams.Subscriber import org.reactivestreams.Subscription final class CancellationTest extends Test: - final class Sub[A](b: AtomicBoolean.Unsafe) extends Subscriber[A]: + final class Sub[A](b: AtomicBoolean) extends Subscriber[A]: import AllowUnsafe.embrace.danger - def onNext(t: A) = b.set(true) - def onComplete() = b.set(true) - def onError(e: Throwable) = b.set(true) + def onNext(t: A) = IO.Unsafe.run(b.set(true)).eval + def onComplete() = IO.Unsafe.run(b.set(true)).eval + def onError(e: Throwable) = IO.Unsafe.run(b.set(true)).eval def onSubscribe(s: Subscription) = () end Sub val stream: Stream[Int, Any] = Stream.range(0, 5, 1) - val attempts = 1000 + val attempts = 100 - def testStreamSubscription(clue: String)(program: Subscription => Unit): Unit < IO = + def testStreamSubscription(clue: String)(program: Subscription => Unit): Unit < (IO & Resource) = Loop(attempts) { index => if index <= 0 then Loop.done else - IO.Unsafe { - val flag = AtomicBoolean.Unsafe.init(false) - StreamSubscription.Unsafe.subscribe(stream, new Sub(flag)) { fiber => - discard(IO.Unsafe.run(fiber).eval) - } - }.map { subscription => - program(subscription) - }.andThen(Loop.continue(index - 1)) + for + flag <- AtomicBoolean.init(false) + subscription <- StreamSubscription.subscribe(stream, new Sub(flag)) + _ <- IO(program(subscription)) + yield Loop.continue(index - 1) end if } diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala index 8c10d8050..e1b418620 100644 --- a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala @@ -3,6 +3,7 @@ package kyo.interop.reactivestreams import kyo.* import kyo.Duration import kyo.interop.reactivestreams.* +import kyo.interop.reactivestreams.StreamSubscriber.EmitStrategy final class PublisherToSubscriberTest extends Test: import PublisherToSubscriberTest.* @@ -17,59 +18,118 @@ final class PublisherToSubscriberTest extends Test: .andThen(int) }) - "should have the same output as input" in runJVM { - for - stream <- randomStream - publisher <- StreamPublisher[Int, Async](stream) - subscriber <- StreamSubscriber[Int](BufferSize) - _ = publisher.subscribe(subscriber) - (isSame, _) <- subscriber.stream - .runFold(true -> 0) { case ((acc, expected), cur) => - Random - .use(_.nextInt(10)) - .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) - .andThen((acc && (expected == cur)) -> (expected + 1)) + "eager" - { + "should have the same output as input" in runJVM { + for + stream <- randomStream + publisher <- stream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + (isSame, _) <- subscriber.stream + .runFold(true -> 0) { case ((acc, expected), cur) => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen((acc && (expected == cur)) -> (expected + 1)) + } + yield assert(isSame) + } + + "should propagate errors downstream" in runJVM { + for + inputStream <- IO { + Stream.range(0, 10, 1, 1).map { int => + if int < 5 then + Async.sleep(Duration.fromUnits(10, Duration.Units.Millis)).andThen(int) + else + Abort.panic(TestError) + } } - yield assert(isSame) + publisher <- inputStream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + result <- Abort.run[Throwable](subscriber.stream.runDiscard) + yield result match + case Result.Error(e: Throwable) => assert(e == TestError) + case _ => assert(false) + end for + } + + "should cancel upstream if downstream completes" in runJVM { + def emit(ack: Ack, cur: Int, stopPromise: Fiber.Promise[Nothing, Unit]): Ack < (Emit[Chunk[Int]] & IO) = + ack match + case Ack.Stop => stopPromise.completeDiscard(Result.success(())).andThen(Ack.Stop) + case Ack.Continue(_) => Emit.andMap(Chunk(cur))(emit(_, cur + 1, stopPromise)) + end emit + + for + stopPromise <- Fiber.Promise.init[Nothing, Unit] + stream <- IO(Stream(Emit.andMap(Chunk.empty[Int])(emit(_, 0, stopPromise)))) + publisher <- stream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + _ <- subscriber.stream.take(10).runDiscard + _ <- stopPromise.get + yield assert(true) + end for + } } - "should propagate errors downstream" in runJVM { - for - inputStream <- IO { - Stream.range(0, 10, 1, 1).map { int => - if int < 5 then - Async.sleep(Duration.fromUnits(10, Duration.Units.Millis)).andThen(int) - else - Abort.panic(TestError) + "buffer" - { + "should have the same output as input" in runJVM { + for + stream <- randomStream + publisher <- stream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize, EmitStrategy.Buffer) + _ = publisher.subscribe(subscriber) + (isSame, _) <- subscriber.stream + .runFold(true -> 0) { case ((acc, expected), cur) => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen((acc && (expected == cur)) -> (expected + 1)) + } + yield assert(isSame) + } + + "should propagate errors downstream" in runJVM { + for + inputStream <- IO { + Stream.range(0, 10, 1, 1).map { int => + if int < 5 then + Async.sleep(Duration.fromUnits(10, Duration.Units.Millis)).andThen(int) + else + Abort.panic(TestError) + } } - } - publisher <- StreamPublisher[Int, Async](inputStream) - subscriber <- StreamSubscriber[Int](BufferSize) - _ = publisher.subscribe(subscriber) - result <- Abort.run[Throwable](subscriber.stream.runDiscard) - yield result match - case Result.Error(e: Throwable) => assert(e == TestError) - case _ => assert(false) - end for - } + publisher <- inputStream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize, EmitStrategy.Buffer) + _ = publisher.subscribe(subscriber) + result <- Abort.run[Throwable](subscriber.stream.runDiscard) + yield result match + case Result.Error(e: Throwable) => assert(e == TestError) + case _ => assert(false) + end for + } - "should cancel upstream if downstream completes" in runJVM { - def emit(ack: Ack, cur: Int, stopPromise: Fiber.Promise[Nothing, Unit]): Ack < (Emit[Chunk[Int]] & IO) = - ack match - case Ack.Stop => stopPromise.completeDiscard(Result.success(())).andThen(Ack.Stop) - case Ack.Continue(_) => Emit.andMap(Chunk(cur))(emit(_, cur + 1, stopPromise)) - end emit + "should cancel upstream if downstream completes" in runJVM { + def emit(ack: Ack, cur: Int, stopPromise: Fiber.Promise[Nothing, Unit]): Ack < (Emit[Chunk[Int]] & IO) = + ack match + case Ack.Stop => stopPromise.completeDiscard(Result.success(())).andThen(Ack.Stop) + case Ack.Continue(_) => Emit.andMap(Chunk(cur))(emit(_, cur + 1, stopPromise)) + end emit - for - stopPromise <- Fiber.Promise.init[Nothing, Unit] - stream <- IO(Stream(Emit.andMap(Chunk.empty[Int])(emit(_, 0, stopPromise)))) - publisher <- StreamPublisher[Int, IO](stream) - subscriber <- StreamSubscriber[Int](BufferSize) - _ = publisher.subscribe(subscriber) - _ <- subscriber.stream.take(10).runDiscard - _ <- stopPromise.get - yield assert(true) - end for + for + stopPromise <- Fiber.Promise.init[Nothing, Unit] + stream <- IO(Stream(Emit.andMap(Chunk.empty[Int])(emit(_, 0, stopPromise)))) + publisher <- stream.toPublisher + subscriber <- StreamSubscriber[Int](BufferSize, EmitStrategy.Buffer) + _ = publisher.subscribe(subscriber) + _ <- subscriber.stream.take(10).runDiscard + _ <- stopPromise.get + yield assert(true) + end for + } } end PublisherToSubscriberTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala index f2ce89ca7..cf57ac318 100644 --- a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala @@ -14,7 +14,7 @@ import org.reactivestreams.tck.SubscriberWhiteboxVerification.WhiteboxSubscriber import org.reactivestreams.tck.TestEnvironment import org.scalatestplus.testng.* -class StreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: +class EagerStreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: import AllowUnsafe.embrace.danger private val counter = new AtomicInteger() @@ -26,7 +26,21 @@ class StreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestE }.eval def createElement(i: Int): Int = counter.getAndIncrement -end StreamSubscriberTest +end EagerStreamSubscriberTest + +class BufferStreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber( + p: SubscriberWhiteboxVerification.WhiteboxSubscriberProbe[Int] + ): Subscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 16, EmitStrategy.Buffer).map(s => new WhiteboxSubscriber(s, p)) + }.eval + + def createElement(i: Int): Int = counter.getAndIncrement +end BufferStreamSubscriberTest final class WhiteboxSubscriber[V]( sub: StreamSubscriber[V], @@ -74,7 +88,7 @@ final class WhiteboxSubscriber[V]( end onNext end WhiteboxSubscriber -final class SubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: +final class EagerSubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: import AllowUnsafe.embrace.danger private val counter = new AtomicInteger() @@ -88,4 +102,20 @@ final class SubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](n discard(IO.Unsafe.run(computation).eval) def createElement(i: Int): Int = counter.incrementAndGet() -end SubscriberBlackboxSpec +end EagerSubscriberBlackboxSpec + +final class BufferSubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber(): StreamSubscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 16, EmitStrategy.Buffer) + }.eval + + override def triggerRequest(s: Subscriber[? >: Int]): Unit = + val computation: Long < IO = s.asInstanceOf[StreamSubscriber[Int]].request + discard(IO.Unsafe.run(computation).eval) + + def createElement(i: Int): Int = counter.incrementAndGet() +end BufferSubscriberBlackboxSpec