diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index af5a91c5f9458..dedae380a8cf6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -96,9 +96,9 @@ object SparkSubmit { */ private def kill(args: SparkSubmitArguments): Unit = { val client = new StandaloneRestClient - val response = client.killDriver(args.master, args.driverToKill) + val response = client.killSubmission(args.master, args.driverToKill) response match { - case k: KillDriverResponse => handleRestResponse(k) + case k: KillSubmissionResponse => handleRestResponse(k) case r => handleUnexpectedRestResponse(r) } } @@ -109,9 +109,9 @@ object SparkSubmit { */ private def requestStatus(args: SparkSubmitArguments): Unit = { val client = new StandaloneRestClient - val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor) + val response = client.requestSubmissionStatus(args.master, args.driverToRequestStatusFor) response match { - case s: DriverStatusResponse => handleRestResponse(s) + case s: SubmissionStatusResponse => handleRestResponse(s) case r => handleUnexpectedRestResponse(r) } } @@ -135,9 +135,9 @@ object SparkSubmit { if (args.isStandaloneCluster && args.isRestEnabled) { printStream.println("Running Spark using the REST application submission protocol.") val client = new StandaloneRestClient - val response = client.submitDriver(args) + val response = client.createSubmission(args) response match { - case s: SubmitDriverResponse => handleRestResponse(s) + case s: CreateSubmissionResponse => handleRestResponse(s) case r => handleUnexpectedRestResponse(r) } } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala deleted file mode 100644 index cd6fc40a0c114..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A request to query the status of a driver in the REST application submission protocol. - */ -class DriverStatusRequest extends SubmitRestProtocolRequest { - var driverId: String = null - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(driverId, "driverId") - } -} - -/** - * A response to the [[DriverStatusRequest]] in the REST application submission protocol. - */ -class DriverStatusResponse extends SubmitRestProtocolResponse { - var driverId: String = null - - // standalone cluster mode only - var driverState: String = null - var workerId: String = null - var workerHostPort: String = null - - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(driverId, "driverId") - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala deleted file mode 100644 index a057c742bae11..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A request to kill a driver in the REST application submission protocol. - */ -class KillDriverRequest extends SubmitRestProtocolRequest { - var driverId: String = null - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(driverId, "driverId") - } -} - -/** - * A response to the [[KillDriverRequest]] in the REST application submission protocol. - */ -class KillDriverResponse extends SubmitRestProtocolResponse { - var driverId: String = null - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(driverId, "driverId") - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index df7319235c653..a7251b5cf97bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -17,100 +17,170 @@ package org.apache.spark.deploy.rest -import java.net.URL +import java.io.{FileNotFoundException, DataOutputStream} +import java.net.{HttpURLConnection, URL} -import org.apache.spark.{SPARK_VERSION => sparkVersion} +import scala.io.Source + +import com.google.common.base.Charsets + +import org.apache.spark.{Logging, SparkException, SPARK_VERSION => sparkVersion} import org.apache.spark.deploy.SparkSubmitArguments /** - * A client that submits applications to the standalone Master using the REST protocol - * This client is intended to communicate with the [[StandaloneRestServer]]. Cluster mode only. + * A client that submits applications to the standalone Master using a REST protocol. + * This client is intended to communicate with the [[StandaloneRestServer]] and is + * currently used for cluster mode only. + * + * The specific request sent to the server depends on the action as follows: + * (1) submit - POST to http://.../submissions/create + * (2) kill - POST http://.../submissions/kill/[submissionId] + * (3) status - GET http://.../submissions/status/[submissionId] + * + * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields. + * Otherwise, the URL fully specifies the intended action of the client. */ -private[spark] class StandaloneRestClient extends SubmitRestClient { +private[spark] class StandaloneRestClient extends Logging { import StandaloneRestClient._ /** - * Request that the REST server submit a driver specified by the provided arguments. + * Submit an application specified by the provided arguments. * - * If the driver was successfully submitted, this polls the status of the driver that was - * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, - * this reports failure and logs an error message provided by the REST server. + * If the submission was successful, poll the status of the submission and report + * it to the user. Otherwise, report the error message provided by the server. */ - override def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = { + def createSubmission(args: SparkSubmitArguments): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch a driver in ${args.master}.") validateSubmitArgs(args) - val response = super.submitDriver(args) - val submitResponse = response match { - case s: SubmitDriverResponse => s - case _ => return response - } - // Report status of submitted driver to user - val submitSuccess = submitResponse.success.toBoolean - if (submitSuccess) { - val driverId = submitResponse.driverId - if (driverId != null) { - logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") - pollSubmittedDriverStatus(args.master, driverId) - } else { - logError("Application successfully submitted, but driver ID was not provided!") - } - } else { - val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") - logError("Application submission failed" + failMessage) + val master = args.master + val url = getSubmitUrl(master) + val request = constructSubmitRequest(args) + val response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => reportSubmissionStatus(master, s) + case _ => // unexpected type, let upstream caller handle it } - submitResponse + response } - /** Request that the REST server kill the specified driver. */ - override def killDriver(master: String, driverId: String): SubmitRestProtocolResponse = { + /** Request that the server kill the specified submission. */ + def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill submission $submissionId in $master.") validateMaster(master) - super.killDriver(master, driverId) + post(getKillUrl(master, submissionId)) } - /** Request the status of the specified driver from the REST server. */ - override def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolResponse = { + /** Request the status of a submission from the server. */ + def requestSubmissionStatus(master: String, submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of submission $submissionId in $master.") validateMaster(master) - super.requestDriverStatus(master, driverId) + get(getStatusUrl(master, submissionId)) + } + + /** Send a GET request to the specified URL. */ + private def get(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending GET request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + readResponse(conn) + } + + /** Send a POST request to the specified URL. */ + private def post(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + readResponse(conn) + } + + /** Send a POST request with the given JSON as the body to the specified URL. */ + private def postJson(url: URL, json: String): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url:\n$json") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(json.getBytes(Charsets.UTF_8)) + out.close() + readResponse(conn) } /** - * Poll the status of the driver that was just submitted and log it. - * This retries up to a fixed number of times before giving up. + * Read the response from the given connection. + * + * The response is expected to represent a [[SubmitRestProtocolResponse]] in the form of JSON. + * Additionally, this validates the response to ensure that it is properly constructed. + * If the response represents an error, report the message from the server. */ - private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { - (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => - val response = requestDriverStatus(master, driverId) - val statusResponse = response match { - case s: DriverStatusResponse => s - case _ => return + private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + try { + val responseJson = Source.fromInputStream(connection.getInputStream).mkString + logDebug(s"Response from the REST server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + // The response should have already been validated on the server. + // In case this is not true, validate it ourselves to avoid potential NPEs. + try { + response.validate() + } catch { + case e: SubmitRestProtocolException => + throw new SubmitRestProtocolException("Malformed response received from server", e) } - val statusSuccess = statusResponse.success.toBoolean - if (statusSuccess) { - val driverState = Option(statusResponse.driverState) - val workerId = Option(statusResponse.workerId) - val workerHostPort = Option(statusResponse.workerHostPort) - val exception = Option(statusResponse.message) - // Log driver state, if present - driverState match { - case Some(state) => logInfo(s"State of driver $driverId is now $state.") - case _ => logError(s"State of driver $driverId was not found!") - } - // Log worker node, if present - (workerId, workerHostPort) match { - case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") - case _ => - } - // Log exception stack trace, if present - exception.foreach { e => logError(e) } - return + // If the response is an error, log the message + // Otherwise, simply return the response + response match { + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + case response: SubmitRestProtocolResponse => + response + case unexpected => + throw new SubmitRestProtocolException( + s"Unexpected message received from server:\n$unexpected") } - Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } catch { + case e: FileNotFoundException => + throw new SparkException(s"Unable to connect to server ${connection.getURL}", e) + } + } + + /** Return the REST URL for creating a new submission. */ + private def getSubmitUrl(master: String): URL = { + val baseUrl = master.stripPrefix("spark://") + new URL(s"http://$baseUrl/submissions/create") + } + + /** Return the REST URL for killing an existing submission. */ + private def getKillUrl(master: String, submissionId: String): URL = { + val baseUrl = master.stripPrefix("spark://") + new URL(s"http://$baseUrl/submissions/kill/$submissionId") + } + + /** Return the REST URL for requesting the status of an existing submission. */ + private def getStatusUrl(master: String, submissionId: String): URL = { + val baseUrl = master.stripPrefix("spark://") + new URL(s"http://$baseUrl/submissions/status/$submissionId") + } + + /** Throw an exception if this is not standalone mode. */ + private def validateMaster(master: String): Unit = { + if (!master.startsWith("spark://")) { + throw new IllegalArgumentException("This REST client is only supported in standalone mode.") } - logError(s"Error: Master did not recognize driver $driverId.") } - /** Construct a submit driver request message. */ - protected override def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest = { - val message = new SubmitDriverRequest + /** Throw an exception if this is not standalone cluster mode. */ + private def validateSubmitArgs(args: SparkSubmitArguments): Unit = { + if (!args.isStandaloneCluster) { + throw new IllegalArgumentException( + "This REST client is only supported in standalone cluster mode.") + } + } + + /** Construct a message that captures the specified parameters for submitting an application. */ + private def constructSubmitRequest(args: SparkSubmitArguments): CreateSubmissionRequest = { + val message = new CreateSubmissionRequest message.clientSparkVersion = sparkVersion message.appName = args.name message.appResource = args.primaryResource @@ -130,48 +200,61 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { sys.env.foreach { case (k, v) => if (k.startsWith("SPARK_")) { message.setEnvironmentVariable(k, v) } } + message.validate() message } - /** Construct a kill driver request message. */ - protected override def constructKillRequest( - master: String, - driverId: String): KillDriverRequest = { - val k = new KillDriverRequest - k.clientSparkVersion = sparkVersion - k.driverId = driverId - k - } - - /** Construct a driver status request message. */ - protected override def constructStatusRequest( - master: String, - driverId: String): DriverStatusRequest = { - val d = new DriverStatusRequest - d.clientSparkVersion = sparkVersion - d.driverId = driverId - d - } - - /** Extract the URL portion of the master address. */ - protected override def getHttpUrl(master: String): URL = { - validateMaster(master) - new URL("http://" + master.stripPrefix("spark://")) - } - - /** Throw an exception if this is not standalone mode. */ - private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + /** Report the status of a newly created submission. */ + private def reportSubmissionStatus(master: String, submitResponse: CreateSubmissionResponse): Unit = { + val submitSuccess = submitResponse.success.toBoolean + if (submitSuccess) { + val submissionId = submitResponse.submissionId + if (submissionId != null) { + logInfo(s"Driver successfully submitted as $submissionId. Polling driver state...") + pollSubmissionStatus(master, submissionId) + } else { + logError("Application successfully submitted, but driver ID was not provided!") + } + } else { + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError("Application submission failed" + failMessage) } } - /** Throw an exception if this is not standalone cluster mode. */ - private def validateSubmitArgs(args: SparkSubmitArguments): Unit = { - if (!args.isStandaloneCluster) { - throw new IllegalArgumentException( - "This REST client is only supported in standalone cluster mode.") + /** + * Poll the status of the specified submission and log it. + * This retries up to a fixed number of times before giving up. + */ + private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val response = requestSubmissionStatus(master, submissionId) + val statusResponse = response match { + case s: SubmissionStatusResponse => s + case _ => return // unexpected type, let upstream caller handle it + } + val statusSuccess = statusResponse.success.toBoolean + if (statusSuccess) { + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case _ => logError(s"State of driver $submissionId was not found!") + } + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) } + logError(s"Error: Master did not recognize submission $submissionId.") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index de0f701fd3258..88f1734dadd97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -17,12 +17,19 @@ package org.apache.spark.deploy.rest -import java.io.File +import java.io.{DataOutputStream, File} +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest} + +import scala.io.Source import akka.actor.ActorRef +import com.google.common.base.Charsets +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.apache.spark.{SPARK_VERSION => sparkVersion} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.{AkkaUtils, Utils} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ @@ -30,64 +37,188 @@ import org.apache.spark.deploy.master.Master /** * A server that responds to requests submitted by the [[StandaloneRestClient]]. - * This is intended to be embedded in the standalone Master. Cluster mode only + * This is intended to be embedded in the standalone Master and used in cluster mode only. */ -private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) - extends SubmitRestServer(host, requestedPort, master.conf) { - protected override val handler = new StandaloneRestServerHandler(master) +private[spark] class StandaloneRestServer( + master: Master, + host: String, + requestedPort: Int) + extends Logging { + + private var _server: Option[Server] = None + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, master.conf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Set up the mapping from contexts to the appropriate servlets: + * (1) submit requests should be directed to /create + * (2) kill requests should be directed to /kill + * (3) status requests should be directed to /status + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, requestedPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/submissions") + mainHandler.addServlet(new ServletHolder(new KillRequestServlet(master)), "/kill/*") + mainHandler.addServlet(new ServletHolder(new StatusRequestServlet(master)), "/status/*") + mainHandler.addServlet(new ServletHolder(new SubmitRequestServlet(master)), "/create") + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } } /** - * A handler for requests submitted to the standalone - * Master via the REST application submission protocol. + * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. */ -private[spark] class StandaloneRestServerHandler( - conf: SparkConf, - masterActor: ActorRef, - masterUrl: String) - extends SubmitRestServerHandler { +private[spark] abstract class StandaloneRestServlet(master: Master) + extends HttpServlet with Logging { + + protected val conf: SparkConf = master.conf + protected val masterActor: ActorRef = master.self + protected val masterUrl: String = master.masterUrl + protected val askTimeout = AkkaUtils.askTimeout(conf) + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def handleResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + try { + val message = validateResponse(responseMessage) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.setStatus(HttpServletResponse.SC_OK) + val content = message.toJson.getBytes(Charsets.UTF_8) + val out = new DataOutputStream(responseServlet.getOutputStream) + out.write(content) + out.close() + } catch { + case e: Exception => + logError("Exception encountered when handling response.", e) + } + } - private val askTimeout = AkkaUtils.askTimeout(conf) + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Exception): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } - def this(master: Master) = { - this(master.conf, master.self, master.masterUrl) + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e } - /** Handle a request to submit a driver. */ - protected override def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse = { - val driverDescription = buildDriverDescription(request) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val s = new SubmitDriverResponse - s.serverSparkVersion = sparkVersion - s.message = response.message - s.success = response.success.toString - s.driverId = response.driverId.orNull - s + /** + * Validate the response message to ensure that it is correctly constructed. + * If it is, simply return the response as is. Otherwise, return an error response + * to propagate the exception back to the client. + */ + private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { + try { + response.validate() + response + } catch { + case e: Exception => + handleError("Internal server error: " + formatException(e)) + } } +} + +/** + * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. + */ +private[spark] class KillRequestServlet(master: Master) extends StandaloneRestServlet(master) { - /** Handle a request to kill a driver. */ - protected override def handleKill(request: KillDriverRequest): KillDriverResponse = { - val driverId = request.driverId + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + try { + val submissionId = request.getPathInfo.stripPrefix("/") + val responseMessage = + if (submissionId.nonEmpty) { + handleKill(submissionId) + } else { + handleError("Submission ID is missing in kill request") + } + handleResponse(responseMessage, response) + } catch { + case e: Exception => + logError("Exception encountered when handling kill request", e) + } + } + + private def handleKill(submissionId: String): KillSubmissionResponse = { val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) - val k = new KillDriverResponse + DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message - k.driverId = driverId + k.submissionId = submissionId k.success = response.success.toString k } +} + +/** + * A servlet for handling status requests passed to the [[StandaloneRestServer]]. + */ +private[spark] class StatusRequestServlet(master: Master) extends StandaloneRestServlet(master) { - /** Handle a request for a driver's status. */ - protected override def handleStatus(request: DriverStatusRequest): DriverStatusResponse = { - val driverId = request.driverId + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + try { + val submissionId = request.getPathInfo.stripPrefix("/") + val responseMessage = + if (submissionId.nonEmpty) { + handleStatus(submissionId) + } else { + handleError("Submission ID is missing in status request") + } + handleResponse(responseMessage, response) + } catch { + case e: Exception => + logError("Exception encountered when handling status request", e) + } + } + + private def handleStatus(submissionId: String): SubmissionStatusResponse = { val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) + DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } - val d = new DriverStatusResponse + val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion - d.driverId = driverId + d.submissionId = submissionId d.success = response.found.toString d.driverState = response.state.map(_.toString).orNull d.workerId = response.workerId.orNull @@ -95,14 +226,73 @@ private[spark] class StandaloneRestServerHandler( d.message = message.orNull d } +} + +/** + * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. + */ +private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRestServlet(master) { + + /** + * Submit an application to the Master with parameters specified in the request message. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If this is successful, return an appropriate response to the client indicating so. + * Otherwise, return error instead. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + try { + val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + .asInstanceOf[SubmitRestProtocolRequest] + val responseMessage = handleSubmit(requestMessage) + response.setContentType("application/json") + response.setCharacterEncoding("utf-8") + response.setStatus(HttpServletResponse.SC_OK) + val content = responseMessage.toJson.getBytes(Charsets.UTF_8) + val out = new DataOutputStream(response.getOutputStream) + out.write(content) + out.close() + } catch { + case e: Exception => logError("Exception while handling request", e) + } + } + + private def handleSubmit(request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + try { + request.validate() + request match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success.toString + submitResponse.submissionId = response.driverId.orNull + submitResponse + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } + } catch { + case e: Exception => handleError(formatException(e)) + } + } /** * Build a driver description from the fields specified in the submit request. * - * This does not currently consider fields used by python applications since - * python is not supported in standalone cluster mode yet. + * This involves constructing a command that takes into account memory, java options, + * classpath and other settings to launch the driver. This does not currently consider + * fields used by python applications since python is not supported in standalone + * cluster mode yet. */ - private def buildDriverDescription(request: SubmitDriverRequest): DriverDescription = { + private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = { // Required fields, including the main class because python is not yet supported val appName = request.appName val appResource = request.appResource diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala deleted file mode 100644 index 9af29a41e2288..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -import java.io.{DataOutputStream, FileNotFoundException} -import java.net.{HttpURLConnection, URL} - -import scala.io.Source - -import com.google.common.base.Charsets - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.deploy.SparkSubmitArguments - -/** - * An abstract client that submits applications using the REST protocol. - * This client is intended to communicate with the [[SubmitRestServer]]. - */ -private[spark] abstract class SubmitRestClient extends Logging { - - /** Request that the REST server submit a driver using the provided arguments. */ - def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to launch a driver in ${args.master}.") - val url = getHttpUrl(args.master) - val request = constructSubmitRequest(args) - val response = sendHttp(url, request) - handleResponse(response) - } - - /** Request that the REST server kill the specified driver. */ - def killDriver(master: String, driverId: String): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to kill driver $driverId in $master.") - val url = getHttpUrl(master) - val request = constructKillRequest(master, driverId) - val response = sendHttp(url, request) - handleResponse(response) - } - - /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request for the status of driver $driverId in $master.") - val url = getHttpUrl(master) - val request = constructStatusRequest(master, driverId) - val response = sendHttp(url, request) - handleResponse(response) - } - - /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ - protected def getHttpUrl(master: String): URL - - // Construct the appropriate type of message based on the request type - protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest - protected def constructKillRequest(master: String, driverId: String): KillDriverRequest - protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequest - - /** - * Send the provided request in an HTTP message to the given URL. - * This assumes that both the request and the response use the JSON format. - * Return the response received from the REST server. - */ - private def sendHttp(url: URL, request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { - try { - val conn = url.openConnection().asInstanceOf[HttpURLConnection] - conn.setRequestMethod("POST") - conn.setRequestProperty("Content-Type", "application/json") - conn.setRequestProperty("charset", "utf-8") - conn.setDoOutput(true) - request.validate() - val requestJson = request.toJson - logDebug(s"Sending the following request to the REST server:\n$requestJson") - val out = new DataOutputStream(conn.getOutputStream) - out.write(requestJson.getBytes(Charsets.UTF_8)) - out.close() - val responseJson = Source.fromInputStream(conn.getInputStream).mkString - logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolMessage.fromJson(responseJson).asInstanceOf[SubmitRestProtocolResponse] - } catch { - case e: FileNotFoundException => - throw new SparkException(s"Unable to connect to REST server $url", e) - } - } - - /** Validate the response and log any error messages provided by the server. */ - private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { - try { - response.validate() - response match { - case e: ErrorResponse => logError(s"Server responded with error:\n${e.message}") - case _ => - } - } catch { - case e: SubmitRestProtocolException => - throw new SubmitRestProtocolException("Malformed response received from server", e) - } - response - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index ffb1bb164d0ff..372af58dea352 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -131,46 +131,9 @@ abstract class SubmitRestProtocolMessage { } /** - * An abstract request sent from the client in the REST application submission protocol. + * Helper methods to process serialized [[SubmitRestProtocolMessage]]s. */ -abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { - var clientSparkVersion: String = null - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(clientSparkVersion, "clientSparkVersion") - } -} - -/** - * An abstract response sent from the server in the REST application submission protocol. - */ -abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { - var serverSparkVersion: String = null - var success: String = null - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(serverSparkVersion, "serverSparkVersion") - assertFieldIsSet(success, "success") - assertFieldIsBoolean(success, "success") - } -} - -/** - * An error response message used in the REST application submission protocol. - */ -class ErrorResponse extends SubmitRestProtocolResponse { - - // request was unsuccessful - success = "false" - - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(message, "message") - assert(!success.toBoolean, s"The 'success' field must be false in $messageType.") - } -} - -object SubmitRestProtocolMessage { +private[spark] object SubmitRestProtocolMessage { private val packagePrefix = this.getClass.getPackage.getName private val mapper = new ObjectMapper() .registerModule(DefaultScalaModule) @@ -181,16 +144,12 @@ object SubmitRestProtocolMessage { * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. */ def parseAction(json: String): String = { - parseField(json, "action").getOrElse { - throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") - } - } - - /** Parse the value of the specified field from the given JSON. */ - def parseField(json: String, field: String): Option[String] = { parse(json).asInstanceOf[JObject].obj - .find { case (f, _) => f == field } + .find { case (f, _) => f == "action" } .map { case (_, v) => v.asInstanceOf[JString].s } + .getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } } /** @@ -204,7 +163,7 @@ object SubmitRestProtocolMessage { val className = parseAction(json) val clazz = Class.forName(packagePrefix + "." + className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) - fromJson(json, clazz) + mapper.readValue(json, clazz) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala similarity index 84% rename from core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 4eb7efc2555b4..01f51b9a0d904 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -20,12 +20,23 @@ package org.apache.spark.deploy.rest import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.fasterxml.jackson.annotation.{JsonProperty, JsonIgnore, JsonInclude} +import com.fasterxml.jackson.annotation.JsonProperty + +/** + * An abstract request sent from the client in the REST application submission protocol. + */ +private[spark] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + var clientSparkVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") + } +} /** * A request to submit a driver in the REST application submission protocol. */ -class SubmitDriverRequest extends SubmitRestProtocolRequest { +private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest { var appName: String = null var appResource: String = null var mainClass: String = null @@ -68,10 +79,3 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { assertFieldIsNumeric(totalExecutorCores, "totalExecutorCores") } } - -/** - * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. - */ -class SubmitDriverResponse extends SubmitRestProtocolResponse { - var driverId: String = null -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala new file mode 100644 index 0000000000000..421aaf6a13c42 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An abstract response sent from the server in the REST application submission protocol. + */ +private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + var serverSparkVersion: String = null + var success: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") + assertFieldIsSet(success, "success") + assertFieldIsBoolean(success, "success") + } +} + +/** + * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. + */ +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null +} + +/** + * A response to a kill request in the REST application submission protocol. + */ +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + } +} + +/** + * A response to a status request in the REST application submission protocol. + */ +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + } +} + +/** + * An error response message used in the REST application submission protocol. + */ +private[spark] class ErrorResponse extends SubmitRestProtocolResponse { + + // request was unsuccessful + success = "false" + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(message, "message") + assert(!success.toBoolean, s"The 'success' field must be false in $messageType.") + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala deleted file mode 100644 index 6c2a3a159da8e..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -import java.io.DataOutputStream -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} - -import scala.io.Source - -import com.google.common.base.Charsets -import org.eclipse.jetty.server.{Request, Server} -import org.eclipse.jetty.server.handler.AbstractHandler -import org.eclipse.jetty.util.thread.QueuedThreadPool - -import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging, SparkConf} -import org.apache.spark.util.Utils - -/** - * An abstract server that responds to requests submitted by the - * [[SubmitRestClient]] in the REST application submission protocol. - */ -private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, conf: SparkConf) - extends Logging { - - protected val handler: SubmitRestServerHandler - private var _server: Option[Server] = None - - /** Start the server and return the bound port. */ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, conf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } - - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, requestedPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - server.setHandler(handler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } -} - -/** - * An abstract handler for requests submitted via the REST application submission protocol. - * This represents the main handler used in the [[SubmitRestServer]]. - */ -private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { - protected def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse - protected def handleKill(request: KillDriverRequest): KillDriverResponse - protected def handleStatus(request: DriverStatusRequest): DriverStatusResponse - - /** - * Handle a request submitted by the [[SubmitRestClient]]. - * This assumes that both the request and the response use the JSON format. - */ - override def handle( - target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - try { - val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - .asInstanceOf[SubmitRestProtocolRequest] - val responseMessage = constructResponseMessage(requestMessage) - response.setContentType("application/json") - response.setCharacterEncoding("utf-8") - response.setStatus(HttpServletResponse.SC_OK) - val content = responseMessage.toJson.getBytes(Charsets.UTF_8) - val out = new DataOutputStream(response.getOutputStream) - out.write(content) - out.close() - baseRequest.setHandled(true) - } catch { - case e: Exception => logError("Exception while handling request", e) - } - } - - /** - * Construct the appropriate response message based on the type of the request message. - * If an exception is thrown, construct an error message instead. - */ - private def constructResponseMessage( - request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { - // Validate the request message to ensure that it is correctly constructed. If the request - // is sent via the SubmitRestClient, it should have already been validated remotely. In case - // this is not true, do it again here to guard against potential NPEs. If validation fails, - // send an error message back to the sender. - val response = - try { - request.validate() - request match { - case submit: SubmitDriverRequest => handleSubmit(submit) - case kill: KillDriverRequest => handleKill(kill) - case status: DriverStatusRequest => handleStatus(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") - } - } catch { - case e: Exception => handleError(formatException(e)) - } - // Validate the response message to ensure that it is correctly constructed. If it is not, - // propagate the exception back to the client and signal that it is a server error. - try { - response.validate() - } catch { - case e: Exception => - return handleError("Internal server error: " + formatException(e)) - } - response - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - private def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Exception): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index fa994118883f3..7052526c13bc9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -53,13 +53,13 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val resultsFile = File.createTempFile("test-submit", ".txt") val numbers = Seq(1, 2, 3) val size = 500 - val driverId = submitApplication(resultsFile, numbers, size) - waitUntilFinished(driverId) + val submissionId = submitApplication(resultsFile, numbers, size) + waitUntilFinished(submissionId) validateResult(resultsFile, numbers, size) } test("kill empty driver") { - val response = client.killDriver(masterRestUrl, "driver-that-does-not-exist") + val response = client.killSubmission(masterRestUrl, "driver-that-does-not-exist") val killResponse = getKillResponse(response) val killSuccess = killResponse.success assert(killSuccess === "false") @@ -69,12 +69,12 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val resultsFile = File.createTempFile("test-kill", ".txt") val numbers = Seq(1, 2, 3) val size = 500 - val driverId = submitApplication(resultsFile, numbers, size) - val response = client.killDriver(masterRestUrl, driverId) + val submissionId = submitApplication(resultsFile, numbers, size) + val response = client.killSubmission(masterRestUrl, submissionId) val killResponse = getKillResponse(response) val killSuccess = killResponse.success - waitUntilFinished(driverId) - val response2 = client.requestDriverStatus(masterRestUrl, driverId) + waitUntilFinished(submissionId) + val response2 = client.requestSubmissionStatus(masterRestUrl, submissionId) val statusResponse = getStatusResponse(response2) val statusSuccess = statusResponse.success val driverState = statusResponse.driverState @@ -86,7 +86,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B } test("request status for empty driver") { - val response = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") + val response = client.requestSubmissionStatus(masterRestUrl, "driver-that-does-not-exist") val statusResponse = getStatusResponse(response) val statusSuccess = statusResponse.success assert(statusSuccess === "false") @@ -129,52 +129,52 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) SparkSubmit.prepareSubmitEnvironment(args) - val response = client.submitDriver(args) + val response = client.createSubmission(args) val submitResponse = getSubmitResponse(response) - val driverId = submitResponse.driverId - assert(driverId != null, "Application submission was unsuccessful!") - driverId + val submissionId = submitResponse.submissionId + assert(submissionId != null, "Application submission was unsuccessful!") + submissionId } /** Wait until the given driver has finished running up to the specified timeout. */ - private def waitUntilFinished(driverId: String, maxSeconds: Int = 30): Unit = { + private def waitUntilFinished(submissionId: String, maxSeconds: Int = 30): Unit = { var finished = false val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { - val response = client.requestDriverStatus(masterRestUrl, driverId) + val response = client.requestSubmissionStatus(masterRestUrl, submissionId) val statusResponse = getStatusResponse(response) val driverState = statusResponse.driverState finished = driverState != DriverState.SUBMITTED.toString && driverState != DriverState.RUNNING.toString if (System.currentTimeMillis > expireTime) { - fail(s"Driver $driverId did not finish within $maxSeconds seconds.") + fail(s"Driver $submissionId did not finish within $maxSeconds seconds.") } } } /** Return the response as a submit driver response, or fail with error otherwise. */ - private def getSubmitResponse(response: SubmitRestProtocolResponse): SubmitDriverResponse = { + private def getSubmitResponse(response: SubmitRestProtocolResponse): CreateSubmissionResponse = { response match { - case s: SubmitDriverResponse => s + case s: CreateSubmissionResponse => s case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected submit response. Actual: ${r.toJson}") } } /** Return the response as a kill driver response, or fail with error otherwise. */ - private def getKillResponse(response: SubmitRestProtocolResponse): KillDriverResponse = { + private def getKillResponse(response: SubmitRestProtocolResponse): KillSubmissionResponse = { response match { - case k: KillDriverResponse => k + case k: KillSubmissionResponse => k case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected kill response. Actual: ${r.toJson}") } } /** Return the response as a driver status response, or fail with error otherwise. */ - private def getStatusResponse(response: SubmitRestProtocolResponse): DriverStatusResponse = { + private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = { response match { - case s: DriverStatusResponse => s + case s: SubmissionStatusResponse => s case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected status response. Actual: ${r.toJson}") } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 97158a2cecda8..1adde1233508a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -94,8 +94,8 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newResponse.message === null) } - test("SubmitDriverRequest") { - val message = new SubmitDriverRequest + test("CreateSubmissionRequest") { + val message = new CreateSubmissionRequest intercept[SubmitRestProtocolException] { message.validate() } message.clientSparkVersion = "1.2.3" message.appName = "SparkPie" @@ -142,7 +142,7 @@ class SubmitRestProtocolSuite extends FunSuite { // test JSON val json = message.toJson assertJsonEquals(json, submitDriverRequestJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverRequest]) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionRequest]) assert(newMessage.clientSparkVersion === "1.2.3") assert(newMessage.appName === "SparkPie") assert(newMessage.appResource === "honey-walnut-cherry.jar") @@ -163,11 +163,11 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.environmentVariables === message.environmentVariables) } - test("SubmitDriverResponse") { - val message = new SubmitDriverResponse + test("CreateSubmissionResponse") { + val message = new CreateSubmissionResponse intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" - message.driverId = "driver_123" + message.submissionId = "driver_123" message.success = "true" message.validate() // bad fields @@ -177,31 +177,17 @@ class SubmitRestProtocolSuite extends FunSuite { // test JSON val json = message.toJson assertJsonEquals(json, submitDriverResponseJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverResponse]) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse]) assert(newMessage.serverSparkVersion === "1.2.3") - assert(newMessage.driverId === "driver_123") + assert(newMessage.submissionId === "driver_123") assert(newMessage.success === "true") } - test("KillDriverRequest") { - val message = new KillDriverRequest - intercept[SubmitRestProtocolException] { message.validate() } - message.clientSparkVersion = "1.2.3" - message.driverId = "driver_123" - message.validate() - // test JSON - val json = message.toJson - assertJsonEquals(json, killDriverRequestJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverRequest]) - assert(newMessage.clientSparkVersion === "1.2.3") - assert(newMessage.driverId === "driver_123") - } - - test("KillDriverResponse") { - val message = new KillDriverResponse + test("KillSubmissionResponse") { + val message = new KillSubmissionResponse intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" - message.driverId = "driver_123" + message.submissionId = "driver_123" message.success = "true" message.validate() // bad fields @@ -211,31 +197,17 @@ class SubmitRestProtocolSuite extends FunSuite { // test JSON val json = message.toJson assertJsonEquals(json, killDriverResponseJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverResponse]) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse]) assert(newMessage.serverSparkVersion === "1.2.3") - assert(newMessage.driverId === "driver_123") + assert(newMessage.submissionId === "driver_123") assert(newMessage.success === "true") } - test("DriverStatusRequest") { - val message = new DriverStatusRequest - intercept[SubmitRestProtocolException] { message.validate() } - message.clientSparkVersion = "1.2.3" - message.driverId = "driver_123" - message.validate() - // test JSON - val json = message.toJson - assertJsonEquals(json, driverStatusRequestJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusRequest]) - assert(newMessage.clientSparkVersion === "1.2.3") - assert(newMessage.driverId === "driver_123") - } - - test("DriverStatusResponse") { - val message = new DriverStatusResponse + test("SubmissionStatusResponse") { + val message = new SubmissionStatusResponse intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" - message.driverId = "driver_123" + message.submissionId = "driver_123" message.success = "true" message.validate() // optional fields @@ -249,9 +221,9 @@ class SubmitRestProtocolSuite extends FunSuite { // test JSON val json = message.toJson assertJsonEquals(json, driverStatusResponseJson) - val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusResponse]) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse]) assert(newMessage.serverSparkVersion === "1.2.3") - assert(newMessage.driverId === "driver_123") + assert(newMessage.submissionId === "driver_123") assert(newMessage.driverState === "RUNNING") assert(newMessage.success === "true") assert(newMessage.workerId === "worker_123") @@ -295,7 +267,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val submitDriverRequestJson = """ |{ - | "action" : "SubmitDriverRequest", + | "action" : "CreateSubmissionRequest", | "appArgs" : ["two slices","a hint of cinnamon"], | "appName" : "SparkPie", | "appResource" : "honey-walnut-cherry.jar", @@ -320,48 +292,30 @@ class SubmitRestProtocolSuite extends FunSuite { private val submitDriverResponseJson = """ |{ - | "action" : "SubmitDriverResponse", - | "driverId" : "driver_123", + | "action" : "CreateSubmissionResponse", | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", | "success" : "true" |} """.stripMargin - private val killDriverRequestJson = - """ - |{ - | "action" : "KillDriverRequest", - | "clientSparkVersion" : "1.2.3", - | "driverId" : "driver_123" - |} - """.stripMargin - private val killDriverResponseJson = """ |{ - | "action" : "KillDriverResponse", - | "driverId" : "driver_123", + | "action" : "KillSubmissionResponse", | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", | "success" : "true" |} """.stripMargin - private val driverStatusRequestJson = - """ - |{ - | "action" : "DriverStatusRequest", - | "clientSparkVersion" : "1.2.3", - | "driverId" : "driver_123" - |} - """.stripMargin - private val driverStatusResponseJson = """ |{ - | "action" : "DriverStatusResponse", - | "driverId" : "driver_123", + | "action" : "SubmissionStatusResponse", | "driverState" : "RUNNING", | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", | "success" : "true", | "workerHostPort" : "1.2.3.4:7780", | "workerId" : "worker_123"