Skip to content

Commit

Permalink
[SPARK-4277] Support external shuffle service on executor
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed Nov 6, 2014
1 parent b41a39e commit 258417c
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 14 deletions.
14 changes: 3 additions & 11 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
*/
def getSecretKey(): String = secretKey

override def getSaslUser(appId: String): String = {
val myAppId = sparkConf.getAppId
require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
getSaslUser()
}

override def getSecretKey(appId: String): String = {
val myAppId = sparkConf.getAppId
require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
getSecretKey()
}
// Default SecurityManager only has a single secret key, so ignore appId.
override def getSaslUser(appId: String): String = getSaslUser()
override def getSecretKey(appId: String): String = getSecretKey()
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ private[spark] class Worker(
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]

// The shuffle service is not actually started unless configured.
var shuffleService = new WorkerShuffleService(conf, securityMgr)

val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
Expand Down Expand Up @@ -154,6 +157,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
Expand Down Expand Up @@ -419,6 +423,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
Expand All @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
workDir: String,
workerNumber: Option[Int] = None): (ActorSystem, Int) = {

// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.deploy.worker

import org.apache.spark.{Logging, SparkConf, SecurityManager}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslRpcHandler
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler

/**
* Provides a server from which Executors can read shuffle files (rather than reading directly from
* each other), to provide uninterrupted access to the files in the face of executors being turned
* off or killed.
*
* Optionally requires SASL authentication in order to read. See [[SecurityManager]].
*/
class WorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging {

private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()

private val transportConf = SparkTransportConf.fromSparkConf(sparkConf)
private val blockHandler = new ExternalShuffleBlockHandler()
private val transportContext: TransportContext = {
val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
new TransportContext(transportConf, handler)
}

private var server: TransportServer = _

/** Starts the external shuffle service if the user has configured us to. */
def startIfEnabled() {
if (enabled) {
require(server == null, "Shuffle server already started")
logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
server = transportContext.createServer(port)
}
}

def stop() {
if (enabled && server != null) {
server.close()
server = null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
private[this] var currentResult: FetchResult = null
@volatile private[this] var currentResult: FetchResult = null

/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public void encode(ByteBuf buf) {

public static SaslMessage decode(ByteBuf buf) {
if (buf.readByte() != TAG_BYTE) {
throw new IllegalStateException("Expected SaslMessage, received something else");
throw new IllegalStateException("Expected SaslMessage, received something else"
+ " (maybe your client does not have SASL enabled?)");
}

int idLength = buf.readInt();
Expand Down

0 comments on commit 258417c

Please sign in to comment.