diff --git a/build.sbt b/build.sbt index 36490ac34b..ebaa132997 100644 --- a/build.sbt +++ b/build.sbt @@ -124,6 +124,6 @@ lazy val zhttpTest = (project in file("zio-http-test")) lazy val example = (project in file("./example")) .settings(stdSettings("example")) .settings(publishSetting(false)) - .settings(runSettings("example.FileStreaming")) + .settings(runSettings("example.Main")) .settings(libraryDependencies ++= Seq(`jwt-core`)) .dependsOn(zhttp) diff --git a/zio-http/src/main/scala/zhttp/service/HttpRuntime.scala b/zio-http/src/main/scala/zhttp/service/HttpRuntime.scala index 05b67735e2..2d3e0888ab 100644 --- a/zio-http/src/main/scala/zhttp/service/HttpRuntime.scala +++ b/zio-http/src/main/scala/zhttp/service/HttpRuntime.scala @@ -1,9 +1,9 @@ package zhttp.service import io.netty.channel.{ChannelHandlerContext, EventLoopGroup => JEventLoopGroup} -import io.netty.util.concurrent.{EventExecutor, Future} +import io.netty.util.concurrent.{EventExecutor, Future, GenericFutureListener} +import zio._ import zio.internal.Executor -import zio.{Exit, Runtime, URIO, ZIO} import scala.collection.mutable import scala.concurrent.{ExecutionContext => JExecutionContext} @@ -14,16 +14,23 @@ import scala.jdk.CollectionConverters._ * cancel the execution when the channel closes. */ final class HttpRuntime[+R](strategy: HttpRuntime.Strategy[R]) { - def unsafeRun(ctx: ChannelHandlerContext)(program: ZIO[R, Throwable, Any]): Unit = { + val rtm = strategy.getRuntime(ctx) + + // Close the connection if the program fails + // When connection closes, interrupt the program + rtm .unsafeRunAsync(for { fiber <- program.fork - _ <- ZIO.effect { - ctx.channel().closeFuture.addListener((_: Future[_ <: Void]) => rtm.unsafeRunAsync_(fiber.interrupt): Unit) + close <- UIO { + val close = closeListener(rtm, fiber) + ctx.channel().closeFuture.addListener(close) + close } _ <- fiber.join + _ <- UIO(ctx.channel().closeFuture().removeListener(close)) } yield ()) { case Exit.Success(_) => () case Exit.Failure(cause) => @@ -31,18 +38,39 @@ final class HttpRuntime[+R](strategy: HttpRuntime.Strategy[R]) { case None => () case Some(_) => System.err.println(cause.prettyPrint) } - ctx.close() + if (ctx.channel().isOpen) ctx.close() } } + + private def closeListener(rtm: Runtime[Any], fiber: Fiber.Runtime[_, _]): GenericFutureListener[Future[_ >: Void]] = + (_: Future[_ >: Void]) => rtm.unsafeRunAsync_(fiber.interrupt): Unit } object HttpRuntime { + def dedicated[R](group: JEventLoopGroup): URIO[R, HttpRuntime[R]] = + Strategy.dedicated(group).map(runtime => new HttpRuntime[R](runtime)) + + def default[R]: URIO[R, HttpRuntime[R]] = + Strategy.default().map(runtime => new HttpRuntime[R](runtime)) + + def sticky[R](group: JEventLoopGroup): URIO[R, HttpRuntime[R]] = + Strategy.sticky(group).map(runtime => new HttpRuntime[R](runtime)) + sealed trait Strategy[R] { def getRuntime(ctx: ChannelHandlerContext): Runtime[R] } object Strategy { + def dedicated[R](group: JEventLoopGroup): ZIO[R, Nothing, Strategy[R]] = + ZIO.runtime[R].map(runtime => Dedicated(runtime, group)) + + def default[R](): ZIO[R, Nothing, Strategy[R]] = + ZIO.runtime[R].map(runtime => Default(runtime)) + + def sticky[R](group: JEventLoopGroup): ZIO[R, Nothing, Strategy[R]] = + ZIO.runtime[R].map(runtime => Group(runtime, group)) + case class Default[R](runtime: Runtime[R]) extends Strategy[R] { override def getRuntime(ctx: ChannelHandlerContext): Runtime[R] = runtime } @@ -73,23 +101,5 @@ object HttpRuntime { override def getRuntime(ctx: ChannelHandlerContext): Runtime[R] = localRuntime.getOrElse(ctx.executor(), runtime) } - - def sticky[R](group: JEventLoopGroup): ZIO[R, Nothing, Strategy[R]] = - ZIO.runtime[R].map(runtime => Group(runtime, group)) - - def default[R](): ZIO[R, Nothing, Strategy[R]] = - ZIO.runtime[R].map(runtime => Default(runtime)) - - def dedicated[R](group: JEventLoopGroup): ZIO[R, Nothing, Strategy[R]] = - ZIO.runtime[R].map(runtime => Dedicated(runtime, group)) } - - def sticky[R](group: JEventLoopGroup): URIO[R, HttpRuntime[R]] = - Strategy.sticky(group).map(runtime => new HttpRuntime[R](runtime)) - - def dedicated[R](group: JEventLoopGroup): URIO[R, HttpRuntime[R]] = - Strategy.dedicated(group).map(runtime => new HttpRuntime[R](runtime)) - - def default[R]: URIO[R, HttpRuntime[R]] = - Strategy.default().map(runtime => new HttpRuntime[R](runtime)) } diff --git a/zio-http/src/main/scala/zhttp/service/Server.scala b/zio-http/src/main/scala/zhttp/service/Server.scala index cc2137a754..711f957734 100644 --- a/zio-http/src/main/scala/zhttp/service/Server.scala +++ b/zio-http/src/main/scala/zhttp/service/Server.scala @@ -215,7 +215,7 @@ object Server { for { channelFactory <- ZManaged.access[ServerChannelFactory](_.get) eventLoopGroup <- ZManaged.access[EventLoopGroup](_.get) - zExec <- HttpRuntime.default[R].toManaged_ + zExec <- HttpRuntime.sticky[R](eventLoopGroup).toManaged_ reqHandler = settings.app.compile(zExec, settings) respHandler = ServerResponseHandler(zExec, settings, ServerTimeGenerator.make) init = ServerChannelInitializer(zExec, settings, reqHandler, respHandler)