From f5d5c0e78b63f28abc4f459449266fbea667d02c Mon Sep 17 00:00:00 2001 From: Ryan Emerson Date: Thu, 11 Jul 2024 14:57:53 +0100 Subject: [PATCH] Initial impl. Tests needed --- src/org/jgroups/stack/GossipRouter.java | 1 + src/org/jgroups/util/FileWatcher.java | 105 +++++++++++++++ .../jgroups/util/ReloadingX509KeyManager.java | 101 ++++++++++++++ .../util/ReloadingX509TrustManager.java | 94 +++++++++++++ src/org/jgroups/util/SslContextFactory.java | 123 +++++++++++------- src/org/jgroups/util/TLS.java | 7 +- 6 files changed, 384 insertions(+), 47 deletions(-) create mode 100644 src/org/jgroups/util/FileWatcher.java create mode 100644 src/org/jgroups/util/ReloadingX509KeyManager.java create mode 100644 src/org/jgroups/util/ReloadingX509TrustManager.java diff --git a/src/org/jgroups/stack/GossipRouter.java b/src/org/jgroups/stack/GossipRouter.java index 03b195a19fa..3eac7430a7b 100644 --- a/src/org/jgroups/stack/GossipRouter.java +++ b/src/org/jgroups/stack/GossipRouter.java @@ -954,6 +954,7 @@ public static void main(String[] args) throws Exception { String type=""; if(tls.enabled()) { tls.init(); + tls.setWatcher(new FileWatcher()); SSLContext context=tls.createContext(); SocketFactory socket_factory=tls.createSocketFactory(context); router.socketFactory(socket_factory); diff --git a/src/org/jgroups/util/FileWatcher.java b/src/org/jgroups/util/FileWatcher.java new file mode 100644 index 00000000000..511889f9939 --- /dev/null +++ b/src/org/jgroups/util/FileWatcher.java @@ -0,0 +1,105 @@ +package org.jgroups.util; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +import org.jgroups.logging.Log; +import org.jgroups.logging.LogFactory; + +public class FileWatcher implements Runnable, AutoCloseable { + + static final Log LOG = LogFactory.getLog(FileWatcher.class); + + public static final int SLEEP = 2_000; + private final Thread thread; + private final ConcurrentHashMap watched; + private boolean running = true; + + public FileWatcher() { + watched = new ConcurrentHashMap<>(); + thread = new Thread(this, "FileWatcher"); + Runtime.getRuntime().addShutdownHook(new Thread(this::stop)); + thread.start(); + } + + public void unwatch(Path path) { + watched.remove(path); + LOG.debug("Unwatched %s", path); + } + + public void watch(Path path, Consumer callback) { + watched.compute(path, (k, w) -> { + if (w == null) { + w = new Watched(); + try { + w.lastModified = Files.getLastModifiedTime(path).toMillis(); + } catch (FileNotFoundException | NoSuchFileException e) { + w.lastModified = -1; + LOG.debug("File not found %s", path); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + w.watchers.add(callback); + return w; + }); + LOG.debug("Watching %s", path); + } + + @Override + public void run() { + while (running) { + try { + Thread.sleep(SLEEP); + } catch (InterruptedException e) { + return; + } + if (!running) { + return; + } + for (Map.Entry e : watched.entrySet()) { + Watched w = e.getValue(); + try { + long lastModified = Files.getLastModifiedTime(e.getKey()).toMillis(); + if (w.lastModified < lastModified) { + w.lastModified = lastModified; + for (Consumer c : w.watchers) { + c.accept(e.getKey()); + } + } + } catch (FileNotFoundException | NoSuchFileException ex) { + w.lastModified = -1; + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + } + + public void stop() { + running = false; + try { + thread.join(); + } catch (InterruptedException e) { + // Ignore + } + } + + @Override + public void close() { + stop(); + } + + static class Watched { + long lastModified; + List> watchers = new ArrayList<>(2); + } +} diff --git a/src/org/jgroups/util/ReloadingX509KeyManager.java b/src/org/jgroups/util/ReloadingX509KeyManager.java new file mode 100644 index 00000000000..ceca604c729 --- /dev/null +++ b/src/org/jgroups/util/ReloadingX509KeyManager.java @@ -0,0 +1,101 @@ +package org.jgroups.util; + +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; +import java.nio.file.Path; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; + +import org.jgroups.logging.Log; +import org.jgroups.logging.LogFactory; + +/** + * A {@link X509ExtendedKeyManager} which uses a @{@link FileWatcher} to check for changes. + */ +public final class ReloadingX509KeyManager extends X509ExtendedKeyManager implements Closeable { + + static final Log LOG = LogFactory.getLog(ReloadingX509KeyManager.class); + + private final AtomicReference manager; + private final Path path; + private final Function action; + private final FileWatcher watcher; + private Instant lastLoaded; + + public ReloadingX509KeyManager(FileWatcher watcher, Path path, Function action) { + Objects.requireNonNull(watcher, "watcher must be non-null"); + Objects.requireNonNull(path, "path must be non-null"); + Objects.requireNonNull(action, "action must be non-null"); + + this.manager = new AtomicReference<>(); + this.watcher = watcher; + this.path = path; + this.action = action; + reload(this.path); + watcher.watch(path, this::reload); + } + + private void reload(Path path) { + manager.set(action.apply(path)); + lastLoaded = Instant.now(); + LOG.debug("Loaded '%s'", path); + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return manager.get().getClientAliases(keyType, issuers); + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return manager.get().chooseClientAlias(keyType, issuers, socket); + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return manager.get().getServerAliases(keyType, issuers); + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return manager.get().chooseServerAlias(keyType, issuers, socket); + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return manager.get().getCertificateChain(alias); + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return manager.get().getPrivateKey(alias); + } + + @Override + public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return manager.get().chooseEngineClientAlias(keyType, issuers, engine); + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { + return manager.get().chooseEngineServerAlias(keyType, issuers, engine); + } + + public Instant lastLoaded() { + return lastLoaded; + } + + @Override + public void close() throws IOException { + watcher.unwatch(path); + } +} diff --git a/src/org/jgroups/util/ReloadingX509TrustManager.java b/src/org/jgroups/util/ReloadingX509TrustManager.java new file mode 100644 index 00000000000..fc24b1f964f --- /dev/null +++ b/src/org/jgroups/util/ReloadingX509TrustManager.java @@ -0,0 +1,94 @@ +package org.jgroups.util; + +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; +import java.nio.file.Path; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedTrustManager; + +import org.jgroups.logging.Log; +import org.jgroups.logging.LogFactory; + +/** + * A {@link X509ExtendedTrustManager} which uses a @{@link FileWatcher} to check for changes. + */ +public class ReloadingX509TrustManager extends X509ExtendedTrustManager implements Closeable { + + static final Log LOG = LogFactory.getLog(ReloadingX509TrustManager.class); + + private final AtomicReference manager; + private final Path path; + private final Function action; + private final FileWatcher watcher; + private Instant lastLoaded; + + public ReloadingX509TrustManager(FileWatcher watcher, Path path, Function action) { + Objects.requireNonNull(watcher, "watcher must be non-null"); + Objects.requireNonNull(path, "path must be non-null"); + Objects.requireNonNull(action, "action must be non-null"); + this.manager = new AtomicReference<>(); + this.path = path; + this.action = action; + this.watcher = watcher; + reload(this.path); + watcher.watch(this.path, this::reload); + } + + private void reload(Path path) { + manager.set(action.apply(path)); + lastLoaded = Instant.now(); + LOG.debug("Loaded '%s'", path); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + manager.get().checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + manager.get().checkServerTrusted(chain, authType); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return manager.get().getAcceptedIssuers(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + manager.get().checkClientTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + manager.get().checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException { + manager.get().checkClientTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException { + manager.get().checkServerTrusted(chain, authType, engine); + } + + public Instant lastLoaded() { + return lastLoaded; + } + + @Override + public void close() throws IOException { + watcher.unwatch(path); + } +} diff --git a/src/org/jgroups/util/SslContextFactory.java b/src/org/jgroups/util/SslContextFactory.java index 0514315613c..1f0539b581e 100644 --- a/src/org/jgroups/util/SslContextFactory.java +++ b/src/org/jgroups/util/SslContextFactory.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.security.GeneralSecurityException; import java.security.KeyStore; @@ -22,6 +23,8 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509ExtendedTrustManager; import org.jgroups.logging.Log; import org.jgroups.logging.LogFactory; @@ -50,6 +53,7 @@ public class SslContextFactory { private String sslProtocol = DEFAULT_SSL_PROTOCOL; private ClassLoader classLoader; private String providerName; + private FileWatcher watcher; public SslContextFactory() { } @@ -133,6 +137,11 @@ public SslContextFactory classLoader(ClassLoader classLoader) { return this; } + public SslContextFactory watcher(FileWatcher watcher) { + this.watcher = watcher; + return this; + } + public SSLContext getContext() { try { SSLContext sslContext; @@ -154,66 +163,90 @@ public SSLContext getContext() { public void initializeContext(SSLContext sslContext) { try { - KeyManager[] keyManagers = null; + KeyManager[] kms = null; if (keyStoreFileName != null || keyStore != null) { - KeyManagerFactory kmf = getKeyManagerFactory(); - keyManagers = kmf.getKeyManagers(); + if (keyStoreFileName != null && watcher != null) { + kms = new KeyManager[]{new ReloadingX509KeyManager(watcher, Path.of(keyStoreFileName), p -> getKeyManager())}; + } else { + kms = new KeyManager[]{getKeyManager()}; + } } - TrustManager[] trustManagers = null; + TrustManager[] tms = null; if (trustStoreFileName != null || trustStore != null) { - TrustManagerFactory tmf = getTrustManagerFactory(); - trustManagers = tmf.getTrustManagers(); + if (trustStoreFileName != null && watcher != null) { + tms = new TrustManager[]{new ReloadingX509TrustManager(watcher, Path.of(trustStoreFileName), p -> getTrustManager())}; + } else { + tms = new TrustManager[]{getTrustManager()}; + } } - sslContext.init(keyManagers, trustManagers, null); + sslContext.init(kms, tms, null); } catch (Exception e) { throw new RuntimeException("Could not initialize SSL", e); } } - public KeyManagerFactory getKeyManagerFactory() throws IOException, GeneralSecurityException { - Provider provider; - KeyStore ks = keyStore != null ? keyStore : null; - if (ks == null) { - String type = keyStoreType != null ? keyStoreType : DEFAULT_KEYSTORE_TYPE; - provider = findProvider(this.providerName, KeyStore.class.getSimpleName(), type); - ks = provider != null ? KeyStore.getInstance(type, provider) : KeyStore.getInstance(type); - loadKeyStore(ks, keyStoreFileName, keyStorePassword, classLoader); - } else { - provider = ks.getProvider(); - } - if (keyAlias != null) { - if (ks.containsAlias(keyAlias) && ks.isKeyEntry(keyAlias)) { - KeyStore.PasswordProtection passParam = new KeyStore.PasswordProtection(keyStorePassword); - KeyStore.Entry entry = ks.getEntry(keyAlias, passParam); - // Recreate the keystore with just one key - ks = provider != null ? KeyStore.getInstance(keyStoreType, provider) : KeyStore.getInstance(keyStoreType); - ks.load(null, null); - ks.setEntry(keyAlias, entry, passParam); + private X509ExtendedKeyManager getKeyManager() { + try { + Provider provider; + KeyStore ks = keyStore != null ? keyStore : null; + if (ks == null) { + String type = keyStoreType != null ? keyStoreType : DEFAULT_KEYSTORE_TYPE; + provider = findProvider(this.providerName, KeyStore.class.getSimpleName(), type); + ks = provider != null ? KeyStore.getInstance(type, provider) : KeyStore.getInstance(type); + loadKeyStore(ks, keyStoreFileName, keyStorePassword, classLoader); } else { - throw new RuntimeException("No alias '" + keyAlias + "' in key store '" + keyStoreFileName + "'"); + provider = keyStore.getProvider(); } + if (keyAlias != null) { + if (ks.containsAlias(keyAlias) && ks.isKeyEntry(keyAlias)) { + KeyStore.PasswordProtection passParam = new KeyStore.PasswordProtection(keyStorePassword); + KeyStore.Entry entry = ks.getEntry(keyAlias, passParam); + // Recreate the keystore with just one key + ks = provider != null ? KeyStore.getInstance(keyStoreType, provider) : KeyStore.getInstance(keyStoreType); + ks.load(null, null); + ks.setEntry(keyAlias, entry, passParam); + } else { + throw new RuntimeException(String.format("The alias '%s' does not exist in the key store '%s'", keyAlias, keyStoreFileName)); + } + } + String algorithm = KeyManagerFactory.getDefaultAlgorithm(); + provider = findProvider(this.providerName, KeyManagerFactory.class.getSimpleName(), algorithm); + KeyManagerFactory kmf = provider != null ? KeyManagerFactory.getInstance(algorithm, provider) : KeyManagerFactory.getInstance(algorithm); + kmf.init(ks, keyStorePassword); + for (KeyManager km : kmf.getKeyManagers()) { + if (km instanceof X509ExtendedKeyManager) { + return (X509ExtendedKeyManager) km; + } + } + throw new GeneralSecurityException("Could not obtain an X509ExtendedKeyManager"); + } catch (GeneralSecurityException | IOException e) { + throw new RuntimeException("Error while initializing SSL context", e); } - String algorithm = KeyManagerFactory.getDefaultAlgorithm(); - provider = findProvider(this.providerName, KeyManagerFactory.class.getSimpleName(), algorithm); - KeyManagerFactory kmf = provider != null ? KeyManagerFactory.getInstance(algorithm, provider) : KeyManagerFactory.getInstance(algorithm); - kmf.init(ks, keyStorePassword); - return kmf; } - public TrustManagerFactory getTrustManagerFactory() throws IOException, GeneralSecurityException { - Provider provider; - KeyStore ts = trustStore != null ? trustStore : null; - if (ts == null) { - String type = trustStoreType != null ? trustStoreType : DEFAULT_KEYSTORE_TYPE; - provider = findProvider(this.providerName, KeyStore.class.getSimpleName(), type); - ts = provider != null ? KeyStore.getInstance(type, provider) : KeyStore.getInstance(type); - loadKeyStore(ts, trustStoreFileName, trustStorePassword, classLoader); + private X509ExtendedTrustManager getTrustManager() { + try { + Provider provider; + KeyStore ts = trustStore != null ? trustStore : null; + if (ts == null) { + String type = trustStoreType != null ? trustStoreType : DEFAULT_KEYSTORE_TYPE; + provider = findProvider(this.providerName, KeyStore.class.getSimpleName(), type); + ts = provider != null ? KeyStore.getInstance(type, provider) : KeyStore.getInstance(type); + loadKeyStore(ts, trustStoreFileName, trustStorePassword, classLoader); + } + String algorithm = KeyManagerFactory.getDefaultAlgorithm(); + provider = findProvider(this.providerName, TrustManagerFactory.class.getSimpleName(), algorithm); + TrustManagerFactory tmf = provider != null ? TrustManagerFactory.getInstance(algorithm, provider) : TrustManagerFactory.getInstance(algorithm); + tmf.init(ts); + for (TrustManager tm : tmf.getTrustManagers()) { + if (tm instanceof X509ExtendedTrustManager) { + return (X509ExtendedTrustManager) tm; + } + } + throw new GeneralSecurityException("Could not obtain an X509TrustManager"); + } catch (GeneralSecurityException | IOException e) { + throw new RuntimeException("Error while initializing SSL context", e); } - String algorithm = KeyManagerFactory.getDefaultAlgorithm(); - provider = findProvider(this.providerName, TrustManagerFactory.class.getSimpleName(), algorithm); - TrustManagerFactory tmf = provider != null ? TrustManagerFactory.getInstance(algorithm, provider) : TrustManagerFactory.getInstance(algorithm); - tmf.init(ts); - return tmf; } private static void loadKeyStore(KeyStore ks, String keyStoreFileName, char[] keyStorePassword, ClassLoader classLoader) throws IOException, GeneralSecurityException { diff --git a/src/org/jgroups/util/TLS.java b/src/org/jgroups/util/TLS.java index 25e56e801cf..9641ec57b4c 100644 --- a/src/org/jgroups/util/TLS.java +++ b/src/org/jgroups/util/TLS.java @@ -64,7 +64,7 @@ public class TLS implements Lifecycle { converter=SniMatcherConverter.class) protected List sni_matchers=new ArrayList<>(); - + protected FileWatcher watcher; public boolean enabled() {return enabled;} @@ -109,6 +109,8 @@ public class TLS implements Lifecycle { public List getSniMatchers() {return sni_matchers;} public TLS setSniMatchers(List s) {this.sni_matchers=s; return this;} + public FileWatcher getWatcher() {return watcher;} + public void setWatcher(FileWatcher watcher) {this.watcher = watcher;} @Override public void init() throws Exception { @@ -137,7 +139,8 @@ public SSLContext createContext() { .keyAlias(keystore_alias) .trustStoreFileName(truststore_path) .trustStorePassword(truststore_password) - .trustStoreType(truststore_type); + .trustStoreType(truststore_type) + .watcher(watcher); return sslContextFactory.getContext(); }