From 9350eec4f77ba2f4b4936406ff6ccd475a0c245f Mon Sep 17 00:00:00 2001 From: Rajkumar Natarajan Date: Wed, 20 Dec 2023 22:40:50 -0800 Subject: [PATCH] chore: support mtls settings --- .../src/main/scala/zio/http/SSLConfig.scala | 59 +++++++++++++++---- .../http/netty/server/ServerSSLDecoder.scala | 39 +++++++----- 2 files changed, 71 insertions(+), 27 deletions(-) diff --git a/zio-http/src/main/scala/zio/http/SSLConfig.scala b/zio-http/src/main/scala/zio/http/SSLConfig.scala index 876b59eef8..d681ae5d01 100644 --- a/zio-http/src/main/scala/zio/http/SSLConfig.scala +++ b/zio-http/src/main/scala/zio/http/SSLConfig.scala @@ -17,16 +17,32 @@ package zio.http import zio.Config -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http.SSLConfig._ -final case class SSLConfig(behaviour: HttpBehaviour, data: Data, provider: Provider) +sealed trait ClientAuth + +object ClientAuth { + case object Required extends ClientAuth + case object NoneClientAuth extends ClientAuth + case object Optional extends ClientAuth + +} + +final case class SSLConfig( + behaviour: HttpBehaviour, + data: Data, + provider: Provider, + clientAuth: Option[ClientAuth] = None, +) object SSLConfig { def apply(data: Data): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK) + new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK, None) + + def apply(data: Data, clientAuth: ClientAuth): SSLConfig = + new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK, Some(clientAuth)) val config: Config[SSLConfig] = ( @@ -38,22 +54,41 @@ object SSLConfig { } def fromFile(certPath: String, keyPath: String): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.FromFile(certPath, keyPath), Provider.JDK) + fromFile(HttpBehaviour.Redirect, certPath, keyPath) + + def fromFile(certPath: String, keyPath: String, clientAuth: ClientAuth): SSLConfig = + fromFile(HttpBehaviour.Redirect, certPath, keyPath, Some(clientAuth)) - def fromFile(behaviour: HttpBehaviour, certPath: String, keyPath: String): SSLConfig = - new SSLConfig(behaviour, Data.FromFile(certPath, keyPath), Provider.JDK) + def fromFile( + behaviour: HttpBehaviour, + certPath: String, + keyPath: String, + clientAuth: Option[ClientAuth] = None, + ): SSLConfig = + new SSLConfig(behaviour, Data.FromFile(certPath, keyPath), Provider.JDK, clientAuth) def fromResource(certPath: String, keyPath: String): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.FromResource(certPath, keyPath), Provider.JDK) + fromResource(HttpBehaviour.Redirect, certPath, keyPath, None) - def fromResource(behaviour: HttpBehaviour, certPath: String, keyPath: String): SSLConfig = - new SSLConfig(behaviour, Data.FromResource(certPath, keyPath), Provider.JDK) + def fromResource(certPath: String, keyPath: String, clientAuth: ClientAuth): SSLConfig = + fromResource(HttpBehaviour.Redirect, certPath, keyPath, Some(clientAuth)) + + def fromResource( + behaviour: HttpBehaviour, + certPath: String, + keyPath: String, + clientAuth: Option[ClientAuth] = None, + ): SSLConfig = + new SSLConfig(behaviour, Data.FromResource(certPath, keyPath), Provider.JDK, clientAuth) def generate: SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.Generate, Provider.JDK) + generate(HttpBehaviour.Redirect, None) + + def generate(clientAuth: ClientAuth): SSLConfig = + generate(HttpBehaviour.Redirect, Some(clientAuth)) - def generate(behaviour: HttpBehaviour): SSLConfig = - new SSLConfig(behaviour, Data.Generate, Provider.JDK) + def generate(behaviour: HttpBehaviour, clientAuth: Option[ClientAuth] = None): SSLConfig = + new SSLConfig(behaviour, Data.Generate, Provider.JDK, clientAuth) sealed trait HttpBehaviour object HttpBehaviour { diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala index 9ffeed426d..10b093e63c 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala @@ -19,11 +19,9 @@ package zio.http.netty.server import java.io.FileInputStream import java.util -import zio.stacktracer.TracingImplicits.disableAutoTrace - import zio.http.SSLConfig.{HttpBehaviour, Provider} import zio.http.netty.Names -import zio.http.{SSLConfig, Server} +import zio.http.{ClientAuth, SSLConfig, Server} import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext @@ -33,27 +31,38 @@ import io.netty.handler.ssl.ApplicationProtocolConfig.{ SelectedListenerFailureBehavior, SelectorFailureBehavior, } +import io.netty.handler.ssl._ import io.netty.handler.ssl.util.SelfSignedCertificate -import io.netty.handler.ssl.{SslContext, SslHandler, _} +import io.netty.handler.ssl.{ClientAuth => NettyClientAuth} object SSLUtil { + def getClientAuth(clientAuth: ClientAuth): NettyClientAuth = clientAuth match { + case ClientAuth.Required => NettyClientAuth.REQUIRE + case ClientAuth.Optional => NettyClientAuth.OPTIONAL + case _ => NettyClientAuth.NONE + } + implicit class SslContextBuilderOps(self: SslContextBuilder) { def toNettyProvider(sslProvider: Provider): SslProvider = sslProvider match { case Provider.OpenSSL => SslProvider.OPENSSL case Provider.JDK => SslProvider.JDK } - def buildWithDefaultOptions(sslConfig: SSLConfig): SslContext = self - .sslProvider(toNettyProvider(sslConfig.provider)) - .applicationProtocolConfig( - new ApplicationProtocolConfig( - Protocol.ALPN, - SelectorFailureBehavior.NO_ADVERTISE, - SelectedListenerFailureBehavior.ACCEPT, - ApplicationProtocolNames.HTTP_1_1, - ), - ) - .build() + def buildWithDefaultOptions(sslConfig: SSLConfig): SslContext = { + val clientAuthConfig: Option[ClientAuth] = sslConfig.clientAuth + clientAuthConfig.foreach(ca => self.clientAuth(getClientAuth(ca))) + self + .sslProvider(toNettyProvider(sslConfig.provider)) + .applicationProtocolConfig( + new ApplicationProtocolConfig( + Protocol.ALPN, + SelectorFailureBehavior.NO_ADVERTISE, + SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_1_1, + ), + ) + .build() + } } def sslConfigToSslContext(sslConfig: SSLConfig): SslContext = sslConfig.data match {