Skip to content

Commit

Permalink
Feat: Server Request Decompression (#1095)
Browse files Browse the repository at this point in the history
* feat: server request decompression

* test: add test for deflate
  • Loading branch information
girdharshubham authored Feb 25, 2022
1 parent 7f9b409 commit a9ae523
Showing 6 changed files with 69 additions and 26 deletions.
40 changes: 26 additions & 14 deletions zio-http/src/main/scala/zhttp/service/Server.scala
Original file line number Diff line number Diff line change
@@ -19,18 +19,19 @@ sealed trait Server[-R, +E] { self =>
Concat(self, other)

private def settings[R1 <: R, E1 >: E](s: Config[R1, E1] = Config()): Config[R1, E1] = self match {
case Concat(self, other) => other.settings(self.settings(s))
case LeakDetection(level) => s.copy(leakDetectionLevel = level)
case MaxRequestSize(size) => s.copy(maxRequestSize = size)
case Error(errorHandler) => s.copy(error = Some(errorHandler))
case Ssl(sslOption) => s.copy(sslOption = sslOption)
case App(app) => s.copy(app = app)
case Address(address) => s.copy(address = address)
case AcceptContinue(enabled) => s.copy(acceptContinue = enabled)
case KeepAlive(enabled) => s.copy(keepAlive = enabled)
case FlowControl(enabled) => s.copy(flowControl = enabled)
case ConsolidateFlush(enabled) => s.copy(consolidateFlush = enabled)
case UnsafeChannelPipeline(init) => s.copy(channelInitializer = init)
case Concat(self, other) => other.settings(self.settings(s))
case LeakDetection(level) => s.copy(leakDetectionLevel = level)
case MaxRequestSize(size) => s.copy(maxRequestSize = size)
case Error(errorHandler) => s.copy(error = Some(errorHandler))
case Ssl(sslOption) => s.copy(sslOption = sslOption)
case App(app) => s.copy(app = app)
case Address(address) => s.copy(address = address)
case AcceptContinue(enabled) => s.copy(acceptContinue = enabled)
case KeepAlive(enabled) => s.copy(keepAlive = enabled)
case FlowControl(enabled) => s.copy(flowControl = enabled)
case ConsolidateFlush(enabled) => s.copy(consolidateFlush = enabled)
case UnsafeChannelPipeline(init) => s.copy(channelInitializer = init)
case RequestDecompression(enabled, strict) => s.copy(requestDecompression = (enabled, strict))
}

def make(implicit
@@ -129,6 +130,14 @@ sealed trait Server[-R, +E] { self =>
*/
def withUnsafeChannelPipeline(unsafePipeline: ChannelPipeline => Unit): Server[R, E] =
Concat(self, UnsafeChannelPipeline(unsafePipeline))

/**
* Creates a new server with netty's HttpContentDecompressor to decompress
* Http requests (@see <a href =
* "https://netty.io/4.1/api/io/netty/handler/codec/http/HttpContentDecompressor.html">HttpContentDecompressor</a>).
*/
def withRequestDecompression(enabled: Boolean, strict: Boolean): Server[R, E] =
Concat(self, RequestDecompression(enabled, strict))
}

object Server {
@@ -146,6 +155,7 @@ object Server {
consolidateFlush: Boolean = false,
flowControl: Boolean = true,
channelInitializer: ChannelPipeline => Unit = null,
requestDecompression: (Boolean, Boolean) = (false, false),
)

/**
@@ -165,6 +175,7 @@ object Server {
private final case class AcceptContinue(enabled: Boolean) extends UServer
private final case class FlowControl(enabled: Boolean) extends UServer
private final case class UnsafeChannelPipeline(init: ChannelPipeline => Unit) extends UServer
private final case class RequestDecompression(enabled: Boolean, strict: Boolean) extends UServer

def app[R, E](http: HttpApp[R, E]): Server[R, E] = Server.App(http)
def maxRequestSize(size: Int): UServer = Server.MaxRequestSize(size)
@@ -176,14 +187,15 @@ object Server {
def error[R](errorHandler: Throwable => ZIO[R, Nothing, Unit]): Server[R, Nothing] = Server.Error(errorHandler)
def ssl(sslOptions: ServerSSLOptions): UServer = Server.Ssl(sslOptions)
def acceptContinue: UServer = Server.AcceptContinue(true)
val disableFlowControl: UServer = Server.FlowControl(false)
def requestDecompression(strict: Boolean): UServer = Server.RequestDecompression(enabled = true, strict = strict)
def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline)
val disableFlowControl: UServer = Server.FlowControl(false)
val disableLeakDetection: UServer = LeakDetection(LeakDetectionLevel.DISABLED)
val simpleLeakDetection: UServer = LeakDetection(LeakDetectionLevel.SIMPLE)
val advancedLeakDetection: UServer = LeakDetection(LeakDetectionLevel.ADVANCED)
val paranoidLeakDetection: UServer = LeakDetection(LeakDetectionLevel.PARANOID)
val disableKeepAlive: UServer = Server.KeepAlive(false)
val consolidateFlush: UServer = ConsolidateFlush(true)
def unsafePipeline(pipeline: ChannelPipeline => Unit): UServer = UnsafeChannelPipeline(pipeline)

/**
* Creates a server from a http app.
1 change: 1 addition & 0 deletions zio-http/src/main/scala/zhttp/service/package.scala
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ package object service {
private[service] val HTTP_SERVER_FLUSH_CONSOLIDATION = "HTTP_SERVER_FLUSH_CONSOLIDATION"
private[service] val CLIENT_INBOUND_HANDLER = "CLIENT_INBOUND_HANDLER"
private[service] val WEB_SOCKET_CLIENT_PROTOCOL_HANDLER = "WEB_SOCKET_CLIENT_PROTOCOL_HANDLER"
private[service] val HTTP_REQUEST_DECOMPRESSION = "HTTP_REQUEST_DECOMPRESSION"

type ChannelFactory = Has[JChannelFactory[Channel]]
type EventLoopGroup = Has[JEventLoopGroup]
Original file line number Diff line number Diff line change
@@ -42,6 +42,10 @@ final case class ServerChannelInitializer[R](
)
pipeline.addLast("encoder", new HttpResponseEncoder())

// HttpContentDecompressor
if (cfg.requestDecompression._1)
pipeline.addLast(HTTP_REQUEST_DECOMPRESSION, new HttpContentDecompressor(cfg.requestDecompression._2))

// TODO: See if server codec is really required

// ObjectAggregator
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zhttp/internal/HttpRunnableSpec.scala
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self =>
def run(
path: Path = !!,
method: Method = Method.GET,
content: String = "",
content: HttpData = HttpData.empty,
headers: Headers = Headers.empty,
version: Version = Version.Http_1_1,
): ZIO[R, Throwable, A] =
@@ -38,7 +38,7 @@ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self =>
url = URL(path), // url set here is overridden later via `deploy` method
method = method,
headers = headers,
data = HttpData.fromString(content),
data = content,
version = version,
),
).catchAll {
2 changes: 1 addition & 1 deletion zio-http/src/test/scala/zhttp/service/ClientSpec.scala
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ object ClientSpec extends HttpRunnableSpec {
} +
testM("echo POST request content") {
val app = Http.collectZIO[Request] { case req => req.bodyAsString.map(Response.text(_)) }
val res = app.deploy.bodyAsString.run(method = Method.POST, content = "ZIO user")
val res = app.deploy.bodyAsString.run(method = Method.POST, content = HttpData.fromString("ZIO user"))
assertM(res)(equalTo("ZIO user"))
} +
testM("empty content") {
44 changes: 35 additions & 9 deletions zio-http/src/test/scala/zhttp/service/ServerSpec.scala
Original file line number Diff line number Diff line change
@@ -4,12 +4,12 @@ import zhttp.html._
import zhttp.http._
import zhttp.internal.{DynamicServer, HttpGen, HttpRunnableSpec}
import zhttp.service.server._
import zio.ZIO
import zio.duration.durationInt
import zio.stream.ZStream
import zio.stream.{ZStream, ZTransducer}
import zio.test.Assertion._
import zio.test.TestAspect._
import zio.test._
import zio.{Chunk, ZIO}

import java.nio.file.Paths

@@ -35,7 +35,7 @@ object ServerSpec extends HttpRunnableSpec {
case _ -> !! / "HExitFailure" => HExit.fail(new RuntimeException("FAILURE"))
}

private val app = serve { nonZIO ++ staticApp ++ DynamicServer.app }
private val app = serve(nonZIO ++ staticApp ++ DynamicServer.app, Some(Server.requestDecompression(true)))

def dynamicAppSpec = suite("DynamicAppSpec") {
suite("success") {
@@ -88,15 +88,15 @@ object ServerSpec extends HttpRunnableSpec {
assertM(res)(equalTo(Status.OK))
} +
testM("body is ok") {
val res = app.deploy.bodyAsString.run(content = "ABC")
val res = app.deploy.bodyAsString.run(content = HttpData.fromString("ABC"))
assertM(res)(equalTo("ABC"))
} +
testM("empty string") {
val res = app.deploy.bodyAsString.run(content = "")
val res = app.deploy.bodyAsString.run(content = HttpData.fromString(""))
assertM(res)(equalTo(""))
} +
testM("one char") {
val res = app.deploy.bodyAsString.run(content = "1")
val res = app.deploy.bodyAsString.run(content = HttpData.fromString("1"))
assertM(res)(equalTo("1"))
}
} +
@@ -112,6 +112,32 @@ object ServerSpec extends HttpRunnableSpec {
val res = app.deploy.bodyAsString.run()
assertM(res)(equalTo("abc"))
}
} +
suite("decompression") {
val app = Http.collectZIO[Request] { case req => req.bodyAsString.map(body => Response.text(body)) }.deploy
val content = "some-text"
val stream = ZStream.fromChunk(Chunk.fromArray(content.getBytes))

testM("gzip") {
val res = for {
body <- stream.transduce(ZTransducer.gzip()).runCollect
response <- app.run(
content = HttpData.fromChunk(body),
headers = Headers.contentEncoding(HeaderValues.gzip),
)
} yield response
assertM(res.flatMap(_.bodyAsString))(equalTo(content))
} +
testM("deflate") {
val res = for {
body <- stream.transduce(ZTransducer.deflate()).runCollect
response <- app.run(
content = HttpData.fromChunk(body),
headers = Headers.contentEncoding(HeaderValues.deflate),
)
} yield response
assertM(res.flatMap(_.bodyAsString))(equalTo(content))
}
}
}

@@ -155,13 +181,13 @@ object ServerSpec extends HttpRunnableSpec {
}
testM("has content-length") {
checkAllM(Gen.alphaNumericString) { string =>
val res = app.deploy.bodyAsString.run(content = string)
val res = app.deploy.bodyAsString.run(content = HttpData.fromString(string))
assertM(res)(equalTo(string.length.toString))
}
} +
testM("POST Request.getBody") {
val app = Http.collectZIO[Request] { case req => req.body.as(Response.ok) }
val res = app.deploy.status.run(path = !!, method = Method.POST, content = "some text")
val res = app.deploy.status.run(path = !!, method = Method.POST, content = HttpData.fromString("some text"))
assertM(res)(equalTo(Status.OK))
}
}
@@ -211,7 +237,7 @@ object ServerSpec extends HttpRunnableSpec {
}
.deploy
.bodyAsString
.run(content = "abc")
.run(content = HttpData.fromString("abc"))
assertM(res)(equalTo("abc"))
} +
testM("file-streaming") {

0 comments on commit a9ae523

Please sign in to comment.