diff --git a/zio-http/src/main/scala/zhttp/service/Server.scala b/zio-http/src/main/scala/zhttp/service/Server.scala index dbed3eb7ab..302be0bc90 100644 --- a/zio-http/src/main/scala/zhttp/service/Server.scala +++ b/zio-http/src/main/scala/zhttp/service/Server.scala @@ -15,9 +15,6 @@ sealed trait Server[-R, +E] { self => import Server._ - def ++[R1 <: R, E1 >: E](other: Server[R1, E1]): Server[R1, E1] = - Concat(self, other) - private def settings[R1 <: R, E1 >: E](s: Config[R1, E1] = Config()): Config[R1, E1] = self match { case Concat(self, other) => other.settings(self.settings(s)) case LeakDetection(level) => s.copy(leakDetectionLevel = level) @@ -32,9 +29,12 @@ sealed trait Server[-R, +E] { self => case UnsafeChannelPipeline(init) => s.copy(channelInitializer = init) case RequestDecompression(enabled, strict) => s.copy(requestDecompression = (enabled, strict)) case ObjectAggregator(maxRequestSize) => s.copy(objectAggregator = maxRequestSize) - case UnsafeServerbootstrap(init) => s.copy(serverbootstrapInitializer = init) + case UnsafeServerBootstrap(init) => s.copy(serverBootstrapInitializer = init) } + def ++[R1 <: R, E1 >: E](other: Server[R1, E1]): Server[R1, E1] = + Concat(self, other) + def make(implicit ev: E <:< Throwable, ): ZManaged[R with EventLoopGroup with ServerChannelFactory, Throwable, Start] = @@ -51,9 +51,10 @@ sealed trait Server[-R, +E] { self => start.provideSomeLayer[R1](EventLoopGroup.auto(0) ++ ServerChannelFactory.auto) /** - * Creates a new server listening on the provided port. + * Creates a new server using a HttpServerExpectContinueHandler to send a 100 + * HttpResponse if necessary. */ - def withPort(port: Int): Server[R, E] = Concat(self, Server.Address(new InetSocketAddress(port))) + def withAcceptContinue(enable: Boolean): Server[R, E] = Concat(self, Server.AcceptContinue(enable)) /** * Creates a new server listening on the provided hostname and port. @@ -73,21 +74,17 @@ sealed trait Server[-R, +E] { self => def withBinding(inetSocketAddress: InetSocketAddress): Server[R, E] = Concat(self, Server.Address(inetSocketAddress)) /** - * Creates a new server with the errorHandler provided. - */ - def withError[R1](errorHandler: Throwable => ZIO[R1, Nothing, Unit]): Server[R with R1, E] = - Concat(self, Server.Error(errorHandler)) - - /** - * Creates a new server with the following ssl options. + * Creates a new server with FlushConsolidationHandler to control the flush + * operations in a more efficient way if enabled (@see FlushConsolidationHandler). */ - def withSsl(sslOptions: ServerSSLOptions): Server[R, E] = Concat(self, Server.Ssl(sslOptions)) + def withConsolidateFlush(enable: Boolean): Server[R, E] = Concat(self, ConsolidateFlush(enable)) /** - * Creates a new server using a HttpServerExpectContinueHandler to send a 100 - * HttpResponse if necessary. + * Creates a new server with the errorHandler provided. */ - def withAcceptContinue(enable: Boolean): Server[R, E] = Concat(self, Server.AcceptContinue(enable)) + def withError[R1](errorHandler: Throwable => ZIO[R1, Nothing, Unit]): Server[R with R1, E] = + Concat(self, Server.Error(errorHandler)) /** * Creates a new server using netty FlowControlHandler if enable (@see */ def withFlowControl(enable: Boolean): Server[R, E] = Concat(self, Server.FlowControl(enable)) + /** + * Creates a new server with netty's HttpServerKeepAliveHandler to close + * persistent connections when enable is true (@see HttpServerKeepAliveHandler). + */ + def withKeepAlive(enable: Boolean): Server[R, E] = Concat(self, KeepAlive(enable)) + /** * Creates a new server with the leak detection level provided (@see ResourceLeakDetector.Level). @@ -102,18 +106,29 @@ sealed trait Server[-R, +E] { self => def withLeakDetection(level: LeakDetectionLevel): Server[R, E] = Concat(self, LeakDetection(level)) /** - * Creates a new server with netty's HttpServerKeepAliveHandler to close - * persistent connections when enable is true (@see HttpServerKeepAliveHandler). + * Creates a new server with HttpObjectAggregator with the specified max size + * of the aggregated content. */ - def withKeepAlive(enable: Boolean): Server[R, E] = Concat(self, KeepAlive(enable)) + def withObjectAggregator(maxRequestSize: Int = Int.MaxValue): Server[R, E] = + Concat(self, ObjectAggregator(maxRequestSize)) /** - * Creates a new server with FlushConsolidationHandler to control the flush - * operations in a more efficient way if enabled (@see FlushConsolidationHandler). + * Creates a new server listening on the provided port. */ - def withConsolidateFlush(enable: Boolean): Server[R, E] = Concat(self, ConsolidateFlush(enable)) + def withPort(port: Int): Server[R, E] = Concat(self, Server.Address(new InetSocketAddress(port))) + + /** + * Creates a new server with netty's HttpContentDecompressor to decompress + * Http requests (@see HttpContentDecompressor). + */ + def withRequestDecompression(enabled: Boolean, strict: Boolean): Server[R, E] = + Concat(self, RequestDecompression(enabled, strict)) + + /** + * Creates a new server with the following ssl options. + */ + def withSsl(sslOptions: ServerSSLOptions): Server[R, E] = Concat(self, Server.Ssl(sslOptions)) /** * Creates a new server by passing a function that modifies the channel @@ -131,89 +146,64 @@ sealed trait Server[-R, +E] { self => * bootstrap is generally not advised unless you know what you are doing. */ def withUnsafeServerBootstrap(unsafeServerbootstrap: ServerBootstrap => Unit): Server[R, E] = - Concat(self, UnsafeServerbootstrap(unsafeServerbootstrap)) - - /** - * Creates a new server with netty's HttpContentDecompressor to decompress - * Http requests (@see HttpContentDecompressor). - */ - def withRequestDecompression(enabled: Boolean, strict: Boolean): Server[R, E] = - Concat(self, RequestDecompression(enabled, strict)) - - /** - * Creates a new server with HttpObjectAggregator with the specified max size - * of the aggregated content. - */ - def withObjectAggregator(maxRequestSize: Int = Int.MaxValue): Server[R, E] = - Concat(self, ObjectAggregator(maxRequestSize)) + Concat(self, UnsafeServerBootstrap(unsafeServerbootstrap)) } object Server { - private[zhttp] final case class Config[-R, +E]( - leakDetectionLevel: LeakDetectionLevel = LeakDetectionLevel.SIMPLE, - error: Option[Throwable => ZIO[R, Nothing, Unit]] = None, - sslOption: ServerSSLOptions = null, + val disableFlowControl: UServer = Server.FlowControl(false) + val disableLeakDetection: UServer = LeakDetection(LeakDetectionLevel.DISABLED) + val simpleLeakDetection: UServer = LeakDetection(LeakDetectionLevel.SIMPLE) + val advancedLeakDetection: UServer = LeakDetection(LeakDetectionLevel.ADVANCED) + val paranoidLeakDetection: UServer = LeakDetection(LeakDetectionLevel.PARANOID) + val disableKeepAlive: UServer = Server.KeepAlive(false) + val consolidateFlush: UServer = ConsolidateFlush(true) - // TODO: move app out of settings - app: HttpApp[R, E] = Http.empty, - address: InetSocketAddress = new InetSocketAddress(8080), - acceptContinue: Boolean = false, - keepAlive: Boolean = true, - consolidateFlush: Boolean = false, - flowControl: Boolean = true, - channelInitializer: ChannelPipeline => Unit = null, - requestDecompression: (Boolean, Boolean) = (false, false), - objectAggregator: Int = -1, - serverbootstrapInitializer: ServerBootstrap => Unit = null, - ) { - def useAggregator: Boolean = objectAggregator >= 0 - } + def acceptContinue: UServer = Server.AcceptContinue(true) + + def app[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http) /** - * Holds server start information. + * Creates a server from a http app. */ - final case class Start(port: Int = 0) + def apply[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http) + + def bind(port: Int): UServer = Server.Address(new InetSocketAddress(port)) + + def bind(hostname: String, port: Int): UServer = Server.Address(new InetSocketAddress(hostname, port)) + + def bind(inetAddress: InetAddress, port: Int): UServer = Server.Address(new InetSocketAddress(inetAddress, port)) - private final case class Concat[R, E](self: Server[R, E], other: Server[R, E]) extends Server[R, E] - private final case class LeakDetection(level: LeakDetectionLevel) extends UServer - private final case class Error[R](errorHandler: Throwable => ZIO[R, Nothing, Unit]) extends Server[R, Nothing] - private final case class Ssl(sslOptions: ServerSSLOptions) extends UServer - private final case class Address(address: InetSocketAddress) extends UServer - private final case class App[R, E](app: HttpApp[R, E]) extends Server[R, E] - private final case class KeepAlive(enabled: Boolean) extends Server[Any, Nothing] - private final case class ConsolidateFlush(enabled: Boolean) extends Server[Any, Nothing] - private final case class AcceptContinue(enabled: Boolean) extends UServer - private final case class FlowControl(enabled: Boolean) extends UServer - private final case class UnsafeChannelPipeline(init: ChannelPipeline => Unit) extends UServer - private final case class RequestDecompression(enabled: Boolean, strict: Boolean) extends UServer - private final case class ObjectAggregator(maxRequestSize: Int) extends UServer - private final case class UnsafeServerbootstrap(init: ServerBootstrap => Unit) extends UServer - - def app[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http) - def port(port: Int): UServer = Server.Address(new InetSocketAddress(port)) - def bind(port: Int): UServer = Server.Address(new InetSocketAddress(port)) - def bind(hostname: String, port: Int): UServer = Server.Address(new InetSocketAddress(hostname, port)) - def bind(inetAddress: InetAddress, port: Int): UServer = Server.Address(new InetSocketAddress(inetAddress, port)) def bind(inetSocketAddress: InetSocketAddress): UServer = Server.Address(inetSocketAddress) + + def enableObjectAggregator(maxRequestSize: Int = Int.MaxValue): UServer = ObjectAggregator(maxRequestSize) + def error[R](errorHandler: Throwable => ZIO[R, Nothing, Unit]): Server[R, Nothing] = Server.Error(errorHandler) - def ssl(sslOptions: ServerSSLOptions): UServer = Server.Ssl(sslOptions) - def acceptContinue: UServer = Server.AcceptContinue(true) + + def make[R]( + server: Server[R, Throwable], + ): ZManaged[R with EventLoopGroup with ServerChannelFactory, Throwable, Start] = { + val settings = server.settings() + for { + channelFactory <- ZManaged.access[ServerChannelFactory](_.get) + eventLoopGroup <- ZManaged.access[EventLoopGroup](_.get) + zExec <- HttpRuntime.sticky[R](eventLoopGroup).toManaged_ + handler = new ServerResponseWriter(zExec, settings, ServerTime.make) + reqHandler = settings.app.compile(zExec, settings, handler) + init = ServerChannelInitializer(zExec, settings, reqHandler) + serverBootstrap = new ServerBootstrap().channelFactory(channelFactory).group(eventLoopGroup) + chf <- ZManaged.effect(serverBootstrap.childHandler(init).bind(settings.address)) + _ <- ChannelFuture.asManaged(chf) + port <- ZManaged.effect(chf.channel().localAddress().asInstanceOf[InetSocketAddress].getPort) + } yield { + ResourceLeakDetector.setLevel(settings.leakDetectionLevel.jResourceLeakDetectionLevel) + Start(port) + } + } + + def port(port: Int): UServer = Server.Address(new InetSocketAddress(port)) + def requestDecompression(strict: Boolean): UServer = Server.RequestDecompression(enabled = true, strict = strict) - val disableFlowControl: UServer = Server.FlowControl(false) - val disableLeakDetection: UServer = LeakDetection(LeakDetectionLevel.DISABLED) - val simpleLeakDetection: UServer = LeakDetection(LeakDetectionLevel.SIMPLE) - val advancedLeakDetection: UServer = LeakDetection(LeakDetectionLevel.ADVANCED) - val paranoidLeakDetection: UServer = LeakDetection(LeakDetectionLevel.PARANOID) - val disableKeepAlive: UServer = Server.KeepAlive(false) - val consolidateFlush: UServer = ConsolidateFlush(true) - def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline) - def enableObjectAggregator(maxRequestSize: Int = Int.MaxValue): UServer = ObjectAggregator(maxRequestSize) - def unsafeServerbootstrap(serverBootstrap: ServerBootstrap => Unit): UServer = UnsafeServerbootstrap(serverBootstrap) - /** - * Creates a server from a http app. - */ - def apply[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http) + def ssl(sslOptions: ServerSSLOptions): UServer = Server.Ssl(sslOptions) /** * Launches the app on the provided port. @@ -251,24 +241,60 @@ object Server { .useForever .provideSomeLayer[R](EventLoopGroup.auto(0) ++ ServerChannelFactory.auto) - def make[R]( - server: Server[R, Throwable], - ): ZManaged[R with EventLoopGroup with ServerChannelFactory, Throwable, Start] = { - val settings = server.settings() - for { - channelFactory <- ZManaged.access[ServerChannelFactory](_.get) - eventLoopGroup <- ZManaged.access[EventLoopGroup](_.get) - zExec <- HttpRuntime.sticky[R](eventLoopGroup).toManaged_ - handler = new ServerResponseWriter(zExec, settings, ServerTime.make) - reqHandler = settings.app.compile(zExec, settings, handler) - init = ServerChannelInitializer(zExec, settings, reqHandler) - serverBootstrap = new ServerBootstrap().channelFactory(channelFactory).group(eventLoopGroup) - chf <- ZManaged.effect(serverBootstrap.childHandler(init).bind(settings.address)) - _ <- ChannelFuture.asManaged(chf) - port <- ZManaged.effect(chf.channel().localAddress().asInstanceOf[InetSocketAddress].getPort) - } yield { - ResourceLeakDetector.setLevel(settings.leakDetectionLevel.jResourceLeakDetectionLevel) - Start(port) - } + def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline) + + def unsafeServerBootstrap(serverBootstrap: ServerBootstrap => Unit): UServer = UnsafeServerBootstrap(serverBootstrap) + + /** + * Holds server start information. + */ + final case class Start(port: Int = 0) + + private[zhttp] final case class Config[-R, +E]( + leakDetectionLevel: LeakDetectionLevel = LeakDetectionLevel.SIMPLE, + error: Option[Throwable => ZIO[R, Nothing, Unit]] = None, + sslOption: ServerSSLOptions = null, + + // TODO: move app out of settings + app: HttpApp[R, E] = Http.empty, + address: InetSocketAddress = new InetSocketAddress(8080), + acceptContinue: Boolean = false, + keepAlive: Boolean = true, + consolidateFlush: Boolean = false, + flowControl: Boolean = true, + channelInitializer: ChannelPipeline => Unit = null, + requestDecompression: (Boolean, Boolean) = (false, false), + objectAggregator: Int = -1, + serverBootstrapInitializer: ServerBootstrap => Unit = null, + ) { + def useAggregator: Boolean = objectAggregator >= 0 } + + private final case class Concat[R, E](self: Server[R, E], other: Server[R, E]) extends Server[R, E] + + private final case class LeakDetection(level: LeakDetectionLevel) extends UServer + + private final case class Error[R](errorHandler: Throwable => ZIO[R, Nothing, Unit]) extends Server[R, Nothing] + + private final case class Ssl(sslOptions: ServerSSLOptions) extends UServer + + private final case class Address(address: InetSocketAddress) extends UServer + + private final case class App[R, E](app: HttpApp[R, E]) extends Server[R, E] + + private final case class KeepAlive(enabled: Boolean) extends Server[Any, Nothing] + + private final case class ConsolidateFlush(enabled: Boolean) extends Server[Any, Nothing] + + private final case class AcceptContinue(enabled: Boolean) extends UServer + + private final case class FlowControl(enabled: Boolean) extends UServer + + private final case class UnsafeChannelPipeline(init: ChannelPipeline => Unit) extends UServer + + private final case class RequestDecompression(enabled: Boolean, strict: Boolean) extends UServer + + private final case class ObjectAggregator(maxRequestSize: Int) extends UServer + + private final case class UnsafeServerBootstrap(init: ServerBootstrap => Unit) extends UServer }