Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Merge client and server Request #894

Merged
merged 3 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zio-http/src/main/scala/zhttp/http/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ object Request {
method: Method = Method.GET,
url: URL = URL.root,
headers: Headers = Headers.empty,
remoteAddress: Option[InetAddress] = None,
data: HttpData = HttpData.Empty,
remoteAddress: Option[InetAddress] = None,
): Request = {
val m = method
val u = url
Expand All @@ -121,7 +121,7 @@ object Request {
remoteAddress: Option[InetAddress],
content: HttpData = HttpData.empty,
): UIO[Request] =
UIO(Request(method, url, headers, remoteAddress, content))
UIO(Request(method, url, headers, content, remoteAddress))

/**
* Lift request to TypedRequest with option to extract params
Expand Down
63 changes: 15 additions & 48 deletions zio-http/src/main/scala/zhttp/service/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,38 @@ package zhttp.service

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.{ByteBuf, ByteBufUtil}
import io.netty.channel.{
Channel,
ChannelFactory => JChannelFactory,
ChannelHandlerContext,
EventLoopGroup => JEventLoopGroup,
}
import io.netty.handler.codec.http.HttpVersion
import io.netty.channel.{Channel, ChannelFactory => JChannelFactory, EventLoopGroup => JEventLoopGroup}
import io.netty.handler.codec.http.{FullHttpRequest, HttpVersion}
import zhttp.http.URL.Location
import zhttp.http._
import zhttp.http.headers.HeaderExtension
import zhttp.service
import zhttp.service.Client.{ClientRequest, ClientResponse}
import zhttp.service.Client.ClientResponse
import zhttp.service.client.ClientSSLHandler.ClientSSLOptions
import zhttp.service.client.{ClientChannelInitializer, ClientInboundHandler}
import zio.{Chunk, Promise, Task, ZIO}

import java.net.{InetAddress, InetSocketAddress}
import java.net.InetSocketAddress

final case class Client(rtm: HttpRuntime[Any], cf: JChannelFactory[Channel], el: JEventLoopGroup)
extends HttpMessageCodec {
def request(
request: Client.ClientRequest,
request: Request,
sslOption: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
): Task[Client.ClientResponse] =
for {
promise <- Promise.make[Throwable, Client.ClientResponse]
_ <- Task(asyncRequest(request, promise, sslOption)).catchAll(cause => promise.fail(cause))
jReq <- encodeClientParams(HttpVersion.HTTP_1_1, request)
_ <- Task(asyncRequest(request, jReq, promise, sslOption)).catchAll(cause => promise.fail(cause))
res <- promise.await
} yield res

private def asyncRequest(
req: ClientRequest,
req: Request,
jReq: FullHttpRequest,
promise: Promise[Throwable, ClientResponse],
sslOption: ClientSSLOptions,
): Unit = {
val jReq = encodeClientParams(HttpVersion.HTTP_1_1, req)
try {
val hand = ClientInboundHandler(rtm, jReq, promise)
val host = req.url.host
Expand Down Expand Up @@ -111,71 +107,42 @@ object Client {
method: Method,
url: URL,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url))
request(Request(method, url))

def request(
method: Method,
url: URL,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url), sslOptions)
request(Request(method, url), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url, headers), sslOptions)
request(Request(method, url, headers), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
content: HttpData,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(ClientRequest(method, url, headers, content))
request(Request(method, url, headers, content, None))

def request(
req: ClientRequest,
req: Request,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req))

def request(
req: ClientRequest,
req: Request,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req, sslOptions))

final case class ClientRequest(
method: Method,
url: URL,
getHeaders: Headers = Headers.empty,
data: HttpData = HttpData.empty,
private val channelContext: ChannelHandlerContext = null,
) extends HeaderExtension[ClientRequest] { self =>

def getBodyAsString: Option[String] = data match {
case HttpData.Text(text, _) => Some(text)
case HttpData.BinaryChunk(data) => Some(new String(data.toArray, HTTP_CHARSET))
case HttpData.BinaryByteBuf(data) => Some(data.toString(HTTP_CHARSET))
case _ => Option.empty
}

def remoteAddress: Option[InetAddress] = {
if (channelContext != null && channelContext.channel().remoteAddress().isInstanceOf[InetSocketAddress])
Some(channelContext.channel().remoteAddress().asInstanceOf[InetSocketAddress].getAddress)
else
None
}

/**
* Updates the headers using the provided function
*/
override def updateHeaders(update: Headers => Headers): ClientRequest =
self.copy(getHeaders = update(self.getHeaders))
}

final case class ClientResponse(status: Status, headers: Headers, private[zhttp] val buffer: ByteBuf)
extends HeaderExtension[ClientResponse] { self =>

Expand Down
58 changes: 27 additions & 31 deletions zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
package zhttp.service

import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{DefaultFullHttpRequest, FullHttpRequest, HttpHeaderNames, HttpVersion}
import zhttp.http.HTTP_CHARSET
import zhttp.http.Request
import zio.Task
trait EncodeClientParams {

/**
* Converts client params to JFullHttpRequest
*/
def encodeClientParams(jVersion: HttpVersion, req: Client.ClientRequest): FullHttpRequest = {
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.encode

val content = req.getBodyAsString match {
case Some(text) => Unpooled.copiedBuffer(text, HTTP_CHARSET)
case None => Unpooled.EMPTY_BUFFER
}

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
def encodeClientParams(jVersion: HttpVersion, req: Request): Task[FullHttpRequest] = req.getBodyAsByteBuf.map {
content =>
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.encode

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
}
}

This file was deleted.

85 changes: 85 additions & 0 deletions zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package zhttp.http

import io.netty.handler.codec.http.{HttpHeaderNames, HttpVersion}
import zhttp.internal.HttpGen
import zhttp.service.EncodeClientParams
import zio.random.Random
import zio.test.Assertion._
import zio.test._

object EncodeRequestSpec extends DefaultRunnableSpec with EncodeClientParams {

val anyClientParam: Gen[Random with Sized, Request] = HttpGen.clientRequest(
Copy link
Member

@girdharshubham girdharshubham Jan 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#nitpick This as well?
candidate for HttpGen?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can keep it here for now, as this test is the only user of this Gen

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was already there, the file has been renamed only.

HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
)

val clientParamWithAbsoluteUrl = HttpGen.clientRequest(
dataGen = HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
urlGen = HttpGen.genAbsoluteURL,
)

def clientParamWithFiniteData(size: Int): Gen[Random with Sized, Request] = HttpGen.clientRequest(
Copy link
Member

@girdharshubham girdharshubham Jan 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#nitpick Should we keep it HttpGen?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can keep it here for now, as this test is the only user of this Gen

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was already there, the file has been renamed only.

for {
content <- Gen.alphaNumericStringBounded(size, size)
data <- Gen.fromIterable(List(HttpData.fromString(content)))
} yield data,
)

def spec = suite("EncodeClientParams") {
testM("method") {
checkM(anyClientParam) { params =>
val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method())
assertM(method)(equalTo(params.method.asHttpMethod))
}
} +
testM("method on HttpData.File") {
checkM(HttpGen.clientParamsForFileHttpData()) { params =>
val method = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.method())
assertM(method)(equalTo(params.method.asHttpMethod))
}
} +
suite("uri") {
testM("uri") {
checkM(anyClientParam) { params =>
val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri())
assertM(uri)(equalTo(params.url.relative.encode))
}
} +
testM("uri on HttpData.File") {
checkM(HttpGen.clientParamsForFileHttpData()) { params =>
val uri = encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.uri())
assertM(uri)(equalTo(params.url.relative.encode))
}
}
} +
testM("content-length") {
checkM(clientParamWithFiniteData(5)) { params =>
val len = encodeClientParams(HttpVersion.HTTP_1_1, params).map(
_.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).toLong,
)
assertM(len)(equalTo(5L))
}
} +
testM("host header") {
checkM(anyClientParam) { params =>
val hostHeader = HttpHeaderNames.HOST
val headers = encodeClientParams(HttpVersion.HTTP_1_1, params).map(h => Option(h.headers().get(hostHeader)))
assertM(headers)(equalTo(params.url.host))
}
} +
testM("host header when absolute url") {
checkM(clientParamWithAbsoluteUrl) { params =>
val hostHeader = HttpHeaderNames.HOST
for {
reqHeaders <- encodeClientParams(HttpVersion.HTTP_1_1, params).map(_.headers())
} yield assert(reqHeaders.getAll(hostHeader).size)(equalTo(1)) && assert(Option(reqHeaders.get(hostHeader)))(
equalTo(params.url.host),
)
}
}
}
}
Loading