From bc7094e168f4a905ab09d7059c828d63e43265e3 Mon Sep 17 00:00:00 2001 From: Hasnain Lakhani Date: Fri, 13 Oct 2023 11:46:10 -0500 Subject: [PATCH] [SPARK-45426][CORE] Add support for a ReloadingX509TrustManager ### What changes were proposed in this pull request? This adds in support for trust store reloading, mirroring the Hadoop implementation (see source comments for a link). I believe reusing the existing code instead of adding a dependency is fine license wise (see https://github.com/apache/spark/pull/42685/files#r1333667328) ### Why are the changes needed? This helps us refresh trust stores without needing downtime ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit tests (also copied from upstream) ``` build/sbt > project network-common > testOnly org.apache.spark.network.ssl.ReloadingX509TrustManagerSuite ``` The rest of the changes and integration were tested as part of https://github.com/apache/spark/pull/42685 ### Was this patch authored or co-authored using generative AI tooling? No Closes #43249 from hasnain-db/spark-tls-reloading. Authored-by: Hasnain Lakhani Signed-off-by: Mridul Muralidharan gmail.com> --- .../ssl/ReloadingX509TrustManager.java | 226 +++++++++++++ .../ssl/ReloadingX509TrustManagerSuite.java | 317 ++++++++++++++++++ .../spark/network/ssl/SslSampleConfigs.java | 16 +- 3 files changed, 555 insertions(+), 4 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java new file mode 100644 index 0000000000000..4c39a5d2a3de2 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.ssl; + +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.annotations.VisibleForTesting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * A {@link TrustManager} implementation that reloads its configuration when + * the truststore file on disk changes. + * This implementation is based off of the + * org.apache.hadoop.security.ssl.ReloadingX509TrustManager class in the Apache Hadoop Encrypted + * Shuffle implementation. + * + * @see Hadoop MapReduce Next Generation - Encrypted Shuffle + */ +public final class ReloadingX509TrustManager + implements X509TrustManager, Runnable { + + private final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManager.class); + + private final String type; + private final File file; + // The file being pointed to by `file` if it's a link + private String canonicalPath; + private final String password; + private long lastLoaded; + private final long reloadInterval; + @VisibleForTesting + protected volatile int reloadCount; + @VisibleForTesting + protected volatile int needsReloadCheckCounts; + private final AtomicReference trustManagerRef; + + private volatile boolean running; + private Thread reloader; + + /** + * Creates a reloadable trustmanager. The trustmanager reloads itself + * if the underlying trustore file has changed. + * + * @param type type of truststore file, typically 'jks'. + * @param trustStore the truststore file. + * @param password password of the truststore file. + * @param reloadInterval interval to check if the truststore file has + * changed, in milliseconds. + * @throws IOException thrown if the truststore could not be initialized due + * to an IO error. + * @throws GeneralSecurityException thrown if the truststore could not be + * initialized due to a security error. + */ + public ReloadingX509TrustManager( + String type, File trustStore, String password, long reloadInterval) + throws IOException, GeneralSecurityException { + this.type = type; + this.file = trustStore; + this.canonicalPath = this.file.getCanonicalPath(); + this.password = password; + this.trustManagerRef = new AtomicReference(); + this.trustManagerRef.set(loadTrustManager()); + this.reloadInterval = reloadInterval; + this.reloadCount = 0; + this.needsReloadCheckCounts = 0; + } + + /** + * Starts the reloader thread. + */ + public void init() { + reloader = new Thread(this, "Truststore reloader thread"); + reloader.setDaemon(true); + running = true; + reloader.start(); + } + + /** + * Stops the reloader thread. + */ + public void destroy() throws InterruptedException { + running = false; + reloader.interrupt(); + reloader.join(); + } + + /** + * Returns the reload check interval. + * + * @return the reload check interval, in milliseconds. + */ + public long getReloadInterval() { + return reloadInterval; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + X509TrustManager tm = trustManagerRef.get(); + if (tm != null) { + tm.checkClientTrusted(chain, authType); + } else { + throw new CertificateException("Unknown client chain certificate: " + + chain[0].toString() + ". Please ensure the correct trust store is specified in the config"); + } + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + X509TrustManager tm = trustManagerRef.get(); + if (tm != null) { + tm.checkServerTrusted(chain, authType); + } else { + throw new CertificateException("Unknown server chain certificate: " + + chain[0].toString() + ". Please ensure the correct trust store is specified in the config"); + } + } + + private static final X509Certificate[] EMPTY = new X509Certificate[0]; + + @Override + public X509Certificate[] getAcceptedIssuers() { + X509Certificate[] issuers = EMPTY; + X509TrustManager tm = trustManagerRef.get(); + if (tm != null) { + issuers = tm.getAcceptedIssuers(); + } + return issuers; + } + + boolean needsReload() throws IOException { + boolean reload = true; + File latestCanonicalFile = file.getCanonicalFile(); + if (file.exists() && latestCanonicalFile.exists()) { + // `file` can be a symbolic link. We need to reload if it points to another file, + // or if the file has been modified + if (latestCanonicalFile.getPath().equals(canonicalPath) && + latestCanonicalFile.lastModified() == lastLoaded) { + reload = false; + } + } else { + lastLoaded = 0; + } + return reload; + } + + X509TrustManager loadTrustManager() + throws IOException, GeneralSecurityException { + X509TrustManager trustManager = null; + KeyStore ks = KeyStore.getInstance(type); + File latestCanonicalFile = file.getCanonicalFile(); + canonicalPath = latestCanonicalFile.getPath(); + lastLoaded = latestCanonicalFile.lastModified(); + try (FileInputStream in = new FileInputStream(latestCanonicalFile)) { + ks.load(in, password.toCharArray()); + logger.debug("Loaded truststore '" + file + "'"); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + TrustManager[] trustManagers = trustManagerFactory.getTrustManagers(); + for (TrustManager trustManager1 : trustManagers) { + if (trustManager1 instanceof X509TrustManager) { + trustManager = (X509TrustManager) trustManager1; + break; + } + } + return trustManager; + } + + @Override + public void run() { + while (running) { + try { + Thread.sleep(reloadInterval); + } catch (InterruptedException e) { + //NOP + } + try { + if (running && needsReload()) { + try { + trustManagerRef.set(loadTrustManager()); + this.reloadCount += 1; + } catch (Exception ex) { + logger.warn( + "Could not load truststore (keep using existing one) : " + ex.toString(), + ex + ); + } + } + } catch (IOException ex) { + logger.warn("Could not check whether truststore needs reloading: " + ex.toString(), ex); + } + needsReloadCheckCounts++; + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java new file mode 100644 index 0000000000000..7e2cc38e70b34 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.ssl; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.junit.jupiter.api.Assertions.*; + +import static org.apache.spark.network.ssl.SslSampleConfigs.*; + +public class ReloadingX509TrustManagerSuite { + + private final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManagerSuite.class); + + /** + * Waits until reload count hits the requested value, sleeping 100ms at a time. + * If the maximum number of attempts is hit, throws a RuntimeException + * @param tm the trust manager to wait for + * @param count The count to wait for + * @param attempts The number of attempts to wait for + */ + private void waitForReloadCount(ReloadingX509TrustManager tm, int count, int attempts) + throws InterruptedException { + if (tm.reloadCount > count) { + throw new IllegalStateException( + "Passed invalid count " + count + " to waitForReloadCount, already have " + tm.reloadCount); + } + for (int i = 0; i < attempts; i++) { + if (tm.reloadCount >= count) { + return; + } + // Adapted from SystemClock.waitTillTime + long startTime = System.currentTimeMillis(); + long targetTime = startTime + 100; + long currentTime = startTime; + while (currentTime < targetTime) { + long sleepTime = Math.min(10, targetTime - currentTime); + Thread.sleep(sleepTime); + currentTime = System.currentTimeMillis(); + } + } + throw new IllegalStateException("Trust store not reloaded after " + attempts + " attempts!"); + } + + /** + * Waits until we make some number of attempts to reload, and verifies + * that the actual reload count did not change + * + * @param tm the trust manager to wait for + * @param attempts The number of attempts to wait for + */ + private void waitForNoReload(ReloadingX509TrustManager tm, int attempts) + throws InterruptedException { + int oldReloadCount = tm.reloadCount; + int checkCount = tm.needsReloadCheckCounts; + int target = checkCount + attempts; + while (checkCount < target) { + Thread.sleep(100); + checkCount = tm.needsReloadCheckCounts; + } + assertEquals(oldReloadCount, tm.reloadCount); + } + + /** + * Tests to ensure that loading a missing trust-store fails + * + * @throws Exception + */ + @Test + public void testLoadMissingTrustStore() throws Exception { + File trustStore = new File("testmissing.jks"); + assertFalse(trustStore.exists()); + + assertThrows(IOException.class, () -> { + ReloadingX509TrustManager tm = new ReloadingX509TrustManager( + KeyStore.getDefaultType(), + trustStore, + "password", + 10 + ); + try { + tm.init(); + } finally { + tm.destroy(); + } + }); + } + + /** + * Tests to ensure that loading a corrupt trust-store fails + * + * @throws Exception + */ + @Test + public void testLoadCorruptTrustStore() throws Exception { + File corruptStore = File.createTempFile("truststore-corrupt", "jks"); + corruptStore.deleteOnExit(); + OutputStream os = new FileOutputStream(corruptStore); + os.write(1); + os.close(); + + assertThrows(IOException.class, () -> { + ReloadingX509TrustManager tm = new ReloadingX509TrustManager( + KeyStore.getDefaultType(), + corruptStore, + "password", + 10 + ); + try { + tm.init(); + } finally { + tm.destroy(); + corruptStore.delete(); + } + }); + } + + /** + * Tests that we successfully reload when a file is updated + * @throws Exception + */ + @Test + public void testReload() throws Exception { + KeyPair kp = generateKeyPair("RSA"); + X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA"); + X509Certificate cert2 = generateCertificate("CN=Cert2", kp, 30, "SHA1withRSA"); + File trustStore = File.createTempFile("testreload", "jks"); + trustStore.deleteOnExit(); + createTrustStore(trustStore, "password", "cert1", cert1); + + ReloadingX509TrustManager tm = + new ReloadingX509TrustManager("jks", trustStore, "password", 1); + assertEquals(1, tm.getReloadInterval()); + assertEquals(0, tm.reloadCount); + try { + tm.init(); + assertEquals(1, tm.getAcceptedIssuers().length); + // At this point we haven't reloaded, just the initial load + assertEquals(0, tm.reloadCount); + + // Add another cert + Map certs = new HashMap(); + certs.put("cert1", cert1); + certs.put("cert2", cert2); + createTrustStore(trustStore, "password", certs); + + // Wait up to 5s until we reload + waitForReloadCount(tm, 1, 50); + + assertEquals(2, tm.getAcceptedIssuers().length); + } finally { + tm.destroy(); + trustStore.delete(); + } + } + + /** + * Tests that we keep old certs if the trust store goes missing + * + * @throws Exception + */ + @Test + public void testReloadMissingTrustStore() throws Exception { + KeyPair kp = generateKeyPair("RSA"); + X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA"); + File trustStore = new File("testmissing.jks"); + trustStore.deleteOnExit(); + assertFalse(trustStore.exists()); + createTrustStore(trustStore, "password", "cert1", cert1); + + ReloadingX509TrustManager tm = + new ReloadingX509TrustManager("jks", trustStore, "password", 1); + assertEquals(0, tm.reloadCount); + try { + tm.init(); + assertEquals(1, tm.getAcceptedIssuers().length); + X509Certificate cert = tm.getAcceptedIssuers()[0]; + trustStore.delete(); + + // Wait for up to 5s - we should *not* reload + waitForNoReload(tm, 50); + + assertEquals(1, tm.getAcceptedIssuers().length); + assertEquals(cert, tm.getAcceptedIssuers()[0]); + } finally { + tm.destroy(); + } + } + + /** + * Tests that we keep old certs if the new truststore is corrupt + * @throws Exception + */ + @Test + public void testReloadCorruptTrustStore() throws Exception { + KeyPair kp = generateKeyPair("RSA"); + X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA"); + File corruptStore = File.createTempFile("truststore-corrupt", "jks"); + corruptStore.deleteOnExit(); + createTrustStore(corruptStore, "password", "cert1", cert1); + + ReloadingX509TrustManager tm = + new ReloadingX509TrustManager("jks", corruptStore, "password", 1); + assertEquals(0, tm.reloadCount); + try { + tm.init(); + assertEquals(1, tm.getAcceptedIssuers().length); + X509Certificate cert = tm.getAcceptedIssuers()[0]; + + OutputStream os = new FileOutputStream(corruptStore); + os.write(1); + os.close(); + corruptStore.setLastModified(System.currentTimeMillis() - 1000); + + // Wait for up to 5s - we should *not* reload + waitForNoReload(tm, 50); + + assertEquals(1, tm.getAcceptedIssuers().length); + assertEquals(cert, tm.getAcceptedIssuers()[0]); + } finally { + tm.destroy(); + corruptStore.delete(); + } + } + + /** + * Tests that we successfully reload when the trust store is a symlink + * and we update the contents of the pointed-to file or we update the file it points to. + * @throws Exception + */ + @Test + public void testReloadSymlink() throws Exception { + KeyPair kp = generateKeyPair("RSA"); + X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA"); + X509Certificate cert2 = generateCertificate("CN=Cert2", kp, 30, "SHA1withRSA"); + X509Certificate cert3 = generateCertificate("CN=Cert3", kp, 30, "SHA1withRSA"); + + File trustStore1 = File.createTempFile("testreload", "jks"); + trustStore1.deleteOnExit(); + createTrustStore(trustStore1, "password", "cert1", cert1); + + File trustStore2 = File.createTempFile("testreload", "jks"); + Map certs = new HashMap(); + certs.put("cert1", cert1); + certs.put("cert2", cert2); + createTrustStore(trustStore2, "password", certs); + + File trustStoreSymlink = File.createTempFile("testreloadsymlink", "jks"); + trustStoreSymlink.delete(); + Files.createSymbolicLink(trustStoreSymlink.toPath(), trustStore1.toPath()); + + ReloadingX509TrustManager tm = + new ReloadingX509TrustManager("jks", trustStoreSymlink, "password", 1); + assertEquals(1, tm.getReloadInterval()); + assertEquals(0, tm.reloadCount); + logger.info("TRUST STORE 1 IS" + trustStore1); + logger.info("TRUST STORE 2 IS " + trustStore2); + try { + tm.init(); + assertEquals(1, tm.getAcceptedIssuers().length); + // At this point we haven't reloaded, just the initial load + assertEquals(0, tm.reloadCount); + + // Repoint to trustStore2, which has another cert + logger.info("REPOINTING SYMLINK!!!"); + trustStoreSymlink.delete(); + Files.createSymbolicLink(trustStoreSymlink.toPath(), trustStore2.toPath()); + logger.info("REPOINTED!!!"); + + // Wait up to 5s until we reload + waitForReloadCount(tm, 1, 50); + + assertEquals(2, tm.getAcceptedIssuers().length); + + // Add another cert + certs.put("cert3", cert3); + createTrustStore(trustStore2, "password", certs); + + // Wait up to 5s until we reload + waitForReloadCount(tm, 2, 50); + + assertEquals(3, tm.getAcceptedIssuers().length); + } finally { + tm.destroy(); + trustStore1.delete(); + trustStore2.delete(); + trustStoreSymlink.delete(); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java index 3c81b0af3186c..2a04d740e8ad8 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java @@ -21,6 +21,8 @@ import java.io.FileOutputStream; import java.io.IOException; import java.math.BigInteger; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.Key; @@ -41,9 +43,6 @@ import org.apache.spark.network.util.MapConfigProvider; -/** - * - */ public class SslSampleConfigs { public static final String keyStorePath = getAbsolutePath("/keystore"); @@ -217,9 +216,18 @@ private static KeyStore createEmptyKeyStore() private static void saveKeyStore( KeyStore ks, File keyStore, String password) throws GeneralSecurityException, IOException { - FileOutputStream out = new FileOutputStream(keyStore); + // Write the file atomically to ensure tests don't read a partial write + File tempFile = File.createTempFile("temp-key-store", "jks"); + FileOutputStream out = new FileOutputStream(tempFile); try { ks.store(out, password.toCharArray()); + out.close(); + Files.move( + tempFile.toPath(), + keyStore.toPath(), + StandardCopyOption.REPLACE_EXISTING, + StandardCopyOption.ATOMIC_MOVE + ); } finally { out.close(); }