diff --git a/project/Dependencies.scala b/project/Dependencies.scala index e03bb21135..168e27d964 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -15,6 +15,7 @@ object Dependencies { val netty = Seq( "netty-codec-http", + "netty-handler-proxy", "netty-transport-native-epoll", "netty-transport-native-kqueue", ).map { name => diff --git a/zio-http/src/main/scala/zhttp/http/Proxy.scala b/zio-http/src/main/scala/zhttp/http/Proxy.scala new file mode 100644 index 0000000000..f79c2e7c89 --- /dev/null +++ b/zio-http/src/main/scala/zhttp/http/Proxy.scala @@ -0,0 +1,55 @@ +package zhttp.http +import io.netty.handler.proxy.HttpProxyHandler +import zhttp.http.middleware.Auth.Credentials + +import java.net.InetSocketAddress + +/** + * Represents the connection to the forward proxy before running the request + * + * @param url: + * url address of the proxy server + * @param credentials: + * credentials for the proxy server. Encodes credentials with basic auth and + * put under the 'proxy-authorization' header + * @param headers: + * headers for the request to the proxy server + */ +final case class Proxy( + url: URL, + credentials: Option[Credentials] = None, + headers: Headers = Headers.empty, +) { self => + + def withUrl(url: URL): Proxy = self.copy(url = url) + def withCredentials(credentials: Credentials): Proxy = self.copy(credentials = Some(credentials)) + def withHeaders(headers: Headers): Proxy = self.copy(headers = headers) + + /** + * Converts a Proxy to [io.netty.handler.proxy.HttpProxyHandler] + */ + private[zhttp] def encode: Option[HttpProxyHandler] = credentials.fold(unauthorizedProxy)(authorizedProxy) + + private def authorizedProxy(credentials: Credentials): Option[HttpProxyHandler] = for { + proxyAddress <- buildProxyAddress + uname = credentials.uname + upassword = credentials.upassword + encodedHeaders = headers.encode + } yield new HttpProxyHandler(proxyAddress, uname, upassword, encodedHeaders) + + private def unauthorizedProxy: Option[HttpProxyHandler] = for { + proxyAddress <- buildProxyAddress + encodedHeaders = headers.encode + } yield { + new HttpProxyHandler(proxyAddress, encodedHeaders) + } + + private def buildProxyAddress: Option[InetSocketAddress] = for { + proxyHost <- url.host + proxyPort <- url.port + } yield new InetSocketAddress(proxyHost, proxyPort) +} + +object Proxy { + val empty: Proxy = Proxy(URL.empty) +} diff --git a/zio-http/src/main/scala/zhttp/service/Client.scala b/zio-http/src/main/scala/zhttp/service/Client.scala index 3ce1a1026b..16fa34be78 100644 --- a/zio-http/src/main/scala/zhttp/service/Client.scala +++ b/zio-http/src/main/scala/zhttp/service/Client.scala @@ -10,6 +10,7 @@ import io.netty.channel.{ } import io.netty.handler.codec.http._ import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler +import io.netty.handler.proxy.HttpProxyHandler import zhttp.http._ import zhttp.service import zhttp.service.Client.Config @@ -71,6 +72,7 @@ final case class Client[R](rtm: HttpRuntime[R], cf: JChannelFactory[Channel], el val isWebSocket = req.url.scheme.exists(_.isWebSocket) val isSSL = req.url.scheme.exists(_.isSecure) + val isProxy = clientConfig.proxy.isDefined val initializer = new ChannelInitializer[Channel]() { override def initChannel(ch: Channel): Unit = { @@ -78,6 +80,19 @@ final case class Client[R](rtm: HttpRuntime[R], cf: JChannelFactory[Channel], el val pipeline = ch.pipeline() val sslOption: ClientSSLOptions = clientConfig.ssl.getOrElse(ClientSSLOptions.DefaultSSL) + // Adding proxy handler + if (isProxy) { + val handler: HttpProxyHandler = + clientConfig.proxy + .flatMap(_.encode) + .getOrElse(new HttpProxyHandler(new InetSocketAddress(host, port))) + + pipeline.addLast( + PROXY_HANDLER, + handler, + ) + } + // If a https or wss request is made we need to add the ssl handler at the starting of the pipeline. if (isSSL) pipeline.addLast(SSL_HANDLER, ClientSSLHandler.ssl(sslOption).newHandler(ch.alloc, host, port)) @@ -167,9 +182,15 @@ object Client { } yield res } - case class Config(socketApp: Option[SocketApp[Any]] = None, ssl: Option[ClientSSLOptions] = None) { self => + case class Config( + socketApp: Option[SocketApp[Any]] = None, + ssl: Option[ClientSSLOptions] = None, + proxy: Option[Proxy] = None, + ) { + self => def withSSL(ssl: ClientSSLOptions): Config = self.copy(ssl = Some(ssl)) def withSocketApp(socketApp: SocketApp[Any]): Config = self.copy(socketApp = Some(socketApp)) + def withProxy(proxy: Proxy): Config = self.copy(proxy = Some(proxy)) } object Config { diff --git a/zio-http/src/main/scala/zhttp/service/package.scala b/zio-http/src/main/scala/zhttp/service/package.scala index 508e632ffe..aea33b39b1 100644 --- a/zio-http/src/main/scala/zhttp/service/package.scala +++ b/zio-http/src/main/scala/zhttp/service/package.scala @@ -33,6 +33,7 @@ package object service extends Logging { private[service] val WEB_SOCKET_CLIENT_PROTOCOL_HANDLER = "WEB_SOCKET_CLIENT_PROTOCOL_HANDLER" private[service] val HTTP_REQUEST_DECOMPRESSION = "HTTP_REQUEST_DECOMPRESSION" private[service] val LOW_LEVEL_LOGGING = "LOW_LEVEL_LOGGING" + private[service] val PROXY_HANDLER = "PROXY_HANDLER" private[zhttp] val HTTP_CONTENT_HANDLER = "HTTP_CONTENT_HANDLER" } diff --git a/zio-http/src/test/scala/zhttp/http/ProxySpec.scala b/zio-http/src/test/scala/zhttp/http/ProxySpec.scala new file mode 100644 index 0000000000..8a149ef154 --- /dev/null +++ b/zio-http/src/test/scala/zhttp/http/ProxySpec.scala @@ -0,0 +1,40 @@ +package zhttp.http + +import zhttp.http.middleware.Auth.Credentials +import zio.test.Assertion.{equalTo, isNone, isNull, isSome} +import zio.test._ + +object ProxySpec extends DefaultRunnableSpec { + private val validUrl = URL.fromString("http://localhost:8123").toOption.getOrElse(URL.empty) + + override def spec = suite("Proxy")( + suite("Authenticated Proxy") { + test("successfully encode valid proxy") { + val username = "unameTest" + val password = "upassTest" + val proxy = Proxy(validUrl, Some(Credentials(username, password))) + val encoded = proxy.encode + + assert(encoded.map(_.username()))(isSome(equalTo(username))) && + assert(encoded.map(_.password()))(isSome(equalTo(password))) && + assert(encoded.map(_.authScheme()))(isSome(equalTo("basic"))) + } + + test("fail to encode invalid proxy") { + val proxy = Proxy(URL.empty) + val encoded = proxy.encode + + assert(encoded.map(_.username()))(isNone) + } + } + suite("Unauthenticated proxy") { + test("successfully encode valid proxy") { + val proxy = Proxy(validUrl) + val encoded = proxy.encode + + assert(encoded)(isSome) && + assert(encoded.map(_.username()))(isSome(isNull)) && + assert(encoded.map(_.password()))(isSome(isNull)) && + assert(encoded.map(_.authScheme()))(isSome(equalTo("none"))) + } + }, + ) +} diff --git a/zio-http/src/test/scala/zhttp/service/ClientSpec.scala b/zio-http/src/test/scala/zhttp/service/ClientSpec.scala index 17c367b121..647c67c655 100644 --- a/zio-http/src/test/scala/zhttp/service/ClientSpec.scala +++ b/zio-http/src/test/scala/zhttp/service/ClientSpec.scala @@ -1,8 +1,10 @@ package zhttp.service - import zhttp.http._ +import zhttp.http.middleware.Auth.Credentials import zhttp.internal.{DynamicServer, HttpRunnableSpec} +import zhttp.service.Client.Config import zhttp.service.server._ +import zio.ZIO import zio.duration.durationInt import zio.test.Assertion._ import zio.test.TestAspect.{sequential, timeout} @@ -43,6 +45,57 @@ object ClientSpec extends HttpRunnableSpec { testM("handle connection failure") { val res = Client.request("http://localhost:1").either assertM(res)(isLeft(isSubtype[ConnectException](anything))) + } + + testM("handle proxy connection failure") { + val res = + for { + validServerPort <- ZIO.accessM[DynamicServer](_.get.port) + serverUrl <- ZIO.fromEither(URL.fromString(s"http://localhost:$validServerPort")) + proxyUrl <- ZIO.fromEither(URL.fromString("http://localhost:0001")) + out <- Client.request( + Request(url = serverUrl), + Config().withProxy(Proxy(proxyUrl)), + ) + } yield out + assertM(res.either)(isLeft(isSubtype[ConnectException](anything))) + } + + testM("proxy respond Ok") { + val res = + for { + port <- ZIO.accessM[DynamicServer](_.get.port) + url <- ZIO.fromEither(URL.fromString(s"http://localhost:$port")) + id <- DynamicServer.deploy(Http.ok) + proxy = Proxy.empty.withUrl(url).withHeaders(Headers(DynamicServer.APP_ID, id)) + out <- Client.request( + Request(url = url), + Config().withProxy(proxy), + ) + } yield out + assertM(res.either)(isRight) + } + + testM("proxy respond Ok for auth server") { + val proxyAuthApp = Http.collect[Request] { case req => + val proxyAuthHeaderName = HeaderNames.proxyAuthorization.toString + req.headers.toList.collectFirst { case (`proxyAuthHeaderName`, _) => + Response.ok + }.getOrElse(Response.status(Status.Forbidden)) + } + + val res = + for { + port <- ZIO.accessM[DynamicServer](_.get.port) + url <- ZIO.fromEither(URL.fromString(s"http://localhost:$port")) + id <- DynamicServer.deploy(proxyAuthApp) + proxy = Proxy.empty + .withUrl(url) + .withHeaders(Headers(DynamicServer.APP_ID, id)) + .withCredentials(Credentials("test", "test")) + out <- Client.request( + Request(url = url), + Config().withProxy(proxy), + ) + } yield out + assertM(res.either)(isRight) } }