diff --git a/src/main/java/io/vertx/core/net/KeyStoreOptionsBase.java b/src/main/java/io/vertx/core/net/KeyStoreOptionsBase.java index 8929ab24d41..9782433456f 100644 --- a/src/main/java/io/vertx/core/net/KeyStoreOptionsBase.java +++ b/src/main/java/io/vertx/core/net/KeyStoreOptionsBase.java @@ -21,6 +21,7 @@ import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509KeyManager; import java.security.KeyStore; +import java.util.Objects; import java.util.function.Function; import java.util.function.Supplier; @@ -233,4 +234,21 @@ public Function trustManagerMapper(Vertx vertx) throws E @Override public abstract KeyStoreOptionsBase copy(); + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj != null && obj.getClass() == getClass()) { + KeyStoreOptionsBase that = (KeyStoreOptionsBase) obj; + return Objects.equals(provider, that.provider) && + Objects.equals(type, that.type) && + Objects.equals(password, that.password) && + Objects.equals(path, that.path) && + Objects.equals(value, that.value) && + Objects.equals(alias, that.alias) && + Objects.equals(aliasPassword, that.aliasPassword); + } + return false; + } } diff --git a/src/main/java/io/vertx/core/net/PemKeyCertOptions.java b/src/main/java/io/vertx/core/net/PemKeyCertOptions.java index 2895161ac3e..9dd5e2a81c0 100644 --- a/src/main/java/io/vertx/core/net/PemKeyCertOptions.java +++ b/src/main/java/io/vertx/core/net/PemKeyCertOptions.java @@ -25,6 +25,7 @@ import java.security.KeyStore; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.function.Function; /** @@ -385,6 +386,21 @@ public PemKeyCertOptions addCertValue(Buffer certValue) { return this; } + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj != null && obj.getClass() == getClass()) { + PemKeyCertOptions that = (PemKeyCertOptions) obj; + return Objects.equals(keyPaths, that.keyPaths) && + Objects.equals(keyValues, that.keyValues) && + Objects.equals(certPaths, that.certPaths) && + Objects.equals(certValues, that.certValues); + } + return false; + } + @Override public PemKeyCertOptions copy() { return new PemKeyCertOptions(this); diff --git a/src/main/java/io/vertx/core/net/PemTrustOptions.java b/src/main/java/io/vertx/core/net/PemTrustOptions.java index 788302541a5..08fde5712b3 100644 --- a/src/main/java/io/vertx/core/net/PemTrustOptions.java +++ b/src/main/java/io/vertx/core/net/PemTrustOptions.java @@ -175,6 +175,18 @@ public Function trustManagerMapper(Vertx vertx) throws E return helper != null ? helper::getTrustMgr : null; } + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj != null && obj.getClass() == getClass()) { + PemTrustOptions that = (PemTrustOptions) obj; + return Objects.equals(certPaths, that.certPaths) && Objects.equals(certValues, that.certValues); + } + return false; + } + @Override public PemTrustOptions copy() { return new PemTrustOptions(this); diff --git a/src/main/java/io/vertx/core/net/SSLOptions.java b/src/main/java/io/vertx/core/net/SSLOptions.java index 08f1ca588c0..826519d19f7 100644 --- a/src/main/java/io/vertx/core/net/SSLOptions.java +++ b/src/main/java/io/vertx/core/net/SSLOptions.java @@ -324,6 +324,25 @@ public SSLOptions removeEnabledSecureTransportProtocol(String protocol) { return this; } + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof SSLOptions) { + SSLOptions that = (SSLOptions) obj; + return sslHandshakeTimeoutUnit.toNanos(sslHandshakeTimeout) == that.sslHandshakeTimeoutUnit.toNanos(sslHandshakeTimeout) && + Objects.equals(keyCertOptions, that.keyCertOptions) && + Objects.equals(trustOptions, that.trustOptions) && + Objects.equals(enabledCipherSuites, that.enabledCipherSuites) && + Objects.equals(crlPaths, that.crlPaths) && + Objects.equals(crlValues, that.crlValues) && + useAlpn == that.useAlpn && + Objects.equals(enabledSecureTransportProtocols, that.enabledSecureTransportProtocols); + } + return false; + } + /** * Convert to JSON * diff --git a/src/main/java/io/vertx/core/net/impl/NetClientImpl.java b/src/main/java/io/vertx/core/net/impl/NetClientImpl.java index 2f759e79295..10c384d7f42 100644 --- a/src/main/java/io/vertx/core/net/impl/NetClientImpl.java +++ b/src/main/java/io/vertx/core/net/impl/NetClientImpl.java @@ -70,7 +70,7 @@ public class NetClientImpl implements MetricsProvider, NetClient, Closeable { private final VertxInternal vertx; private final NetClientOptions options; private final SSLHelper sslHelper; - private final AtomicReference> sslChannelProvider = new AtomicReference<>(); + private Future sslChannelProvider; private final ChannelGroup channelGroup; private final TCPMetrics metrics; private final CloseFuture closeFuture; @@ -180,9 +180,21 @@ public Metrics getMetrics() { @Override public Future updateSSLOptions(SSLOptions options) { - Future fut = sslHelper.buildChannelProvider(new SSLOptions(options), vertx.getOrCreateContext()); - fut.onSuccess(v -> sslChannelProvider.set(fut)); - return fut.mapEmpty(); + Future fut; + ContextInternal ctx = vertx.getOrCreateContext(); + synchronized (this) { + fut = sslHelper.updateSslContext(new SSLOptions(options), ctx); + sslChannelProvider = fut; + } + return fut.transform(ar -> { + if (ar.failed()) { + return ctx.failedFuture(ar.cause()); + } else if (ar.succeeded() && ar.result().error() != null) { + return ctx.failedFuture(ar.result().error()); + } else { + return ctx.succeededFuture(); + } + }); } @Override @@ -245,21 +257,17 @@ public void connectInternal(ProxyOptions proxyOptions, if (closeFuture.isClosed()) { connectHandler.fail(new IllegalStateException("Client is closed")); } else { - Future fut; - while (true) { - fut = sslChannelProvider.get(); + Future fut; + synchronized (NetClientImpl.this) { + fut = sslChannelProvider; if (fut == null) { - fut = sslHelper.buildChannelProvider(options.getSslOptions(), context); - if (sslChannelProvider.compareAndSet(null, fut)) { - break; - } - } else { - break; + fut = sslHelper.updateSslContext(options.getSslOptions(), context); + sslChannelProvider = fut; } } fut.onComplete(ar -> { if (ar.succeeded()) { - connectInternal2(proxyOptions, remoteAddress, peerAddress, ar.result(), serverName, ssl, useAlpn, registerWriteHandlers, connectHandler, context, remainingAttempts); + connectInternal2(proxyOptions, remoteAddress, peerAddress, ar.result().sslChannelProvider(), serverName, ssl, useAlpn, registerWriteHandlers, connectHandler, context, remainingAttempts); } else { connectHandler.fail(ar.cause()); } diff --git a/src/main/java/io/vertx/core/net/impl/SSLHelper.java b/src/main/java/io/vertx/core/net/impl/SSLHelper.java index d102f0d2064..1fb4fc0abed 100755 --- a/src/main/java/io/vertx/core/net/impl/SSLHelper.java +++ b/src/main/java/io/vertx/core/net/impl/SSLHelper.java @@ -110,6 +110,7 @@ public static SSLEngineOptions resolveEngineOptions(SSLEngineOptions engineOptio private Function keyManagerFactoryMapper; private Function trustManagerMapper; private List crls; + private Future cachedProvider; public SSLHelper(TCPSSLOptions options, List applicationProtocols) { this.sslEngineOptions = options.getSslEngineOptions(); @@ -123,6 +124,17 @@ public SSLHelper(TCPSSLOptions options, List applicationProtocols) { this.applicationProtocols = applicationProtocols; } + private static class CachedProvider { + final SSLOptions options; + final SslChannelProvider sslChannelProvider; + final Throwable failure; + CachedProvider(SSLOptions options, SslChannelProvider sslChannelProvider, Throwable failure) { + this.options = options; + this.sslChannelProvider = sslChannelProvider; + this.failure = failure; + } + } + private class EngineConfig { private final SSLOptions sslOptions; @@ -151,6 +163,44 @@ SslContextProvider sslContextProvider() { } } + /** + * Update cached options. This method ensures updates are serialized a nd performed when options is different + * (based on {@code equals}). Updates only happen when transforming {@code options} to a {@link SslChannelProvider} + * succeeds. + * + * @param options the options to use + * @param ctx the vertx context + * @return a future of the resolved channel provider + */ + public Future updateSslContext(SSLOptions options, ContextInternal ctx) { + synchronized (this) { + if (cachedProvider == null) { + cachedProvider = this.buildChannelProvider(options, ctx).map(a -> new CachedProvider(options, a, null)); + } else { + cachedProvider = cachedProvider.transform(prev -> { + if (prev.succeeded() && prev.result().options.equals(options)) { + return Future.succeededFuture(prev.result()); + } else { + return this + .buildChannelProvider(options, ctx) + .transform(ar -> { + if (ar.succeeded()) { + return ctx.succeededFuture(new CachedProvider(options, ar.result(), null)); + } else { + if (prev.succeeded()) { + return ctx.succeededFuture(new CachedProvider(prev.result().options, prev.result().sslChannelProvider, ar.cause())); + } else { + return ctx.failedFuture(prev.cause()); + } + } + }); + } + }); + } + return cachedProvider.map(c -> new SslContextUpdate(c.sslChannelProvider, c.failure)); + } + } + /** * Initialize the helper, this loads and validates the configuration. * diff --git a/src/main/java/io/vertx/core/net/impl/SslContextUpdate.java b/src/main/java/io/vertx/core/net/impl/SslContextUpdate.java new file mode 100644 index 00000000000..8369efd11e9 --- /dev/null +++ b/src/main/java/io/vertx/core/net/impl/SslContextUpdate.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2011-2023 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ +package io.vertx.core.net.impl; + +/** + * Result of the SslContext update operation. + */ +public class SslContextUpdate { + + private final SslChannelProvider sslChannelProvider; + private final Throwable error; + + SslContextUpdate(SslChannelProvider sslChannelProvider, Throwable error) { + this.sslChannelProvider = sslChannelProvider; + this.error = error; + } + + /** + * @return the latest and freshest {@link SslChannelProvider} + */ + public SslChannelProvider sslChannelProvider() { + return sslChannelProvider; + } + + /** + * @return the optional error of the update operation + */ + public Throwable error() { + return error; + } +} diff --git a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java index 9b0ea6323d2..9ea90d086ec 100644 --- a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java +++ b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java @@ -39,7 +39,6 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; /** @@ -66,7 +65,7 @@ public abstract class TCPServerBase implements Closeable, MetricsProvider { // Main private SSLHelper sslHelper; - private AtomicReference sslChannelProvider; + private volatile Future sslChannelProvider; private ServerChannelLoadBalancer channelBalancer; private Future bindFuture; private Set servers; @@ -77,12 +76,15 @@ public TCPServerBase(VertxInternal vertx, NetServerOptions options) { this.vertx = vertx; this.options = new NetServerOptions(options); this.creatingContext = vertx.getContext(); - this.sslChannelProvider = new AtomicReference<>(); } public SslContextProvider sslContextProvider() { - SslChannelProvider ref = sslChannelProvider.get(); - return ref != null ? ref.sslContextProvider() : null; + SslContextUpdate update = sslChannelProvider.result(); + if (update != null) { + return update.sslChannelProvider().sslContextProvider(); + } else { + return null; + } } public int actualPort() { @@ -97,11 +99,23 @@ protected SSLHelper createSSLHelper() { } public Future updateSSLOptions(SSLOptions options) { - return sslHelper.buildChannelProvider(new SSLOptions(options), listenContext).andThen(ar -> { - if (ar.succeeded()) { - TCPServerBase.this.sslChannelProvider.set(ar.result()); - } - }).mapEmpty(); + TCPServerBase server = actualServer; + if (server != null && server != this) { + return server.updateSSLOptions(options); + } else { + ContextInternal ctx = vertx.getOrCreateContext(); + Future update = sslHelper.updateSslContext(new SSLOptions(options), ctx); + sslChannelProvider = update; + return update.transform(ar -> { + if (ar.failed()) { + return ctx.failedFuture(ar.cause()); + } else if (ar.succeeded() && ar.result().error() != null) { + return ctx.failedFuture(ar.result().error()); + } else { + return ctx.succeededFuture(); + } + }); + } } public Future bind(SocketAddress address) { @@ -151,7 +165,7 @@ private synchronized Future listen(SocketAddress localAddress, ContextI bindFuture = promise; sslHelper = createSSLHelper(); childHandler = childHandler(listenContext, localAddress); - worker = ch -> childHandler.accept(ch, sslChannelProvider.get()); + worker = ch -> childHandler.accept(ch, sslChannelProvider.result().sslChannelProvider()); servers = new HashSet<>(); servers.add(this); channelBalancer = new ServerChannelLoadBalancer(vertx.getAcceptorEventLoopGroup().next()); @@ -164,12 +178,9 @@ private synchronized Future listen(SocketAddress localAddress, ContextI listenContext.addCloseHook(this); // Initialize SSL before binding - sslHelper.buildChannelProvider(options.getSslOptions(), listenContext).onComplete(ar -> { + sslChannelProvider = sslHelper.updateSslContext(options.getSslOptions(), listenContext).onComplete(ar -> { if (ar.succeeded()) { - // - sslChannelProvider.set(ar.result()); - // Socket bind channelBalancer.addWorker(eventLoop, worker); ServerBootstrap bootstrap = new ServerBootstrap(); @@ -227,7 +238,7 @@ private synchronized Future listen(SocketAddress localAddress, ContextI metrics = main.metrics; sslChannelProvider = main.sslChannelProvider; childHandler = childHandler(listenContext, localAddress); - worker = ch -> childHandler.accept(ch, sslChannelProvider.get()); + worker = ch -> childHandler.accept(ch, sslChannelProvider.result().sslChannelProvider()); actualServer.servers.add(this); actualServer.channelBalancer.addWorker(eventLoop, worker); listenContext.addCloseHook(this); diff --git a/src/test/java/io/vertx/core/http/HttpTLSTest.java b/src/test/java/io/vertx/core/http/HttpTLSTest.java index 4f3ea472107..b6ae2a7a85e 100755 --- a/src/test/java/io/vertx/core/http/HttpTLSTest.java +++ b/src/test/java/io/vertx/core/http/HttpTLSTest.java @@ -1640,7 +1640,7 @@ public void testHAProxy() throws Exception { } @Test - public void testReloadSSLOptions() throws Exception { + public void testUpdateSSLOptions() throws Exception { server = createHttpServer(createBaseServerOptions().setSsl(true).setKeyCertOptions(Cert.SERVER_JKS.get())) .requestHandler(req -> { req.response().end("Hello World"); @@ -1664,6 +1664,61 @@ public void testReloadSSLOptions() throws Exception { await(); } + @Test + public void testUpdateWithInvalidSSLOptions() throws Exception { + server = createHttpServer(createBaseServerOptions().setSsl(true).setKeyCertOptions(Cert.SERVER_JKS.get())) + .requestHandler(req -> { + req.response().end("Hello World"); + }); + startServer(testAddress); + client = createHttpClient(new HttpClientOptions().setKeepAlive(false).setSsl(true).setTrustOptions(Trust.SERVER_JKS.get())); + Future last = server.updateSSLOptions(new SSLOptions().setKeyCertOptions(new JksOptions().setValue(TestUtils.randomBuffer(20)).setPassword("invalid"))); + last.onComplete(onFailure(err -> { + client + .request(requestOptions) + .compose(req -> req.send().compose(HttpClientResponse::body)) + .onComplete(onSuccess(body -> { + assertEquals("Hello World", body.toString()); + testComplete(); + })); + })); + await(); + } + + @Test + public void testConcurrentUpdateSSLOptions() throws Exception { + server = createHttpServer(createBaseServerOptions().setSsl(true).setKeyCertOptions(Cert.SERVER_JKS.get())) + .requestHandler(req -> { + req.response().end("Hello World"); + }); + startServer(testAddress); + client = createHttpClient(new HttpClientOptions().setKeepAlive(false).setSsl(true).setTrustOptions(Trust.SERVER_JKS_ROOT_CA.get())); + List list = Arrays.asList( + Cert.SERVER_PKCS12.get(), + Cert.SERVER_PEM.get(), + Cert.SERVER_PEM.get(), + Cert.SERVER_JKS_ROOT_CA.get()); + AtomicInteger seq = new AtomicInteger(); + Future last = null; + for (int i = 0;i < list.size();i++) { + int val = i; + last = server.updateSSLOptions(new SSLOptions().setKeyCertOptions(list.get(i))); + last.onComplete(onSuccess(v -> { + assertEquals(val, seq.getAndIncrement()); + })); + } + last.onComplete(onSuccess(v -> { + client + .request(requestOptions) + .compose(req -> req.send().compose(HttpClientResponse::body)) + .onComplete(onSuccess(body -> { + assertEquals("Hello World", body.toString()); + testComplete(); + })); + })); + await(); + } + @Test public void testEngineUseEventLoopThread() throws Exception { testUseThreadPool(false, false);