Skip to content

Commit

Permalink
Fix: Echo HttpData directly from Request (#1244)
Browse files Browse the repository at this point in the history
* test: add failing test for echoing raw data

* fix: fix unsafe cast issue

* test: simplify test case
  • Loading branch information
tusharmath authored May 8, 2022
1 parent 79674a9 commit f5f2883
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 25 deletions.
18 changes: 5 additions & 13 deletions zio-http/src/main/scala/zhttp/http/HttpData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,11 @@ object HttpData {
}
}

private[zhttp] final class UnsafeContent(private val httpContent: HttpContent) extends AnyVal {
def content: ByteBuf = httpContent.content()

def isLast: Boolean = httpContent.isInstanceOf[LastHttpContent]
}

private[zhttp] final class UnsafeChannel(private val ctx: ChannelHandlerContext) extends AnyVal {
def read(): Unit = ctx.read(): Unit
}

private[zhttp] final case class UnsafeAsync(unsafeRun: (UnsafeChannel => UnsafeContent => Unit) => Unit)
private[zhttp] final case class UnsafeAsync(unsafeRun: (ChannelHandlerContext => HttpContent => Unit) => Unit)
extends HttpData {

private def isLast(msg: HttpContent): Boolean = msg.isInstanceOf[LastHttpContent]

/**
* Encodes the HttpData into a ByteBuf.
*/
Expand All @@ -148,7 +140,7 @@ object HttpData {
val buffer = Unpooled.compositeBuffer()
msg => {
buffer.addComponent(true, msg.content)
if (msg.isLast) cb(UIO(buffer)) else ch.read()
if (isLast(msg)) cb(UIO(buffer)) else ch.read(): Unit
}
}),
)
Expand All @@ -163,7 +155,7 @@ object HttpData {
unsafeRun(ch =>
msg => {
cb(ZIO.succeed(Chunk(msg.content)))
if (msg.isLast) cb(ZIO.fail(None)) else ch.read()
if (isLast(msg)) cb(ZIO.fail(None)) else ch.read(): Unit
},
),
)
Expand Down
9 changes: 7 additions & 2 deletions zio-http/src/main/scala/zhttp/service/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ private[zhttp] final case class Handler[R](

override def remoteAddress: Option[InetAddress] = getRemoteAddress

override def data: HttpData = HttpData.fromByteBuf(jReq.content())
override def data: HttpData = HttpData.fromByteBuf(jReq.content())

override def version: Version = Version.unsafeFromJava(jReq.protocolVersion())

/**
Expand All @@ -63,7 +64,11 @@ private[zhttp] final case class Handler[R](
HttpData.UnsafeAsync(callback =>
ctx
.pipeline()
.addAfter(HTTP_REQUEST_HANDLER, HTTP_CONTENT_HANDLER, new RequestBodyHandler(callback)): Unit,
.addAfter(
HTTP_REQUEST_HANDLER,
HTTP_CONTENT_HANDLER,
new RequestBodyHandler(callback(ctx)),
): Unit,
)

override def headers: Headers = Headers.make(jReq.headers())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
package zhttp.service
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}
import zhttp.http.HttpData.{UnsafeChannel, UnsafeContent}

final class RequestBodyHandler(val callback: UnsafeChannel => UnsafeContent => Unit)
final class RequestBodyHandler(val callback: HttpContent => Unit)
extends SimpleChannelInboundHandler[HttpContent](false) { self =>

private var onMessage: UnsafeContent => Unit = _

override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = {
self.onMessage(new UnsafeContent(msg))
self.callback(msg)
if (msg.isInstanceOf[LastHttpContent]) {
ctx.channel().pipeline().remove(self): Unit
}
}

override def handlerAdded(ctx: ChannelHandlerContext): Unit = {
self.onMessage = callback(new UnsafeChannel(ctx))
ctx.read(): Unit
}
}
10 changes: 8 additions & 2 deletions zio-http/src/main/scala/zhttp/service/ServerResponseWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private[zhttp] final class ServerResponseWriter[R](
/**
* Writes data on the channel
*/
private def writeData(data: HttpData.Complete, jReq: HttpRequest)(implicit ctx: Ctx): Unit = {
private def writeData(data: HttpData, jReq: HttpRequest)(implicit ctx: Ctx): Unit = {
data match {

case _: HttpData.FromAsciiString => flushReleaseAndRead(jReq)
Expand All @@ -103,6 +103,12 @@ private[zhttp] final class ServerResponseWriter[R](
case HttpData.JavaFile(unsafeGet) =>
unsafeWriteFileContent(unsafeGet())
releaseAndRead(jReq)

case HttpData.UnsafeAsync(unsafeRun) =>
unsafeRun { _ => msg =>
ctx.writeAndFlush(msg)
if (!msg.isInstanceOf[LastHttpContent]) ctx.read(): Unit
}
}
}

Expand All @@ -127,7 +133,7 @@ private[zhttp] final class ServerResponseWriter[R](

def write(msg: Response, jReq: HttpRequest)(implicit ctx: Ctx): Unit = {
ctx.write(encodeResponse(msg))
writeData(msg.data.asInstanceOf[HttpData.Complete], jReq)
writeData(msg.data, jReq)
()
}

Expand Down
11 changes: 9 additions & 2 deletions zio-http/src/test/scala/zhttp/service/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ object ServerSpec extends HttpRunnableSpec {
private val env =
EventLoopGroup.nio() ++ ChannelFactory.nio ++ ServerChannelFactory.nio ++ DynamicServer.live

private val MaxSize = 1024 * 10
private val app =
serve(DynamicServer.app, Some(Server.requestDecompression(true) ++ Server.enableObjectAggregator(4096)))
serve(DynamicServer.app, Some(Server.requestDecompression(true) ++ Server.enableObjectAggregator(MaxSize)))
private val appWithReqStreaming = serve(DynamicServer.app, Some(Server.requestDecompression(true)))

def dynamicAppSpec = suite("DynamicAppSpec") {
Expand Down Expand Up @@ -104,6 +105,12 @@ object ServerSpec extends HttpRunnableSpec {
testM("one char") {
val res = app.deploy.bodyAsString.run(content = HttpData.fromString("1"))
assertM(res)(equalTo("1"))
} +
testM("data") {
val dataStream = ZStream.repeat("A").take(MaxSize.toLong)
val app = Http.collect[Request] { case req => Response(data = req.data) }
val res = app.deploy.bodyAsByteBuf.map(_.readableBytes()).run(content = HttpData.fromStream(dataStream))
assertM(res)(equalTo(MaxSize))
}
} +
suite("headers") {
Expand Down Expand Up @@ -299,6 +306,6 @@ object ServerSpec extends HttpRunnableSpec {
val spec = dynamicAppSpec + responseSpec + requestSpec + requestBodySpec + serverErrorSpec
suiteM("app without request streaming") { app.as(List(spec)).useNow } +
suiteM("app with request streaming") { appWithReqStreaming.as(List(spec)).useNow }
}.provideCustomLayerShared(env) @@ timeout(20 seconds)
}.provideCustomLayerShared(env) @@ timeout(10 seconds)

}

0 comments on commit f5f2883

Please sign in to comment.