diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2479911dd6237..511b7ba4f74ae 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -98,16 +98,31 @@ private[spark] object SSLOptions extends Logging { * The parent directory of that location is used as a base directory to resolve relative paths * to keystore and truststore. */ - def parse(conf: SparkConf, ns: String): SSLOptions = { - val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = false) + def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + .orElse(defaults.flatMap(_.keyStore)) + val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + .orElse(defaults.flatMap(_.keyStorePassword)) + val keyPassword = conf.getOption(s"$ns.keyPassword") + .orElse(defaults.flatMap(_.keyPassword)) + val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + .orElse(defaults.flatMap(_.trustStore)) + val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + .orElse(defaults.flatMap(_.trustStorePassword)) + val protocol = conf.getOption(s"$ns.protocol") - val enabledAlgorithms = conf.get(s"$ns.enabledAlgorithms", defaultValue = "") - .split(",").map(_.trim).filter(_.nonEmpty).toSet + .orElse(defaults.flatMap(_.protocol)) + + val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) + .orElse(defaults.map(_.enabledAlgorithms)) + .getOrElse(Set.empty) new SSLOptions( enabled, diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index d3ab0790e5e95..923fde26c0164 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -201,7 +201,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) ) } - val sslOptions = SSLOptions.parse(sparkConf, "spark.ssl") + val sslOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) logDebug(s"SSLConfiguration: $sslOptions") val (sslSocketFactory, hostnameVerifier) = if (sslOptions.enabled) { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 01b20b5deeb7a..444a33371bd71 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -55,4 +55,69 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } + test("test resolving property with defaults specified ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + test("test whether defaults can be overridden ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ui.ssl.enabled", "false") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === false) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("12345")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === Set("ABC", "DEF")) + } + }