diff --git a/.gitignore b/.gitignore index ca2708f55..51a929e0d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ jmh-result.json *.jfr *.json *.gpg +test-output \ No newline at end of file diff --git a/build.sbt b/build.sbt index 64bb08a71..86aa83d8c 100644 --- a/build.sbt +++ b/build.sbt @@ -96,6 +96,7 @@ lazy val kyoJVM = project `kyo-stats-registry`.jvm, `kyo-stats-otel`.jvm, `kyo-cache`.jvm, + `kyo-reactive-streams`.jvm, `kyo-sttp`.jvm, `kyo-tapir`.jvm, `kyo-caliban`.jvm, @@ -306,6 +307,22 @@ lazy val `kyo-cache` = ) .jvmSettings(mimaCheck(false)) +lazy val `kyo-reactive-streams` = + crossProject(JVMPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-reactive-streams")) + .dependsOn(`kyo-core`) + .settings( + `kyo-settings`, + libraryDependencies ++= Seq( + "org.reactivestreams" % "reactive-streams" % "1.0.4", + "org.reactivestreams" % "reactive-streams-tck" % "1.0.4" % Test, + "org.scalatestplus" %% "testng-7-5" % "3.2.17.0" % Test + ) + ) + .jvmSettings(mimaCheck(false)) + lazy val `kyo-sttp` = crossProject(JSPlatform, JVMPlatform, NativePlatform) .withoutSuffixFor(JVMPlatform) diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala new file mode 100644 index 000000000..2be446b82 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala @@ -0,0 +1,33 @@ +package kyo.interop + +import java.util.concurrent.Flow.* +import kyo.* +import kyo.interop.reactivestreams +import kyo.kernel.Boundary +import org.reactivestreams.FlowAdapters + +package object flow: + inline def fromPublisher[T]( + publisher: Publisher[T], + bufferSize: Int + )( + using + Frame, + Tag[T] + ): Stream[T, Async] < IO = reactivestreams.fromPublisher(FlowAdapters.toPublisher(publisher), bufferSize) + + def subscribeToStream[T, Ctx]( + stream: Stream[T, Ctx], + subscriber: Subscriber[? >: T] + )( + using + Boundary[Ctx, IO], + Frame, + Tag[T] + ): Subscription < (Resource & IO & Ctx) = + reactivestreams.subscribeToStream(stream, FlowAdapters.toSubscriber(subscriber)).map { subscription => + new Subscription: + override def request(n: Long): Unit = subscription.request(n) + override def cancel(): Unit = subscription.cancel() + } +end flow diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala new file mode 100644 index 000000000..d34205364 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala @@ -0,0 +1,124 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.interop.reactivestreams.StreamSubscription.StreamFinishState +import kyo.kernel.Boundary +import kyo.kernel.ContextEffect.Isolated +import kyo.kernel.Safepoint +import kyo.scheduler.IOTask +import org.reactivestreams.* + +abstract class StreamPublisher[V, Ctx] private ( + stream: Stream[V, Ctx] +) extends Publisher[V]: + + protected def bind(subscriber: Subscriber[? >: V]): Unit + + override def subscribe(subscriber: Subscriber[? >: V]): Unit = + if isNull(subscriber) then + throw new NullPointerException("Subscriber must not be null.") + else + bind(subscriber) + end subscribe + +end StreamPublisher + +object StreamPublisher: + def apply[V, Ctx]( + stream: Stream[V, Ctx], + capacity: Int = Int.MaxValue + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamPublisher[V, Ctx] < (Resource & IO & Ctx) = + inline def interruptPanic = Result.Panic(Fiber.Interrupted(frame)) + + def discardSubscriber(subscriber: Subscriber[? >: V]): Unit = + subscriber.onSubscribe(new Subscription: + override def request(n: Long): Unit = () + override def cancel(): Unit = () + ) + subscriber.onComplete() + end discardSubscriber + + def consumeChannel( + channel: Channel[Subscriber[? >: V]], + supervisorPromise: Fiber.Promise[Nothing, Unit] + ): Unit < (Async & Ctx) = + Loop(()) { _ => + channel.closed.map { + if _ then + Loop.done + else + val result = Abort.run[Closed] { + for + subscriber <- channel.take + subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) + fiber <- subscription.subscribe.andThen(subscription.consume) + _ <- supervisorPromise.onComplete(_ => discard(fiber.interrupt(interruptPanic))) + yield () + } + result.map { + case Result.Success(_) => Loop.continue(()) + case _ => Loop.done + } + } + } + + IO.Unsafe { + for + channel <- + Resource.acquireRelease(Channel.init[Subscriber[? >: V]](capacity))( + _.close.map(_.foreach(_.foreach(discardSubscriber(_)))) + ) + publisher <- IO.Unsafe { + new StreamPublisher[V, Ctx](stream): + override protected def bind( + subscriber: Subscriber[? >: V] + ): Unit = + channel.unsafe.offer(subscriber) match + case Result.Success(true) => () + case _ => discardSubscriber(subscriber) + } + supervisorPromise <- Fiber.Promise.init[Nothing, Unit] + _ <- Resource.acquireRelease(boundary((trace, context) => + Fiber.fromTask(IOTask(consumeChannel(channel, supervisorPromise), trace, context)) + ))( + _.interrupt.map(discard(_)) + ) + yield publisher + end for + } + end apply + + object Unsafe: + def apply[V, Ctx]( + stream: Stream[V, Ctx], + subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit + )( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamPublisher[V, Ctx] = + new StreamPublisher[V, Ctx](stream): + override protected def bind( + subscriber: Subscriber[? >: V] + ): Unit = + discard(StreamSubscription.Unsafe.subscribe( + stream, + subscriber + )( + subscribeCallback + )( + using + allowance, + boundary, + frame, + tag + )) + end Unsafe +end StreamPublisher diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala new file mode 100644 index 000000000..a9d9832b8 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala @@ -0,0 +1,205 @@ +package kyo.interop.reactivestreams + +import StreamSubscriber.* +import kyo.* +import kyo.Emit.Ack +import org.reactivestreams.* + +final class StreamSubscriber[V] private ( + bufferSize: Int +)( + using + allowance: AllowUnsafe, + frame: Frame +) extends Subscriber[V]: + + private enum UpstreamState derives CanEqual: + case Uninitialized extends UpstreamState + case WaitForRequest(subscription: Subscription, items: Chunk[V], remaining: Int) extends UpstreamState + case Finished(reason: Maybe[Throwable], leftOver: Chunk[V]) extends UpstreamState + end UpstreamState + + private val state = AtomicRef.Unsafe.init( + UpstreamState.Uninitialized -> Maybe.empty[Fiber.Promise.Unsafe[Nothing, Unit]] + ) + + private def throwIfNull[A](b: A): Unit = if isNull(b) then throw new NullPointerException() + + override def onSubscribe(subscription: Subscription): Unit = + throwIfNull(subscription) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.Uninitialized, maybePromise) => + // Notify if someone wait + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.WaitForRequest(subscription, Chunk.empty, 0) -> Absent + case other => + // wrong state, cancel incoming subscription + sideEffect = () => subscription.cancel() + other + } + sideEffect() + end onSubscribe + + override def onNext(item: V): Unit = + throwIfNull(item) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onNext + + override def onError(throwable: Throwable): Unit = + throwIfNull(throwable) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Maybe(throwable), items) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onError + + override def onComplete(): Unit = + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Absent, items) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onComplete + + private[interop] def await: Boolean < Async = + var sideEffect: () => (Boolean < Async) = () => IO(false) + state.update { + case (UpstreamState.Uninitialized, Absent) => + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + sideEffect = () => promise.safe.use(_ => false) + UpstreamState.Uninitialized -> Present(promise) + case s @ (UpstreamState.Uninitialized, Present(promise)) => + sideEffect = () => promise.safe.use(_ => false) + s + case s @ (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + if items.isEmpty then + if remaining == 0 then + sideEffect = () => IO(true) + UpstreamState.WaitForRequest(subscription, items, remaining) -> Absent + else + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + sideEffect = () => promise.safe.use(_ => false) + UpstreamState.WaitForRequest(subscription, items, remaining) -> Present(promise) + else + sideEffect = () => IO(false) + s + case s @ (UpstreamState.Finished(_, _), maybePromise) => + sideEffect = () => + maybePromise match + case Present(promise) => promise.safe.use(_ => false) + case Absent => IO(false) + s + case other => + sideEffect = () => IO(false) + other + } + sideEffect() + end await + + private[interop] def request: Long < IO = + var sideEffect: () => Long < IO = () => IO(0L) + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + sideEffect = () => IO(subscription.request(bufferSize)).andThen(bufferSize.toLong) + UpstreamState.WaitForRequest(subscription, items, remaining + bufferSize) -> maybePromise + case other => + sideEffect = () => IO(0L) + other + } + sideEffect() + end request + + private[interop] def poll: Result[Throwable | SubscriberDone, Chunk[V]] < IO = + var sideEffect: () => (Result[Throwable | SubscriberDone, Chunk[V]] < IO) = () => IO(Result.success(Chunk.empty)) + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + sideEffect = () => IO(Result.success(items)) + UpstreamState.WaitForRequest(subscription, Chunk.empty, remaining) -> Absent + case s @ (UpstreamState.Finished(reason, leftOver), Absent) => + if leftOver.isEmpty then + sideEffect = () => + IO { + reason match + case Present(error) => Result.fail(error) + case Absent => Result.fail(SubscriberDone) + + } + s + else + sideEffect = () => IO(Result.success(leftOver)) + UpstreamState.Finished(reason, Chunk.empty) -> Absent + case other => + sideEffect = () => IO(Result.success(Chunk.empty)) + other + } + sideEffect() + end poll + + private[interop] def interupt: Unit < IO = + var sideEffect: () => (Unit < IO) = () => IO.unit + state.update { + case (UpstreamState.Uninitialized, maybePromise) => + // Notify if someone wait + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Absent, Chunk.empty) -> Absent + case (UpstreamState.WaitForRequest(subscription, _, _), Absent) => + sideEffect = () => IO(subscription.cancel()) + UpstreamState.Finished(Absent, Chunk.empty) -> Absent + case other => + sideEffect = () => IO.unit + other + } + sideEffect() + end interupt + + private[interop] def emit(ack: Ack)(using Tag[V]): Ack < (Emit[Chunk[V]] & Async) = + ack match + case Ack.Stop => interupt.andThen(Ack.Stop) + case Ack.Continue(_) => + await.map { + if _ then + request.andThen(Ack.Continue()) + else + poll.map { + case Result.Success(nextChunk) => Emit(nextChunk) + case Result.Error(e: Throwable) => Abort.panic(e) + case _ => Ack.Stop + } + }.map(emit) + + def stream(using Tag[V]): Stream[V, Async] = Stream(Emit.andMap(Chunk.empty)(emit)) + +end StreamSubscriber + +object StreamSubscriber: + + abstract private[reactivestreams] class SubscriberDone + private[reactivestreams] case object SubscriberDone extends SubscriberDone + + def apply[V]( + bufferSize: Int + )( + using Frame + ): StreamSubscriber[V] < IO = IO.Unsafe(new StreamSubscriber(bufferSize)) +end StreamSubscriber diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala new file mode 100644 index 000000000..c3605aa23 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala @@ -0,0 +1,255 @@ +package kyo.interop.reactivestreams + +import StreamSubscription.* +import kyo.* +import kyo.Emit.Ack +import kyo.interop.reactivestreams.StreamSubscription.StreamFinishState +import kyo.kernel.ArrowEffect +import kyo.kernel.Boundary +import kyo.kernel.Safepoint +import kyo.scheduler.IOPromise +import kyo.scheduler.IOTask +import org.reactivestreams.* + +final class StreamSubscription[V, Ctx] private[reactivestreams] ( + private val stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] +)( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + tag: Tag[V], + frame: Frame +) extends Subscription: + + private enum DownstreamState derives CanEqual: + case Uninitialized extends DownstreamState + case Requesting( + requested: Long, + maybePromise: Maybe[(Long, Fiber.Promise.Unsafe[Unit, Long])] + ) extends DownstreamState + case Finished extends DownstreamState + end DownstreamState + + private val state = AtomicRef.Unsafe.init(DownstreamState.Uninitialized) + + private def offer(n: Long)(using Frame): Result[Unit, Long] < Async = + var sideEffect: () => Result[Unit, Long] < Async = () => null.asInstanceOf + state.update { + case DownstreamState.Requesting(0L, Absent) => + // No one requested, accumulate offerring and wait + val promise = Fiber.Promise.Unsafe.init[Unit, Long]() + sideEffect = () => promise.safe.getResult + DownstreamState.Requesting(0L, Present(n -> promise)) + case DownstreamState.Requesting(requested, Absent) => + // Someone requested, we offer right away + val accepted = Math.min(requested, n) + val nextRequested = requested - accepted + sideEffect = () => IO(Result.success(accepted)) + DownstreamState.Requesting(nextRequested, Absent) + case DownstreamState.Finished => + // Downstream cancelled + sideEffect = () => IO(Result.fail(())) + DownstreamState.Finished + case other => + sideEffect = () => IO(Result.success(0L)) + other + } + sideEffect() + end offer + + override def request(n: Long): Unit = + if n <= 0 then subscriber.onError(new IllegalArgumentException("non-positive subscription request")) + var sideEffect: () => Unit = () => () + state.update { + case DownstreamState.Requesting(0L, Present(offered -> promise)) => + val accepted = Math.min(offered, n) + val nextRequested = n - accepted + val nextOfferred = offered - accepted + sideEffect = () => promise.completeDiscard(Result.success(accepted)) + DownstreamState.Requesting(nextRequested, Absent) + case DownstreamState.Requesting(requested, Absent) => + val nextRequested = Math.min(Long.MaxValue - requested, n) + requested + sideEffect = () => () + DownstreamState.Requesting(nextRequested, Absent) + case other => + sideEffect = () => () + other + } + sideEffect() + end request + + override def cancel(): Unit = + given Frame = Frame.internal + var sideEffect: () => Unit = () => () + state.update { + case DownstreamState.Requesting(_, Present(_ -> promise)) => + sideEffect = () => promise.completeDiscard(Result.fail(())) + DownstreamState.Finished + case other => + sideEffect = () => () + DownstreamState.Finished + } + sideEffect() + end cancel + + def subscribe: Unit < (IO & Ctx) = + var sideEffect: () => Unit < (IO & Ctx) = () => IO.unit + state.update { + case DownstreamState.Uninitialized => + sideEffect = () => + IO { + subscriber.onSubscribe(this) + } + DownstreamState.Requesting(0L, Absent) + case other => + sideEffect = () => IO.unit + other + } + sideEffect() + end subscribe + + private[reactivestreams] def consume( + using + tag: Tag[Emit[Chunk[V]]], + frame: Frame, + safepoint: Safepoint + ): Fiber[Nothing, StreamFinishState] < (IO & Ctx) = + def consumeStream: StreamFinishState < (Abort[Nothing] & Async & Ctx) = + ArrowEffect.handleState(tag, 0: (Long | StreamFinishState), stream.emit.unit)( + handle = + [C] => + (input, state, cont) => + // Handle the input chunk + if input.nonEmpty then + // Input chunk contains values that we need to feed the subscriber + Loop[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState), Abort[Nothing] & Async & Ctx]( + input, + state + ) { + (curChunk, curState) => + curState match + case leftOver: Long => + if curChunk.isEmpty then + // We finish the current chunk, go next + Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( + Ack.Continue() -> leftOver + ) + else + if leftOver > 0 then + // Some requests left from last loop, feed them + val taken = Math.min(curChunk.size, leftOver) + val nextLeftOver = leftOver - taken + curChunk.take(taken.toInt).foreach { value => + subscriber.onNext(value) + } + // Loop the rest + Loop.continue(curChunk.drop(taken.toInt), nextLeftOver) + else + for + // We signal that we can `offer` "curChunk.size" elements + // then we wait until subscriber picks up that offer + acceptedResult <- offer(curChunk.size) + outcome = acceptedResult match + // Subscriber requests "accepted" elements + case Result.Success(accepted) => + val taken = Math.min(curChunk.size, accepted) + val nextLeftOver = accepted - taken + curChunk.take(taken.toInt).foreach { value => + subscriber.onNext(value) + } + // Loop the rest + IO(Loop.continue(curChunk.drop(taken.toInt), nextLeftOver)) + case Result.Error(e) => + e match + case t: Throwable => + Abort.panic(t) + .andThen(Loop.done[ + Chunk[V], + Long | StreamFinishState, + (Ack, Long | StreamFinishState) + ]( + Ack.Stop -> StreamFinishState.StreamCanceled + )) + case _: Unit => + IO(Loop.done[ + Chunk[V], + Long | StreamFinishState, + (Ack, Long | StreamFinishState) + ]( + Ack.Stop -> StreamFinishState.StreamCanceled + )) + yield outcome + end for + end if + case finishState: StreamFinishState => + Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( + Ack.Stop -> finishState + ) + }.map { case (ack, state) => + state -> cont(ack) + } + else + // The input chunk is empty, we go next + state -> cont(Ack.Continue()) + , + done = (state, _) => + state match + case _: Long => StreamFinishState.StreamComplete + case finishState: StreamFinishState => finishState + ) + + boundary { (trace, context) => + val fiber = Fiber.fromTask(IOTask(consumeStream, safepoint.copyTrace(trace), context)) + fiber.unsafe.onComplete { + case Result.Success(StreamFinishState.StreamComplete) => subscriber.onComplete() + case Result.Panic(e) => subscriber.onError(e) + case _ => () + } + fiber + } + end consume + +end StreamSubscription + +object StreamSubscription: + + private[reactivestreams] enum StreamFinishState derives CanEqual: + case StreamComplete, StreamCanceled + end StreamFinishState + + def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamSubscription[V, Ctx] < (IO & Ctx & Resource) = + for + subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) + _ <- subscription.subscribe + _ <- Resource.acquireRelease(subscription.consume)(_.interrupt.unit) + yield subscription + + object Unsafe: + def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit + )( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamSubscription[V, Ctx] = + val subscription = new StreamSubscription[V, Ctx](stream, subscriber) + subscribeCallback(subscription.subscribe.andThen(subscription.consume)) + subscription + end subscribe + end Unsafe + +end StreamSubscription diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala new file mode 100644 index 000000000..4463bd4b1 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala @@ -0,0 +1,33 @@ +package kyo.interop + +import kyo.* +import kyo.interop.reactivestreams.* +import kyo.kernel.Boundary +import org.reactivestreams.* +import org.reactivestreams.FlowAdapters + +package object reactivestreams: + def fromPublisher[T]( + publisher: Publisher[T], + bufferSize: Int + )( + using + Frame, + Tag[T] + ): Stream[T, Async] < IO = + for + subscriber <- StreamSubscriber[T](bufferSize) + _ <- IO(publisher.subscribe(subscriber)) + yield subscriber.stream + + def subscribeToStream[T, Ctx]( + stream: Stream[T, Ctx], + subscriber: Subscriber[? >: T] + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[T] + ): Subscription < (Resource & IO & Ctx) = + StreamSubscription.subscribe(stream, subscriber)(using boundary, frame, tag) +end reactivestreams diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala new file mode 100644 index 000000000..439a59764 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala @@ -0,0 +1,28 @@ +package kyo + +import kyo.internal.BaseKyoTest +import kyo.kernel.Platform +import org.scalatest.NonImplicitAssertions +import org.scalatest.freespec.AsyncFreeSpec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future + +abstract class Test extends AsyncFreeSpec with BaseKyoTest[Async & Resource] with NonImplicitAssertions: + + def run(v: Future[Assertion] < (Async & Resource)): Future[Assertion] = + import AllowUnsafe.embrace.danger + Abort.run[Any](v) + .map(_.fold(e => throw new Exception("Test failed with " + e))(identity)) + .pipe(Resource.run) + .pipe(Async.run) + .map(_.toFuture) + .map(_.flatten) + .pipe(IO.Unsafe.run) + .eval + end run + + type Assertion = org.scalatest.compatible.Assertion + def success = succeed + + override given executionContext: ExecutionContext = Platform.executionContext +end Test diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala new file mode 100644 index 000000000..62b9959cd --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala @@ -0,0 +1,53 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.interop.reactivestreams.StreamSubscription +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription + +final class CancellationTest extends Test: + final class Sub[A](b: AtomicBoolean.Unsafe) extends Subscriber[A]: + import AllowUnsafe.embrace.danger + def onNext(t: A) = b.set(true) + def onComplete() = b.set(true) + def onError(e: Throwable) = b.set(true) + def onSubscribe(s: Subscription) = () + end Sub + + val stream: Stream[Int, Any] = Stream.range(0, 5, 1) + + val attempts = 1000 + + def testStreamSubscription(clue: String)(program: Subscription => Unit): Unit < IO = + Loop(attempts) { index => + if index <= 0 then + Loop.done + else + IO.Unsafe { + val flag = AtomicBoolean.Unsafe.init(false) + StreamSubscription.Unsafe.subscribe(stream, new Sub(flag)) { fiber => + discard(IO.Unsafe.run(fiber).eval) + } + }.map { subscription => + program(subscription) + }.andThen(Loop.continue(index - 1)) + end if + } + + "after subscription is canceled request must be NOOPs" in runJVM { + testStreamSubscription(clue = "onNext was called after the subscription was canceled") { sub => + sub.cancel() + sub.request(1) + sub.request(1) + sub.request(1) + }.map(_ => assert(true)) + } + + "after subscription is canceled additional cancellations must be NOOPs" in runJVM { + testStreamSubscription(clue = "onComplete was called after the subscription was canceled") { + sub => + sub.cancel() + sub.cancel() + }.map(_ => assert(true)) + } +end CancellationTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala new file mode 100644 index 000000000..c12ec394e --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala @@ -0,0 +1,81 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.Duration +import kyo.Emit.Ack +import kyo.interop.reactivestreams.* + +final class PublisherToSubscriberTest extends Test: + import PublisherToSubscriberTest.* + + private def randomStream: Stream[Int, Async] < IO = + IO(Stream + .range(0, 1 << 10, 1, BufferSize) + .map { int => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen(int) + }) + + "should have the same output as input" in runJVM { + for + stream <- randomStream + publisher <- StreamPublisher[Int, Async](stream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + (isSame, _) <- subscriber.stream + .runFold(true -> 0) { case ((acc, expected), cur) => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen((acc && (expected == cur)) -> (expected + 1)) + } + yield assert(isSame) + } + + "should propagate errors downstream" in runJVM { + for + inputStream <- IO { + Stream.range(0, 10, 1, 1).map { int => + if int < 5 then + Async.sleep(Duration.fromUnits(10, Duration.Units.Millis)).andThen(int) + else + Abort.panic(TestError) + } + } + publisher <- StreamPublisher[Int, Async](inputStream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + result <- Abort.run[Throwable](subscriber.stream.runDiscard) + yield result match + case Result.Error(e: Throwable) => assert(e == TestError) + case _ => assert(false) + end for + } + + "should cancel upstream if downstream completes" in runJVM { + def emit(ack: Ack, cur: Int, stopPromise: Fiber.Promise[Nothing, Unit]): Ack < (Emit[Chunk[Int]] & IO) = + ack match + case Ack.Stop => stopPromise.completeDiscard(Result.success(())).andThen(Ack.Stop) + case Ack.Continue(_) => Emit.andMap(Chunk(cur))(emit(_, cur + 1, stopPromise)) + end emit + + for + stopPromise <- Fiber.Promise.init[Nothing, Unit] + stream <- IO(Stream(Emit.andMap(Chunk.empty[Int])(emit(_, 0, stopPromise)))) + publisher <- StreamPublisher[Int, IO](stream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + _ <- subscriber.stream.take(10).runDiscard + _ <- stopPromise.get + yield assert(true) + end for + } +end PublisherToSubscriberTest + +object PublisherToSubscriberTest: + type TestError = TestError.type + object TestError extends Exception("BOOM") + private val BufferSize = 1 << 4 +end PublisherToSubscriberTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala new file mode 100644 index 000000000..be0322e62 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala @@ -0,0 +1,50 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.Emit.Ack +import kyo.interop.reactivestreams.StreamPublisher +import org.reactivestreams.tck.PublisherVerification +import org.reactivestreams.tck.TestEnvironment +import org.scalatestplus.testng.* + +final class StreamPublisherTest extends PublisherVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + given Frame = Frame.internal + + private def createStream(n: Int = 1) = + if n <= 0 then + Stream.empty[Int] + else + val chunkSize = Math.sqrt(n).floor.intValue + Stream.range(0, n, 1, chunkSize).map { int => + Random + .use(_.nextInt(50)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen(int) + } + + override def createPublisher(n: Long): StreamPublisher[Int, Async] = + if n > Int.MaxValue then + null + else + StreamPublisher.Unsafe( + createStream(n.toInt), + subscribeCallback = fiber => + discard(IO.Unsafe.run(Abort.run(Async.runAndBlock(Duration.Infinity)(fiber))).eval) + ) + end if + end createPublisher + + override def createFailedPublisher(): StreamPublisher[Int, Async] = + StreamPublisher.Unsafe( + createStream(), + subscribeCallback = fiber => + val asynced = Async.runAndBlock(Duration.Infinity)(fiber) + val aborted = Abort.run(asynced) + val ioed = IO.Unsafe.run(aborted).eval + ioed match + case Result.Success(fiber) => discard(fiber.unsafe.interrupt()) + case _ => () + ) + end createFailedPublisher +end StreamPublisherTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala new file mode 100644 index 000000000..f2ce89ca7 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala @@ -0,0 +1,91 @@ +package kyo.interop.reactivestreams + +import java.lang.Thread +import java.util.concurrent.atomic.AtomicInteger +import kyo.* +import kyo.Result.* +import kyo.interop.reactivestreams.StreamSubscriber +import kyo.interop.reactivestreams.StreamSubscriber.* +import org.reactivestreams.* +import org.reactivestreams.tck.SubscriberBlackboxVerification +import org.reactivestreams.tck.SubscriberWhiteboxVerification +import org.reactivestreams.tck.SubscriberWhiteboxVerification.SubscriberPuppet +import org.reactivestreams.tck.SubscriberWhiteboxVerification.WhiteboxSubscriberProbe +import org.reactivestreams.tck.TestEnvironment +import org.scalatestplus.testng.* + +class StreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber( + p: SubscriberWhiteboxVerification.WhiteboxSubscriberProbe[Int] + ): Subscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 1).map(s => new WhiteboxSubscriber(s, p)) + }.eval + + def createElement(i: Int): Int = counter.getAndIncrement +end StreamSubscriberTest + +final class WhiteboxSubscriber[V]( + sub: StreamSubscriber[V], + probe: WhiteboxSubscriberProbe[V] +)( + using Tag[V] +) extends Subscriber[V]: + import AllowUnsafe.embrace.danger + + def onError(t: Throwable): Unit = + sub.onError(t) + probe.registerOnError(t) + end onError + + def onSubscribe(s: Subscription): Unit = + sub.onSubscribe(s) + probe.registerOnSubscribe( + new SubscriberPuppet: + override def triggerRequest(elements: Long): Unit = + val computation: Unit < IO = Loop(elements) { remaining => + if remaining <= 0 then + Loop.done + else + sub.request.map { accepted => + Loop.continue(remaining - accepted) + } + } + IO.Unsafe.run(computation).eval + end triggerRequest + + override def signalCancel(): Unit = + s.cancel() + end signalCancel + ) + end onSubscribe + + def onComplete(): Unit = + sub.onComplete() + probe.registerOnComplete() + end onComplete + + def onNext(a: V): Unit = + sub.onNext(a) + probe.registerOnNext(a) + end onNext +end WhiteboxSubscriber + +final class SubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber(): StreamSubscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 1) + }.eval + + override def triggerRequest(s: Subscriber[? >: Int]): Unit = + val computation: Long < IO = s.asInstanceOf[StreamSubscriber[Int]].request + discard(IO.Unsafe.run(computation).eval) + + def createElement(i: Int): Int = counter.incrementAndGet() +end SubscriberBlackboxSpec