Skip to content

Commit

Permalink
[kyo-reactive-stream]: fix feedbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
HollandDM committed Dec 13, 2024
1 parent 9ab73ed commit 2a28269
Show file tree
Hide file tree
Showing 17 changed files with 1,019 additions and 646 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ jmh-result.json
*.jfr
*.json
*.gpg
test-output
test-output
.DS_Store
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package kyo.interop.flow

import java.util.concurrent.Flow.*
import kyo.*
import kyo.interop.flow.StreamSubscription.StreamFinishState
import kyo.kernel.Boundary
import scala.annotation.nowarn

abstract private[kyo] class StreamPublisher[V, Ctx](
stream: Stream[V, Ctx]
) extends Publisher[V]:

protected def bind(subscriber: Subscriber[? >: V]): Unit

override def subscribe(subscriber: Subscriber[? >: V]): Unit =
if isNull(subscriber) then
throw new NullPointerException("Subscriber must not be null.")
else
bind(subscriber)
end subscribe

end StreamPublisher

object StreamPublisher:

def apply[V, Ctx](
stream: Stream[V, Ctx],
capacity: Int = Int.MaxValue
)(
using
Boundary[Ctx, IO],
Frame,
Tag[Emit[Chunk[V]]],
Tag[Poll[Chunk[V]]]
): StreamPublisher[V, Ctx] < (Resource & IO & Ctx) =
inline def interruptPanic = Result.Panic(Fiber.Interrupted(scala.compiletime.summonInline[Frame]))

def discardSubscriber(subscriber: Subscriber[? >: V]): Unit =
subscriber.onSubscribe(new Subscription:
override def request(n: Long): Unit = ()
override def cancel(): Unit = ()
)
subscriber.onComplete()
end discardSubscriber

def consumeChannel(
channel: Channel[Subscriber[? >: V]],
supervisor: Fiber.Promise[Nothing, Unit]
): Unit < (Async & Ctx) =
Loop(()) { _ =>
Abort.recover[Closed](_ => Loop.done) {
for
subscriber <- channel.take
subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber))
fiber <- subscription.subscribe.andThen(subscription.consume)
_ <- supervisor.onComplete(_ => discard(fiber.interrupt(interruptPanic)))
yield Loop.continue(())
}
}

for
channel <-
Resource.acquireRelease(Channel.init[Subscriber[? >: V]](capacity))(
_.close.map(_.foreach(_.foreach(discardSubscriber(_))))
)
publisher <- IO.Unsafe {
new StreamPublisher[V, Ctx](stream):
override protected def bind(
subscriber: Subscriber[? >: V]
): Unit =
channel.unsafe.offer(subscriber) match
case Result.Success(true) => ()
case _ => discardSubscriber(subscriber)
}
supervisor <- Resource.acquireRelease(Fiber.Promise.init[Nothing, Unit])(_.interrupt.map(discard(_)))
_ <- Resource.acquireRelease(Async._run(consumeChannel(channel, supervisor)))(_.interrupt.map(discard(_)))
yield publisher
end for
end apply

object Unsafe:
@nowarn("msg=anonymous")
inline def apply[V, Ctx](
stream: Stream[V, Ctx],
subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit
)(
using
AllowUnsafe,
Frame,
Tag[Emit[Chunk[V]]],
Tag[Poll[Chunk[V]]]
): StreamPublisher[V, Ctx] =
new StreamPublisher[V, Ctx](stream):
override protected def bind(
subscriber: Subscriber[? >: V]
): Unit =
discard(StreamSubscription.Unsafe._subscribe(
stream,
subscriber
)(
subscribeCallback
))
end Unsafe
end StreamPublisher
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
package kyo.interop.flow

import StreamSubscriber.*
import java.util.concurrent.Flow.*
import kyo.*
import scala.annotation.tailrec

final private[kyo] class StreamSubscriber[V](
bufferSize: Int,
strategy: EmitStrategy
)(
using AllowUnsafe
) extends Subscriber[V]:

private enum UpstreamState derives CanEqual:
case Uninitialized extends UpstreamState
case WaitForRequest(subscription: Subscription, items: Chunk[V], remaining: Int) extends UpstreamState
case Finished(reason: Maybe[Throwable], leftOver: Chunk[V]) extends UpstreamState
end UpstreamState

private val state = AtomicRef.Unsafe.init(
UpstreamState.Uninitialized -> Maybe.empty[Fiber.Promise.Unsafe[Nothing, Unit]]
)

private inline def throwIfNull[A](b: A): Unit = if isNull(b) then throw new NullPointerException()

override def onSubscribe(subscription: Subscription): Unit =
throwIfNull(subscription)
@tailrec def handleSubscribe(): Unit =
val curState = state.get()
curState match
case (UpstreamState.Uninitialized, maybePromise) =>
val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty, 0) -> Absent
if state.compareAndSet(curState, nextState) then
maybePromise.foreach(_.completeDiscard(Result.success(())))
else
handleSubscribe()
end if
case other =>
if state.compareAndSet(curState, other) then
subscription.cancel()
else
handleSubscribe()
end match
end handleSubscribe
handleSubscribe()
end onSubscribe

override def onNext(item: V): Unit =
throwIfNull(item)
@tailrec def handleNext(): Unit =
val curState = state.get()
curState match
case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) =>
if (strategy == EmitStrategy.Eager) || (strategy == EmitStrategy.Buffer && remaining == 1) then
val nextState = UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> Absent
if state.compareAndSet(curState, nextState) then
maybePromise.foreach(_.completeDiscard(Result.success(())))
else
handleNext()
end if
else
val nextState = UpstreamState.WaitForRequest(subscription, items.append(item), remaining - 1) -> maybePromise
if !state.compareAndSet(curState, nextState) then handleNext()
case other =>
if !state.compareAndSet(curState, other) then handleNext()
end match
end handleNext
handleNext()
end onNext

override def onError(throwable: Throwable): Unit =
throwIfNull(throwable)
@tailrec def handleError(): Unit =
val curState = state.get()
curState match
case (UpstreamState.WaitForRequest(_, items, _), maybePromise) =>
val nextState = UpstreamState.Finished(Maybe(throwable), items) -> Absent
if state.compareAndSet(curState, nextState) then
maybePromise.foreach(_.completeDiscard(Result.success(())))
else
handleError()
end if
case other =>
if !state.compareAndSet(curState, other) then handleError()
end match
end handleError
handleError()
end onError

override def onComplete(): Unit =
@tailrec def handleComplete(): Unit =
val curState = state.get()
curState match
case (UpstreamState.WaitForRequest(_, items, _), maybePromise) =>
val nextState = UpstreamState.Finished(Absent, items) -> Absent
if state.compareAndSet(curState, nextState) then
maybePromise.foreach(_.completeDiscard(Result.success(())))
else
handleComplete()
end if
case other =>
if !state.compareAndSet(curState, other) then handleComplete()
end match
end handleComplete
handleComplete()
end onComplete

private[interop] def await(using Frame): Boolean < Async =
@tailrec def handleAwait(): Boolean < Async =
val curState = state.get()
curState match
case (UpstreamState.Uninitialized, Absent) =>
val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]()
val nextState = UpstreamState.Uninitialized -> Present(promise)
if state.compareAndSet(curState, nextState) then
promise.safe.use(_ => false)
else
handleAwait()
end if
case s @ (UpstreamState.Uninitialized, Present(promise)) =>
if state.compareAndSet(curState, s) then
promise.safe.use(_ => false)
else
handleAwait()
case s @ (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) =>
if items.isEmpty then
if remaining == 0 then
val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty[V], 0) -> Absent
if state.compareAndSet(curState, nextState) then
IO(true)
else
handleAwait()
end if
else
val promise = Fiber.Promise.Unsafe.init[Nothing, Unit]()
val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty[V], remaining) -> Present(promise)
if state.compareAndSet(curState, nextState) then
promise.safe.use(_ => false)
else
handleAwait()
end if
else
if state.compareAndSet(curState, s) then
IO(false)
else
handleAwait()
case other =>
if state.compareAndSet(curState, other) then
IO(false)
else
handleAwait()
end match
end handleAwait
IO(handleAwait())
end await

private[interop] def request(using Frame): Long < IO =
@tailrec def handleRequest(): Long < IO =
val curState = state.get()
curState match
case (UpstreamState.WaitForRequest(subscription, items, remaining), maybePromise) =>
val nextState = UpstreamState.WaitForRequest(subscription, items, remaining + bufferSize) -> maybePromise
if state.compareAndSet(curState, nextState) then
IO(subscription.request(bufferSize)).andThen(bufferSize.toLong)
else
handleRequest()
end if
case other =>
if state.compareAndSet(curState, other) then
IO(0L)
else
handleRequest()
end match
end handleRequest
IO(handleRequest())
end request

private[interop] def poll(using Frame): Result[Throwable | SubscriberDone, Chunk[V]] < IO =
@tailrec def handlePoll(): Result[Throwable | SubscriberDone, Chunk[V]] < IO =
val curState = state.get()
curState match
case (UpstreamState.WaitForRequest(subscription, items, remaining), Absent) =>
val nextState = UpstreamState.WaitForRequest(subscription, Chunk.empty, remaining) -> Absent
if state.compareAndSet(curState, nextState) then
IO(Result.success(items))
else
handlePoll()
end if
case s @ (UpstreamState.Finished(reason, leftOver), Absent) =>
if leftOver.isEmpty then
if state.compareAndSet(curState, s) then
IO {
reason match
case Present(error) => Result.fail(error)
case Absent => Result.fail(SubscriberDone)
}
else
handlePoll()
else
val nextState = UpstreamState.Finished(reason, Chunk.empty) -> Absent
if state.compareAndSet(curState, nextState) then
IO(Result.success(leftOver))
else
handlePoll()
end if
end if
case other =>
if state.compareAndSet(curState, other) then
IO(Result.success(Chunk.empty))
else
handlePoll()
end match
end handlePoll
IO(handlePoll())
end poll

private[interop] def interupt(using Frame): Unit < IO =
@tailrec def handleInterupt(): Unit < IO =
val curState = state.get()
curState match
case (UpstreamState.Uninitialized, maybePromise) =>
val nextState = UpstreamState.Finished(Absent, Chunk.empty) -> Absent
if state.compareAndSet(curState, nextState) then
IO(maybePromise.foreach(_.completeDiscard(Result.success(()))))
else
handleInterupt()
end if
case (UpstreamState.WaitForRequest(subscription, _, _), Absent) =>
val nextState = UpstreamState.Finished(Absent, Chunk.empty) -> Absent
if state.compareAndSet(curState, nextState) then
IO(subscription.cancel())
else
handleInterupt()
end if
case other =>
if state.compareAndSet(curState, other) then
IO.unit
else
handleInterupt()
end match
end handleInterupt
IO(handleInterupt())
end interupt

private[interop] def emit(using Frame, Tag[Emit[Chunk[V]]]): Ack < (Emit[Chunk[V]] & Async) =
Emit.andMap(Chunk.empty) { ack =>
Loop(ack) {
case Ack.Stop => interupt.andThen(Loop.done(Ack.Stop))
case Ack.Continue(_) =>
await
.map {
case true => request.andThen(Ack.Continue())
case false => poll.map {
case Result.Success(nextChunk) => Emit(nextChunk)
case Result.Error(e: Throwable) => Abort.panic(e)
case _ => Ack.Stop
}
}
.map(Loop.continue(_))
}
}

def stream(using Frame, Tag[Emit[Chunk[V]]]): Stream[V, Async] = Stream(emit)

end StreamSubscriber

object StreamSubscriber:

abstract private[flow] class SubscriberDone
private[flow] case object SubscriberDone extends SubscriberDone

enum EmitStrategy derives CanEqual:
case Eager // Emit value to downstream stream as soon as the subscriber receives one
case Buffer // Subscriber buffers received values and emit them only when reaching bufferSize
end EmitStrategy

def apply[V](
bufferSize: Int,
strategy: EmitStrategy = EmitStrategy.Eager
)(
using Frame
): StreamSubscriber[V] < IO = IO.Unsafe(new StreamSubscriber(bufferSize, strategy))
end StreamSubscriber
Loading

0 comments on commit 2a28269

Please sign in to comment.