Skip to content

Commit

Permalink
refactor: merge Request for Client and Server
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharmath committed Jan 26, 2022
1 parent dc0d80f commit 55089b8
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 189 deletions.
4 changes: 2 additions & 2 deletions zio-http/src/main/scala/zhttp/http/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ object Request {
method: Method = Method.GET,
url: URL = URL.root,
headers: Headers = Headers.empty,
remoteAddress: Option[InetAddress] = None,
data: HttpData = HttpData.Empty,
remoteAddress: Option[InetAddress] = None,
): Request = {
val m = method
val u = url
Expand All @@ -121,7 +121,7 @@ object Request {
remoteAddress: Option[InetAddress],
content: HttpData = HttpData.empty,
): UIO[Request] =
UIO(Request(method, url, headers, remoteAddress, content))
UIO(Request(method, url, headers, content, remoteAddress))

/**
* Lift request to TypedRequest with option to extract params
Expand Down
63 changes: 15 additions & 48 deletions zio-http/src/main/scala/zhttp/service/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,38 @@ package zhttp.service

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.{ByteBuf, ByteBufUtil}
import io.netty.channel.{
Channel,
ChannelFactory => JChannelFactory,
ChannelHandlerContext,
EventLoopGroup => JEventLoopGroup,
}
import io.netty.handler.codec.http.HttpVersion
import io.netty.channel.{Channel, ChannelFactory => JChannelFactory, EventLoopGroup => JEventLoopGroup}
import io.netty.handler.codec.http.{FullHttpRequest, HttpVersion}
import zhttp.http.URL.Location
import zhttp.http._
import zhttp.http.headers.HeaderExtension
import zhttp.service
import zhttp.service.Client.{ClientRequest, ClientResponse}
import zhttp.service.Client.ClientResponse
import zhttp.service.client.ClientSSLHandler.ClientSSLOptions
import zhttp.service.client.{ClientChannelInitializer, ClientInboundHandler}
import zio.{Chunk, Promise, Task, ZIO}

import java.net.{InetAddress, InetSocketAddress}
import java.net.InetSocketAddress

final case class Client(rtm: HttpRuntime[Any], cf: JChannelFactory[Channel], el: JEventLoopGroup)
extends HttpMessageCodec {
def request(
request: Client.ClientRequest,
request: Request,
sslOption: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
): Task[Client.ClientResponse] =
for {
promise <- Promise.make[Throwable, Client.ClientResponse]
_ <- Task(asyncRequest(request, promise, sslOption)).catchAll(cause => promise.fail(cause))
jReq <- encodeClientParams(HttpVersion.HTTP_1_1, request)
_ <- Task(asyncRequest(request, jReq, promise, sslOption)).catchAll(cause => promise.fail(cause))
res <- promise.await
} yield res

private def asyncRequest(
req: ClientRequest,
req: Request,
jReq: FullHttpRequest,
promise: Promise[Throwable, ClientResponse],
sslOption: ClientSSLOptions,
): Unit = {
val jReq = encodeClientParams(HttpVersion.HTTP_1_1, req)
try {
val hand = ClientInboundHandler(rtm, jReq, promise)
val host = req.url.host
Expand Down Expand Up @@ -111,71 +107,42 @@ object Client {
method: Method,
url: URL,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url))
request(Request(method, url))

def request(
method: Method,
url: URL,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url), sslOptions)
request(Request(method, url), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url, headers), sslOptions)
request(Request(method, url, headers), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
content: HttpData,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url, headers, content))
request(Request(method, url, headers, content, None))

def request(
req: ClientRequest,
req: Request,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req))

def request(
req: ClientRequest,
req: Request,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req, sslOptions))

final case class ClientRequest(
method: Method,
url: URL,
getHeaders: Headers = Headers.empty,
data: HttpData = HttpData.empty,
private val channelContext: ChannelHandlerContext = null,
) extends HeaderExtension[ClientRequest] { self =>

def getBodyAsString: Option[String] = data match {
case HttpData.Text(text, _) => Some(text)
case HttpData.BinaryChunk(data) => Some(new String(data.toArray, HTTP_CHARSET))
case HttpData.BinaryByteBuf(data) => Some(data.toString(HTTP_CHARSET))
case _ => Option.empty
}

def remoteAddress: Option[InetAddress] = {
if (channelContext != null && channelContext.channel().remoteAddress().isInstanceOf[InetSocketAddress])
Some(channelContext.channel().remoteAddress().asInstanceOf[InetSocketAddress].getAddress)
else
None
}

/**
* Updates the headers using the provided function
*/
override def updateHeaders(update: Headers => Headers): ClientRequest =
self.copy(getHeaders = update(self.getHeaders))
}

final case class ClientResponse(status: Status, headers: Headers, private[zhttp] val buffer: ByteBuf)
extends HeaderExtension[ClientResponse] { self =>

Expand Down
63 changes: 32 additions & 31 deletions zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
package zhttp.service

import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{DefaultFullHttpRequest, FullHttpRequest, HttpHeaderNames, HttpVersion}
import zhttp.http.HTTP_CHARSET
import zhttp.http.Request
import zio.Task
trait EncodeClientParams {

/**
* Converts client params to JFullHttpRequest
*/
def encodeClientParams(jVersion: HttpVersion, req: Client.ClientRequest): FullHttpRequest = {
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.asString

val content = req.getBodyAsString match {
case Some(text) => Unpooled.copiedBuffer(text, HTTP_CHARSET)
case None => Unpooled.EMPTY_BUFFER
}

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
def encodeClientParams(jVersion: HttpVersion, req: Request): Task[FullHttpRequest] = req.getBodyAsByteBuf.map {
content =>
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.asString

// val content = req.getBodyAsString match {
// case Some(text) => Unpooled.copiedBuffer(text, HTTP_CHARSET)
// case None => Unpooled.EMPTY_BUFFER
// }

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
}
}
83 changes: 0 additions & 83 deletions zio-http/src/test/scala/zhttp/http/EncodeClientRequestSpec.scala

This file was deleted.

85 changes: 85 additions & 0 deletions zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package zhttp.http

import io.netty.handler.codec.http.{HttpHeaderNames, HttpVersion}
import zhttp.internal.HttpGen
import zhttp.service.EncodeClientParams
import zio.random.Random
import zio.test.Assertion._
import zio.test._

object EncodeRequestSpec extends DefaultRunnableSpec with EncodeClientParams {

val anyClientParam: Gen[Random with Sized, Request] = HttpGen.clientRequest(
HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
)

val clientParamWithAbsoluteUrl = HttpGen.clientRequest(
dataGen = HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
urlGen = HttpGen.genAbsoluteURL,
)

def clientParamWithFiniteData(size: Int): Gen[Random with Sized, Request] = HttpGen.clientRequest(
for {
content <- Gen.alphaNumericStringBounded(size, size)
data <- Gen.fromIterable(List(HttpData.fromString(content)))
} yield data,
)

def spec = suite("EncodeClientParams") {
testM("method") {
checkM(anyClientParam) { params =>
val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method())
assertM(method)(equalTo(params.method.asHttpMethod))
}
} +
testM("method on HttpData.File") {
checkM(HttpGen.clientParamsForFileHttpData()) { params =>
val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method())
assertM(method)(equalTo(params.method.asHttpMethod))
}
} +
suite("uri") {
testM("uri") {
checkM(anyClientParam) { params =>
val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri())
assertM(uri)(equalTo(params.url.relative.asString))
}
} +
testM("uri on HttpData.File") {
checkM(HttpGen.clientParamsForFileHttpData()) { params =>
val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri())
assertM(uri)(equalTo(params.url.relative.asString))
}
}
} +
testM("content-length") {
checkM(clientParamWithFiniteData(5)) { params =>
val len = encodeClientParams(HttpVersion.HTTP_1_1, params).map(
_.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).toLong,
)
assertM(len)(equalTo(5L))
}
} +
testM("host header") {
checkM(anyClientParam) { params =>
val hostHeader = HttpHeaderNames.HOST
val headers = encodeClientParams(HttpVersion.HTTP_1_1, params).map(h => Option(h.headers().get(hostHeader)))
assertM(headers)(equalTo(params.url.host))
}
} +
testM("host header when absolute url") {
checkM(clientParamWithAbsoluteUrl) { params =>
val hostHeader = HttpHeaderNames.HOST
for {
reqHeaders <- encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.headers())
} yield assert(reqHeaders.getAll(hostHeader).size)(equalTo(1)) && assert(Option(reqHeaders.get(hostHeader)))(
equalTo(params.url.host),
)
}
}
}
}
Loading

0 comments on commit 55089b8

Please sign in to comment.