Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Refactor: Merge client and server Request" #915

Merged
merged 1 commit into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zio-http/src/main/scala/zhttp/http/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ object Request {
method: Method = Method.GET,
url: URL = URL.root,
headers: Headers = Headers.empty,
data: HttpData = HttpData.Empty,
remoteAddress: Option[InetAddress] = None,
data: HttpData = HttpData.Empty,
): Request = {
val m = method
val u = url
Expand All @@ -121,7 +121,7 @@ object Request {
remoteAddress: Option[InetAddress],
content: HttpData = HttpData.empty,
): UIO[Request] =
UIO(Request(method, url, headers, content, remoteAddress))
UIO(Request(method, url, headers, remoteAddress, content))

/**
* Lift request to TypedRequest with option to extract params
Expand Down
63 changes: 48 additions & 15 deletions zio-http/src/main/scala/zhttp/service/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,42 @@ package zhttp.service

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.{ByteBuf, ByteBufUtil}
import io.netty.channel.{Channel, ChannelFactory => JChannelFactory, EventLoopGroup => JEventLoopGroup}
import io.netty.handler.codec.http.{FullHttpRequest, HttpVersion}
import io.netty.channel.{
Channel,
ChannelFactory => JChannelFactory,
ChannelHandlerContext,
EventLoopGroup => JEventLoopGroup,
}
import io.netty.handler.codec.http.HttpVersion
import zhttp.http.URL.Location
import zhttp.http._
import zhttp.http.headers.HeaderExtension
import zhttp.service
import zhttp.service.Client.ClientResponse
import zhttp.service.Client.{ClientRequest, ClientResponse}
import zhttp.service.client.ClientSSLHandler.ClientSSLOptions
import zhttp.service.client.{ClientChannelInitializer, ClientInboundHandler}
import zio.{Chunk, Promise, Task, ZIO}

import java.net.InetSocketAddress
import java.net.{InetAddress, InetSocketAddress}

final case class Client(rtm: HttpRuntime[Any], cf: JChannelFactory[Channel], el: JEventLoopGroup)
extends HttpMessageCodec {
def request(
request: Request,
request: Client.ClientRequest,
sslOption: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
): Task[Client.ClientResponse] =
for {
promise <- Promise.make[Throwable, Client.ClientResponse]
jReq <- encodeClientParams(HttpVersion.HTTP_1_1, request)
_ <- Task(asyncRequest(request, jReq, promise, sslOption)).catchAll(cause => promise.fail(cause))
_ <- Task(asyncRequest(request, promise, sslOption)).catchAll(cause => promise.fail(cause))
res <- promise.await
} yield res

private def asyncRequest(
req: Request,
jReq: FullHttpRequest,
req: ClientRequest,
promise: Promise[Throwable, ClientResponse],
sslOption: ClientSSLOptions,
): Unit = {
val jReq = encodeClientParams(HttpVersion.HTTP_1_1, req)
try {
val hand = ClientInboundHandler(rtm, jReq, promise)
val host = req.url.host
Expand Down Expand Up @@ -107,42 +111,71 @@ object Client {
method: Method,
url: URL,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(Request(method, url))
request(ClientRequest(method, url))

def request(
method: Method,
url: URL,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(Request(method, url), sslOptions)
request(ClientRequest(method, url), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(Request(method, url, headers), sslOptions)
request(ClientRequest(method, url, headers), sslOptions)

def request(
method: Method,
url: URL,
headers: Headers,
content: HttpData,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
request(Request(method, url, headers, content, None))
request(ClientRequest(method, url, headers, content))

def request(
req: Request,
req: ClientRequest,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req))

def request(
req: Request,
req: ClientRequest,
sslOptions: ClientSSLOptions,
): ZIO[EventLoopGroup with ChannelFactory, Throwable, ClientResponse] =
make.flatMap(_.request(req, sslOptions))

final case class ClientRequest(
method: Method,
url: URL,
getHeaders: Headers = Headers.empty,
data: HttpData = HttpData.empty,
private val channelContext: ChannelHandlerContext = null,
) extends HeaderExtension[ClientRequest] { self =>

def getBodyAsString: Option[String] = data match {
case HttpData.Text(text, _) => Some(text)
case HttpData.BinaryChunk(data) => Some(new String(data.toArray, HTTP_CHARSET))
case HttpData.BinaryByteBuf(data) => Some(data.toString(HTTP_CHARSET))
case _ => Option.empty
}

def remoteAddress: Option[InetAddress] = {
if (channelContext != null && channelContext.channel().remoteAddress().isInstanceOf[InetSocketAddress])
Some(channelContext.channel().remoteAddress().asInstanceOf[InetSocketAddress].getAddress)
else
None
}

/**
* Updates the headers using the provided function
*/
override def updateHeaders(update: Headers => Headers): ClientRequest =
self.copy(getHeaders = update(self.getHeaders))
}

final case class ClientResponse(status: Status, headers: Headers, private[zhttp] val buffer: ByteBuf)
extends HeaderExtension[ClientResponse] { self =>

Expand Down
58 changes: 31 additions & 27 deletions zio-http/src/main/scala/zhttp/service/EncodeClientParams.scala
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
package zhttp.service

import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{DefaultFullHttpRequest, FullHttpRequest, HttpHeaderNames, HttpVersion}
import zhttp.http.Request
import zio.Task
import zhttp.http.HTTP_CHARSET
trait EncodeClientParams {

/**
* Converts client params to JFullHttpRequest
*/
def encodeClientParams(jVersion: HttpVersion, req: Request): Task[FullHttpRequest] = req.getBodyAsByteBuf.map {
content =>
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.encode

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
def encodeClientParams(jVersion: HttpVersion, req: Client.ClientRequest): FullHttpRequest = {
val method = req.method.asHttpMethod
val url = req.url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = url.relative.encode

val content = req.getBodyAsString match {
case Some(text) => Unpooled.copiedBuffer(text, HTTP_CHARSET)
case None => Unpooled.EMPTY_BUFFER
}

val encodedReqHeaders = req.getHeaders.encode

val headers = url.host match {
case Some(value) => encodedReqHeaders.set(HttpHeaderNames.HOST, value)
case None => encodedReqHeaders
}

val writerIndex = content.writerIndex()
if (writerIndex != 0) {
headers.set(HttpHeaderNames.CONTENT_LENGTH, writerIndex.toString())
}
// TODO: we should also add a default user-agent req header as some APIs might reject requests without it.
val jReq = new DefaultFullHttpRequest(jVersion, method, path, content)
jReq.headers().set(headers)

jReq
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package zhttp.http

import io.netty.handler.codec.http.{HttpHeaderNames, HttpVersion}
import zhttp.internal.HttpGen
import zhttp.service.{Client, EncodeClientParams}
import zio.random.Random
import zio.test.Assertion._
import zio.test._

object EncodeClientRequestSpec extends DefaultRunnableSpec with EncodeClientParams {

val anyClientParam: Gen[Random with Sized, Client.ClientRequest] = HttpGen.clientRequest(
HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
)

val clientParamWithAbsoluteUrl = HttpGen.clientRequest(
dataGen = HttpGen.httpData(
Gen.listOf(Gen.alphaNumericString),
),
urlGen = HttpGen.genAbsoluteURL,
)

def clientParamWithFiniteData(size: Int): Gen[Random with Sized, Client.ClientRequest] = HttpGen.clientRequest(
for {
content <- Gen.alphaNumericStringBounded(size, size)
data <- Gen.fromIterable(List(HttpData.fromString(content)))
} yield data,
)

def spec = suite("EncodeClientParams") {
testM("method") {
check(anyClientParam) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
assert(req.method())(equalTo(params.method.asHttpMethod))
}
} +
testM("method on HttpData.File") {
check(HttpGen.clientParamsForFileHttpData()) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
assert(req.method())(equalTo(params.method.asHttpMethod))
}
} +
suite("uri") {
testM("uri") {
check(anyClientParam) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
assert(req.uri())(equalTo(params.url.relative.encode))
}
} +
testM("uri on HttpData.File") {
check(HttpGen.clientParamsForFileHttpData()) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
assert(req.uri())(equalTo(params.url.relative.encode))
}
}
} +
testM("content-length") {
check(clientParamWithFiniteData(5)) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
assert(req.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).toLong)(equalTo(5L))
}
} +
testM("host header") {
check(anyClientParam) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
val hostHeader = HttpHeaderNames.HOST
assert(Option(req.headers().get(hostHeader)))(equalTo(params.url.host))
}
} +
testM("host header when absolute url") {
check(clientParamWithAbsoluteUrl) { params =>
val req = encodeClientParams(HttpVersion.HTTP_1_1, params)
val reqHeaders = req.headers()
val hostHeader = HttpHeaderNames.HOST

assert(reqHeaders.getAll(hostHeader).size)(equalTo(1)) &&
assert(Option(reqHeaders.get(hostHeader)))(equalTo(params.url.host))
}
}
}
}
85 changes: 0 additions & 85 deletions zio-http/src/test/scala/zhttp/http/EncodeRequestSpec.scala

This file was deleted.

Loading