Skip to content

Commit

Permalink
SPARK-3883: Refactored methods to resolve Akka address and made it possible to easily configure multiple communication layers for SSL
Browse files Browse the repository at this point in the history
  • Loading branch information
jacek-lewandowski committed Feb 2, 2015
1 parent 72b2541 commit 90a8762
Show file tree
Hide file tree
Showing 18 changed files with 114 additions and 73 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/HttpServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private[spark] class HttpServer(
private def doStart(startPort: Int): (Server, Int) = {
val server = new Server()

val connector = securityManager.sslOptions.createJettySslContextFactory()
val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory()
.map(new SslSocketConnector(_)).getOrElse(new SocketConnector)

connector.setMaxIdleTime(60 * 1000)
Expand Down Expand Up @@ -159,7 +159,7 @@ private[spark] class HttpServer(
if (server == null) {
throw new ServerStateException("Server is not started")
} else {
val scheme = if (securityManager.sslOptions.enabled) "https" else "http"
val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http"
s"$scheme://${Utils.localIpAddress}:$port"
}
}
Expand Down
22 changes: 15 additions & 7 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,25 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
)
}

val sslOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None)
logDebug(s"SSLConfiguration: $sslOptions")
// the default SSL configuration - it will be used by all communication layers unless overwritten
private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None)

val (sslSocketFactory, hostnameVerifier) = if (sslOptions.enabled) {
// SSL configuration for different communication layers - they can override the default
// configuration at a specified namespace. The namespace *must* start with spark.ssl.
val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions))
val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions))

logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions")
logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions")

val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) {
val trustStoreManagers =
for (trustStore <- sslOptions.trustStore) yield {
val input = Files.asByteSource(sslOptions.trustStore.get).openStream()
for (trustStore <- fileServerSSLOptions.trustStore) yield {
val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream()

try {
val ks = KeyStore.getInstance(KeyStore.getDefaultType)
ks.load(input, sslOptions.trustStorePassword.get.toCharArray)
ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray)

val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
tmf.init(ks)
Expand All @@ -232,7 +240,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
}: TrustManager
})

val sslContext = SSLContext.getInstance(sslOptions.protocol.getOrElse("Default"))
val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default"))
sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null)

val hostVerifier = new HostnameVerifier {
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
val timeout = AkkaUtils.askTimeout(conf)

override def preStart() = {
masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master, conf))
masterActor = context.actorSelection(
Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system)))

context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])

Expand Down Expand Up @@ -161,7 +162,7 @@ object Client {
"driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))

// Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
Master.toAkkaUrl(driverArgs.master, conf)
Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))

actorSystem.awaitTermination()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[spark] class AppClient(
conf: SparkConf)
extends Logging {

val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, conf))
val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))

val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
Expand Down Expand Up @@ -107,8 +107,9 @@ private[spark] class AppClient(
def changeMaster(url: String) {
// activeMasterUrl is a valid Spark url since we receive it from master.
activeMasterUrl = url
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl, conf))
masterAddress = Master.toAkkaAddress(activeMasterUrl, conf)
master = context.actorSelection(
Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem)))
masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem))
}

private def isPossibleMaster(remoteUrl: Address) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -860,19 +860,19 @@ private[spark] object Master extends Logging {
*
* @throws SparkException if the url is invalid
*/
def toAkkaUrl(sparkUrl: String, conf: SparkConf): String = {
def toAkkaUrl(sparkUrl: String, protocol: String): String = {
val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
AkkaUtils.address(systemName, host, port, actorName, conf)
AkkaUtils.address(protocol, systemName, host, port, actorName)
}

/**
* Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`.
*
* @throws SparkException if the url is invalid
*/
def toAkkaAddress(sparkUrl: String, conf: SparkConf): Address = {
def toAkkaAddress(sparkUrl: String, protocol: String): Address = {
val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
Address(AkkaUtils.protocol(conf), systemName, host, port)
Address(protocol, systemName, host, port)
}

def startSystemAndActor(
Expand Down
14 changes: 10 additions & 4 deletions core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ private[spark] class Worker(
var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
val akkaUrl = AkkaUtils.address(actorSystemName, host, port, actorName, conf)
val akkaUrl = AkkaUtils.address(
AkkaUtils.protocol(context.system),
actorSystemName,
host,
port,
actorName)
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
Expand Down Expand Up @@ -174,8 +179,9 @@ private[spark] class Worker(
// activeMasterUrl it's a valid Spark url since we receive it from master.
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl, conf))
masterAddress = Master.toAkkaAddress(activeMasterUrl, conf)
master = context.actorSelection(
Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system)))
masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system))
connected = true
// Cancel any outstanding re-registration attempts because we found a new master
registrationRetryTimer.foreach(_.cancel())
Expand Down Expand Up @@ -540,7 +546,7 @@ private[spark] object Worker extends Logging {
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
conf = conf, securityManager = securityMgr)
val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, conf))
val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ private[spark] class SimrSchedulerBackend(
super.start()

val driverUrl = AkkaUtils.address(
AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
sc.conf.get("spark.driver.host"),
sc.conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME,
sc.conf)
CoarseGrainedSchedulerBackend.ACTOR_NAME)

val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
val fs = FileSystem.get(conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ private[spark] class SparkDeploySchedulerBackend(

// The endpoint for executors to talk to us
val driverUrl = AkkaUtils.address(
AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME,
conf)
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}",
"{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ private[spark] class CoarseMesosSchedulerBackend(
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
val driverUrl = AkkaUtils.address(
AkkaUtils.protocol(sc.env.actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME,
conf)
CoarseGrainedSchedulerBackend.ACTOR_NAME)

val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
Expand Down
27 changes: 18 additions & 9 deletions core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ package org.apache.spark.util
import scala.collection.JavaConversions.mapAsJavaMap
import scala.concurrent.Await
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.util.Try

import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
import akka.pattern.ask

import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}

import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, SSLOptions}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}

/**
* Various utility classes for working with Akka.
Expand Down Expand Up @@ -91,7 +92,8 @@ private[spark] object AkkaUtils extends Logging {
val secureCookie = if (isAuthOn) secretKey else ""
logDebug(s"In createActorSystem, requireCookie is: $requireCookie")

val akkaSslConfig = securityManager.sslOptions.createAkkaConfig.getOrElse(ConfigFactory.empty())
val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig
.getOrElse(ConfigFactory.empty())

val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String])
.withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString(
Expand Down Expand Up @@ -216,7 +218,7 @@ private[spark] object AkkaUtils extends Logging {
val driverHost: String = conf.get("spark.driver.host", "localhost")
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
val url = address(driverActorSystemName, driverHost, driverPort, name, conf)
val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
val timeout = AkkaUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
Expand All @@ -230,26 +232,33 @@ private[spark] object AkkaUtils extends Logging {
actorSystem: ActorSystem): ActorRef = {
val executorActorSystemName = SparkEnv.executorActorSystemName
Utils.checkHost(host, "Expected hostname")
val url = address(executorActorSystemName, host, port, name, conf)
val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
val timeout = AkkaUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
}

def protocol(conf: SparkConf): String = {
if (conf.getBoolean("spark.ssl.enabled", defaultValue = false)) {
def protocol(actorSystem: ActorSystem): String = {
protocol(Try {
actorSystem.settings.config.getBoolean("akka.remote.netty.tcp.enable-ssl")
}.getOrElse(false))
}

def protocol(ssl: Boolean = false): String = {
if (ssl) {
"akka.ssl.tcp"
} else {
"akka.tcp"
}
}

def address(
protocol: String,
systemName: String,
host: String,
port: Any,
actorName: String,
conf: SparkConf): String = {
s"${protocol(conf)}://$systemName@$host:$port/user/$actorName"
actorName: String): String = {
s"$protocol://$systemName@$host:$port/user/$actorName"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class MapOutputTrackerSuite extends FunSuite {
securityManager = new SecurityManager(conf))
val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
val timeout = AkkaUtils.lookupTimeout(conf)
slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)

Expand Down
36 changes: 25 additions & 11 deletions core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,32 @@ class SecurityManagerSuite extends FunSuite {

val securityManager = new SecurityManager(conf)

assert(securityManager.sslOptions.enabled === true)
assert(securityManager.fileServerSSLOptions.enabled === true)
assert(securityManager.akkaSSLOptions.enabled === true)

assert(securityManager.sslSocketFactory.isDefined === true)
assert(securityManager.hostnameVerifier.isDefined === true)

assert(securityManager.sslOptions.trustStore.isDefined === true)
assert(securityManager.sslOptions.trustStore.get.getName === "truststore")
assert(securityManager.sslOptions.keyStore.isDefined === true)
assert(securityManager.sslOptions.keyStore.get.getName === "keystore")
assert(securityManager.sslOptions.trustStorePassword === Some("password"))
assert(securityManager.sslOptions.keyStorePassword === Some("password"))
assert(securityManager.sslOptions.keyPassword === Some("password"))
assert(securityManager.sslOptions.protocol === Some("TLSv1"))
assert(securityManager.sslOptions.enabledAlgorithms ===
assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true)
assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore")
assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true)
assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore")
assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyPassword === Some("password"))
assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1"))
assert(securityManager.fileServerSSLOptions.enabledAlgorithms ===
Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))

assert(securityManager.akkaSSLOptions.trustStore.isDefined === true)
assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore")
assert(securityManager.akkaSSLOptions.keyStore.isDefined === true)
assert(securityManager.akkaSSLOptions.keyStore.get.getName === "keystore")
assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyPassword === Some("password"))
assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1"))
assert(securityManager.akkaSSLOptions.enabledAlgorithms ===
Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))
}

Expand All @@ -155,7 +168,8 @@ class SecurityManagerSuite extends FunSuite {

val securityManager = new SecurityManager(conf)

assert(securityManager.sslOptions.enabled === false)
assert(securityManager.fileServerSSLOptions.enabled === false)
assert(securityManager.akkaSSLOptions.enabled === false)
assert(securityManager.sslSocketFactory.isDefined === false)
assert(securityManager.hostnameVerifier.isDefined === false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,46 @@ package org.apache.spark.deploy.master
import akka.actor.Address
import org.scalatest.FunSuite

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.{SSLOptions, SparkConf, SparkException}

class MasterSuite extends FunSuite {

test("toAkkaUrl") {
val conf = new SparkConf(loadDefaults = false)
val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", conf)
val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp")
assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
}

test("toAkkaUrl with SSL") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.ssl.enabled", "true")
val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", conf)
val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp")
assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
}

test("toAkkaUrl: a typo url") {
val conf = new SparkConf(loadDefaults = false)
val e = intercept[SparkException] {
Master.toAkkaUrl("spark://1.2. 3.4:1234", conf)
Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp")
}
assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
}

test("toAkkaAddress") {
val conf = new SparkConf(loadDefaults = false)
val address = Master.toAkkaAddress("spark://1.2.3.4:1234", conf)
val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp")
assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
}

test("toAkkaAddress with SSL") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.ssl.enabled", "true")
val address = Master.toAkkaAddress("spark://1.2.3.4:1234", conf)
val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp")
assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
}

test("toAkkaAddress: a typo url") {
val conf = new SparkConf(loadDefaults = false)
val e = intercept[SparkException] {
Master.toAkkaAddress("spark://1.2. 3.4:1234", conf)
Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp")
}
assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
}
Expand Down
Loading

0 comments on commit 90a8762

Please sign in to comment.