diff --git a/zio-http/src/main/scala/zhttp/http/Request.scala b/zio-http/src/main/scala/zhttp/http/Request.scala index e78be8918b..a930ea5ed5 100644 --- a/zio-http/src/main/scala/zhttp/http/Request.scala +++ b/zio-http/src/main/scala/zhttp/http/Request.scala @@ -95,8 +95,8 @@ object Request { method: Method = Method.GET, url: URL = URL.root, headers: Headers = Headers.empty, - data: HttpData = HttpData.Empty, remoteAddress: Option[InetAddress] = None, + data: HttpData = HttpData.Empty, ): Request = { val m = method val u = url @@ -121,7 +121,7 @@ object Request { remoteAddress: Option[InetAddress], content: HttpData = HttpData.empty, ): UIO[Request] = - UIO(Request(method, url, headers, content, remoteAddress)) + UIO(Request(method, url, headers, remoteAddress, content)) /** * Lift request to TypedRequest with option to extract params diff --git a/zio-http/src/main/scala/zhttp/service/Client.scala b/zio-http/src/main/scala/zhttp/service/Client.scala index 40beb22f0f..853eab2be2 100644 --- a/zio-http/src/main/scala/zhttp/service/Client.scala +++ b/zio-http/src/main/scala/zhttp/service/Client.scala @@ -2,38 +2,42 @@ package zhttp.service import io.netty.bootstrap.Bootstrap import io.netty.buffer.{ByteBuf, ByteBufUtil} -import io.netty.channel.{Channel, ChannelFactory => JChannelFactory, EventLoopGroup => JEventLoopGroup} -import io.netty.handler.codec.http.{FullHttpRequest, HttpVersion} +import io.netty.channel.{ + Channel, + ChannelFactory => JChannelFactory, + ChannelHandlerContext, + EventLoopGroup => JEventLoopGroup, +} +import io.netty.handler.codec.http.HttpVersion import zhttp.http.URL.Location import zhttp.http._ import zhttp.http.headers.HeaderExtension import zhttp.service -import zhttp.service.Client.ClientResponse +import zhttp.service.Client.{ClientRequest, ClientResponse} import zhttp.service.client.ClientSSLHandler.ClientSSLOptions import zhttp.service.client.{ClientChannelInitializer, ClientInboundHandler} import zio.{Chunk, Promise, Task, ZIO} -import java.net.InetSocketAddress +import java.net.{InetAddress, InetSocketAddress} final case class Client(rtm: HttpRuntime[Any], cf: JChannelFactory[Channel], el: JEventLoopGroup) extends HttpMessageCodec { def request( - request: Request, + request: Client.ClientRequest, sslOption: ClientSSLOptions = ClientSSLOptions.DefaultSSL, ): Task[Client.ClientResponse] = for { promise <- Promise.make[Throwable, Client.ClientResponse] - jReq <- encodeClientParams(HttpVersion.HTTP_1_1, request) - _ <- Task(asyncRequest(request, jReq, promise, sslOption)).catchAll(cause => promise.fail(cause)) + _ <- Task(asyncRequest(request, promise, sslOption)).catchAll(cause => promise.fail(cause)) res <- promise.await } yield res private def asyncRequest( - req: Request, - jReq: FullHttpRequest, + req: ClientRequest, promise: Promise[Throwable, ClientResponse], sslOption: ClientSSLOptions, ): Unit = { + val jReq = encodeClientParams(HttpVersion.HTTP_1_1, req) try { val hand = ClientInboundHandler(rtm, jReq, promise) val host = req.url.host @@ -107,14 +111,14 @@ object Client { method: Method, url: URL, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = - request(Request(method, url)) + request(ClientRequest(method, url)) def request( method: Method, url: URL, sslOptions: ClientSSLOptions, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = - request(Request(method, url), sslOptions) + request(ClientRequest(method, url), sslOptions) def request( method: Method, @@ -122,7 +126,7 @@ object Client { headers: Headers, sslOptions: ClientSSLOptions, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = - request(Request(method, url, headers), sslOptions) + request(ClientRequest(method, url, headers), sslOptions) def request( method: Method, @@ -130,19 +134,48 @@ object Client { headers: Headers, content: HttpData, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = - request(Request(method, url, headers, content, None)) + request(ClientRequest(method, url, headers, content)) def request( - req: Request, + req: ClientRequest, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = make.flatMap(_.request(req)) def request( - req: Request, + req: ClientRequest, sslOptions: ClientSSLOptions, ): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] = make.flatMap(_.request(req, sslOptions)) + final case class ClientRequest( + method: Method, + url: URL, + getHeaders: Headers = Headers.empty, + data: HttpData = HttpData.empty, + private val channelContext: ChannelHandlerContext = null, + ) extends HeaderExtension[ClientRequest] { self => + + def getBodyAsString: Option[String] = data match { + case HttpData.Text(text, _) => Some(text) + case HttpData.BinaryChunk(data) => Some(new String(data.toArray, HTTP_CHARSET)) + case HttpData.BinaryByteBuf(data) => Some(data.toString(HTTP_CHARSET)) + case _ => Option.empty + } + + def remoteAddress: Option[InetAddress] = { + if (channelContext != null && channelContext.channel().remoteAddress().isInstanceOf[InetSocketAddress]) + Some(channelContext.channel().remoteAddress().asInstanceOf[InetSocketAddress].getAddress) + else + None + } + + /** + * Updates the headers using the provided function + */ + override def updateHeaders(update: Headers => Headers): ClientRequest = + self.copy(getHeaders = update(self.getHeaders)) + } + final case class ClientResponse(status: Status, headers: Headers, private[zhttp] val buffer: ByteBuf) extends HeaderExtension[ClientResponse] { self => diff --git a/zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala b/zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala index ca5d4b4807..f9de469a08 100644 --- a/zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala +++ b/zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala @@ -1,37 +1,41 @@ package zhttp.service +import io.netty.buffer.Unpooled import io.netty.handler.codec.http.{DefaultFullHttpRequest, FullHttpRequest, HttpHeaderNames, HttpVersion} -import zhttp.http.Request -import zio.Task +import zhttp.http.HTTP_CHARSET trait EncodeClientParams { /** * Converts client params to JFullHttpRequest */ - def encodeClientParams(jVersion: HttpVersion, req: Request): Task[FullHttpRequest] = req.getBodyAsByteBuf.map { - content => - val method = req.method.asHttpMethod - val url = req.url - - // As per the spec, the path should contain only the relative path. - // Host and port information should be in the headers. - val path = url.relative.encode - - val encodedReqHeaders = req.getHeaders.encode - - val headers = url.host match { - case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value) - case None => encodedReqHeaders - } - - val writerIndex = content.writerIndex() - if (writerIndex != 0) { - headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString()) - } - // TODO: we should also add a default user-agent req header as some APIs might reject requests without it. - val jReq = new DefaultFullHttpRequest(jVersion, method, path, content) - jReq.headers().set(headers) - - jReq + def encodeClientParams(jVersion: HttpVersion, req: Client.ClientRequest): FullHttpRequest = { + val method = req.method.asHttpMethod + val url = req.url + + // As per the spec, the path should contain only the relative path. + // Host and port information should be in the headers. + val path = url.relative.encode + + val content = req.getBodyAsString match { + case Some(text) => Unpooled.copiedBuffer(text, HTTP_CHARSET) + case None => Unpooled.EMPTY_BUFFER + } + + val encodedReqHeaders = req.getHeaders.encode + + val headers = url.host match { + case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value) + case None => encodedReqHeaders + } + + val writerIndex = content.writerIndex() + if (writerIndex != 0) { + headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString()) + } + // TODO: we should also add a default user-agent req header as some APIs might reject requests without it. + val jReq = new DefaultFullHttpRequest(jVersion, method, path, content) + jReq.headers().set(headers) + + jReq } } diff --git a/zio-http/src/test/scala/zhttp/http/EncodeClientRequestSpec.scala b/zio-http/src/test/scala/zhttp/http/EncodeClientRequestSpec.scala new file mode 100644 index 0000000000..b69a33246d --- /dev/null +++ b/zio-http/src/test/scala/zhttp/http/EncodeClientRequestSpec.scala @@ -0,0 +1,83 @@ +package zhttp.http + +import io.netty.handler.codec.http.{HttpHeaderNames, HttpVersion} +import zhttp.internal.HttpGen +import zhttp.service.{Client, EncodeClientParams} +import zio.random.Random +import zio.test.Assertion._ +import zio.test._ + +object EncodeClientRequestSpec extends DefaultRunnableSpec with EncodeClientParams { + + val anyClientParam: Gen[Random with Sized, Client.ClientRequest] = HttpGen.clientRequest( + HttpGen.httpData( + Gen.listOf(Gen.alphaNumericString), + ), + ) + + val clientParamWithAbsoluteUrl = HttpGen.clientRequest( + dataGen = HttpGen.httpData( + Gen.listOf(Gen.alphaNumericString), + ), + urlGen = HttpGen.genAbsoluteURL, + ) + + def clientParamWithFiniteData(size: Int): Gen[Random with Sized, Client.ClientRequest] = HttpGen.clientRequest( + for { + content <- Gen.alphaNumericStringBounded(size, size) + data <- Gen.fromIterable(List(HttpData.fromString(content))) + } yield data, + ) + + def spec = suite("EncodeClientParams") { + testM("method") { + check(anyClientParam) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + assert(req.method())(equalTo(params.method.asHttpMethod)) + } + } + + testM("method on HttpData.File") { + check(HttpGen.clientParamsForFileHttpData()) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + assert(req.method())(equalTo(params.method.asHttpMethod)) + } + } + + suite("uri") { + testM("uri") { + check(anyClientParam) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + assert(req.uri())(equalTo(params.url.relative.encode)) + } + } + + testM("uri on HttpData.File") { + check(HttpGen.clientParamsForFileHttpData()) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + assert(req.uri())(equalTo(params.url.relative.encode)) + } + } + } + + testM("content-length") { + check(clientParamWithFiniteData(5)) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + assert(req.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).toLong)(equalTo(5L)) + } + } + + testM("host header") { + check(anyClientParam) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + val hostHeader = HttpHeaderNames.HOST + assert(Option(req.headers().get(hostHeader)))(equalTo(params.url.host)) + } + } + + testM("host header when absolute url") { + check(clientParamWithAbsoluteUrl) { params => + val req = encodeClientParams(HttpVersion.HTTP_1_1, params) + val reqHeaders = req.headers() + val hostHeader = HttpHeaderNames.HOST + + assert(reqHeaders.getAll(hostHeader).size)(equalTo(1)) && + assert(Option(reqHeaders.get(hostHeader)))(equalTo(params.url.host)) + } + } + } +} diff --git a/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala b/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala deleted file mode 100644 index 0dcf1c3b5e..0000000000 --- a/zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala +++ /dev/null @@ -1,85 +0,0 @@ -package zhttp.http - -import io.netty.handler.codec.http.{HttpHeaderNames, HttpVersion} -import zhttp.internal.HttpGen -import zhttp.service.EncodeClientParams -import zio.random.Random -import zio.test.Assertion._ -import zio.test._ - -object EncodeRequestSpec extends DefaultRunnableSpec with EncodeClientParams { - - val anyClientParam: Gen[Random with Sized, Request] = HttpGen.clientRequest( - HttpGen.httpData( - Gen.listOf(Gen.alphaNumericString), - ), - ) - - val clientParamWithAbsoluteUrl = HttpGen.clientRequest( - dataGen = HttpGen.httpData( - Gen.listOf(Gen.alphaNumericString), - ), - urlGen = HttpGen.genAbsoluteURL, - ) - - def clientParamWithFiniteData(size: Int): Gen[Random with Sized, Request] = HttpGen.clientRequest( - for { - content <- Gen.alphaNumericStringBounded(size, size) - data <- Gen.fromIterable(List(HttpData.fromString(content))) - } yield data, - ) - - def spec = suite("EncodeClientParams") { - testM("method") { - checkM(anyClientParam) { params => - val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method()) - assertM(method)(equalTo(params.method.asHttpMethod)) - } - } + - testM("method on HttpData.File") { - checkM(HttpGen.clientParamsForFileHttpData()) { params => - val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method()) - assertM(method)(equalTo(params.method.asHttpMethod)) - } - } + - suite("uri") { - testM("uri") { - checkM(anyClientParam) { params => - val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri()) - assertM(uri)(equalTo(params.url.relative.encode)) - } - } + - testM("uri on HttpData.File") { - checkM(HttpGen.clientParamsForFileHttpData()) { params => - val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri()) - assertM(uri)(equalTo(params.url.relative.encode)) - } - } - } + - testM("content-length") { - checkM(clientParamWithFiniteData(5)) { params => - val len = encodeClientParams(HttpVersion.HTTP_1_1, params).map( - _.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).toLong, - ) - assertM(len)(equalTo(5L)) - } - } + - testM("host header") { - checkM(anyClientParam) { params => - val hostHeader = HttpHeaderNames.HOST - val headers = encodeClientParams(HttpVersion.HTTP_1_1, params).map(h => Option(h.headers().get(hostHeader))) - assertM(headers)(equalTo(params.url.host)) - } - } + - testM("host header when absolute url") { - checkM(clientParamWithAbsoluteUrl) { params => - val hostHeader = HttpHeaderNames.HOST - for { - reqHeaders <- encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.headers()) - } yield assert(reqHeaders.getAll(hostHeader).size)(equalTo(1)) && assert(Option(reqHeaders.get(hostHeader)))( - equalTo(params.url.host), - ) - } - } - } -} diff --git a/zio-http/src/test/scala/zhttp/http/GetBodyAsStringSpec.scala b/zio-http/src/test/scala/zhttp/http/GetBodyAsStringSpec.scala index bd3da3cc7e..aad22542d1 100644 --- a/zio-http/src/test/scala/zhttp/http/GetBodyAsStringSpec.scala +++ b/zio-http/src/test/scala/zhttp/http/GetBodyAsStringSpec.scala @@ -1,6 +1,7 @@ package zhttp.http import io.netty.handler.codec.http.HttpHeaderNames +import zhttp.service.Client import zio.Chunk import zio.test.Assertion._ import zio.test._ @@ -15,26 +16,27 @@ object GetBodyAsStringSpec extends DefaultRunnableSpec { val charsetGen: Gen[Any, Charset] = Gen.fromIterable(List(UTF_8, UTF_16, UTF_16BE, UTF_16LE, US_ASCII, ISO_8859_1)) - checkM(charsetGen) { charset => - val encoded = Request( - Method.GET, - URL(Path("/")), - headers = Headers(HttpHeaderNames.CONTENT_TYPE.toString, s"text/html; charset=$charset"), - data = HttpData.BinaryChunk(Chunk.fromArray("abc".getBytes(charset))), - ).getBodyAsString + check(charsetGen) { charset => + val encoded = Client + .ClientRequest( + Method.GET, + URL(Path("/")), + getHeaders = Headers(HttpHeaderNames.CONTENT_TYPE.toString, s"text/html; charset=$charset"), + data = HttpData.BinaryChunk(Chunk.fromArray("abc".getBytes())), + ) + .getBodyAsString + val actual = Option(new String(Chunk.fromArray("abc".getBytes(charset)).toArray, charset)) - val expected = new String(Chunk.fromArray("abc".getBytes(charset)).toArray, charset) - - assertM(encoded)(equalTo(expected)) + assert(actual)(equalTo(encoded)) } } + - testM("should map bytes to default utf-8 if no charset given") { + test("should map bytes to default utf-8 if no charset given") { val data = Chunk.fromArray("abc".getBytes()) val content = HttpData.BinaryChunk(data) - val request = Request(Method.GET, URL(Path("/")), data = content) + val request = Client.ClientRequest(Method.GET, URL(Path("/")), data = content) val encoded = request.getBodyAsString - val actual = new String(data.toArray, HTTP_CHARSET) - assertM(encoded)(equalTo(actual)) + val actual = Option(new String(data.toArray, HTTP_CHARSET)) + assert(actual)(equalTo(encoded)) }, ) } diff --git a/zio-http/src/test/scala/zhttp/internal/HttpGen.scala b/zio-http/src/test/scala/zhttp/internal/HttpGen.scala index 5b06448608..996deb4d4a 100644 --- a/zio-http/src/test/scala/zhttp/internal/HttpGen.scala +++ b/zio-http/src/test/scala/zhttp/internal/HttpGen.scala @@ -3,6 +3,7 @@ package zhttp.internal import io.netty.buffer.Unpooled import zhttp.http.URL.Location import zhttp.http._ +import zhttp.service.Client.ClientRequest import zio.random.Random import zio.stream.ZStream import zio.test.{Gen, Sized} @@ -22,7 +23,7 @@ object HttpGen { url <- urlGen headers <- Gen.listOf(headerGen).map(Headers(_)) data <- dataGen - } yield Request(method, url, headers, data, None) + } yield ClientRequest(method, url, headers, data) def clientParamsForFileHttpData() = { for { @@ -30,7 +31,7 @@ object HttpGen { method <- HttpGen.method url <- HttpGen.url headers <- Gen.listOf(HttpGen.header).map(Headers(_)) - } yield Request(method, url, headers, HttpData.fromFile(file), None) + } yield ClientRequest(method, url, headers, HttpData.fromFile(file)) } def cookies: Gen[Random with Sized, Cookie] = for { @@ -118,7 +119,7 @@ object HttpGen { url <- HttpGen.url headers <- Gen.listOf(HttpGen.header).map(Headers(_)) data <- HttpGen.httpData(Gen.listOf(Gen.alphaNumericString)) - } yield Request(method, url, headers, data, None) + } yield Request(method, url, headers, None, data) def response[R](gContent: Gen[R, List[String]]): Gen[Random with Sized with R, Response] = { for { diff --git a/zio-http/src/test/scala/zhttp/internal/HttpRunnableSpec.scala b/zio-http/src/test/scala/zhttp/internal/HttpRunnableSpec.scala index 5718e794c4..1642799a76 100644 --- a/zio-http/src/test/scala/zhttp/internal/HttpRunnableSpec.scala +++ b/zio-http/src/test/scala/zhttp/internal/HttpRunnableSpec.scala @@ -21,7 +21,7 @@ import zio.{Has, Task, ZIO, ZManaged} */ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self => - implicit class RunnableClientHttpSyntax[R, A](app: Http[R, Throwable, Request, A]) { + implicit class RunnableClientHttpSyntax[R, A](app: Http[R, Throwable, Client.ClientRequest, A]) { /** * Runs the deployed Http app by making a real http request to it. The method allows us to configure individual @@ -34,12 +34,11 @@ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self => headers: Headers = Headers.empty, ): ZIO[R, Throwable, A] = app( - Request( + Client.ClientRequest( method, URL(path, Location.Absolute(Scheme.HTTP, "localhost", 0)), headers, HttpData.fromString(content), - None, ), ).catchAll { case Some(value) => ZIO.fail(value) @@ -59,7 +58,7 @@ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self => for { port <- Http.fromZIO(DynamicServer.getPort) id <- Http.fromZIO(DynamicServer.deploy(app)) - response <- Http.fromFunctionZIO[Request] { params => + response <- Http.fromFunctionZIO[Client.ClientRequest] { params => Client.request( params .addHeader(DynamicServer.APP_ID, id) @@ -75,7 +74,7 @@ abstract class HttpRunnableSpec extends DefaultRunnableSpec { self => def deployWebSocket: HttpTestClient[SttpClient, client3.Response[Either[String, WebSocket[Task]]]] = for { id <- Http.fromZIO(DynamicServer.deploy(app)) res <- - Http.fromFunctionZIO[Request](params => + Http.fromFunctionZIO[Client.ClientRequest](params => for { port <- DynamicServer.getPort url = s"ws://localhost:$port${params.url.path.asString}" @@ -118,7 +117,7 @@ object HttpRunnableSpec { Http[ R with EventLoopGroup with ChannelFactory with DynamicServer with ServerChannelFactory, Throwable, - Request, + Client.ClientRequest, A, ] }