From 8e4f6bf3ba8f11d3eb081714330d131d216f4299 Mon Sep 17 00:00:00 2001 From: Amit Kumar Singh Date: Tue, 1 Mar 2022 20:11:49 +0530 Subject: [PATCH] Feature: Request Streaming (#1048) * introduce `Incoming` and `Outgoing` inHttpData * streaming support * benchmark disable objectAggregator * cleanup * refactor * cleanup + PR comments * cleanup + PR comments * cleanup + PR comments * refactor: rename variable * memory leak * refactor: Handler now extends ChannelInboundHandlerAdapter * refactor: remove unused methods from UnsafeChannel * remove bodyAsCharSequenceStream operator * refactor: remove unnecessary methods on HttpData * refactor: re-implement `bodyAsStream` * refactor: remove unsafe modification of pipeline from HttpData * refactor: rename HttpData types * fix 2.12 build * refactor: remove type param * PR comment * PR comment * refaector: simplify releaseRequest * refactor: reorder methods in ServerResponseHandler * refactor: make methods final * refactor: rename HttpData traits * add `bodyAsByteArray` and derive `body` and `bodyAsString` from it. * add test: should throw error for HttpData.Incoming * Introduce `useAggregator` method on settings and use it everywhere * remove sharable from `ServerResponseHandler` * Update zio-http/src/main/scala/zhttp/http/Request.scala * refactor: remove unnecessary pattern matching * throw exception on unknown message type * simplify test * refactor: change order of ContentHandler. Move it before the RequestHandler * test: update test structure * refactor: move pattern match logic to WebSocketUpgrade * revert addBefore Change because of degrade in performance (#1089) * fix static server issue with streaming * take case of auto read if body is not used * autoRead when needed * Update zio-http/src/main/scala/zhttp/service/RequestBodyHandler.scala * Update zio-http/src/main/scala/zhttp/http/Response.scala * Update zio-http/src/main/scala/zhttp/http/Response.scala * remove test which is not used * Update zio-http/src/main/scala/zhttp/service/Handler.scala Co-authored-by: Shrey Mehta <36622672+smehta91@users.noreply.github.com> * Update zio-http/src/main/scala/zhttp/service/Handler.scala Co-authored-by: Shrey Mehta <36622672+smehta91@users.noreply.github.com> * style: fmt * exclude Head in 404 check Co-authored-by: Tushar Mathur Co-authored-by: Shrey Mehta <36622672+smehta91@users.noreply.github.com> --- .../src/main/scala/zhttp/http/HttpData.scala | 120 +++++++++++----- .../src/main/scala/zhttp/http/Request.scala | 27 +++- .../src/main/scala/zhttp/http/Response.scala | 16 ++- .../main/scala/zhttp/service/Handler.scala | 131 +++++++++++++----- .../zhttp/service/RequestBodyHandler.scala | 22 +++ .../src/main/scala/zhttp/service/Server.scala | 41 +++--- .../main/scala/zhttp/service/package.scala | 1 + .../server/ServerChannelInitializer.scala | 3 +- .../service/server/WebSocketUpgrade.scala | 24 +++- .../handlers/ServerResponseHandler.scala | 87 +++++++----- .../test/scala/zhttp/service/ServerSpec.scala | 46 ++++-- .../zhttp/service/StaticServerSpec.scala | 26 ++-- 12 files changed, 380 insertions(+), 164 deletions(-) create mode 100644 zio-http/src/main/scala/zhttp/service/RequestBodyHandler.scala diff --git a/zio-http/src/main/scala/zhttp/http/HttpData.scala b/zio-http/src/main/scala/zhttp/http/HttpData.scala index 54770a763a..6959942d3e 100644 --- a/zio-http/src/main/scala/zhttp/http/HttpData.scala +++ b/zio-http/src/main/scala/zhttp/http/HttpData.scala @@ -1,9 +1,11 @@ package zhttp.http import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.{HttpContent, LastHttpContent} import zio.blocking.Blocking.Service.live.effectBlocking import zio.stream.ZStream -import zio.{Chunk, Task, UIO} +import zio.{Chunk, Task, UIO, ZIO} import java.io.FileInputStream import java.nio.charset.Charset @@ -16,7 +18,7 @@ sealed trait HttpData { self => /** * Returns true if HttpData is a stream */ - def isChunked: Boolean = self match { + final def isChunked: Boolean = self match { case HttpData.BinaryStream(_) => true case _ => false } @@ -24,30 +26,22 @@ sealed trait HttpData { self => /** * Returns true if HttpData is empty */ - def isEmpty: Boolean = self match { + final def isEmpty: Boolean = self match { case HttpData.Empty => true case _ => false } - def toByteBuf: Task[ByteBuf] = { + final def toByteBuf: Task[ByteBuf] = { self match { - case HttpData.Text(text, charset) => UIO(Unpooled.copiedBuffer(text, charset)) - case HttpData.BinaryChunk(data) => UIO(Unpooled.copiedBuffer(data.toArray)) - case HttpData.BinaryByteBuf(data) => UIO(data) - case HttpData.Empty => UIO(Unpooled.EMPTY_BUFFER) - case HttpData.BinaryStream(stream) => - stream - .asInstanceOf[ZStream[Any, Throwable, ByteBuf]] - .fold(Unpooled.compositeBuffer())((c, b) => c.addComponent(b)) - case HttpData.RandomAccessFile(raf) => - effectBlocking { - val fis = new FileInputStream(raf().getFD) - val fileContent: Array[Byte] = new Array[Byte](raf().length().toInt) - fis.read(fileContent) - Unpooled.copiedBuffer(fileContent) - } + case self: HttpData.Incoming => self.encode + case self: HttpData.Outgoing => self.encode } } + + final def toByteBufStream: ZStream[Any, Throwable, ByteBuf] = self match { + case self: HttpData.Incoming => self.encodeAsStream + case self: HttpData.Outgoing => ZStream.fromEffect(self.encode) + } } object HttpData { @@ -68,33 +62,89 @@ object HttpData { def fromChunk(data: Chunk[Byte]): HttpData = BinaryChunk(data) /** - * Helper to create HttpData from Stream of bytes + * Helper to create HttpData from contents of a file */ - def fromStream(stream: ZStream[Any, Throwable, Byte]): HttpData = - HttpData.BinaryStream(stream.mapChunks(chunks => Chunk(Unpooled.copiedBuffer(chunks.toArray)))) + def fromFile(file: => java.io.File): HttpData = { + RandomAccessFile(() => new java.io.RandomAccessFile(file, "r")) + } /** * Helper to create HttpData from Stream of string */ - def fromStream(stream: ZStream[Any, Throwable, String], charset: Charset = HTTP_CHARSET): HttpData = - HttpData.BinaryStream(stream.map(str => Unpooled.copiedBuffer(str, charset))) + def fromStream(stream: ZStream[Any, Throwable, CharSequence], charset: Charset = HTTP_CHARSET): HttpData = + HttpData.BinaryStream(stream.map(str => Unpooled.wrappedBuffer(str.toString.getBytes(charset)))) /** - * Helper to create HttpData from String + * Helper to create HttpData from Stream of bytes */ - def fromString(text: String, charset: Charset = HTTP_CHARSET): HttpData = Text(text, charset) + def fromStream(stream: ZStream[Any, Throwable, Byte]): HttpData = + HttpData.BinaryStream(stream.mapChunks(chunks => Chunk(Unpooled.wrappedBuffer(chunks.toArray)))) /** - * Helper to create HttpData from contents of a file + * Helper to create HttpData from String */ - def fromFile(file: => java.io.File): HttpData = { - RandomAccessFile(() => new java.io.RandomAccessFile(file, "r")) + def fromString(text: String, charset: Charset = HTTP_CHARSET): HttpData = Text(text, charset) + + private[zhttp] sealed trait Outgoing extends HttpData { self => + def encode: ZIO[Any, Throwable, ByteBuf] = + self match { + case HttpData.Text(text, charset) => UIO(Unpooled.copiedBuffer(text, charset)) + case HttpData.BinaryChunk(data) => UIO(Unpooled.copiedBuffer(data.toArray)) + case HttpData.BinaryByteBuf(data) => UIO(data) + case HttpData.Empty => UIO(Unpooled.EMPTY_BUFFER) + case HttpData.BinaryStream(stream) => + stream + .asInstanceOf[ZStream[Any, Throwable, ByteBuf]] + .fold(Unpooled.compositeBuffer())((c, b) => c.addComponent(b)) + case HttpData.RandomAccessFile(raf) => + effectBlocking { + val fis = new FileInputStream(raf().getFD) + val fileContent: Array[Byte] = new Array[Byte](raf().length().toInt) + fis.read(fileContent) + Unpooled.copiedBuffer(fileContent) + } + } + } + + private[zhttp] final class UnsafeContent(private val httpContent: HttpContent) extends AnyVal { + def content: ByteBuf = httpContent.content() + + def isLast: Boolean = httpContent.isInstanceOf[LastHttpContent] + } + + private[zhttp] final class UnsafeChannel(private val ctx: ChannelHandlerContext) extends AnyVal { + def read(): Unit = ctx.read(): Unit + } + + private[zhttp] final case class Incoming(unsafeRun: (UnsafeChannel => UnsafeContent => Unit) => Unit) + extends HttpData { + def encode: ZIO[Any, Nothing, ByteBuf] = for { + body <- ZIO.effectAsync[Any, Nothing, ByteBuf](cb => + unsafeRun(ch => { + val buffer = Unpooled.compositeBuffer() + msg => { + buffer.addComponent(true, msg.content) + if (msg.isLast) cb(UIO(buffer)) else ch.read() + } + }), + ) + } yield body + + def encodeAsStream: ZStream[Any, Nothing, ByteBuf] = ZStream + .effectAsync[Any, Nothing, ByteBuf](cb => + unsafeRun(ch => + msg => { + cb(ZIO.succeed(Chunk(msg.content))) + if (msg.isLast) cb(ZIO.fail(None)) else ch.read() + }, + ), + ) } - private[zhttp] final case class Text(text: String, charset: Charset) extends HttpData - private[zhttp] final case class BinaryChunk(data: Chunk[Byte]) extends HttpData - private[zhttp] final case class BinaryByteBuf(data: ByteBuf) extends HttpData - private[zhttp] final case class BinaryStream(stream: ZStream[Any, Throwable, ByteBuf]) extends HttpData - private[zhttp] final case class RandomAccessFile(unsafeGet: () => java.io.RandomAccessFile) extends HttpData - private[zhttp] case object Empty extends HttpData + private[zhttp] final case class Text(text: String, charset: Charset) extends Outgoing + private[zhttp] final case class BinaryChunk(data: Chunk[Byte]) extends Outgoing + private[zhttp] final case class BinaryByteBuf(data: ByteBuf) extends Outgoing + private[zhttp] final case class BinaryStream(stream: ZStream[Any, Throwable, ByteBuf]) extends Outgoing + private[zhttp] final case class RandomAccessFile(unsafeGet: () => java.io.RandomAccessFile) extends Outgoing + private[zhttp] case object Empty extends Outgoing } diff --git a/zio-http/src/main/scala/zhttp/http/Request.scala b/zio-http/src/main/scala/zhttp/http/Request.scala index 144c6c8407..22a01ba7e0 100644 --- a/zio-http/src/main/scala/zhttp/http/Request.scala +++ b/zio-http/src/main/scala/zhttp/http/Request.scala @@ -3,6 +3,7 @@ package zhttp.http import io.netty.buffer.{ByteBuf, ByteBufUtil} import io.netty.handler.codec.http.{DefaultFullHttpRequest, HttpRequest} import zhttp.http.headers.HeaderExtension +import zio.stream.ZStream import zio.{Chunk, Task, UIO} import java.net.InetAddress @@ -33,17 +34,33 @@ trait Request extends HeaderExtension[Request] { self => */ def data: HttpData + final def bodyAsByteArray: Task[Array[Byte]] = + bodyAsByteBuf.flatMap(buf => Task(ByteBufUtil.getBytes(buf)).ensuring(UIO(buf.release(buf.refCnt())))) + /** * Decodes the content of request as a Chunk of Bytes */ - def body: Task[Chunk[Byte]] = - bodyAsByteBuf.flatMap(buf => Task(Chunk.fromArray(ByteBufUtil.getBytes(buf)))) + final def body: Task[Chunk[Byte]] = + bodyAsByteArray.map(Chunk.fromArray) /** * Decodes the content of request as string */ - def bodyAsString: Task[String] = - bodyAsByteBuf.flatMap(buf => Task(buf.toString(charset))) + final def bodyAsString: Task[String] = + bodyAsByteArray.map(new String(_, charset)) + + /** + * Decodes the content of request as stream of bytes + */ + final def bodyAsStream: ZStream[Any, Throwable, Byte] = data.toByteBufStream + .mapM[Any, Throwable, Chunk[Byte]] { buf => + Task { + val bytes = Chunk.fromArray(ByteBufUtil.getBytes(buf)) + buf.release(buf.refCnt()) + bytes + } + } + .flattenChunks /** * Gets all the headers in the Request @@ -95,7 +112,7 @@ trait Request extends HeaderExtension[Request] { self => */ def url: URL - private[zhttp] def bodyAsByteBuf: Task[ByteBuf] = data.toByteBuf + private[zhttp] final def bodyAsByteBuf: Task[ByteBuf] = data.toByteBuf } object Request { diff --git a/zio-http/src/main/scala/zhttp/http/Response.scala b/zio-http/src/main/scala/zhttp/http/Response.scala index 91e90e016d..b913f30457 100644 --- a/zio-http/src/main/scala/zhttp/http/Response.scala +++ b/zio-http/src/main/scala/zhttp/http/Response.scala @@ -75,12 +75,16 @@ final case class Response private ( val jHeaders = self.headers.encode val jContent = self.data match { - case HttpData.Text(text, charset) => Unpooled.wrappedBuffer(text.getBytes(charset)) - case HttpData.BinaryChunk(data) => Unpooled.copiedBuffer(data.toArray) - case HttpData.BinaryByteBuf(data) => data - case HttpData.BinaryStream(_) => null - case HttpData.Empty => Unpooled.EMPTY_BUFFER - case HttpData.RandomAccessFile(_) => null + case HttpData.Incoming(_) => null + case data: HttpData.Outgoing => + data match { + case HttpData.Text(text, charset) => Unpooled.wrappedBuffer(text.getBytes(charset)) + case HttpData.BinaryChunk(data) => Unpooled.copiedBuffer(data.toArray) + case HttpData.BinaryByteBuf(data) => data + case HttpData.BinaryStream(_) => null + case HttpData.Empty => Unpooled.EMPTY_BUFFER + case HttpData.RandomAccessFile(_) => null + } } val hasContentLength = jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) diff --git a/zio-http/src/main/scala/zhttp/service/Handler.scala b/zio-http/src/main/scala/zhttp/service/Handler.scala index d144d05757..2ebb360c22 100644 --- a/zio-http/src/main/scala/zhttp/service/Handler.scala +++ b/zio-http/src/main/scala/zhttp/service/Handler.scala @@ -16,54 +16,115 @@ private[zhttp] final case class Handler[R]( runtime: HttpRuntime[R], config: Server.Config[R, Throwable], serverTimeGenerator: ServerTime, -) extends SimpleChannelInboundHandler[FullHttpRequest](false) +) extends SimpleChannelInboundHandler[HttpObject](false) with WebSocketUpgrade[R] with ServerResponseHandler[R] { self => - override def channelRead0(ctx: Ctx, jReq: FullHttpRequest): Unit = { - jReq.touch("server.Handler-channelRead0") + override def channelRead0(ctx: Ctx, msg: HttpObject): Unit = { + implicit val iCtx: ChannelHandlerContext = ctx - try - ( - unsafeRun( - jReq, - app, - new Request { - override def method: Method = Method.fromHttpMethod(jReq.method()) + msg match { + case jReq: FullHttpRequest => + jReq.touch("server.Handler-channelRead0") + try + unsafeRun( + jReq, + app, + new Request { + override def method: Method = Method.fromHttpMethod(jReq.method()) + + override def url: URL = URL.fromString(jReq.uri()).getOrElse(null) + + override def headers: Headers = Headers.make(jReq.headers()) + + override def remoteAddress: Option[InetAddress] = { + ctx.channel().remoteAddress() match { + case m: InetSocketAddress => Some(m.getAddress) + case _ => None + } + } - override def url: URL = URL.fromString(jReq.uri()).getOrElse(null) + override def data: HttpData = HttpData.fromByteBuf(jReq.content()) - override def headers: Headers = Headers.make(jReq.headers()) + /** + * Gets the HttpRequest + */ + override def unsafeEncode = jReq + }, + ) + catch { + case throwable: Throwable => + writeResponse( + Response + .fromHttpError(HttpError.InternalServerError(cause = Some(throwable))) + .withConnection(HeaderValues.close), + jReq, + ): Unit + } + case jReq: HttpRequest => + if (canHaveBody(jReq)) { + ctx.channel().config().setAutoRead(false): Unit + } + try + unsafeRun( + jReq, + app, + new Request { + override def data: HttpData = HttpData.Incoming(callback => + ctx + .pipeline() + .addAfter(HTTP_REQUEST_HANDLER, HTTP_CONTENT_HANDLER, new RequestBodyHandler(callback)): Unit, + ) + + override def headers: Headers = Headers.make(jReq.headers()) + + override def method: Method = Method.fromHttpMethod(jReq.method()) + + override def remoteAddress: Option[InetAddress] = { + ctx.channel().remoteAddress() match { + case m: InetSocketAddress => Some(m.getAddress) + case _ => None + } + } - override def unsafeEncode: HttpRequest = jReq + override def url: URL = URL.fromString(jReq.uri()).getOrElse(null) + + /** + * Gets the HttpRequest + */ + override def unsafeEncode = jReq + }, + ) + catch { + case throwable: Throwable => + writeResponse( + Response + .fromHttpError(HttpError.InternalServerError(cause = Some(throwable))) + .withConnection(HeaderValues.close), + jReq, + ): Unit + } + + case msg: HttpContent => + ctx.fireChannelRead(msg): Unit + + case _ => + throw new IllegalStateException(s"Unexpected message type: ${msg.getClass.getName}") - override def remoteAddress: Option[InetAddress] = { - ctx.channel().remoteAddress() match { - case m: InetSocketAddress => Some(m.getAddress) - case _ => None - } - } - - override def data: HttpData = HttpData.fromByteBuf(jReq.content()) - }, - ), - ) - catch { - case throwable: Throwable => - writeResponse( - Response - .fromHttpError(HttpError.InternalServerError(cause = Some(throwable))) - .withConnection(HeaderValues.close), - jReq, - ): Unit } + + } + + private def canHaveBody(req: HttpRequest): Boolean = req.method() match { + case HttpMethod.GET | HttpMethod.HEAD | HttpMethod.OPTIONS | HttpMethod.TRACE => false + case _ => true } /** * Executes http apps */ private def unsafeRun[A]( - jReq: FullHttpRequest, + jReq: HttpRequest, http: Http[R, Throwable, A, Response], a: A, )(implicit ctx: Ctx): Unit = { @@ -86,7 +147,7 @@ private[zhttp] final case class Handler[R]( }, res => - if (self.isWebSocket(res)) UIO(self.upgradeToWebSocket(ctx, jReq, res)) + if (self.isWebSocket(res)) UIO(self.upgradeToWebSocket(jReq, res)) else { for { _ <- ZIO { @@ -99,7 +160,7 @@ private[zhttp] final case class Handler[R]( case HExit.Success(res) => if (self.isWebSocket(res)) { - self.upgradeToWebSocket(ctx, jReq, res) + self.upgradeToWebSocket(jReq, res) } else { writeResponse(res, jReq): Unit } diff --git a/zio-http/src/main/scala/zhttp/service/RequestBodyHandler.scala b/zio-http/src/main/scala/zhttp/service/RequestBodyHandler.scala new file mode 100644 index 0000000000..27d395a71c --- /dev/null +++ b/zio-http/src/main/scala/zhttp/service/RequestBodyHandler.scala @@ -0,0 +1,22 @@ +package zhttp.service +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.handler.codec.http.{HttpContent, LastHttpContent} +import zhttp.http.HttpData.{UnsafeChannel, UnsafeContent} + +final class RequestBodyHandler(val callback: UnsafeChannel => UnsafeContent => Unit) + extends SimpleChannelInboundHandler[HttpContent](false) { self => + + private var onMessage: UnsafeContent => Unit = _ + + override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = { + self.onMessage(new UnsafeContent(msg)) + if (msg.isInstanceOf[LastHttpContent]) { + ctx.channel().pipeline().remove(self): Unit + } + } + + override def handlerAdded(ctx: ChannelHandlerContext): Unit = { + self.onMessage = callback(new UnsafeChannel(ctx)) + ctx.read(): Unit + } +} diff --git a/zio-http/src/main/scala/zhttp/service/Server.scala b/zio-http/src/main/scala/zhttp/service/Server.scala index 0e80aa18b4..5200ccb2a0 100644 --- a/zio-http/src/main/scala/zhttp/service/Server.scala +++ b/zio-http/src/main/scala/zhttp/service/Server.scala @@ -21,7 +21,6 @@ sealed trait Server[-R, +E] { self => private def settings[R1 <: R, E1 >: E](s: Config[R1, E1] = Config()): Config[R1, E1] = self match { case Concat(self, other) => other.settings(self.settings(s)) case LeakDetection(level) => s.copy(leakDetectionLevel = level) - case MaxRequestSize(size) => s.copy(maxRequestSize = size) case Error(errorHandler) => s.copy(error = Some(errorHandler)) case Ssl(sslOption) => s.copy(sslOption = sslOption) case App(app) => s.copy(app = app) @@ -32,6 +31,7 @@ sealed trait Server[-R, +E] { self => case ConsolidateFlush(enabled) => s.copy(consolidateFlush = enabled) case UnsafeChannelPipeline(init) => s.copy(channelInitializer = init) case RequestDecompression(enabled, strict) => s.copy(requestDecompression = (enabled, strict)) + case ObjectAggregator(maxRequestSize) => s.copy(objectAggregator = maxRequestSize) } def make(implicit @@ -49,12 +49,6 @@ sealed trait Server[-R, +E] { self => def startDefault[R1 <: Has[_] with R](implicit ev: E <:< Throwable): ZIO[R1, Throwable, Nothing] = start.provideSomeLayer[R1](EventLoopGroup.auto(0) ++ ServerChannelFactory.auto) - /** - * Creates a new server with the maximum size of the request specified in - * bytes. - */ - def withMaxRequestSize(size: Int): Server[R, E] = Concat(self, Server.MaxRequestSize(size)) - /** * Creates a new server listening on the provided port. */ @@ -138,12 +132,18 @@ sealed trait Server[-R, +E] { self => */ def withRequestDecompression(enabled: Boolean, strict: Boolean): Server[R, E] = Concat(self, RequestDecompression(enabled, strict)) + + /** + * Creates a new server with HttpObjectAggregator with the specified max size + * of the aggregated content. + */ + def withObjectAggregator(maxRequestSize: Int = Int.MaxValue): Server[R, E] = + Concat(self, ObjectAggregator(maxRequestSize)) } object Server { private[zhttp] final case class Config[-R, +E]( leakDetectionLevel: LeakDetectionLevel = LeakDetectionLevel.SIMPLE, - maxRequestSize: Int = 4 * 1024, // 4 kilo bytes error: Option[Throwable => ZIO[R, Nothing, Unit]] = None, sslOption: ServerSSLOptions = null, @@ -156,7 +156,10 @@ object Server { flowControl: Boolean = true, channelInitializer: ChannelPipeline => Unit = null, requestDecompression: (Boolean, Boolean) = (false, false), - ) + objectAggregator: Int = -1, + ) { + def useAggregator: Boolean = objectAggregator >= 0 + } /** * Holds server start information. @@ -165,7 +168,6 @@ object Server { private final case class Concat[R, E](self: Server[R, E], other: Server[R, E]) extends Server[R, E] private final case class LeakDetection(level: LeakDetectionLevel) extends UServer - private final case class MaxRequestSize(size: Int) extends UServer private final case class Error[R](errorHandler: Throwable => ZIO[R, Nothing, Unit]) extends Server[R, Nothing] private final case class Ssl(sslOptions: ServerSSLOptions) extends UServer private final case class Address(address: InetSocketAddress) extends UServer @@ -176,9 +178,9 @@ object Server { private final case class FlowControl(enabled: Boolean) extends UServer private final case class UnsafeChannelPipeline(init: ChannelPipeline => Unit) extends UServer private final case class RequestDecompression(enabled: Boolean, strict: Boolean) extends UServer + private final case class ObjectAggregator(maxRequestSize: Int) extends UServer def app[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http) - def maxRequestSize(size: Int): UServer = Server.MaxRequestSize(size) def port(port: Int): UServer = Server.Address(new InetSocketAddress(port)) def bind(port: Int): UServer = Server.Address(new InetSocketAddress(port)) def bind(hostname: String, port: Int): UServer = Server.Address(new InetSocketAddress(hostname, port)) @@ -188,14 +190,15 @@ object Server { def ssl(sslOptions: ServerSSLOptions): UServer = Server.Ssl(sslOptions) def acceptContinue: UServer = Server.AcceptContinue(true) def requestDecompression(strict: Boolean): UServer = Server.RequestDecompression(enabled = true, strict = strict) - def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline) - val disableFlowControl: UServer = Server.FlowControl(false) - val disableLeakDetection: UServer = LeakDetection(LeakDetectionLevel.DISABLED) - val simpleLeakDetection: UServer = LeakDetection(LeakDetectionLevel.SIMPLE) - val advancedLeakDetection: UServer = LeakDetection(LeakDetectionLevel.ADVANCED) - val paranoidLeakDetection: UServer = LeakDetection(LeakDetectionLevel.PARANOID) - val disableKeepAlive: UServer = Server.KeepAlive(false) - val consolidateFlush: UServer = ConsolidateFlush(true) + val disableFlowControl: UServer = Server.FlowControl(false) + val disableLeakDetection: UServer = LeakDetection(LeakDetectionLevel.DISABLED) + val simpleLeakDetection: UServer = LeakDetection(LeakDetectionLevel.SIMPLE) + val advancedLeakDetection: UServer = LeakDetection(LeakDetectionLevel.ADVANCED) + val paranoidLeakDetection: UServer = LeakDetection(LeakDetectionLevel.PARANOID) + val disableKeepAlive: UServer = Server.KeepAlive(false) + val consolidateFlush: UServer = ConsolidateFlush(true) + def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline) + def enableObjectAggregator(maxRequestSize: Int = Int.MaxValue): UServer = ObjectAggregator(maxRequestSize) /** * Creates a server from a http app. diff --git a/zio-http/src/main/scala/zhttp/service/package.scala b/zio-http/src/main/scala/zhttp/service/package.scala index 7b0eb82207..342361f34d 100644 --- a/zio-http/src/main/scala/zhttp/service/package.scala +++ b/zio-http/src/main/scala/zhttp/service/package.scala @@ -21,6 +21,7 @@ package object service { private[service] val CLIENT_INBOUND_HANDLER = "CLIENT_INBOUND_HANDLER" private[service] val WEB_SOCKET_CLIENT_PROTOCOL_HANDLER = "WEB_SOCKET_CLIENT_PROTOCOL_HANDLER" private[service] val HTTP_REQUEST_DECOMPRESSION = "HTTP_REQUEST_DECOMPRESSION" + private[zhttp] val HTTP_CONTENT_HANDLER = "HTTP_CONTENT_HANDLER" type ChannelFactory = Has[JChannelFactory[Channel]] type EventLoopGroup = Has[JEventLoopGroup] diff --git a/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala b/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala index 15c837bcbf..90bf6c3703 100644 --- a/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala +++ b/zio-http/src/main/scala/zhttp/service/server/ServerChannelInitializer.scala @@ -50,7 +50,8 @@ final case class ServerChannelInitializer[R]( // ObjectAggregator // Always add ObjectAggregator - pipeline.addLast(HTTP_OBJECT_AGGREGATOR, new HttpObjectAggregator(cfg.maxRequestSize)) + if (cfg.useAggregator) + pipeline.addLast(HTTP_OBJECT_AGGREGATOR, new HttpObjectAggregator(cfg.objectAggregator)) // ExpectContinueHandler // Add expect continue handler is settings is true diff --git a/zio-http/src/main/scala/zhttp/service/server/WebSocketUpgrade.scala b/zio-http/src/main/scala/zhttp/service/server/WebSocketUpgrade.scala index 2d02e6431b..b487a1f42e 100644 --- a/zio-http/src/main/scala/zhttp/service/server/WebSocketUpgrade.scala +++ b/zio-http/src/main/scala/zhttp/service/server/WebSocketUpgrade.scala @@ -6,6 +6,8 @@ import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler import zhttp.http.{Response, Status} import zhttp.service.{HttpRuntime, WEB_SOCKET_HANDLER, WebSocketAppHandler} +import scala.annotation.tailrec + /** * Module to switch protocol to websockets */ @@ -19,15 +21,23 @@ trait WebSocketUpgrade[R] { self: ChannelHandler => * Checks if the response requires to switch protocol to websocket. Returns * true if it can, otherwise returns false */ - final def upgradeToWebSocket(ctx: ChannelHandlerContext, jReq: FullHttpRequest, res: Response): Unit = { + @tailrec + final def upgradeToWebSocket(jReq: HttpRequest, res: Response)(implicit ctx: ChannelHandlerContext): Unit = { val app = res.attribute.socketApp - ctx - .channel() - .pipeline() - .addLast(new WebSocketServerProtocolHandler(app.get.protocol.serverBuilder.build())) - .addLast(WEB_SOCKET_HANDLER, new WebSocketAppHandler(runtime, app.get)) - ctx.channel().eventLoop().submit(() => ctx.fireChannelRead(jReq)): Unit + jReq match { + case jReq: FullHttpRequest => + ctx + .channel() + .pipeline() + .addLast(new WebSocketServerProtocolHandler(app.get.protocol.serverBuilder.build())) + .addLast(WEB_SOCKET_HANDLER, new WebSocketAppHandler(runtime, app.get)) + ctx.channel().eventLoop().submit(() => ctx.fireChannelRead(jReq)): Unit + case jReq: HttpRequest => + val fullRequest = new DefaultFullHttpRequest(jReq.protocolVersion(), jReq.method(), jReq.uri()) + fullRequest.headers().setAll(jReq.headers()) + self.upgradeToWebSocket(fullRequest, res) + } } } diff --git a/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala b/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala index 28db13c94a..14286c178e 100644 --- a/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala +++ b/zio-http/src/main/scala/zhttp/service/server/content/handlers/ServerResponseHandler.scala @@ -1,39 +1,26 @@ package zhttp.service.server.content.handlers import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, DefaultFileRegion} import io.netty.handler.codec.http._ import zhttp.http.{HttpData, Response} import zhttp.service.server.ServerTime -import zhttp.service.{ChannelFuture, HttpRuntime} +import zhttp.service.{ChannelFuture, HttpRuntime, Server} import zio.stream.ZStream import zio.{UIO, ZIO} import java.io.RandomAccessFile -@Sharable private[zhttp] trait ServerResponseHandler[R] { - def serverTime: ServerTime - val rt: HttpRuntime[R] - type Ctx = ChannelHandlerContext + val rt: HttpRuntime[R] + val config: Server.Config[R, Throwable] - def writeResponse(msg: Response, jReq: FullHttpRequest)(implicit ctx: Ctx): Unit = { + def serverTime: ServerTime + def writeResponse(msg: Response, jReq: HttpRequest)(implicit ctx: Ctx): Unit = { ctx.write(encodeResponse(msg)) - msg.data match { - case HttpData.BinaryStream(stream) => - rt.unsafeRun(ctx) { - writeStreamContent(stream).ensuring(UIO(releaseRequest(jReq))) - } - case HttpData.RandomAccessFile(raf) => - unsafeWriteFileContent(raf()) - releaseRequest(jReq) - case _ => - ctx.flush() - releaseRequest(jReq) - } + writeData(msg.data.asInstanceOf[HttpData.Outgoing], jReq) () } @@ -65,9 +52,53 @@ private[zhttp] trait ServerResponseHandler[R] { /** * Releases the FullHttpRequest safely. */ - private def releaseRequest(jReq: FullHttpRequest): Unit = { - if (jReq.refCnt() > 0) { - jReq.release(jReq.refCnt()): Unit + private def releaseRequest(jReq: HttpRequest)(implicit ctx: Ctx): Unit = { + jReq match { + case jReq: FullHttpRequest if jReq.refCnt() > 0 => jReq.release(jReq.refCnt()): Unit + case _ => () + } + } + + /** + * Writes file content to the Channel. Does not use Chunked transfer encoding + */ + private def unsafeWriteFileContent(raf: RandomAccessFile)(implicit ctx: ChannelHandlerContext): Unit = { + val fileLength = raf.length() + // Write the content. + ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength)) + // Write the end marker. + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit + } + + /** + * Writes data on the channel + */ + private def writeData(data: HttpData.Outgoing, jReq: HttpRequest)(implicit ctx: Ctx): Unit = { + data match { + case HttpData.BinaryStream(stream) => + rt.unsafeRun(ctx) { + writeStreamContent(stream).ensuring(UIO { + releaseRequest(jReq) + if (!config.useAggregator && !ctx.channel().config().isAutoRead) { + ctx.channel().config().setAutoRead(true) + ctx.read(): Unit + } // read next HttpContent + }) + } + case HttpData.RandomAccessFile(raf) => + unsafeWriteFileContent(raf()) + releaseRequest(jReq) + if (!config.useAggregator && !ctx.channel().config().isAutoRead) { + ctx.channel().config().setAutoRead(true) + ctx.read(): Unit + } // read next HttpContent + case _ => + ctx.flush() + releaseRequest(jReq) + if (!config.useAggregator && !ctx.channel().config().isAutoRead) { + ctx.channel().config().setAutoRead(true) + ctx.read(): Unit + } // read next HttpContent } } @@ -82,16 +113,4 @@ private[zhttp] trait ServerResponseHandler[R] { _ <- ChannelFuture.unit(ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)) } yield () } - - /** - * Writes file content to the Channel. Does not use Chunked transfer encoding - */ - - private def unsafeWriteFileContent(raf: RandomAccessFile)(implicit ctx: ChannelHandlerContext): Unit = { - val fileLength = raf.length() - // Write the content. - ctx.write(new DefaultFileRegion(raf.getChannel, 0, fileLength)) - // Write the end marker. - ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT): Unit - } } diff --git a/zio-http/src/test/scala/zhttp/service/ServerSpec.scala b/zio-http/src/test/scala/zhttp/service/ServerSpec.scala index d467c07598..ddb3ef1840 100644 --- a/zio-http/src/test/scala/zhttp/service/ServerSpec.scala +++ b/zio-http/src/test/scala/zhttp/service/ServerSpec.scala @@ -23,7 +23,9 @@ object ServerSpec extends HttpRunnableSpec { private val env = EventLoopGroup.nio() ++ ChannelFactory.nio ++ ServerChannelFactory.nio ++ DynamicServer.live - private val app = serve(DynamicServer.app, Some(Server.requestDecompression(true))) + private val app = + serve(DynamicServer.app, Some(Server.requestDecompression(true) ++ Server.enableObjectAggregator(4096))) + private val appWithReqStreaming = serve(DynamicServer.app, None) def dynamicAppSpec = suite("DynamicAppSpec") { suite("success") { @@ -146,7 +148,7 @@ object ServerSpec extends HttpRunnableSpec { } } - def responseSpec = suite("ResponseSpec") { + def responseSpec = suite("ResponseSpec") { testM("data") { checkAllM(nonEmptyContent) { case (string, data) => val res = Http.fromData(data).deploy.bodyAsString.run() @@ -240,14 +242,40 @@ object ServerSpec extends HttpRunnableSpec { } } } + def requestBodySpec = suite("RequestBodySpec") { + testM("POST Request stream") { + val app: Http[Any, Throwable, Request, Response] = Http.collect[Request] { case req => + Response(data = HttpData.fromStream(req.bodyAsStream)) + } + checkAllM(Gen.alphaNumericString) { c => + assertM(app.deploy.bodyAsString.run(path = !!, method = Method.POST, content = HttpData.fromString(c)))( + equalTo(c), + ) + } + } + } + + def serverErrorSpec = suite("ServerErrorSpec") { + val app = Http.fail(new Error("SERVER_ERROR")) + testM("status is 500") { + val res = app.deploy.status.run() + assertM(res)(equalTo(Status.INTERNAL_SERVER_ERROR)) + } + + testM("content is set") { + val res = app.deploy.bodyAsString.run() + assertM(res)(containsString("SERVER_ERROR")) + } + + testM("header is set") { + val res = app.deploy.headers.run().map(_.headerValue("Content-Length")) + assertM(res)(isSome(anything)) + } + } override def spec = - suiteM("Server") { - app - .as( - List(dynamicAppSpec, responseSpec, requestSpec), - ) - .useNow - }.provideCustomLayerShared(env) @@ timeout(30 seconds) + suite("Server") { + val spec = dynamicAppSpec + responseSpec + requestSpec + requestBodySpec + serverErrorSpec + suiteM("app without request streaming") { app.as(List(spec)).useNow } + + suiteM("app with request streaming") { appWithReqStreaming.as(List(spec)).useNow } + }.provideCustomLayerShared(env) @@ timeout(30 seconds) @@ sequential } diff --git a/zio-http/src/test/scala/zhttp/service/StaticServerSpec.scala b/zio-http/src/test/scala/zhttp/service/StaticServerSpec.scala index 788c204f23..4876a1d8a4 100644 --- a/zio-http/src/test/scala/zhttp/service/StaticServerSpec.scala +++ b/zio-http/src/test/scala/zhttp/service/StaticServerSpec.scala @@ -30,6 +30,18 @@ object StaticServerSpec extends HttpRunnableSpec { private val app = serve { nonZIO ++ staticApp } def nonZIOSpec = suite("NonZIOSpec") { + val methodGenWithoutHEAD: Gen[Any, Method] = Gen.fromIterable( + List( + Method.OPTIONS, + Method.GET, + Method.POST, + Method.PUT, + Method.PATCH, + Method.DELETE, + Method.TRACE, + Method.CONNECT, + ), + ) testM("200 response") { checkAllM(HttpGen.method) { method => val actual = status(method, !! / "HExitSuccess") @@ -37,25 +49,13 @@ object StaticServerSpec extends HttpRunnableSpec { } } + testM("500 response") { - val methodGenWithoutHEAD: Gen[Any, Method] = Gen.fromIterable( - List( - Method.OPTIONS, - Method.GET, - Method.POST, - Method.PUT, - Method.PATCH, - Method.DELETE, - Method.TRACE, - Method.CONNECT, - ), - ) checkAllM(methodGenWithoutHEAD) { method => val actual = status(method, !! / "HExitFailure") assertM(actual)(equalTo(Status.INTERNAL_SERVER_ERROR)) } } + testM("404 response ") { - checkAllM(HttpGen.method) { method => + checkAllM(methodGenWithoutHEAD) { method => val actual = status(method, !! / "A") assertM(actual)(equalTo(Status.NOT_FOUND)) }