Skip to content

Commit

Permalink
Use intermediate Channel for accepting sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Nov 22, 2022
1 parent 54d0189 commit 9e6a59b
Showing 1 changed file with 64 additions and 41 deletions.
105 changes: 64 additions & 41 deletions io/jvm-native/src/main/scala/fs2/io/net/SocketGroupPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import java.nio.channels.{
}
import java.nio.channels.AsynchronousChannelGroup
import cats.syntax.all._
import cats.effect.syntax.all._
import cats.effect.kernel.{Async, Resource}
import com.comcast.ip4s.{Host, IpAddress, Port, SocketAddress}
import fs2.concurrent.Channel

private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
private[fs2] def unsafe[F[_]: Async](channelGroup: AsynchronousChannelGroup): SocketGroup[F] =
Expand Down Expand Up @@ -84,73 +86,94 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
options: List[SocketOption]
): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = {

val setup: Resource[F, AsynchronousServerSocketChannel] =
val setup: Resource[
F,
(AsynchronousServerSocketChannel, Channel[F, Either[Throwable, AsynchronousSocketChannel]])
] =
Resource.eval(address.traverse(_.resolve[F])).flatMap { addr =>
Resource
.make(
Async[F].delay(
AsynchronousServerSocketChannel.open(channelGroup)
)
)(sch => Async[F].delay(if (sch.isOpen) sch.close()))
.evalTap(ch =>
.evalTap { sch =>
Async[F].delay(
ch.bind(
sch.bind(
new InetSocketAddress(
addr.map(_.toInetAddress).orNull,
port.map(_.value).getOrElse(0)
)
)
)
)
}

def acceptIncoming(
sch: AsynchronousServerSocketChannel
): Stream[F, Socket[F]] = {
def go: Stream[F, Socket[F]] = {
def acceptChannel: F[AsynchronousSocketChannel] =
Async[F].async[AsynchronousSocketChannel] { cb =>
Async[F]
.delay {
sch.accept(
null,
new CompletionHandler[AsynchronousSocketChannel, Void] {
def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit =
cb(Right(ch))
def failed(rsn: Throwable, attachment: Void): Unit =
cb(Left(rsn))
}
.mproduct { sch =>
def acceptChannel: F[AsynchronousSocketChannel] =
Async[F].async[AsynchronousSocketChannel] { cb =>
Async[F]
.delay {
sch.accept(
null,
new CompletionHandler[AsynchronousSocketChannel, Void] {
def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit =
cb(Right(ch))
def failed(rsn: Throwable, attachment: Void): Unit =
cb(Left(rsn))
}
)
}
)
.as(Some(Async[F].delay(sch.close())))
}
.as(Some(Async[F].delay(sch.close())))
}

def setOpts(ch: AsynchronousSocketChannel) =
Async[F].delay {
options.foreach(o => ch.setOption(o.key, o.value))
Resource
.make(Channel.synchronous[F, Either[Throwable, AsynchronousSocketChannel]]) {
accepted =>
accepted.close *>
accepted.stream
.foreach(_.traverse_(ch => Async[F].delay(ch.close())))
.compile
.drain
}
.flatTap { accepted =>
Stream
.repeatEval(acceptChannel.attempt)
.through(accepted.sendAll)
.compile
.drain
.background
}
}

Stream.eval(acceptChannel.attempt).flatMap {
}

def acceptIncoming(sch: AsynchronousServerSocketChannel)(
incoming: Stream[F, Either[Throwable, AsynchronousSocketChannel]]
): Stream[F, Socket[F]] = {
def setOpts(ch: AsynchronousSocketChannel) =
Async[F].delay {
options.foreach(o => ch.setOption(o.key, o.value))
}

incoming
.flatMap {
case Left(_) => Stream.empty[F]
case Right(accepted) =>
Stream.resource(Socket.forAsync(accepted).evalTap(_ => setOpts(accepted)))
} ++ go
}

go.handleErrorWith {
case err: AsynchronousCloseException =>
Stream.eval(Async[F].delay(sch.isOpen)).flatMap { isOpen =>
if (isOpen) Stream.raiseError[F](err)
else Stream.empty
}
case err => Stream.raiseError[F](err)
}
}
.handleErrorWith {
case err: AsynchronousCloseException =>
Stream.eval(Async[F].delay(sch.isOpen)).flatMap { isOpen =>
if (isOpen) Stream.raiseError[F](err)
else Stream.empty
}
case err => Stream.raiseError[F](err)
}
}

setup.map { sch =>
setup.map { case (sch, incoming) =>
val jLocalAddress = sch.getLocalAddress.asInstanceOf[java.net.InetSocketAddress]
val localAddress = SocketAddress.fromInetSocketAddress(jLocalAddress)
(localAddress, acceptIncoming(sch))
(localAddress, incoming.stream.through(acceptIncoming(sch)(_)))
}
}
}
Expand Down

0 comments on commit 9e6a59b

Please sign in to comment.