Integrate REST protocol in standalone mode
This commit embeds the REST server in the standalone Master and
forces spark-submit to submit applications through the REST client.
This is the first working end-to-end implementation of a stable
submission interface in standalone cluster mode.
Andrew Or committed Jan 19, 2015
1 parent 53e7c0e commit af9d9cb
Showing 11 changed files with 297 additions and 95 deletions.
18 changes: 11 additions & 7 deletions core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -29,8 +29,7 @@ import org.apache.spark.util.MemoryParam
* Command-line parser for the driver client.
*/
private[spark] class ClientArguments(args: Array[String]) {
val defaultCores = 1
val defaultMemory = 512
import ClientArguments._

var cmd: String = "" // 'launch' or 'kill'
var logLevel = Level.WARN
@@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) {
var master: String = ""
var jarUrl: String = ""
var mainClass: String = ""
var supervise: Boolean = false
var memory: Int = defaultMemory
var cores: Int = defaultCores
var supervise: Boolean = DEFAULT_SUPERVISE
var memory: Int = DEFAULT_MEMORY
var cores: Int = DEFAULT_CORES
private var _driverOptions = ListBuffer[String]()
def driverOptions = _driverOptions.toSeq

@@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) {
|Usage: DriverClient kill <active-master> <driver-id>
|
|Options:
| -c CORES, --cores CORES Number of cores to request (default: $defaultCores)
| -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory)
| -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES)
| -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY)
| -s, --supervise Whether to restart the driver on failure
| (default: $DEFAULT_SUPERVISE)
| -v, --verbose Print more debugging output
""".stripMargin
System.err.println(usage)
@@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) {
}

object ClientArguments {
private[spark] val DEFAULT_CORES = 1
private[spark] val DEFAULT_MEMORY = 512 // MB
private[spark] val DEFAULT_SUPERVISE = false

def isValidJarUrl(s: String): Boolean = {
try {
val uri = new URI(s)
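The hunk above moves the default cores, memory, and supervise values into the ClientArguments companion object so other components can share them. A minimal sketch of how code elsewhere in the package might reuse the shared constants — the object and parameter names here are illustrative, not part of the commit:

package org.apache.spark.deploy.rest

import org.apache.spark.deploy.ClientArguments

object DefaultsSketch {
  // Fall back to the shared default when a request omits a memory value.
  // `requestedMemory` is a hypothetical input used only for illustration.
  def resolveMemoryMb(requestedMemory: Option[String]): Int =
    requestedMemory.map(_.toInt).getOrElse(ClientArguments.DEFAULT_MEMORY)
}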
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}

import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.util.Utils
import org.apache.spark.deploy.rest.StandaloneRestClient

/**
* Main gateway of launching a Spark application.
@@ -72,6 +73,16 @@ object SparkSubmit {
if (appArgs.verbose) {
printStream.println(appArgs)
}

// In standalone cluster mode, use the brand new REST client to submit the application
val doingRest = appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster"
if (doingRest) {
println("Submitting driver through the REST interface.")
new StandaloneRestClient().submitDriver(appArgs)
println("Done submitting driver.")
return
}

val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
}
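The new SparkSubmit code chooses the REST path purely from the master URL scheme plus the deploy mode; everything else keeps the legacy submission path. The predicate in isolation, as a self-contained sketch (the function name is illustrative):

def useRestSubmission(master: String, deployMode: String): Boolean =
  master.startsWith("spark://") && deployMode == "cluster"

assert(useRestSubmission("spark://host:7077", "cluster"))
assert(!useRestSubmission("spark://host:7077", "client"))
assert(!useRestSubmission("local[*]", "cluster"))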
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -104,6 +104,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env)
.orElse(sparkProperties.get("spark.master"))
.orElse(env.get("MASTER"))
.orNull
driverExtraClassPath = Option(driverExtraClassPath)
.orElse(sparkProperties.get("spark.driver.extraClassPath"))
.orNull
driverExtraJavaOptions = Option(driverExtraJavaOptions)
.orElse(sparkProperties.get("spark.driver.extraJavaOptions"))
.orNull
driverExtraLibraryPath = Option(driverExtraLibraryPath)
.orElse(sparkProperties.get("spark.driver.extraLibraryPath"))
.orNull
driverMemory = Option(driverMemory)
.orElse(sparkProperties.get("spark.driver.memory"))
.orElse(env.get("SPARK_DRIVER_MEMORY"))
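Each driver option above resolves with the same precedence: an explicit command-line value wins, otherwise the corresponding spark-defaults property, otherwise null. The pattern in isolation, with illustrative values:

val sparkProperties = Map("spark.driver.extraClassPath" -> "/opt/extra/*")
val fromCommandLine: String = null // no --driver-class-path was passed

// Command-line value first, then the properties file, else null.
val driverExtraClassPath = Option(fromCommandLine)
  .orElse(sparkProperties.get("spark.driver.extraClassPath"))
  .orNull

assert(driverExtraClassPath == "/opt/extra/*")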
core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
@@ -121,6 +122,8 @@ private[spark] class Master(
throw new SparkException("spark.deploy.defaultCores must be positive")
}

val restServer = new StandaloneRestServer(this, host, 6677)

override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
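Note that the Master embeds the REST server with a hard-coded port of 6677, and constructing StandaloneRestServer binds and starts Jetty immediately (see its definition below). If the port were made configurable, a SparkConf lookup would be the natural fit — the key used here is hypothetical and only illustrates the usual pattern, it is not defined by this commit:

import org.apache.spark.SparkConf

// Hypothetical configurable variant; `spark.master.rest.port` is NOT a key
// this commit introduces.
val conf = new SparkConf()
val restPort = conf.getInt("spark.master.rest.port", 6677)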
Expand Up @@ -27,8 +27,8 @@ private[spark] object KillDriverResponseField extends StandaloneRestProtocolFiel
case object MESSAGE extends KillDriverResponseField
case object MASTER extends KillDriverResponseField
case object DRIVER_ID extends KillDriverResponseField
case object DRIVER_STATE extends SubmitDriverResponseField
override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE)
case object SUCCESS extends SubmitDriverResponseField
override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS)
override val optionalFields = Seq.empty
}

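The kill response now reports a boolean SUCCESS field instead of a driver state. Each message type declares its schema through requiredFields; a self-contained toy version of that validation, with simplified stand-in types rather than the commit's classes:

sealed trait Field
case object ACTION extends Field
case object MESSAGE extends Field
case object SUCCESS extends Field

val requiredFields: Seq[Field] = Seq(ACTION, MESSAGE, SUCCESS)

// Validation in the spirit of the protocol: every required field must be set.
def validate(fields: Map[Field, String]): Unit = {
  val missing = requiredFields.filterNot(fields.contains)
  if (missing.nonEmpty) {
    throw new IllegalArgumentException(s"Missing required fields: $missing")
  }
}

validate(Map(ACTION -> "KILL_DRIVER_RESPONSE", MESSAGE -> "driver killed", SUCCESS -> "true"))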
core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -27,6 +27,7 @@ import com.google.common.base.Charsets

import org.apache.spark.{SPARK_VERSION => sparkVersion}
import org.apache.spark.deploy.SparkSubmitArguments
import org.apache.spark.util.Utils

/**
* A client that submits Spark applications using a stable REST protocol in standalone
@@ -63,6 +64,12 @@ private[spark] class StandaloneRestClient {
*/
private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = {
import SubmitDriverRequestField._
val driverMemory = Option(args.driverMemory)
.map { m => Utils.memoryStringToMb(m).toString }
.orNull
val executorMemory = Option(args.executorMemory)
.map { m => Utils.memoryStringToMb(m).toString }
.orNull
val message = new SubmitDriverRequestMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MASTER, args.master)
@@ -72,19 +79,21 @@
.setFieldIfNotNull(JARS, args.jars)
.setFieldIfNotNull(FILES, args.files)
.setFieldIfNotNull(PY_FILES, args.pyFiles)
.setFieldIfNotNull(DRIVER_MEMORY, args.driverMemory)
.setFieldIfNotNull(DRIVER_MEMORY, driverMemory)
.setFieldIfNotNull(DRIVER_CORES, args.driverCores)
.setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions)
.setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath)
.setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath)
.setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString)
.setFieldIfNotNull(EXECUTOR_MEMORY, args.executorMemory)
.setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory)
.setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores)
// Set each Spark property as its own field
// TODO: Include environment variables?
args.childArgs.zipWithIndex.foreach { case (arg, i) =>
message.setFieldIfNotNull(APP_ARG(i), arg)
}
args.sparkProperties.foreach { case (k, v) =>
message.setFieldIfNotNull(SPARK_PROPERTY(k), v)
}
// TODO: set environment variables?
message.validate()
}

Expand Down Expand Up @@ -175,8 +184,8 @@ private[spark] class StandaloneRestClient {
object StandaloneRestClient {
def main(args: Array[String]): Unit = {
assert(args.length > 0)
val client = new StandaloneRestClient
client.killDriver("spark://" + args(0), "abc_driver")
//val client = new StandaloneRestClient
//client.submitDriver("spark://" + args(0))
println("Done.")
}
}
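The client normalizes JVM-style memory strings with Utils.memoryStringToMb before putting them on the wire, so the server always receives a plain megabyte count. A hedged re-implementation of that conversion — an approximation for illustration, not Spark's actual code:

// Convert "512m" / "2g" style strings to a megabyte count.
def memoryStringToMb(str: String): Int = {
  val lower = str.toLowerCase.trim
  if (lower.endsWith("k")) (lower.dropRight(1).toLong / 1024).toInt
  else if (lower.endsWith("m")) lower.dropRight(1).toInt
  else if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
  else if (lower.endsWith("t")) lower.dropRight(1).toInt * 1024 * 1024
  else (lower.toLong / 1024 / 1024).toInt // bare numbers treated as bytes
}

assert(memoryStringToMb("512m") == 512)
assert(memoryStringToMb("2g") == 2048)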
core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala
@@ -63,19 +63,28 @@ private[spark] abstract class StandaloneRestProtocolMessage(

import StandaloneRestProtocolField._

private val fields = new mutable.HashMap[StandaloneRestProtocolField, String]
private val className = Utils.getFormattedClassName(this)
protected val fields = new mutable.HashMap[StandaloneRestProtocolField, String]

// Set the action field
fields(actionField) = action.toString

/** Return all fields currently set in this message. */
def getFields: Map[StandaloneRestProtocolField, String] = fields

/** Return the value of the given field. If the field is not present, return null. */
def getField(key: StandaloneRestProtocolField): String = getFieldOption(key).orNull

/** Return the value of the given field. If the field is not present, throw an exception. */
def getField(key: StandaloneRestProtocolField): String = {
fields.get(key).getOrElse {
def getFieldNotNull(key: StandaloneRestProtocolField): String = {
getFieldOption(key).getOrElse {
throw new IllegalArgumentException(s"Field $key is not set in message $className")
}
}

/** Return the value of the given field as an option. */
def getFieldOption(key: StandaloneRestProtocolField): Option[String] = fields.get(key)

/** Assign the given value to the field, overriding any existing value. */
def setField(key: StandaloneRestProtocolField, value: String): this.type = {
if (key == actionField) {
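The accessor split above gives callers three levels of strictness: getFieldOption for the raw Option, getField for a nullable value, and getFieldNotNull when absence is an error. A self-contained toy with the same semantics — a simplified stand-in, not the commit's class:

import scala.collection.mutable

class ToyMessage {
  protected val fields = new mutable.HashMap[String, String]
  def setField(key: String, value: String): this.type = { fields(key) = value; this }
  def getFieldOption(key: String): Option[String] = fields.get(key)
  def getField(key: String): String = getFieldOption(key).orNull
  def getFieldNotNull(key: String): String =
    getFieldOption(key).getOrElse {
      throw new IllegalArgumentException(s"Field $key is not set")
    }
}

val m = new ToyMessage().setField("MASTER", "spark://host:7077")
assert(m.getField("MASTER") == "spark://host:7077")
assert(m.getField("APP_NAME") == null)       // absent: nullable accessor
assert(m.getFieldOption("APP_NAME").isEmpty) // absent: Option accessor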
core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.rest

import java.io.DataOutputStream
import java.net.InetSocketAddress
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}

import scala.io.Source
@@ -26,25 +27,37 @@ import com.google.common.base.Charsets
import org.eclipse.jetty.server.{Request, Server}
import org.eclipse.jetty.server.handler.AbstractHandler

import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion}
import org.apache.spark.deploy.rest.StandaloneRestProtocolAction._
import org.apache.spark.util.Utils
import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging}
import org.apache.spark.deploy.master.Master
import org.apache.spark.util.{AkkaUtils, Utils}

/**
* A server that responds to requests submitted by the StandaloneRestClient.
*/
private[spark] class StandaloneRestServer(requestedPort: Int) {
val server = new Server(requestedPort)
server.setHandler(new StandaloneRestHandler)
private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) {
val server = new Server(new InetSocketAddress(host, requestedPort))
server.setHandler(new StandaloneRestServerHandler(master))
server.start()
server.join()
}

/**
* A Jetty handler that responds to requests submitted via the standalone REST protocol.
*/
private[spark] class StandaloneRestHandler extends AbstractHandler with Logging {
private[spark] abstract class StandaloneRestHandler(master: Master)
extends AbstractHandler with Logging {

private implicit val askTimeout = AkkaUtils.askTimeout(master.conf)

/** Handle a request to submit a driver. */
protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage
/** Handle a request to kill a driver. */
protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage
/** Handle a request for a driver's status. */
protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage

/**
* Handle a request submitted by the StandaloneRestClient.
*/
override def handle(
target: String,
baseRequest: Request,
@@ -67,74 +80,32 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging
}
}

/**
* Construct the appropriate response message based on the type of the request message.
* If an IllegalArgumentException is thrown in the process, construct an error message.
*/
private def constructResponseMessage(
request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = {
// If the request is sent via the StandaloneRestClient, it should have already been
// validated remotely. In case this is not true, validate the request here to guard
// against potential NPEs. If validation fails, return an ERROR message to the sender.
try {
request.validate()
request match {
case submit: SubmitDriverRequestMessage => handleSubmit(submit)
case kill: KillDriverRequestMessage => handleKill(kill)
case status: DriverStatusRequestMessage => handleStatus(status)
case unexpected => handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
}
} catch {
case e: IllegalArgumentException =>
return handleError(e.getMessage)
}
request match {
case submit: SubmitDriverRequestMessage => handleSubmitRequest(submit)
case kill: KillDriverRequestMessage => handleKillRequest(kill)
case status: DriverStatusRequestMessage => handleStatusRequest(status)
case unexpected => handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
// Propagate exception to user in an ErrorMessage. If the construction of the
// ErrorMessage itself throws an exception, log the exception and ignore the request.
case e: IllegalArgumentException => handleError(e.getMessage)
}
}

private def handleSubmitRequest(
request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = {
import SubmitDriverResponseField._
// TODO: Actually submit the driver
val message = "Driver is submitted successfully..."
val master = request.getField(SubmitDriverRequestField.MASTER)
val driverId = "new_driver_id"
val driverState = "SUBMITTED"
new SubmitDriverResponseMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MESSAGE, message)
.setField(MASTER, master)
.setField(DRIVER_ID, driverId)
.setField(DRIVER_STATE, driverState)
.validate()
}

private def handleKillRequest(request: KillDriverRequestMessage): KillDriverResponseMessage = {
import KillDriverResponseField._
// TODO: Actually kill the driver
val message = "Driver is killed successfully..."
val master = request.getField(KillDriverRequestField.MASTER)
val driverId = request.getField(KillDriverRequestField.DRIVER_ID)
val driverState = "KILLED"
new KillDriverResponseMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MESSAGE, message)
.setField(MASTER, master)
.setField(DRIVER_ID, driverId)
.setField(DRIVER_STATE, driverState)
.validate()
}

private def handleStatusRequest(
request: DriverStatusRequestMessage): DriverStatusResponseMessage = {
import DriverStatusResponseField._
// TODO: Actually look up the status of the driver
val master = request.getField(DriverStatusRequestField.MASTER)
val driverId = request.getField(DriverStatusRequestField.DRIVER_ID)
val driverState = "HEALTHY"
new DriverStatusResponseMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MASTER, master)
.setField(DRIVER_ID, driverId)
.setField(DRIVER_STATE, driverState)
.validate()
}

/** Construct an error message to signal the fact that an exception has been thrown. */
private def handleError(message: String): ErrorMessage = {
import ErrorField._
new ErrorMessage()
@@ -144,10 +115,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging
}
}

object StandaloneRestServer {
def main(args: Array[String]): Unit = {
println("Hey boy I'm starting a server.")
new StandaloneRestServer(6677)
readLine()
}
}
//object StandaloneRestServer {
// def main(args: Array[String]): Unit = {
// println("Hey boy I'm starting a server.")
// new StandaloneRestServer(6677)
// readLine()
// }
//}
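The server now binds to an explicit host through InetSocketAddress and delegates request handling to an abstract handler hierarchy. A minimal self-contained sketch of the same Jetty wiring, with a stub echo handler standing in for StandaloneRestServerHandler (assumes the Jetty version Spark already ships):

import java.net.InetSocketAddress
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}

import org.eclipse.jetty.server.{Request, Server}
import org.eclipse.jetty.server.handler.AbstractHandler

object MiniRestServer {
  def main(args: Array[String]): Unit = {
    // Bind to an explicit host and port, as the new constructor does.
    val server = new Server(new InetSocketAddress("localhost", 6677))
    server.setHandler(new AbstractHandler {
      override def handle(
          target: String,
          baseRequest: Request,
          request: HttpServletRequest,
          response: HttpServletResponse): Unit = {
        // Stub response; the real handler parses and dispatches protocol messages.
        response.setContentType("application/json;charset=utf-8")
        response.setStatus(HttpServletResponse.SC_OK)
        response.getWriter.println("""{"message": "stub response"}""")
        baseRequest.setHandled(true)
      }
    })
    server.start()
    server.join()
  }
}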