diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index bb73c61e68bad..ec0e197cfa345 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -17,12 +17,17 @@ package org.apache.spark.deploy.rest +/** + * A request to query the status of a driver in the REST application submission protocol. + */ class DriverStatusRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) + override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 6da41d09b3f22..2819ef50a75d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -17,12 +17,16 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[DriverStatusRequest]] in the REST application submission protocol. + */ class DriverStatusResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - private val success = new SubmitRestProtocolField[Boolean] - private val driverState = new SubmitRestProtocolField[String] - private val workerId = new SubmitRestProtocolField[String] - private val workerHostPort = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + private val success = new SubmitRestProtocolField[Boolean]("success") + // standalone cluster mode only + private val driverState = new SubmitRestProtocolField[String]("driverState") + private val workerId = new SubmitRestProtocolField[String]("workerId") + private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort") def getDriverId: String = driverId.toString def getSuccess: String = success.toString @@ -38,7 +42,7 @@ class DriverStatusResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") - assertFieldIsSet(success, "success") + assertFieldIsSet(driverId) + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 0e08831e7b6a1..5388cddb070fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * An error response message used in the REST application submission protocol. 
+ */ class ErrorResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(message, "message") + assertFieldIsSet(message) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 31b127876c43e..97f5dd2ba8227 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -17,12 +17,17 @@ package org.apache.spark.deploy.rest +/** + * A request to kill a driver in the REST application submission protocol. + */ class KillDriverRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) + override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index 107b447c3d1c4..fe68800e99800 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[KillDriverRequest]] in the REST application submission protocol. + */ class KillDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - private val success = new SubmitRestProtocolField[Boolean] + private val driverId = new SubmitRestProtocolField[String]("driverId") + private val success = new SubmitRestProtocolField[Boolean]("success") def getDriverId: String = driverId.toString def getSuccess: String = success.toString @@ -29,7 +32,7 @@ class KillDriverResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") - assertFieldIsSet(success, "success") + assertFieldIsSet(driverId) + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index b564006fd7457..6f2752c848a0e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -21,11 +21,10 @@ import java.net.URL import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.deploy.SparkSubmitArguments -import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using the stable REST protocol. - * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. + * A client that submits applications to the standalone Master using the REST protocol + * This client is intended to communicate with the [[StandaloneRestServer]]. Cluster mode only. */ private[spark] class StandaloneRestClient extends SubmitRestClient { import StandaloneRestClient._ @@ -38,7 +37,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * this reports failure and logs an error message provided by the REST server. 
*/ override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { - val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponse] + validateSubmitArgs(args) + val submitResponse = super.submitDriver(args) val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { val driverId = submitResponse.getDriverId @@ -51,14 +51,25 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { submitResponse } + /** Request that the REST server kill the specified driver. */ + override def killDriver(master: String, driverId: String): KillDriverResponse = { + validateMaster(master) + super.killDriver(master, driverId) + } + + /** Request the status of the specified driver from the REST server. */ + override def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { + validateMaster(master) + super.requestDriverStatus(master, driverId) + } + /** - * Poll the status of the driver that was just submitted and report it. - * This retries up to a fixed number of times until giving up. + * Poll the status of the driver that was just submitted and log it. + * This retries up to a fixed number of times before giving up. */ private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => val statusResponse = requestDriverStatus(master, driverId) - .asInstanceOf[DriverStatusResponse] val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { val driverState = statusResponse.getDriverState @@ -75,13 +86,13 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { exception.foreach { e => logError(e) } return } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) } logError(s"Error: Master did not recognize driver $driverId.") } /** Construct a submit driver request message. */ - override protected def constructSubmitRequest( - args: SparkSubmitArguments): SubmitDriverRequest = { + protected override def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest = { val message = new SubmitDriverRequest() .setSparkVersion(sparkVersion) .setAppName(args.name) @@ -99,12 +110,14 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setTotalExecutorCores(args.totalExecutorCores) args.childArgs.foreach(message.addAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } - // TODO: send special environment variables? + sys.env.foreach { case (k, v) => + if (k.startsWith("SPARK_")) { message.setEnvironmentVariable(k, v) } + } message } /** Construct a kill driver request message. */ - override protected def constructKillRequest( + protected override def constructKillRequest( master: String, driverId: String): KillDriverRequest = { new KillDriverRequest() @@ -113,7 +126,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { } /** Construct a driver status request message. */ - override protected def constructStatusRequest( + protected override def constructStatusRequest( master: String, driverId: String): DriverStatusRequest = { new DriverStatusRequest() @@ -121,25 +134,26 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setDriverId(driverId) } + /** Extract the URL portion of the master address. */ + protected override def getHttpUrl(master: String): URL = { + validateMaster(master) + new URL("http://" + master.stripPrefix("spark://")) + } + /** Throw an exception if this is not standalone mode. 
*/ - override protected def validateMaster(master: String): Unit = { + private def validateMaster(master: String): Unit = { if (!master.startsWith("spark://")) { throw new IllegalArgumentException("This REST client is only supported in standalone mode.") } } - /** Throw an exception if this is not cluster deploy mode. */ - override protected def validateDeployMode(deployMode: String): Unit = { - if (deployMode != "cluster") { - throw new IllegalArgumentException("This REST client is only supported in cluster mode.") + /** Throw an exception if this is not standalone cluster mode. */ + private def validateSubmitArgs(args: SparkSubmitArguments): Unit = { + if (!args.isStandaloneCluster) { + throw new IllegalArgumentException( + "This REST client is only supported in standalone cluster mode.") } } - - /** Extract the URL portion of the master address. */ - override protected def getHttpUrl(master: String): URL = { - validateMaster(master) - new URL("http://" + master.stripPrefix("spark://")) - } } private object StandaloneRestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 3fcfe189c6a10..1838647f6ed6e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -24,23 +24,22 @@ import akka.actor.ActorRef import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.SparkConf import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ -import org.apache.spark.deploy.DeployMessages import org.apache.spark.deploy.master.Master /** - * A server that responds to requests submitted by the StandaloneRestClient. - * This is intended to be embedded in the standalone Master. Cluster mode only. + * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * This is intended to be embedded in the standalone Master. Cluster mode only */ private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) extends SubmitRestServer(host, requestedPort, master.conf) { - override protected val handler = new StandaloneRestServerHandler(master) + protected override val handler = new StandaloneRestServerHandler(master) } /** - * A handler for requests submitted to the standalone Master - * via the stable application submission REST protocol. + * A handler for requests submitted to the standalone + * Master via the REST application submission protocol. */ private[spark] class StandaloneRestServerHandler( conf: SparkConf, @@ -55,8 +54,7 @@ private[spark] class StandaloneRestServerHandler( } /** Handle a request to submit a driver. */ - override protected def handleSubmit( - request: SubmitDriverRequest): SubmitDriverResponse = { + protected override def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse = { val driverDescription = buildDriverDescription(request) val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) @@ -68,8 +66,7 @@ private[spark] class StandaloneRestServerHandler( } /** Handle a request to kill a driver. 
*/ - override protected def handleKill( - request: KillDriverRequest): KillDriverResponse = { + protected override def handleKill(request: KillDriverRequest): KillDriverResponse = { val driverId = request.getDriverId val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) @@ -81,16 +78,11 @@ private[spark] class StandaloneRestServerHandler( } /** Handle a request for a driver's status. */ - override protected def handleStatus( - request: DriverStatusRequest): DriverStatusResponse = { + protected override def handleStatus(request: DriverStatusRequest): DriverStatusResponse = { val driverId = request.getDriverId val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) - // Format exception nicely, if it exists - val message = response.exception.map { e => - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"Exception from the cluster:\n$e\n$stackTraceString" - } + val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } new DriverStatusResponse() .setSparkVersion(sparkVersion) .setDriverId(driverId) @@ -103,6 +95,7 @@ private[spark] class StandaloneRestServerHandler( /** * Build a driver description from the fields specified in the submit request. + * * This does not currently consider fields used by python applications since * python is not supported in standalone cluster mode yet. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index 87132511587a3..f2154b48f7d31 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -25,21 +25,24 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.JsonProtocol +/** + * A request to submit a driver in the REST application submission protocol. 
+ */ class SubmitDriverRequest extends SubmitRestProtocolRequest { - private val appName = new SubmitRestProtocolField[String] - private val appResource = new SubmitRestProtocolField[String] - private val mainClass = new SubmitRestProtocolField[String] - private val jars = new SubmitRestProtocolField[String] - private val files = new SubmitRestProtocolField[String] - private val pyFiles = new SubmitRestProtocolField[String] - private val driverMemory = new SubmitRestProtocolField[String] - private val driverCores = new SubmitRestProtocolField[Int] - private val driverExtraJavaOptions = new SubmitRestProtocolField[String] - private val driverExtraClassPath = new SubmitRestProtocolField[String] - private val driverExtraLibraryPath = new SubmitRestProtocolField[String] - private val superviseDriver = new SubmitRestProtocolField[Boolean] - private val executorMemory = new SubmitRestProtocolField[String] - private val totalExecutorCores = new SubmitRestProtocolField[Int] + private val appName = new SubmitRestProtocolField[String]("appName") + private val appResource = new SubmitRestProtocolField[String]("appResource") + private val mainClass = new SubmitRestProtocolField[String]("mainClass") + private val jars = new SubmitRestProtocolField[String]("jars") + private val files = new SubmitRestProtocolField[String]("files") + private val pyFiles = new SubmitRestProtocolField[String]("pyFiles") + private val driverMemory = new SubmitRestProtocolField[String]("driverMemory") + private val driverCores = new SubmitRestProtocolField[Int]("driverCores") + private val driverExtraJavaOptions = new SubmitRestProtocolField[String]("driverExtraJavaOptions") + private val driverExtraClassPath = new SubmitRestProtocolField[String]("driverExtraClassPath") + private val driverExtraLibraryPath = new SubmitRestProtocolField[String]("driverExtraLibraryPath") + private val superviseDriver = new SubmitRestProtocolField[Boolean]("superviseDriver") + private val executorMemory = new SubmitRestProtocolField[String]("executorMemory") + private val totalExecutorCores = new SubmitRestProtocolField[Int]("totalExecutorCores") // Special fields private val appArgs = new ArrayBuffer[String] @@ -101,30 +104,43 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { envVars ++= JsonProtocol.mapFromJson(parse(s)) } + /** Return an array of arguments to be passed to the application. */ @JsonIgnore def getAppArgs: Array[String] = appArgs.toArray + + /** Return a map of Spark properties to be passed to the application as java options. */ @JsonIgnore def getSparkProperties: Map[String, String] = sparkProperties.toMap + + /** Return a map of environment variables to be passed to the application. */ @JsonIgnore def getEnvironmentVariables: Map[String, String] = envVars.toMap + + /** Add a command line argument to be passed to the application. */ @JsonIgnore def addAppArg(s: String): this.type = { appArgs += s; this } + + /** Set a Spark property to be passed to the application as a java option. */ @JsonIgnore def setSparkProperty(k: String, v: String): this.type = { sparkProperties(k) = v; this } + + /** Set an environment variable to be passed to the application. */ @JsonIgnore def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } + /** Serialize the given Array to a compact JSON string. */ private def arrayToJson(arr: Array[String]): String = { if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null } + /** Serialize the given Map to a compact JSON string. 
*/ private def mapToJson(map: Map[String, String]): String = { if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null } override def validate(): Unit = { super.validate() - assertFieldIsSet(appName, "app_name") - assertFieldIsSet(appResource, "app_resource") + assertFieldIsSet(appName) + assertFieldIsSet(appResource) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index b1825af8ce565..a9adf3634f231 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. + */ class SubmitDriverResponse extends SubmitRestProtocolResponse { - private val success = new SubmitRestProtocolField[Boolean] - private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean]("success") + private val driverId = new SubmitRestProtocolField[String]("driverId") def getSuccess: String = success.toString def getDriverId: String = driverId.toString @@ -29,6 +32,6 @@ class SubmitDriverResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(success, "success") + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index eb258290bdc7b..a1be15c9fa5d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -28,14 +28,13 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.SparkSubmitArguments /** - * An abstract client that submits applications using the stable REST protocol. - * This client is intended to communicate with the SubmitRestServer. + * An abstract client that submits applications using the REST protocol. + * This client is intended to communicate with the [[SubmitRestServer]]. */ private[spark] abstract class SubmitRestClient extends Logging { - /** Request that the REST server submit a driver specified by the provided arguments. */ + /** Request that the REST server submit a driver using the provided arguments. */ def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { - validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) logInfo(s"Submitting a request to launch a driver in ${args.master}.") @@ -44,7 +43,6 @@ private[spark] abstract class SubmitRestClient extends Logging { /** Request that the REST server kill the specified driver. */ def killDriver(master: String, driverId: String): KillDriverResponse = { - validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) logInfo(s"Submitting a request to kill driver $driverId in $master.") @@ -53,7 +51,6 @@ private[spark] abstract class SubmitRestClient extends Logging { /** Request the status of the specified driver from the REST server. 
*/ def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { - validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) logInfo(s"Submitting a request for the status of driver $driverId in $master.") @@ -68,17 +65,9 @@ private[spark] abstract class SubmitRestClient extends Logging { protected def constructKillRequest(master: String, driverId: String): KillDriverRequest protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequest - // If the provided arguments are not as expected, throw an exception - protected def validateMaster(master: String): Unit - protected def validateDeployMode(deployMode: String): Unit - protected def validateSubmitArguments(args: SparkSubmitArguments): Unit = { - validateMaster(args.master) - validateDeployMode(args.deployMode) - } - /** * Send the provided request in an HTTP message to the given URL. - * This assumes both the request and the response use the JSON format. + * This assumes that both the request and the response use the JSON format. * Return the response received from the REST server. */ private def sendHttp(url: URL, request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { @@ -96,7 +85,7 @@ private[spark] abstract class SubmitRestClient extends Logging { out.close() val responseJson = Source.fromInputStream(conn.getInputStream).mkString logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolResponse.fromJson(responseJson) + SubmitRestProtocolMessage.fromJson(responseJson).asInstanceOf[SubmitRestProtocolResponse] } catch { case e: FileNotFoundException => throw new SparkException(s"Unable to connect to REST server $url", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 33e4fe4d5c2bb..2b52fd6bc44a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -17,10 +17,20 @@ package org.apache.spark.deploy.rest -class SubmitRestProtocolField[T] { +/** + * A field used in [[SubmitRestProtocolMessage]]s. + */ +class SubmitRestProtocolField[T](val name: String) { protected var value: Option[T] = None + + /** Return the value or throw an [[IllegalArgumentException]] if the value is not set. */ + def getValue: T = { + value.getOrElse { + throw new IllegalAccessException(s"Value not set in field '$name'!") + } + } + def isSet: Boolean = value.isDefined - def getValue: T = value.getOrElse { throw new IllegalAccessException("Value not set!") } def getValueOption: Option[T] = value def setValue(v: T): Unit = { value = Some(v) } def clearValue(): Unit = { value = None } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 0aa72d236e1a4..2e92eb926d339 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -26,22 +26,29 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.Utils +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. 
+ * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ @JsonInclude(Include.NON_NULL) @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @JsonPropertyOrder(alphabetic = true) abstract class SubmitRestProtocolMessage { - import SubmitRestProtocolMessage._ - private val messageType = Utils.getFormattedClassName(this) - protected val action: String = camelCaseToUnderscores(decapitalize(messageType)) - protected val sparkVersion = new SubmitRestProtocolField[String] - protected val message = new SubmitRestProtocolField[String] + protected val action: String = messageType + protected val sparkVersion: SubmitRestProtocolField[String] + protected val message = new SubmitRestProtocolField[String]("message") // Required for JSON de/serialization and not explicitly used private def getAction: String = action private def setAction(s: String): this.type = this - // Spark version implementation depends on whether this is a request or a response + // Intended for the user and not for JSON de/serialization, which expects more specific keys @JsonIgnore def getSparkVersion: String @JsonIgnore @@ -50,26 +57,37 @@ abstract class SubmitRestProtocolMessage { def getMessage: String = message.toString def setMessage(s: String): this.type = setField(message, s) + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ def toJson: String = { validate() val mapper = new ObjectMapper - val json = mapper.writeValueAsString(this) - postProcessJson(json) + pretty(parse(mapper.writeValueAsString(this))) } + /** Assert the validity of the message. */ def validate(): Unit = { assert(action != null, s"The action field is missing in $messageType!") + assertFieldIsSet(sparkVersion) } - protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { - assert(field.isSet, s"The $name field is missing in $messageType!") + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = { + assert(field.isSet, s"Field '${field.name}' is missing in $messageType!") } + /** Set the field to the given value, or clear the field if the value is null. */ protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = { if (value == null) { field.clearValue() } else { field.setValue(value) } this } + /** + * Set the field to the given boolean value, or clear the field if the value is null. + * If the provided value does not represent a boolean, throw an exception. + */ protected def setBooleanField( field: SubmitRestProtocolField[Boolean], value: String): this.type = { @@ -77,6 +95,10 @@ abstract class SubmitRestProtocolMessage { this } + /** + * Set the field to the given numeric value, or clear the field if the value is null. + * If the provided value does not represent a numeric, throw an exception. + */ protected def setNumericField( field: SubmitRestProtocolField[Int], value: String): this.type = { @@ -84,6 +106,11 @@ abstract class SubmitRestProtocolMessage { this } + /** + * Set the field to the given memory value, or clear the field if the value is null. + * If the provided value does not represent a memory value, throw an exception. + * Valid examples of memory values include "512m", "24g", and "128000". 
+ */ protected def setMemoryField( field: SubmitRestProtocolField[String], value: String): this.type = { @@ -91,116 +118,69 @@ abstract class SubmitRestProtocolMessage { setField(field, value) this } - - private def postProcessJson(json: String): String = { - val fields = parse(json).asInstanceOf[JObject].obj - val newFields = fields.map { case (k, v) => (camelCaseToUnderscores(k), v) } - pretty(render(JObject(newFields))) - } } +/** + * An abstract request sent from the client in the REST application submission protocol. + */ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + protected override val sparkVersion = new SubmitRestProtocolField[String]("client_spark_version") def getClientSparkVersion: String = sparkVersion.toString def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) override def getSparkVersion: String = getClientSparkVersion override def setSparkVersion(s: String) = setClientSparkVersion(s) - override def validate(): Unit = { - super.validate() - assertFieldIsSet(sparkVersion, "client_spark_version") - } } +/** + * An abstract response sent from the server in the REST application submission protocol. + */ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version") def getServerSparkVersion: String = sparkVersion.toString def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) override def getSparkVersion: String = getServerSparkVersion override def setSparkVersion(s: String) = setServerSparkVersion(s) - override def validate(): Unit = { - super.validate() - assertFieldIsSet(sparkVersion, "server_spark_version") - } } object SubmitRestProtocolMessage { private val mapper = new ObjectMapper private val packagePrefix = this.getClass.getPackage.getName - def parseAction(json: String): String = { + /** Parse the value of the action field from the given JSON. */ + def parseAction(json: String): String = parseField(json, "action") + + /** Parse the value of the specified field from the given JSON. */ + def parseField(json: String, field: String): String = { parse(json).asInstanceOf[JObject].obj - .find { case (f, _) => f == "action" } + .find { case (f, _) => f == field } .map { case (_, v) => v.asInstanceOf[JString].s } .getOrElse { - throw new IllegalArgumentException(s"Could not find action field in message:\n$json") - } + throw new IllegalArgumentException(s"Could not find field '$field' in message:\n$json") + } } + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method first parses the action from the JSON and uses it to infers the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ def fromJson(json: String): SubmitRestProtocolMessage = { - val action = parseAction(json) - val className = underscoresToCamelCase(action).capitalize + val className = parseAction(json) val clazz = Class.forName(packagePrefix + "." + className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. 
This is useful for deserializing JSON that + * represents custom user-defined messages. + */ def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { - val fields = parse(json).asInstanceOf[JObject].obj - val processedFields = fields.map { case (k, v) => (underscoresToCamelCase(k), v) } - val processedJson = compact(render(JObject(processedFields))) - mapper.readValue(processedJson, clazz) - } - - private def camelCaseToUnderscores(s: String): String = { - val newString = new StringBuilder - s.foreach { c => - if (c.isUpper) { - newString.append("_" + c.toLower) - } else { - newString.append(c) - } - } - newString.toString() - } - - private def underscoresToCamelCase(s: String): String = { - val newString = new StringBuilder - var capitalizeNext = false - s.foreach { c => - if (c == '_') { - capitalizeNext = true - } else { - val nextChar = if (capitalizeNext) c.toUpper else c - newString.append(nextChar) - capitalizeNext = false - } - } - newString.toString() - } - - private def decapitalize(s: String): String = { - if (s != null && s.nonEmpty) { - s(0).toLower + s.substring(1) - } else { - s - } - } -} - -object SubmitRestProtocolRequest { - def fromJson(s: String): SubmitRestProtocolRequest = { - SubmitRestProtocolMessage.fromJson(s) match { - case req: SubmitRestProtocolRequest => req - case res: SubmitRestProtocolResponse => - throw new IllegalArgumentException(s"Message was not a request:\n$s") - } - } -} - -object SubmitRestProtocolResponse { - def fromJson(s: String): SubmitRestProtocolResponse = { - SubmitRestProtocolMessage.fromJson(s) match { - case req: SubmitRestProtocolRequest => - throw new IllegalArgumentException(s"Message was not a response:\n$s") - case res: SubmitRestProtocolResponse => res - } + mapper.readValue(json, clazz) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 89a2b83d2cdee..5d3fb70f8bcc1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -32,8 +32,8 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging, SparkConf} import org.apache.spark.util.Utils /** - * An abstract server that responds to requests submitted by the SubmitRestClient - * in the stable application submission REST protocol. + * An abstract server that responds to requests submitted by the + * [[SubmitRestClient]] in the REST application submission protocol. */ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, conf: SparkConf) extends Logging { @@ -66,8 +66,8 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, } /** - * An abstract handler for requests submitted via the stable application submission REST protocol. - * This represents the main handler used in the SubmitRestServer. + * An abstract handler for requests submitted via the REST application submission protocol. + * This represents the main handler used in the [[SubmitRestServer]]. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { protected def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse @@ -75,8 +75,8 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi protected def handleStatus(request: DriverStatusRequest): DriverStatusResponse /** - * Handle a request submitted by the SubmitRestClient. 
- * This assumes both the request and the response use the JSON format. + * Handle a request submitted by the [[SubmitRestClient]]. + * This assumes that both the request and the response use the JSON format. */ override def handle( target: String, @@ -85,7 +85,8 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi response: HttpServletResponse): Unit = { try { val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = SubmitRestProtocolRequest.fromJson(requestMessageJson) + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + .asInstanceOf[SubmitRestProtocolRequest] val responseMessage = constructResponseMessage(requestMessage) response.setContentType("application/json") response.setCharacterEncoding("utf-8") @@ -102,7 +103,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi /** * Construct the appropriate response message based on the type of the request message. - * If an IllegalArgumentException is thrown in the process, construct an error message instead. + * If an [[IllegalArgumentException]] is thrown, construct an error message instead. */ private def constructResponseMessage( request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { @@ -121,14 +122,15 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } } catch { - case e: IllegalArgumentException => handleError(e.getMessage) + case e: IllegalArgumentException => handleError(formatException(e)) } // Validate the response message to ensure that it is correctly constructed. If it is not, // propagate the exception back to the client and signal that it is a server error. try { response.validate() } catch { - case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}") + case e: IllegalArgumentException => + handleError("Internal server error: " + formatException(e)) } response } @@ -139,4 +141,10 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi .setSparkVersion(sparkVersion) .setMessage(message) } + + /** Return a human readable String representation of the exception. 
*/ + protected def formatException(e: Exception): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index bc9639095d1b7..4c81f5fabdc1a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -20,45 +20,11 @@ package org.apache.spark.deploy.rest import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite -class DummyRequest extends SubmitRestProtocolRequest { - private val active = new SubmitRestProtocolField[Boolean] - private val age = new SubmitRestProtocolField[Int] - private val name = new SubmitRestProtocolField[String] - - def getActive: String = active.toString - def getAge: String = age.toString - def getName: String = name.toString - - def setActive(s: String): this.type = setBooleanField(active, s) - def setAge(s: String): this.type = setNumericField(age, s) - def setName(s: String): this.type = setField(name, s) - - override def validate(): Unit = { - super.validate() - assertFieldIsSet(name, "name") - assertFieldIsSet(age, "age") - assert(age.getValue > 5, "Not old enough!") - } -} - -class DummyResponse extends SubmitRestProtocolResponse - /** - * Tests for the stable application submission REST protocol. + * Tests for the REST application submission protocol. */ class SubmitRestProtocolSuite extends FunSuite { - /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */ - private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { - val trimmedJson1 = jsonString1.trim - val trimmedJson2 = jsonString2.trim - val json1 = compact(render(parse(trimmedJson1))) - val json2 = compact(render(parse(trimmedJson2))) - // Put this on a separate line to avoid printing comparison twice when test fails - val equals = json1 == json2 - assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) - } - test("get and set fields") { val request = new DummyRequest assert(request.getSparkVersion === null) @@ -319,10 +285,10 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyRequestJson = """ |{ - | "action" : "dummy_request", + | "action" : "DummyRequest", | "active" : "true", | "age" : "25", - | "client_spark_version" : "1.2.3", + | "clientSparkVersion" : "1.2.3", | "name" : "jung" |} """.stripMargin @@ -330,42 +296,42 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyResponseJson = """ |{ - | "action" : "dummy_response", - | "server_spark_version" : "3.3.4" + | "action" : "DummyResponse", + | "serverSparkVersion" : "3.3.4" |} """.stripMargin private val submitDriverRequestJson = """ |{ - | "action" : "submit_driver_request", - | "app_args" : "[\"two slices\",\"a hint of cinnamon\"]", - | "app_name" : "SparkPie", - | "app_resource" : "honey-walnut-cherry.jar", - | "client_spark_version" : "1.2.3", - | "driver_cores" : "180", - | "driver_extra_class_path" : "food-coloring.jar", - | "driver_extra_java_options" : " -Dslices=5 -Dcolor=mostly_red", - | "driver_extra_library_path" : "pickle.jar", - | "driver_memory" : "512m", - | "environment_variables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", - | "executor_memory" : "256m", + | "action" : "SubmitDriverRequest", + | "appArgs" : "[\"two slices\",\"a hint 
of cinnamon\"]", + | "appName" : "SparkPie", + | "appResource" : "honey-walnut-cherry.jar", + | "clientSparkVersion" : "1.2.3", + | "driverCores" : "180", + | "driverExtraClassPath" : "food-coloring.jar", + | "driverExtraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", + | "driverExtraLibraryPath" : "pickle.jar", + | "driverMemory" : "512m", + | "environmentVariables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", + | "executorMemory" : "256m", | "files" : "fireball.png", | "jars" : "mayonnaise.jar,ketchup.jar", - | "main_class" : "org.apache.spark.examples.SparkPie", - | "py_files" : "do-not-eat-my.py", - | "spark_properties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", - | "supervise_driver" : "false", - | "total_executor_cores" : "10000" + | "mainClass" : "org.apache.spark.examples.SparkPie", + | "pyFiles" : "do-not-eat-my.py", + | "sparkProperties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", + | "superviseDriver" : "false", + | "totalExecutorCores" : "10000" |} """.stripMargin private val submitDriverResponseJson = """ |{ - | "action" : "submit_driver_response", - | "driver_id" : "driver_123", - | "server_spark_version" : "1.2.3", + | "action" : "SubmitDriverResponse", + | "driverId" : "driver_123", + | "serverSparkVersion" : "1.2.3", | "success" : "true" |} """.stripMargin @@ -373,18 +339,18 @@ class SubmitRestProtocolSuite extends FunSuite { private val killDriverRequestJson = """ |{ - | "action" : "kill_driver_request", - | "client_spark_version" : "1.2.3", - | "driver_id" : "driver_123" + | "action" : "KillDriverRequest", + | "clientSparkVersion" : "1.2.3", + | "driverId" : "driver_123" |} """.stripMargin private val killDriverResponseJson = """ |{ - | "action" : "kill_driver_response", - | "driver_id" : "driver_123", - | "server_spark_version" : "1.2.3", + | "action" : "KillDriverResponse", + | "driverId" : "driver_123", + | "serverSparkVersion" : "1.2.3", | "success" : "true" |} """.stripMargin @@ -392,31 +358,64 @@ class SubmitRestProtocolSuite extends FunSuite { private val driverStatusRequestJson = """ |{ - | "action" : "driver_status_request", - | "client_spark_version" : "1.2.3", - | "driver_id" : "driver_123" + | "action" : "DriverStatusRequest", + | "clientSparkVersion" : "1.2.3", + | "driverId" : "driver_123" |} """.stripMargin private val driverStatusResponseJson = """ |{ - | "action" : "driver_status_response", - | "driver_id" : "driver_123", - | "driver_state" : "RUNNING", - | "server_spark_version" : "1.2.3", + | "action" : "DriverStatusResponse", + | "driverId" : "driver_123", + | "driverState" : "RUNNING", + | "serverSparkVersion" : "1.2.3", | "success" : "true", - | "worker_host_port" : "1.2.3.4:7780", - | "worker_id" : "worker_123" + | "workerHostPort" : "1.2.3.4:7780", + | "workerId" : "worker_123" |} """.stripMargin private val errorJson = """ |{ - | "action" : "error_response", + | "action" : "ErrorResponse", | "message" : "Field not found in submit request: X", - | "server_spark_version" : "1.2.3" + | "serverSparkVersion" : "1.2.3" |} """.stripMargin + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. 
*/ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } +} + +private class DummyResponse extends SubmitRestProtocolResponse +private class DummyRequest extends SubmitRestProtocolRequest { + private val active = new SubmitRestProtocolField[Boolean]("active") + private val age = new SubmitRestProtocolField[Int]("age") + private val name = new SubmitRestProtocolField[String]("name") + + def getActive: String = active.toString + def getAge: String = age.toString + def getName: String = name.toString + + def setActive(s: String): this.type = setBooleanField(active, s) + def setAge(s: String): this.type = setNumericField(age, s) + def setName(s: String): this.type = setField(name, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(name) + assertFieldIsSet(age) + assert(age.getValue > 5, "Not old enough!") + } }
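For readers skimming this patch, the net effect on the wire format is easiest to see with a concrete exchange. The sketch below is illustrative only: the keys and the action value follow the new camelCase convention exercised by the test fixtures above, while the driver id and version strings are made-up placeholders.

Request sent by the client to kill a driver:

    {
      "action" : "KillDriverRequest",
      "clientSparkVersion" : "1.3.0",
      "driverId" : "driver-20150201-0001"
    }

Response returned by the standalone Master's REST server:

    {
      "action" : "KillDriverResponse",
      "driverId" : "driver-20150201-0001",
      "serverSparkVersion" : "1.3.0",
      "success" : "true"
    }

These match the updated test fixtures above rather than the old snake_case form (kill_driver_request, client_spark_version, driver_id) that this patch removes.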
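The construction and serialization path introduced here can also be summarized in a short, hedged Scala sketch. The chained setters mirror constructKillRequest in StandaloneRestClient; the version and driver id are placeholders, and the cast reflects that SubmitRestProtocolMessage.fromJson now returns the base message type.

    import org.apache.spark.deploy.rest._

    // Build a kill request the same way constructKillRequest does in this patch.
    val request = new KillDriverRequest()
      .setSparkVersion("1.3.0")              // placeholder client version
      .setDriverId("driver-20150201-0001")   // placeholder driver id

    // toJson first calls validate(), which now checks the named fields
    // (action, client Spark version, driverId), then emits camelCase keys.
    val json = request.toJson

    // fromJson reads the "action" field ("KillDriverRequest") and uses it directly
    // as the class name to instantiate, so no underscore/camelCase conversion remains.
    val roundTripped = SubmitRestProtocolMessage.fromJson(json)
      .asInstanceOf[KillDriverRequest]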
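Since every message class in this patch now follows the same named-field pattern, a minimal hypothetical example may help reviewers who are new to the protocol code. PingRequest below is not part of the change; it only illustrates how SubmitRestProtocolField's new name parameter feeds assertFieldIsSet, in the same way DummyRequest does in the test suite.

    package org.apache.spark.deploy.rest

    // Hypothetical message, shown only to illustrate the field/validation pattern.
    class PingRequest extends SubmitRestProtocolRequest {
      // The field carries its own name, so validation errors can report it directly.
      private val target = new SubmitRestProtocolField[String]("target")

      def getTarget: String = target.toString
      def setTarget(s: String): this.type = setField(target, s)

      override def validate(): Unit = {
        super.validate()           // checks that the action and Spark version are set
        assertFieldIsSet(target)   // fails with "Field 'target' is missing in PingRequest!"
      }
    }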
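Finally, for context on how these JSON messages travel between SubmitRestClient and SubmitRestServer: the client's private sendHttp helper (only partially visible in this diff) posts the request's JSON and deserializes the reply through SubmitRestProtocolMessage.fromJson. The standalone sketch below approximates that flow with plain java.net APIs; the URL, port, and error handling are simplified assumptions, not the exact code in the patch.

    import java.io.DataOutputStream
    import java.net.{HttpURLConnection, URL}
    import scala.io.Source

    import org.apache.spark.deploy.rest._

    // Assumed host and port of a standalone Master's REST endpoint; the real client
    // derives the URL from the spark:// master address via getHttpUrl.
    val url = new URL("http://localhost:6066")

    val requestJson = new DriverStatusRequest()
      .setSparkVersion("1.3.0")              // placeholder
      .setDriverId("driver-20150201-0001")   // placeholder
      .toJson

    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setRequestProperty("Content-Type", "application/json")
    conn.setDoOutput(true)

    val out = new DataOutputStream(conn.getOutputStream)
    out.writeBytes(requestJson)
    out.close()

    // The server replies with JSON; fromJson picks the concrete response class
    // from its "action" field, e.g. DriverStatusResponse or ErrorResponse.
    val responseJson = Source.fromInputStream(conn.getInputStream).mkString
    val response = SubmitRestProtocolMessage.fromJson(responseJson)
      .asInstanceOf[SubmitRestProtocolResponse]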