Skip to content

Commit

Permalink
Abstract the success field to the general response
Browse files Browse the repository at this point in the history
This was common basically across all response messages.
  • Loading branch information
Andrew Or committed Jan 30, 2015
1 parent 6c57b4b commit b2fef8b
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 53 deletions.
8 changes: 6 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,19 @@ object SparkSubmit {
* Standalone cluster mode only.
*/
private def kill(args: SparkSubmitArguments): Unit = {
new StandaloneRestClient().killDriver(args.master, args.driverToKill)
val client = new StandaloneRestClient
val response = client.killDriver(args.master, args.driverToKill)
printStream.println(response.toJson)
}

/**
* Request the status of an existing driver using the REST application submission protocol.
* Standalone cluster mode only.
*/
private def requestStatus(args: SparkSubmitArguments): Unit = {
new StandaloneRestClient().requestDriverStatus(args.master, args.driverToRequestStatusFor)
val client = new StandaloneRestClient
val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor)
printStream.println(response.toJson)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.jar.JarFile

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.deploy.SparkSubmitAction.SparkSubmitAction
import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -52,6 +52,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()

// Standalone cluster mode only
Expand All @@ -62,17 +63,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St

private val restEnabledKey = "spark.submit.rest.enabled"

def action: SparkSubmitAction = {
(driverToKill, driverToRequestStatusFor) match {
case (null, null) => SparkSubmitAction.SUBMIT
case (_, null) => SparkSubmitAction.KILL
case (null, _) => SparkSubmitAction.REQUEST_STATUS
case _ => SparkSubmit.printErrorAndExit(
"Requested to both kill and request status for a driver. Choose only one.")
null // never reached
}
}

/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
Expand Down Expand Up @@ -189,14 +179,17 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
if (name == null && primaryResource != null) {
name = Utils.stripDirectory(primaryResource)
}

// Action should be SUBMIT unless otherwise specified
action = Option(action).getOrElse(SUBMIT)
}

/** Ensure that required fields exists. Call this only once all defaults are loaded. */
private def validateArguments(): Unit = {
action match {
case SparkSubmitAction.SUBMIT => validateSubmitArguments()
case SparkSubmitAction.KILL => validateKillArguments()
case SparkSubmitAction.REQUEST_STATUS => validateStatusRequestArguments()
case SUBMIT => validateSubmitArguments()
case KILL => validateKillArguments()
case REQUEST_STATUS => validateStatusRequestArguments()
}
}

Expand Down Expand Up @@ -379,10 +372,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St

case ("--kill") :: value :: tail =>
driverToKill = value
if (action != null) {
SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.")
}
action = KILL
parse(tail)

case ("--status") :: value :: tail =>
driverToRequestStatusFor = value
if (action != null) {
SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.")
}
action = REQUEST_STATUS
parse(tail)

case ("--supervise") :: tail =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ package org.apache.spark.deploy.rest
*/
class DriverStatusRequest extends SubmitRestProtocolRequest {
private val driverId = new SubmitRestProtocolField[String]("driverId")

def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,23 @@ package org.apache.spark.deploy.rest
*/
class DriverStatusResponse extends SubmitRestProtocolResponse {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val success = new SubmitRestProtocolField[Boolean]("success")
// standalone cluster mode only
private val driverState = new SubmitRestProtocolField[String]("driverState")
private val workerId = new SubmitRestProtocolField[String]("workerId")
private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort")

def getDriverId: String = driverId.toString
def getSuccess: String = success.toString
def getDriverState: String = driverState.toString
def getWorkerId: String = workerId.toString
def getWorkerHostPort: String = workerHostPort.toString

def setDriverId(s: String): this.type = setField(driverId, s)
def setSuccess(s: String): this.type = setBooleanField(success, s)
def setDriverState(s: String): this.type = setField(driverState, s)
def setWorkerId(s: String): this.type = setField(workerId, s)
def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(success)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ package org.apache.spark.deploy.rest
* An error response message used in the REST application submission protocol.
*/
class ErrorResponse extends SubmitRestProtocolResponse {
// request was unsuccessful
setSuccess("false")

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(message)
assert(!getSuccess.toBoolean, "The 'success' field cannot be true in an error response.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ package org.apache.spark.deploy.rest
*/
class KillDriverRequest extends SubmitRestProtocolRequest {
private val driverId = new SubmitRestProtocolField[String]("driverId")

def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,10 @@ package org.apache.spark.deploy.rest
*/
class KillDriverResponse extends SubmitRestProtocolResponse {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val success = new SubmitRestProtocolField[Boolean]("success")

def getDriverId: String = driverId.toString
def getSuccess: String = success.toString

def setDriverId(s: String): this.type = setField(driverId, s)
def setSuccess(s: String): this.type = setBooleanField(success, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(success)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,7 @@ package org.apache.spark.deploy.rest
* A response to the [[SubmitDriverRequest]] in the REST application submission protocol.
*/
class SubmitDriverResponse extends SubmitRestProtocolResponse {
private val success = new SubmitRestProtocolField[Boolean]("success")
private val driverId = new SubmitRestProtocolField[String]("driverId")

def getSuccess: String = success.toString
def getDriverId: String = driverId.toString

def setSuccess(s: String): this.type = setBooleanField(success, s)
def setDriverId(s: String): this.type = setField(driverId, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(success)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(args.master)
val request = constructSubmitRequest(args)
val response = sendHttp(url, request)
handleResponse(response)
validateResponse(response)
}

/** Request that the REST server kill the specified driver. */
Expand All @@ -48,7 +48,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(master)
val request = constructKillRequest(master, driverId)
val response = sendHttp(url, request)
handleResponse(response)
validateResponse(response)
}

/** Request the status of the specified driver from the REST server. */
Expand All @@ -57,7 +57,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(master)
val request = constructStatusRequest(master, driverId)
val response = sendHttp(url, request)
handleResponse(response)
validateResponse(response)
}

/** Return the HTTP URL of the REST server that corresponds to the given master URL. */
Expand Down Expand Up @@ -95,14 +95,10 @@ private[spark] abstract class SubmitRestClient extends Logging {
}
}

/** Validate the response and log any error messages produced by the server. */
private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
/** Validate the response... */
private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
try {
response.validate()
response match {
case error: ErrorResponse => logError(s"Server returned error:\n${error.getMessage}")
case _ =>
}
} catch {
case e: SubmitRestProtocolException =>
throw new SubmitRestProtocolException("Malformed response received from server", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,20 @@ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage {
*/
abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version")
def getServerSparkVersion: String = sparkVersion.toString
def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s)
private val success = new SubmitRestProtocolField[Boolean]("success")

override def getSparkVersion: String = getServerSparkVersion
def getServerSparkVersion: String = sparkVersion.toString
def getSuccess: String = success.toString

override def setSparkVersion(s: String) = setServerSparkVersion(s)
def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s)
def setSuccess(s: String): this.type = setBooleanField(success, s)

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(success)
}
}

object SubmitRestProtocolMessage {
Expand Down

0 comments on commit b2fef8b

Please sign in to comment.