From d3933f95ca7092bbd6a228d4408070f887c5c95e Mon Sep 17 00:00:00 2001 From: HollandDM Date: Wed, 27 Nov 2024 21:26:25 +0700 Subject: [PATCH] reactivestream/flow interop --- .gitignore | 1 + build.sbt | 21 +- kyo-core/jvm/src/main/scala/kyo/Path.scala | 11 +- .../main/scala/kyo/interop/flow/package.scala | 33 +++ .../reactive-streams/StreamPublisher.scala | 124 +++++++++ .../reactive-streams/StreamSubscriber.scala | 205 ++++++++++++++ .../reactive-streams/StreamSubscription.scala | 255 ++++++++++++++++++ .../interop/reactive-streams/package.scala | 33 +++ .../src/test/scala/kyo/interop/Test.scala | 28 ++ .../reactive-streams/CancellationTest.scala | 53 ++++ .../PublisherToSubscriberTest.scala | 81 ++++++ .../StreamPublisherTest.scala | 50 ++++ .../StreamSubscriberTest.scala | 91 +++++++ .../src/main/scala/kyo/PlatformBackend.scala | 7 +- .../sttp/client3/HttpClientKyoBackend.scala | 80 +++--- .../sttp/client3/KyoBodyFromHttpClient.scala | 104 +++++++ .../sttp/client3/KyoBodyToHttpClient.scala | 25 ++ .../src/test/scala/kyo/RequestsLiveTest.scala | 10 +- .../shared/src/main/scala/kyo/Requests.scala | 14 +- .../src/test/scala/kyo/RequestsTest.scala | 4 +- 20 files changed, 1173 insertions(+), 57 deletions(-) create mode 100644 kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala create mode 100644 kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala create mode 100644 kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala create mode 100644 kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala create mode 100644 kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala create mode 100644 kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala create mode 100644 kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala create mode 100644 kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala create mode 100644 kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala create mode 100644 kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala create mode 100644 kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyFromHttpClient.scala create mode 100644 kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyToHttpClient.scala diff --git a/.gitignore b/.gitignore index ca2708f55..51a929e0d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ jmh-result.json *.jfr *.json *.gpg +test-output \ No newline at end of file diff --git a/build.sbt b/build.sbt index 64bb08a71..dd98d28f9 100644 --- a/build.sbt +++ b/build.sbt @@ -96,6 +96,7 @@ lazy val kyoJVM = project `kyo-stats-registry`.jvm, `kyo-stats-otel`.jvm, `kyo-cache`.jvm, + `kyo-reactive-streams`.jvm, `kyo-sttp`.jvm, `kyo-tapir`.jvm, `kyo-caliban`.jvm, @@ -306,6 +307,22 @@ lazy val `kyo-cache` = ) .jvmSettings(mimaCheck(false)) +lazy val `kyo-reactive-streams` = + crossProject(JVMPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-reactive-streams")) + .dependsOn(`kyo-core`) + .settings( + `kyo-settings`, + libraryDependencies ++= Seq( + "org.reactivestreams" % "reactive-streams" % "1.0.4", + "org.reactivestreams" % "reactive-streams-tck" % "1.0.4" % Test, + "org.scalatestplus" %% "testng-7-5" % "3.2.17.0" % Test + ) + ) + .jvmSettings(mimaCheck(false)) + lazy val `kyo-sttp` = crossProject(JSPlatform, JVMPlatform, NativePlatform) .withoutSuffixFor(JVMPlatform) @@ -314,7 +331,9 @@ lazy val `kyo-sttp` = .dependsOn(`kyo-core`) .settings( `kyo-settings`, - libraryDependencies += "com.softwaremill.sttp.client3" %%% "core" % "3.10.1" + libraryDependencies ++= Seq( + "com.softwaremill.sttp.client3" %%% "core" % "3.10.1", + ) ) .jsSettings(`js-settings`) .nativeSettings(`native-settings`) diff --git a/kyo-core/jvm/src/main/scala/kyo/Path.scala b/kyo-core/jvm/src/main/scala/kyo/Path.scala index c76d17f35..9ab752b65 100644 --- a/kyo-core/jvm/src/main/scala/kyo/Path.scala +++ b/kyo-core/jvm/src/main/scala/kyo/Path.scala @@ -152,7 +152,7 @@ class Path private (val path: List[String]) derives CanEqual: end if } - private def readLoop[A, ReadTpe, Res]( + private[kyo] def readLoop[A, ReadTpe, Res]( acquire: Res < IO, release: Res => Unit < Async, readOnce: Res => Maybe[ReadTpe] < IO, @@ -361,14 +361,17 @@ object Path: case s: String => List(s) case p: Path => p.path } - val javaPath = if flattened.isEmpty then Paths.get("") else Paths.get(flattened.head, flattened.tail*) - val normalizedPath = javaPath.normalize().toString - new Path(if normalizedPath.isEmpty then Nil else normalizedPath.split(File.separator).toList) + val javaPath = if flattened.isEmpty then Paths.get("") else Paths.get(flattened.head, flattened.tail*) + fromJavaPath(javaPath) end apply def apply(path: Part*): Path = apply(path.toList) + def fromJavaPath(path: JPath): Path = + val normalizedPath = path.normalize().toString + new Path(if normalizedPath.isEmpty then Nil else normalizedPath.split(File.separator).toList) + case class BasePaths( cache: Path, config: Path, diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala new file mode 100644 index 000000000..2be446b82 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/flow/package.scala @@ -0,0 +1,33 @@ +package kyo.interop + +import java.util.concurrent.Flow.* +import kyo.* +import kyo.interop.reactivestreams +import kyo.kernel.Boundary +import org.reactivestreams.FlowAdapters + +package object flow: + inline def fromPublisher[T]( + publisher: Publisher[T], + bufferSize: Int + )( + using + Frame, + Tag[T] + ): Stream[T, Async] < IO = reactivestreams.fromPublisher(FlowAdapters.toPublisher(publisher), bufferSize) + + def subscribeToStream[T, Ctx]( + stream: Stream[T, Ctx], + subscriber: Subscriber[? >: T] + )( + using + Boundary[Ctx, IO], + Frame, + Tag[T] + ): Subscription < (Resource & IO & Ctx) = + reactivestreams.subscribeToStream(stream, FlowAdapters.toSubscriber(subscriber)).map { subscription => + new Subscription: + override def request(n: Long): Unit = subscription.request(n) + override def cancel(): Unit = subscription.cancel() + } +end flow diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala new file mode 100644 index 000000000..d34205364 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamPublisher.scala @@ -0,0 +1,124 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.interop.reactivestreams.StreamSubscription.StreamFinishState +import kyo.kernel.Boundary +import kyo.kernel.ContextEffect.Isolated +import kyo.kernel.Safepoint +import kyo.scheduler.IOTask +import org.reactivestreams.* + +abstract class StreamPublisher[V, Ctx] private ( + stream: Stream[V, Ctx] +) extends Publisher[V]: + + protected def bind(subscriber: Subscriber[? >: V]): Unit + + override def subscribe(subscriber: Subscriber[? >: V]): Unit = + if isNull(subscriber) then + throw new NullPointerException("Subscriber must not be null.") + else + bind(subscriber) + end subscribe + +end StreamPublisher + +object StreamPublisher: + def apply[V, Ctx]( + stream: Stream[V, Ctx], + capacity: Int = Int.MaxValue + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamPublisher[V, Ctx] < (Resource & IO & Ctx) = + inline def interruptPanic = Result.Panic(Fiber.Interrupted(frame)) + + def discardSubscriber(subscriber: Subscriber[? >: V]): Unit = + subscriber.onSubscribe(new Subscription: + override def request(n: Long): Unit = () + override def cancel(): Unit = () + ) + subscriber.onComplete() + end discardSubscriber + + def consumeChannel( + channel: Channel[Subscriber[? >: V]], + supervisorPromise: Fiber.Promise[Nothing, Unit] + ): Unit < (Async & Ctx) = + Loop(()) { _ => + channel.closed.map { + if _ then + Loop.done + else + val result = Abort.run[Closed] { + for + subscriber <- channel.take + subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) + fiber <- subscription.subscribe.andThen(subscription.consume) + _ <- supervisorPromise.onComplete(_ => discard(fiber.interrupt(interruptPanic))) + yield () + } + result.map { + case Result.Success(_) => Loop.continue(()) + case _ => Loop.done + } + } + } + + IO.Unsafe { + for + channel <- + Resource.acquireRelease(Channel.init[Subscriber[? >: V]](capacity))( + _.close.map(_.foreach(_.foreach(discardSubscriber(_)))) + ) + publisher <- IO.Unsafe { + new StreamPublisher[V, Ctx](stream): + override protected def bind( + subscriber: Subscriber[? >: V] + ): Unit = + channel.unsafe.offer(subscriber) match + case Result.Success(true) => () + case _ => discardSubscriber(subscriber) + } + supervisorPromise <- Fiber.Promise.init[Nothing, Unit] + _ <- Resource.acquireRelease(boundary((trace, context) => + Fiber.fromTask(IOTask(consumeChannel(channel, supervisorPromise), trace, context)) + ))( + _.interrupt.map(discard(_)) + ) + yield publisher + end for + } + end apply + + object Unsafe: + def apply[V, Ctx]( + stream: Stream[V, Ctx], + subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit + )( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamPublisher[V, Ctx] = + new StreamPublisher[V, Ctx](stream): + override protected def bind( + subscriber: Subscriber[? >: V] + ): Unit = + discard(StreamSubscription.Unsafe.subscribe( + stream, + subscriber + )( + subscribeCallback + )( + using + allowance, + boundary, + frame, + tag + )) + end Unsafe +end StreamPublisher diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala new file mode 100644 index 000000000..a9d9832b8 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscriber.scala @@ -0,0 +1,205 @@ +package kyo.interop.reactivestreams + +import StreamSubscriber.* +import kyo.* +import kyo.Emit.Ack +import org.reactivestreams.* + +final class StreamSubscriber[V] private ( + bufferSize: Int +)( + using + allowance: AllowUnsafe, + frame: Frame +) extends Subscriber[V]: + + private enum UpstreamState derives CanEqual: + case Uninitialized extends UpstreamState + case WaitForRequest(subscription: Subscription, items: Chunk[V], remaining: Int) extends UpstreamState + case Finished(reason: Maybe[Throwable], leftOver: Chunk[V]) extends UpstreamState + end UpstreamState + + private val state = AtomicRef.Unsafe.init( + UpstreamState.Uninitialized -> Maybe.empty[Fiber.Promise.Unsafe[Nothing, Unit]] + ) + + private def throwIfNull[A](b: A): Unit = if isNull(b) then throw new NullPointerException() + + override def onSubscribe(subscription: Subscription): Unit = + throwIfNull(subscription) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.Uninitialized, maybePromise) => + // Notify if someone wait + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.WaitForRequest(subscription, Chunk.empty, 0) -> Absent + case other => + // wrong state, cancel incoming subscription + sideEffect = () => subscription.cancel() + other + } + sideEffect() + end onSubscribe + + override def onNext(item: V): Unit = + throwIfNull(item) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onNext + + override def onError(throwable: Throwable): Unit = + throwIfNull(throwable) + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Maybe(throwable), items) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onError + + override def onComplete(): Unit = + var sideEffect: () => Unit = () => () + state.update { + case (UpstreamState.WaitForRequest(_, items, _), maybePromise) => + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Absent, items) -> Absent + case other => + sideEffect = () => () + other + } + sideEffect() + end onComplete + + private[interop] def await: Boolean < Async = + var sideEffect: () => (Boolean < Async) = () => IO(false) + state.update { + case (UpstreamState.Uninitialized, Absent) => + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + sideEffect = () => promise.safe.use(_ => false) + UpstreamState.Uninitialized -> Present(promise) + case s @ (UpstreamState.Uninitialized, Present(promise)) => + sideEffect = () => promise.safe.use(_ => false) + s + case s @ (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + if items.isEmpty then + if remaining == 0 then + sideEffect = () => IO(true) + UpstreamState.WaitForRequest(subscription, items, remaining) -> Absent + else + val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]() + sideEffect = () => promise.safe.use(_ => false) + UpstreamState.WaitForRequest(subscription, items, remaining) -> Present(promise) + else + sideEffect = () => IO(false) + s + case s @ (UpstreamState.Finished(_, _), maybePromise) => + sideEffect = () => + maybePromise match + case Present(promise) => promise.safe.use(_ => false) + case Absent => IO(false) + s + case other => + sideEffect = () => IO(false) + other + } + sideEffect() + end await + + private[interop] def request: Long < IO = + var sideEffect: () => Long < IO = () => IO(0L) + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) => + sideEffect = () => IO(subscription.request(bufferSize)).andThen(bufferSize.toLong) + UpstreamState.WaitForRequest(subscription, items, remaining + bufferSize) -> maybePromise + case other => + sideEffect = () => IO(0L) + other + } + sideEffect() + end request + + private[interop] def poll: Result[Throwable | SubscriberDone, Chunk[V]] < IO = + var sideEffect: () => (Result[Throwable | SubscriberDone, Chunk[V]] < IO) = () => IO(Result.success(Chunk.empty)) + state.update { + case (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) => + sideEffect = () => IO(Result.success(items)) + UpstreamState.WaitForRequest(subscription, Chunk.empty, remaining) -> Absent + case s @ (UpstreamState.Finished(reason, leftOver), Absent) => + if leftOver.isEmpty then + sideEffect = () => + IO { + reason match + case Present(error) => Result.fail(error) + case Absent => Result.fail(SubscriberDone) + + } + s + else + sideEffect = () => IO(Result.success(leftOver)) + UpstreamState.Finished(reason, Chunk.empty) -> Absent + case other => + sideEffect = () => IO(Result.success(Chunk.empty)) + other + } + sideEffect() + end poll + + private[interop] def interupt: Unit < IO = + var sideEffect: () => (Unit < IO) = () => IO.unit + state.update { + case (UpstreamState.Uninitialized, maybePromise) => + // Notify if someone wait + sideEffect = () => maybePromise.foreach(_.completeDiscard(Result.success(()))) + UpstreamState.Finished(Absent, Chunk.empty) -> Absent + case (UpstreamState.WaitForRequest(subscription, _, _), Absent) => + sideEffect = () => IO(subscription.cancel()) + UpstreamState.Finished(Absent, Chunk.empty) -> Absent + case other => + sideEffect = () => IO.unit + other + } + sideEffect() + end interupt + + private[interop] def emit(ack: Ack)(using Tag[V]): Ack < (Emit[Chunk[V]] & Async) = + ack match + case Ack.Stop => interupt.andThen(Ack.Stop) + case Ack.Continue(_) => + await.map { + if _ then + request.andThen(Ack.Continue()) + else + poll.map { + case Result.Success(nextChunk) => Emit(nextChunk) + case Result.Error(e: Throwable) => Abort.panic(e) + case _ => Ack.Stop + } + }.map(emit) + + def stream(using Tag[V]): Stream[V, Async] = Stream(Emit.andMap(Chunk.empty)(emit)) + +end StreamSubscriber + +object StreamSubscriber: + + abstract private[reactivestreams] class SubscriberDone + private[reactivestreams] case object SubscriberDone extends SubscriberDone + + def apply[V]( + bufferSize: Int + )( + using Frame + ): StreamSubscriber[V] < IO = IO.Unsafe(new StreamSubscriber(bufferSize)) +end StreamSubscriber diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala new file mode 100644 index 000000000..c3605aa23 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/StreamSubscription.scala @@ -0,0 +1,255 @@ +package kyo.interop.reactivestreams + +import StreamSubscription.* +import kyo.* +import kyo.Emit.Ack +import kyo.interop.reactivestreams.StreamSubscription.StreamFinishState +import kyo.kernel.ArrowEffect +import kyo.kernel.Boundary +import kyo.kernel.Safepoint +import kyo.scheduler.IOPromise +import kyo.scheduler.IOTask +import org.reactivestreams.* + +final class StreamSubscription[V, Ctx] private[reactivestreams] ( + private val stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] +)( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + tag: Tag[V], + frame: Frame +) extends Subscription: + + private enum DownstreamState derives CanEqual: + case Uninitialized extends DownstreamState + case Requesting( + requested: Long, + maybePromise: Maybe[(Long, Fiber.Promise.Unsafe[Unit, Long])] + ) extends DownstreamState + case Finished extends DownstreamState + end DownstreamState + + private val state = AtomicRef.Unsafe.init(DownstreamState.Uninitialized) + + private def offer(n: Long)(using Frame): Result[Unit, Long] < Async = + var sideEffect: () => Result[Unit, Long] < Async = () => null.asInstanceOf + state.update { + case DownstreamState.Requesting(0L, Absent) => + // No one requested, accumulate offerring and wait + val promise = Fiber.Promise.Unsafe.init[Unit, Long]() + sideEffect = () => promise.safe.getResult + DownstreamState.Requesting(0L, Present(n -> promise)) + case DownstreamState.Requesting(requested, Absent) => + // Someone requested, we offer right away + val accepted = Math.min(requested, n) + val nextRequested = requested - accepted + sideEffect = () => IO(Result.success(accepted)) + DownstreamState.Requesting(nextRequested, Absent) + case DownstreamState.Finished => + // Downstream cancelled + sideEffect = () => IO(Result.fail(())) + DownstreamState.Finished + case other => + sideEffect = () => IO(Result.success(0L)) + other + } + sideEffect() + end offer + + override def request(n: Long): Unit = + if n <= 0 then subscriber.onError(new IllegalArgumentException("non-positive subscription request")) + var sideEffect: () => Unit = () => () + state.update { + case DownstreamState.Requesting(0L, Present(offered -> promise)) => + val accepted = Math.min(offered, n) + val nextRequested = n - accepted + val nextOfferred = offered - accepted + sideEffect = () => promise.completeDiscard(Result.success(accepted)) + DownstreamState.Requesting(nextRequested, Absent) + case DownstreamState.Requesting(requested, Absent) => + val nextRequested = Math.min(Long.MaxValue - requested, n) + requested + sideEffect = () => () + DownstreamState.Requesting(nextRequested, Absent) + case other => + sideEffect = () => () + other + } + sideEffect() + end request + + override def cancel(): Unit = + given Frame = Frame.internal + var sideEffect: () => Unit = () => () + state.update { + case DownstreamState.Requesting(_, Present(_ -> promise)) => + sideEffect = () => promise.completeDiscard(Result.fail(())) + DownstreamState.Finished + case other => + sideEffect = () => () + DownstreamState.Finished + } + sideEffect() + end cancel + + def subscribe: Unit < (IO & Ctx) = + var sideEffect: () => Unit < (IO & Ctx) = () => IO.unit + state.update { + case DownstreamState.Uninitialized => + sideEffect = () => + IO { + subscriber.onSubscribe(this) + } + DownstreamState.Requesting(0L, Absent) + case other => + sideEffect = () => IO.unit + other + } + sideEffect() + end subscribe + + private[reactivestreams] def consume( + using + tag: Tag[Emit[Chunk[V]]], + frame: Frame, + safepoint: Safepoint + ): Fiber[Nothing, StreamFinishState] < (IO & Ctx) = + def consumeStream: StreamFinishState < (Abort[Nothing] & Async & Ctx) = + ArrowEffect.handleState(tag, 0: (Long | StreamFinishState), stream.emit.unit)( + handle = + [C] => + (input, state, cont) => + // Handle the input chunk + if input.nonEmpty then + // Input chunk contains values that we need to feed the subscriber + Loop[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState), Abort[Nothing] & Async & Ctx]( + input, + state + ) { + (curChunk, curState) => + curState match + case leftOver: Long => + if curChunk.isEmpty then + // We finish the current chunk, go next + Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( + Ack.Continue() -> leftOver + ) + else + if leftOver > 0 then + // Some requests left from last loop, feed them + val taken = Math.min(curChunk.size, leftOver) + val nextLeftOver = leftOver - taken + curChunk.take(taken.toInt).foreach { value => + subscriber.onNext(value) + } + // Loop the rest + Loop.continue(curChunk.drop(taken.toInt), nextLeftOver) + else + for + // We signal that we can `offer` "curChunk.size" elements + // then we wait until subscriber picks up that offer + acceptedResult <- offer(curChunk.size) + outcome = acceptedResult match + // Subscriber requests "accepted" elements + case Result.Success(accepted) => + val taken = Math.min(curChunk.size, accepted) + val nextLeftOver = accepted - taken + curChunk.take(taken.toInt).foreach { value => + subscriber.onNext(value) + } + // Loop the rest + IO(Loop.continue(curChunk.drop(taken.toInt), nextLeftOver)) + case Result.Error(e) => + e match + case t: Throwable => + Abort.panic(t) + .andThen(Loop.done[ + Chunk[V], + Long | StreamFinishState, + (Ack, Long | StreamFinishState) + ]( + Ack.Stop -> StreamFinishState.StreamCanceled + )) + case _: Unit => + IO(Loop.done[ + Chunk[V], + Long | StreamFinishState, + (Ack, Long | StreamFinishState) + ]( + Ack.Stop -> StreamFinishState.StreamCanceled + )) + yield outcome + end for + end if + case finishState: StreamFinishState => + Loop.done[Chunk[V], Long | StreamFinishState, (Ack, Long | StreamFinishState)]( + Ack.Stop -> finishState + ) + }.map { case (ack, state) => + state -> cont(ack) + } + else + // The input chunk is empty, we go next + state -> cont(Ack.Continue()) + , + done = (state, _) => + state match + case _: Long => StreamFinishState.StreamComplete + case finishState: StreamFinishState => finishState + ) + + boundary { (trace, context) => + val fiber = Fiber.fromTask(IOTask(consumeStream, safepoint.copyTrace(trace), context)) + fiber.unsafe.onComplete { + case Result.Success(StreamFinishState.StreamComplete) => subscriber.onComplete() + case Result.Panic(e) => subscriber.onError(e) + case _ => () + } + fiber + } + end consume + +end StreamSubscription + +object StreamSubscription: + + private[reactivestreams] enum StreamFinishState derives CanEqual: + case StreamComplete, StreamCanceled + end StreamFinishState + + def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamSubscription[V, Ctx] < (IO & Ctx & Resource) = + for + subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber)) + _ <- subscription.subscribe + _ <- Resource.acquireRelease(subscription.consume)(_.interrupt.unit) + yield subscription + + object Unsafe: + def subscribe[V, Ctx]( + stream: Stream[V, Ctx], + subscriber: Subscriber[? >: V] + )( + subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit + )( + using + allowance: AllowUnsafe, + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[V] + ): StreamSubscription[V, Ctx] = + val subscription = new StreamSubscription[V, Ctx](stream, subscriber) + subscribeCallback(subscription.subscribe.andThen(subscription.consume)) + subscription + end subscribe + end Unsafe + +end StreamSubscription diff --git a/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala new file mode 100644 index 000000000..4463bd4b1 --- /dev/null +++ b/kyo-reactive-streams/shared/src/main/scala/kyo/interop/reactive-streams/package.scala @@ -0,0 +1,33 @@ +package kyo.interop + +import kyo.* +import kyo.interop.reactivestreams.* +import kyo.kernel.Boundary +import org.reactivestreams.* +import org.reactivestreams.FlowAdapters + +package object reactivestreams: + def fromPublisher[T]( + publisher: Publisher[T], + bufferSize: Int + )( + using + Frame, + Tag[T] + ): Stream[T, Async] < IO = + for + subscriber <- StreamSubscriber[T](bufferSize) + _ <- IO(publisher.subscribe(subscriber)) + yield subscriber.stream + + def subscribeToStream[T, Ctx]( + stream: Stream[T, Ctx], + subscriber: Subscriber[? >: T] + )( + using + boundary: Boundary[Ctx, IO], + frame: Frame, + tag: Tag[T] + ): Subscription < (Resource & IO & Ctx) = + StreamSubscription.subscribe(stream, subscriber)(using boundary, frame, tag) +end reactivestreams diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala new file mode 100644 index 000000000..439a59764 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/Test.scala @@ -0,0 +1,28 @@ +package kyo + +import kyo.internal.BaseKyoTest +import kyo.kernel.Platform +import org.scalatest.NonImplicitAssertions +import org.scalatest.freespec.AsyncFreeSpec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future + +abstract class Test extends AsyncFreeSpec with BaseKyoTest[Async & Resource] with NonImplicitAssertions: + + def run(v: Future[Assertion] < (Async & Resource)): Future[Assertion] = + import AllowUnsafe.embrace.danger + Abort.run[Any](v) + .map(_.fold(e => throw new Exception("Test failed with " + e))(identity)) + .pipe(Resource.run) + .pipe(Async.run) + .map(_.toFuture) + .map(_.flatten) + .pipe(IO.Unsafe.run) + .eval + end run + + type Assertion = org.scalatest.compatible.Assertion + def success = succeed + + override given executionContext: ExecutionContext = Platform.executionContext +end Test diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala new file mode 100644 index 000000000..62b9959cd --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/CancellationTest.scala @@ -0,0 +1,53 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.interop.reactivestreams.StreamSubscription +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription + +final class CancellationTest extends Test: + final class Sub[A](b: AtomicBoolean.Unsafe) extends Subscriber[A]: + import AllowUnsafe.embrace.danger + def onNext(t: A) = b.set(true) + def onComplete() = b.set(true) + def onError(e: Throwable) = b.set(true) + def onSubscribe(s: Subscription) = () + end Sub + + val stream: Stream[Int, Any] = Stream.range(0, 5, 1) + + val attempts = 1000 + + def testStreamSubscription(clue: String)(program: Subscription => Unit): Unit < IO = + Loop(attempts) { index => + if index <= 0 then + Loop.done + else + IO.Unsafe { + val flag = AtomicBoolean.Unsafe.init(false) + StreamSubscription.Unsafe.subscribe(stream, new Sub(flag)) { fiber => + discard(IO.Unsafe.run(fiber).eval) + } + }.map { subscription => + program(subscription) + }.andThen(Loop.continue(index - 1)) + end if + } + + "after subscription is canceled request must be NOOPs" in runJVM { + testStreamSubscription(clue = "onNext was called after the subscription was canceled") { sub => + sub.cancel() + sub.request(1) + sub.request(1) + sub.request(1) + }.map(_ => assert(true)) + } + + "after subscription is canceled additional cancellations must be NOOPs" in runJVM { + testStreamSubscription(clue = "onComplete was called after the subscription was canceled") { + sub => + sub.cancel() + sub.cancel() + }.map(_ => assert(true)) + } +end CancellationTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala new file mode 100644 index 000000000..c12ec394e --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/PublisherToSubscriberTest.scala @@ -0,0 +1,81 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.Duration +import kyo.Emit.Ack +import kyo.interop.reactivestreams.* + +final class PublisherToSubscriberTest extends Test: + import PublisherToSubscriberTest.* + + private def randomStream: Stream[Int, Async] < IO = + IO(Stream + .range(0, 1 << 10, 1, BufferSize) + .map { int => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen(int) + }) + + "should have the same output as input" in runJVM { + for + stream <- randomStream + publisher <- StreamPublisher[Int, Async](stream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + (isSame, _) <- subscriber.stream + .runFold(true -> 0) { case ((acc, expected), cur) => + Random + .use(_.nextInt(10)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen((acc && (expected == cur)) -> (expected + 1)) + } + yield assert(isSame) + } + + "should propagate errors downstream" in runJVM { + for + inputStream <- IO { + Stream.range(0, 10, 1, 1).map { int => + if int < 5 then + Async.sleep(Duration.fromUnits(10, Duration.Units.Millis)).andThen(int) + else + Abort.panic(TestError) + } + } + publisher <- StreamPublisher[Int, Async](inputStream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + result <- Abort.run[Throwable](subscriber.stream.runDiscard) + yield result match + case Result.Error(e: Throwable) => assert(e == TestError) + case _ => assert(false) + end for + } + + "should cancel upstream if downstream completes" in runJVM { + def emit(ack: Ack, cur: Int, stopPromise: Fiber.Promise[Nothing, Unit]): Ack < (Emit[Chunk[Int]] & IO) = + ack match + case Ack.Stop => stopPromise.completeDiscard(Result.success(())).andThen(Ack.Stop) + case Ack.Continue(_) => Emit.andMap(Chunk(cur))(emit(_, cur + 1, stopPromise)) + end emit + + for + stopPromise <- Fiber.Promise.init[Nothing, Unit] + stream <- IO(Stream(Emit.andMap(Chunk.empty[Int])(emit(_, 0, stopPromise)))) + publisher <- StreamPublisher[Int, IO](stream) + subscriber <- StreamSubscriber[Int](BufferSize) + _ = publisher.subscribe(subscriber) + _ <- subscriber.stream.take(10).runDiscard + _ <- stopPromise.get + yield assert(true) + end for + } +end PublisherToSubscriberTest + +object PublisherToSubscriberTest: + type TestError = TestError.type + object TestError extends Exception("BOOM") + private val BufferSize = 1 << 4 +end PublisherToSubscriberTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala new file mode 100644 index 000000000..be0322e62 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamPublisherTest.scala @@ -0,0 +1,50 @@ +package kyo.interop.reactivestreams + +import kyo.* +import kyo.Emit.Ack +import kyo.interop.reactivestreams.StreamPublisher +import org.reactivestreams.tck.PublisherVerification +import org.reactivestreams.tck.TestEnvironment +import org.scalatestplus.testng.* + +final class StreamPublisherTest extends PublisherVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + given Frame = Frame.internal + + private def createStream(n: Int = 1) = + if n <= 0 then + Stream.empty[Int] + else + val chunkSize = Math.sqrt(n).floor.intValue + Stream.range(0, n, 1, chunkSize).map { int => + Random + .use(_.nextInt(50)) + .map(millis => Async.sleep(Duration.fromUnits(millis, Duration.Units.Millis))) + .andThen(int) + } + + override def createPublisher(n: Long): StreamPublisher[Int, Async] = + if n > Int.MaxValue then + null + else + StreamPublisher.Unsafe( + createStream(n.toInt), + subscribeCallback = fiber => + discard(IO.Unsafe.run(Abort.run(Async.runAndBlock(Duration.Infinity)(fiber))).eval) + ) + end if + end createPublisher + + override def createFailedPublisher(): StreamPublisher[Int, Async] = + StreamPublisher.Unsafe( + createStream(), + subscribeCallback = fiber => + val asynced = Async.runAndBlock(Duration.Infinity)(fiber) + val aborted = Abort.run(asynced) + val ioed = IO.Unsafe.run(aborted).eval + ioed match + case Result.Success(fiber) => discard(fiber.unsafe.interrupt()) + case _ => () + ) + end createFailedPublisher +end StreamPublisherTest diff --git a/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala new file mode 100644 index 000000000..f2ce89ca7 --- /dev/null +++ b/kyo-reactive-streams/shared/src/test/scala/kyo/interop/reactive-streams/StreamSubscriberTest.scala @@ -0,0 +1,91 @@ +package kyo.interop.reactivestreams + +import java.lang.Thread +import java.util.concurrent.atomic.AtomicInteger +import kyo.* +import kyo.Result.* +import kyo.interop.reactivestreams.StreamSubscriber +import kyo.interop.reactivestreams.StreamSubscriber.* +import org.reactivestreams.* +import org.reactivestreams.tck.SubscriberBlackboxVerification +import org.reactivestreams.tck.SubscriberWhiteboxVerification +import org.reactivestreams.tck.SubscriberWhiteboxVerification.SubscriberPuppet +import org.reactivestreams.tck.SubscriberWhiteboxVerification.WhiteboxSubscriberProbe +import org.reactivestreams.tck.TestEnvironment +import org.scalatestplus.testng.* + +class StreamSubscriberTest extends SubscriberWhiteboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber( + p: SubscriberWhiteboxVerification.WhiteboxSubscriberProbe[Int] + ): Subscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 1).map(s => new WhiteboxSubscriber(s, p)) + }.eval + + def createElement(i: Int): Int = counter.getAndIncrement +end StreamSubscriberTest + +final class WhiteboxSubscriber[V]( + sub: StreamSubscriber[V], + probe: WhiteboxSubscriberProbe[V] +)( + using Tag[V] +) extends Subscriber[V]: + import AllowUnsafe.embrace.danger + + def onError(t: Throwable): Unit = + sub.onError(t) + probe.registerOnError(t) + end onError + + def onSubscribe(s: Subscription): Unit = + sub.onSubscribe(s) + probe.registerOnSubscribe( + new SubscriberPuppet: + override def triggerRequest(elements: Long): Unit = + val computation: Unit < IO = Loop(elements) { remaining => + if remaining <= 0 then + Loop.done + else + sub.request.map { accepted => + Loop.continue(remaining - accepted) + } + } + IO.Unsafe.run(computation).eval + end triggerRequest + + override def signalCancel(): Unit = + s.cancel() + end signalCancel + ) + end onSubscribe + + def onComplete(): Unit = + sub.onComplete() + probe.registerOnComplete() + end onComplete + + def onNext(a: V): Unit = + sub.onNext(a) + probe.registerOnNext(a) + end onNext +end WhiteboxSubscriber + +final class SubscriberBlackboxSpec extends SubscriberBlackboxVerification[Int](new TestEnvironment(1000L)), TestNGSuiteLike: + import AllowUnsafe.embrace.danger + private val counter = new AtomicInteger() + + def createSubscriber(): StreamSubscriber[Int] = + IO.Unsafe.run { + StreamSubscriber[Int](bufferSize = 1) + }.eval + + override def triggerRequest(s: Subscriber[? >: Int]): Unit = + val computation: Long < IO = s.asInstanceOf[StreamSubscriber[Int]].request + discard(IO.Unsafe.run(computation).eval) + + def createElement(i: Int): Int = counter.incrementAndGet() +end SubscriberBlackboxSpec diff --git a/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala b/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala index 5e984c6f6..1211112c8 100644 --- a/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala +++ b/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala @@ -2,15 +2,16 @@ package kyo import java.net.http.HttpClient import kyo.Requests.Backend +import kyo.capabilities.KyoStreams import kyo.internal.KyoSttpMonad import sttp.capabilities.WebSockets import sttp.client3.* object PlatformBackend: - def apply(backend: SttpBackend[KyoSttpMonad.M, WebSockets])(using Frame): Backend = + def apply(backend: SttpBackend[KyoSttpMonad.M, KyoStreams & WebSockets])(using Frame): Backend = new Backend: - def send[A: Flat](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, KyoStreams & WebSockets]) = r.send(backend) def apply(client: HttpClient)(using Frame): Backend = @@ -19,6 +20,6 @@ object PlatformBackend: val default = new Backend: val b = HttpClientKyoBackend() - def send[A: Flat](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, KyoStreams & WebSockets]) = r.send(b) end PlatformBackend diff --git a/kyo-sttp/jvm/src/main/scala/sttp/client3/HttpClientKyoBackend.scala b/kyo-sttp/jvm/src/main/scala/sttp/client3/HttpClientKyoBackend.scala index 06a5a22cc..e01e189fd 100644 --- a/kyo-sttp/jvm/src/main/scala/sttp/client3/HttpClientKyoBackend.scala +++ b/kyo-sttp/jvm/src/main/scala/sttp/client3/HttpClientKyoBackend.scala @@ -4,17 +4,25 @@ import java.io.InputStream import java.io.UnsupportedEncodingException import java.net.http.HttpClient import java.net.http.HttpRequest +import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpRequest.BodyPublishers import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers +import java.nio.ByteBuffer import java.util.concurrent.Executor +import java.util.concurrent.Flow.Publisher import java.util.zip.GZIPInputStream import java.util.zip.InflaterInputStream +import java.util as ju import kyo.* +import kyo.capabilities.KyoStreams import kyo.internal.KyoSttpMonad import kyo.internal.KyoSttpMonad.* +import kyo.interop.Adapters +import scala.collection.mutable.ArrayBuffer import sttp.capabilities.WebSockets import sttp.client3.HttpClientBackend.EncodingHandler -import sttp.client3.HttpClientFutureBackend.InputStreamEncodingHandler +import sttp.client3.internal.* import sttp.client3.internal.NoStreams import sttp.client3.internal.emptyInputStream import sttp.client3.internal.httpclient.* @@ -27,8 +35,8 @@ class HttpClientKyoBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: InputStreamEncodingHandler -) extends HttpClientAsyncBackend[M, Nothing, WebSockets, InputStream, InputStream]( + customEncodingHandler: EncodingHandler[KyoStreams.BinaryStream] +) extends HttpClientAsyncBackend[M, KyoStreams, KyoStreams & WebSockets, Publisher[ju.List[ByteBuffer]], KyoStreams.BinaryStream]( client, KyoSttpMonad, closeClient, @@ -36,25 +44,11 @@ class HttpClientKyoBackend private ( customEncodingHandler ): - override val streams: NoStreams = NoStreams - - override protected val bodyToHttpClient = - new BodyToHttpClient[KyoSttpMonad.M, Nothing]: - override val streams: NoStreams = NoStreams - override given monad: MonadError[KyoSttpMonad.M] = KyoSttpMonad - override def streamToPublisher(stream: Nothing) = - stream - - override protected val bodyFromHttpClient = - new InputStreamBodyFromHttpClient[KyoSttpMonad.M, Nothing]: - override def inputStreamToStream(is: InputStream) = - KyoSttpMonad.error(new IllegalStateException("Streaming is not supported")) - override val streams: NoStreams = NoStreams - override given monad: MonadError[KyoSttpMonad.M] = KyoSttpMonad - override def compileWebSocketPipe( - ws: WebSocket[KyoSttpMonad.M], - pipe: streams.Pipe[WebSocketFrame.Data[?], WebSocketFrame] - ) = pipe + override val streams: KyoStreams = KyoStreams + + override protected val bodyToHttpClient = new KyoBodyToHttpClient + + override protected val bodyFromHttpClient = new KyoBodyFromHttpClient override protected def createSimpleQueue[A] = Channel.init[A](Int.MaxValue).map(new KyoSimpleQueue[A](_)) @@ -62,18 +56,26 @@ class HttpClientKyoBackend private ( override protected def createSequencer = Meter.initMutex.map(new KyoSequencer(_)) - override protected def standardEncoding: (InputStream, String) => InputStream = { - case (body, "gzip") => new GZIPInputStream(body) - case (body, "deflate") => new InflaterInputStream(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") + override protected def standardEncoding: (KyoStreams.BinaryStream, String) => KyoStreams.BinaryStream = { + case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") } - override protected def createBodyHandler: HttpResponse.BodyHandler[InputStream] = - BodyHandlers.ofInputStream() - - override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p - - override protected def emptyBody(): InputStream = emptyInputStream() + override protected def createBodyHandler: HttpResponse.BodyHandler[Publisher[ju.List[ByteBuffer]]] = + BodyHandlers.ofPublisher() + + override protected def bodyHandlerBodyToBody(p: Publisher[ju.List[ByteBuffer]]): KyoStreams.BinaryStream = + Adapters.publisherToStream(p).mapChunk { chunkList => + val builder = ArrayBuffer.newBuilder[Byte] + chunkList.foreach { list => + val iterator = list.iterator() + while iterator.hasNext() do + val bytes = iterator.next().safeRead() + builder ++= bytes + } + Chunk.from(builder.result().toArray) + } + + override protected def emptyBody(): KyoStreams.BinaryStream = Stream.empty[Byte] end HttpClientKyoBackend object HttpClientKyoBackend: @@ -84,8 +86,8 @@ object HttpClientKyoBackend: client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: InputStreamEncodingHandler - ): SttpBackend[KyoSttpMonad.M, WebSockets] = + customEncodingHandler: EncodingHandler[KyoStreams.BinaryStream] + ): SttpBackend[KyoSttpMonad.M, KyoStreams & WebSockets] = new FollowRedirectsBackend( new HttpClientKyoBackend( client, @@ -98,9 +100,9 @@ object HttpClientKyoBackend: def apply( options: SttpBackendOptions = SttpBackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: InputStreamEncodingHandler = PartialFunction.empty, + customEncodingHandler: EncodingHandler[KyoStreams.BinaryStream] = PartialFunction.empty, executor: Option[Executor] = Some(r => r.run()) - ): SttpBackend[KyoSttpMonad.M, WebSockets] = + ): SttpBackend[KyoSttpMonad.M, KyoStreams & WebSockets] = HttpClientKyoBackend( HttpClientBackend.defaultClient(options, executor), closeClient = false, @@ -111,8 +113,8 @@ object HttpClientKyoBackend: def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: InputStreamEncodingHandler = PartialFunction.empty - ): SttpBackend[KyoSttpMonad.M, WebSockets] = + customEncodingHandler: EncodingHandler[KyoStreams.BinaryStream] = PartialFunction.empty + ): SttpBackend[KyoSttpMonad.M, KyoStreams & WebSockets] = HttpClientKyoBackend( client, closeClient = false, @@ -120,6 +122,6 @@ object HttpClientKyoBackend: customEncodingHandler ) - def stub: SttpBackendStub[KyoSttpMonad.M, WebSockets] = + def stub: SttpBackendStub[KyoSttpMonad.M, KyoStreams & WebSockets] = SttpBackendStub(KyoSttpMonad) end HttpClientKyoBackend diff --git a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyFromHttpClient.scala b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyFromHttpClient.scala new file mode 100644 index 000000000..e88ed006e --- /dev/null +++ b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyFromHttpClient.scala @@ -0,0 +1,104 @@ +package sttp.client3 + +import kyo.* +import kyo.Emit.Ack +import kyo.Result.Success +import kyo.capabilities.* +import kyo.capabilities.KyoStreams +import kyo.internal.KyoSttpMonad +import kyo.internal.KyoSttpMonad.* +import kyo.sink +import sttp.capabilities.Streams +import sttp.client3.internal.BodyFromResponseAs +import sttp.client3.internal.SttpFile +import sttp.client3.internal.httpclient.BodyFromHttpClient +import sttp.client3.ws.GotAWebSocketException +import sttp.client3.ws.NotAWebSocketException +import sttp.model.ResponseMetadata +import sttp.monad.MonadError +import sttp.ws.WebSocket +import sttp.ws.WebSocketClosed +import sttp.ws.WebSocketFrame + +final class KyoBodyFromHttpClient extends BodyFromHttpClient[KyoSttpMonad.M, KyoStreams, KyoStreams.BinaryStream]: + override val streams: KyoStreams = KyoStreams + override given monad: MonadError[KyoSttpMonad.M] = KyoSttpMonad + + override protected def bodyFromResponseAs + : BodyFromResponseAs[KyoSttpMonad.M, KyoStreams.BinaryStream, WebSocket[KyoSttpMonad.M], KyoStreams.BinaryStream] = + new BodyFromResponseAs[KyoSttpMonad.M, KyoStreams.BinaryStream, WebSocket[KyoSttpMonad.M], KyoStreams.BinaryStream]: + override protected def withReplayableBody( + response: KyoStreams.BinaryStream, + replayableBody: Either[Array[Byte], SttpFile] + ): KyoSttpMonad.M[KyoStreams.BinaryStream] = + replayableBody match + case Left(byteArray) => Stream.init(Chunk.from(byteArray)) + case Right(file) => IO(Path.fromJavaPath(file.toPath)).map(_.readBytesStream) + override protected def regularIgnore(response: KyoStreams.BinaryStream): KyoSttpMonad.M[Unit] = + Resource.run(response.runDiscard) + + override protected def regularAsByteArray( + response: KyoStreams.BinaryStream + ): KyoSttpMonad.M[Array[Byte]] = + Resource.run(response.run.map(_.toList.toArray)) + + override protected def regularAsFile( + response: KyoStreams.BinaryStream, + file: SttpFile + ): KyoSttpMonad.M[SttpFile] = IO(Path.fromJavaPath(file.toPath)) + .map(path => Resource.run(response.sink(path))) + .map(_ => file) + + override protected def regularAsStream( + response: KyoStreams.BinaryStream + ): KyoSttpMonad.M[(KyoStreams.BinaryStream, () => KyoSttpMonad.M[Unit])] = + IO(response, () => Resource.run(response.runDiscard)) + end regularAsStream + + override protected def handleWS[T]( + responseAs: WebSocketResponseAs[T, ?], + meta: ResponseMetadata, + ws: WebSocket[KyoSttpMonad.M] + ): KyoSttpMonad.M[T] = bodyFromWs(responseAs, ws, meta) + + override protected def cleanupWhenNotAWebSocket( + response: KyoStreams.BinaryStream, + e: NotAWebSocketException + ): KyoSttpMonad.M[Unit] = Resource.run(response.runDiscard) + + override protected def cleanupWhenGotWebSocket( + response: WebSocket[KyoSttpMonad.M], + e: GotAWebSocketException + ): KyoSttpMonad.M[Unit] = response.close() + end bodyFromResponseAs + + override def compileWebSocketPipe( + ws: WebSocket[KyoSttpMonad.M], + pipe: KyoStreams.Pipe[WebSocketFrame.Data[?], WebSocketFrame] + ): KyoSttpMonad.M[Unit] = + def receiveFrame: Result[WebSocketClosed, WebSocketFrame] < Async = + Abort.run(Abort.catching[WebSocketClosed](ws.receive())) + + def emitFromWebSocket: Ack < (Emit[Chunk[WebSocketFrame.Data[?]]] & Async) = + Loop[Unit, Ack, Emit[Chunk[WebSocketFrame.Data[?]]] & Async](()) { _ => + receiveFrame.map { + case Success(WebSocketFrame.Close(_, _)) => Loop.done[Unit, Ack](Ack.Stop) + case Success(WebSocketFrame.Ping(payload)) => + ws.send(WebSocketFrame.Pong(payload)).andThen(Loop.continue[Ack]) + case Success(WebSocketFrame.Pong(_)) => Loop.continue[Ack] + case Success(in: WebSocketFrame.Data[?]) => Emit.andMap(Chunk(in)) { + case Ack.Stop => Loop.done[Unit, Ack](Ack.Stop) + case _ => Loop.continue[Ack] + } + case _ => Loop.done[Unit, Ack](Ack.Stop) + } + } + + val pipeComputation: Unit < (Async & Resource) = pipe(Stream(emitFromWebSocket)) + .runForeach(dataFrame => ws.send(dataFrame)) + .andThen(Resource.ensure(ws.close())) + + Resource.run(pipeComputation) + end compileWebSocketPipe + +end KyoBodyFromHttpClient diff --git a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyToHttpClient.scala b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyToHttpClient.scala new file mode 100644 index 000000000..a9a38c8c8 --- /dev/null +++ b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoBodyToHttpClient.scala @@ -0,0 +1,25 @@ +package sttp.client3 + +import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpRequest.BodyPublishers +import java.nio.ByteBuffer +import kyo.* +import kyo.capabilities.KyoStreams +import kyo.internal.KyoSttpMonad +import kyo.internal.KyoSttpMonad.* +import kyo.interop.Adapters +import sttp.client3.internal.httpclient.BodyToHttpClient +import sttp.monad.MonadError + +final class KyoBodyToHttpClient extends BodyToHttpClient[KyoSttpMonad.M, KyoStreams]: + override val streams: KyoStreams = KyoStreams + override given monad: MonadError[KyoSttpMonad.M] = KyoSttpMonad + override def streamToPublisher(stream: KyoStreams.BinaryStream): KyoSttpMonad.M[BodyPublisher] = + val byteBufferStream = stream.mapChunk { chunk => + Chunk(ByteBuffer.wrap(chunk.toArray)) + } + Adapters.streamToPublisher(byteBufferStream).map { publisher => + BodyPublishers.fromPublisher(publisher) + } + end streamToPublisher +end KyoBodyToHttpClient diff --git a/kyo-sttp/jvm/src/test/scala/kyo/RequestsLiveTest.scala b/kyo-sttp/jvm/src/test/scala/kyo/RequestsLiveTest.scala index 31bfe6b03..4283e4ac5 100644 --- a/kyo-sttp/jvm/src/test/scala/kyo/RequestsLiveTest.scala +++ b/kyo-sttp/jvm/src/test/scala/kyo/RequestsLiveTest.scala @@ -1,5 +1,7 @@ package kyo +import kyo.capabilities.KyoStreams +import scala.concurrent.duration import scala.util.* import sttp.client3.* @@ -7,19 +9,19 @@ class RequestsLiveTest extends Test: "requests" - { "live" - { - "success" in run { + "success" in runJVM { for port <- startTestServer("/ping", Success("pong")) r <- Requests(_.get(uri"http://localhost:$port/ping")) yield assert(r == "pong") } - "failure" in run { + "failure" in runJVM { for port <- startTestServer("/ping", Failure(new Exception)) r <- Abort.run(Requests(_.get(uri"http://localhost:$port/ping"))) yield assert(r.isFail) } - "race" in run { + "race" in runJVM { val n = 1000 for port <- startTestServer("/ping", Success("pong")) @@ -33,7 +35,7 @@ class RequestsLiveTest extends Test: private def startTestServer( endpointPath: String, response: Try[String], - port: Int = 8000 + port: Int = 54323 ): Int < (IO & Resource) = IO { diff --git a/kyo-sttp/shared/src/main/scala/kyo/Requests.scala b/kyo-sttp/shared/src/main/scala/kyo/Requests.scala index 44f737551..cf631f574 100644 --- a/kyo-sttp/shared/src/main/scala/kyo/Requests.scala +++ b/kyo-sttp/shared/src/main/scala/kyo/Requests.scala @@ -1,5 +1,7 @@ package kyo +import kyo.capabilities.KyoStreams +import sttp.capabilities.WebSockets import sttp.client3.* /** Represents a failed HTTP request. @@ -36,7 +38,7 @@ object Requests: * @return * The response wrapped in an effect */ - def send[A: Flat](r: Request[A, Any]): Response[A] < (Async & Abort[FailedRequest]) + def send[A: Flat](r: Request[A, KyoStreams & WebSockets]): Response[A] < (Async & Abort[FailedRequest]) /** Wraps the Backend with a meter * @@ -47,7 +49,7 @@ object Requests: */ def withMeter(m: Meter)(using Frame): Backend = new Backend: - def send[A: Flat](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, KyoStreams & WebSockets]) = Abort.run(m.run(self.send(r))).map(r => Abort.get(r.mapFail(FailedRequest(_)))) end Backend @@ -69,7 +71,7 @@ object Requests: local.let(b)(v) /** Type alias for a basic request */ - type BasicRequest = RequestT[Empty, Either[FailedRequest, String], Any] + type BasicRequest = RequestT[Empty, Either[FailedRequest, String], KyoStreams & WebSockets] /** A basic request with error handling */ val basicRequest: BasicRequest = sttp.client3.basicRequest.mapResponse { @@ -88,7 +90,9 @@ object Requests: * @return * The response body wrapped in an effect */ - def apply[E, A](f: BasicRequest => Request[Either[E, A], Any])(using Frame): A < (Async & Abort[FailedRequest | E]) = + def apply[E, A](f: BasicRequest => Request[Either[E, A], KyoStreams & WebSockets])(using + Frame + ): A < (Async & Abort[FailedRequest | E]) = request(f(basicRequest)) /** Sends an HTTP request @@ -102,7 +106,7 @@ object Requests: * @return * The response body wrapped in an effect */ - def request[E, A](req: Request[Either[E, A], Any])(using Frame): A < (Async & Abort[FailedRequest | E]) = + def request[E, A](req: Request[Either[E, A], KyoStreams & WebSockets])(using Frame): A < (Async & Abort[FailedRequest | E]) = local.use(_.send(req)).map { r => Abort.get(r.body) } diff --git a/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala b/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala index f0fd6c80b..3852899f1 100644 --- a/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala +++ b/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala @@ -1,13 +1,15 @@ package kyo +import kyo.capabilities.KyoStreams import scala.util.* +import sttp.capabilities.WebSockets import sttp.client3.* class RequestsTest extends Test: class TestBackend extends Requests.Backend: var calls = 0 - def send[A: Flat](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, KyoStreams & WebSockets]) = calls += 1 Response.ok(Right("mocked")).asInstanceOf[Response[A]] end TestBackend