Skip to content

Commit

Permalink
Use specific HTTP response codes on error
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Or committed Feb 4, 2015
1 parent f98660b commit 792e112
Showing 1 changed file with 106 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}

import scala.io.Source

import akka.actor.ActorRef
import com.google.common.base.Charsets
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
Expand All @@ -38,11 +37,12 @@ import org.apache.spark.deploy.master.Master
/**
* A server that responds to requests submitted by the [[StandaloneRestClient]].
* This is intended to be embedded in the standalone Master and used in cluster mode only.
*
* When an error occurs, this server sends an error response with an appropriate message
* back to the client. If the construction of this error message itself is faulty, the
* server indicates internal error through the response code.
*/
private[spark] class StandaloneRestServer(
master: Master,
host: String,
requestedPort: Int)
private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int)
extends Logging {

import StandaloneRestServer._
Expand Down Expand Up @@ -98,34 +98,36 @@ private object StandaloneRestServer {
/**
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
*/
private[spark] abstract class StandaloneRestServlet(master: Master)
extends HttpServlet with Logging {
private[spark] abstract class StandaloneRestServlet extends HttpServlet with Logging {

protected val conf: SparkConf = master.conf
protected val masterActor: ActorRef = master.self
protected val masterUrl: String = master.masterUrl
protected val askTimeout = AkkaUtils.askTimeout(conf)
/** Service a request. If an exception is thrown in the process, indicate server error. */
protected override def service(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
super.service(request, response)
} catch {
case e: Exception =>
logError("Exception while handling request", e)
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
}
}

/**
* Serialize the given response message to JSON and send it through the response servlet.
* This validates the response before sending it to ensure it is properly constructed.
*/
protected def handleResponse(
protected def sendResponse(
responseMessage: SubmitRestProtocolResponse,
responseServlet: HttpServletResponse): Unit = {
try {
val message = validateResponse(responseMessage)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
val content = message.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(responseServlet.getOutputStream)
out.write(content)
out.close()
} catch {
case e: Exception =>
logError("Exception encountered when handling response.", e)
}
val message = validateResponse(responseMessage, responseServlet)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
val content = message.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(responseServlet.getOutputStream)
out.write(content)
out.close()
}

/** Return a human readable String representation of the exception. */
Expand All @@ -147,12 +149,15 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
* If it is, simply return the response as is. Otherwise, return an error response
* to propagate the exception back to the client.
*/
private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
private def validateResponse(
responseMessage: SubmitRestProtocolResponse,
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
try {
response.validate()
response
responseMessage.validate()
responseMessage
} catch {
case e: Exception =>
responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
handleError("Internal server error: " + formatException(e))
}
}
Expand All @@ -161,7 +166,8 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
private[spark] class KillRequestServlet(master: Master) extends StandaloneRestServlet(master) {
private[spark] class KillRequestServlet(master: Master) extends StandaloneRestServlet {
private val askTimeout = AkkaUtils.askTimeout(master.conf)

/**
* If a submission ID is specified in the URL, have the Master kill the corresponding
Expand All @@ -170,24 +176,20 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
protected override def doPost(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleKill(submissionId)
} else {
handleError("Submission ID is missing in kill request")
}
handleResponse(responseMessage, response)
} catch {
case e: Exception =>
logError("Exception encountered when handling kill request", e)
}
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleKill(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in kill request")
}
sendResponse(responseMessage, response)
}

private def handleKill(submissionId: String): KillSubmissionResponse = {
val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
DeployMessages.RequestKillDriver(submissionId), master.self, askTimeout)
val k = new KillSubmissionResponse
k.serverSparkVersion = sparkVersion
k.message = response.message
Expand All @@ -200,7 +202,8 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
private[spark] class StatusRequestServlet(master: Master) extends StandaloneRestServlet(master) {
private[spark] class StatusRequestServlet(master: Master) extends StandaloneRestServlet {
private val askTimeout = AkkaUtils.askTimeout(master.conf)

/**
* If a submission ID is specified in the URL, request the status of the corresponding
Expand All @@ -209,24 +212,20 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
protected override def doGet(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleStatus(submissionId)
} else {
handleError("Submission ID is missing in status request")
}
handleResponse(responseMessage, response)
} catch {
case e: Exception =>
logError("Exception encountered when handling status request", e)
}
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleStatus(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in status request")
}
sendResponse(responseMessage, response)
}

private def handleStatus(submissionId: String): SubmissionStatusResponse = {
val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
DeployMessages.RequestDriverStatus(submissionId), master.self, askTimeout)
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
val d = new SubmissionStatusResponse
d.serverSparkVersion = sparkVersion
Expand All @@ -243,7 +242,8 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
/**
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRestServlet(master) {
private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRestServlet {
private val askTimeout = AkkaUtils.askTimeout(master.conf)

/**
* Submit an application to the Master with parameters specified in the request message.
Expand All @@ -253,46 +253,48 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
* Otherwise, return error instead.
*/
protected override def doPost(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
.asInstanceOf[SubmitRestProtocolRequest]
val responseMessage = handleSubmit(requestMessage)
response.setContentType("application/json")
response.setCharacterEncoding("utf-8")
response.setStatus(HttpServletResponse.SC_OK)
val content = responseMessage.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(response.getOutputStream)
out.write(content)
out.close()
} catch {
case e: Exception => logError("Exception while handling request", e)
}
requestServlet: HttpServletRequest,
responseServlet: HttpServletResponse): Unit = {
val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
.asInstanceOf[SubmitRestProtocolRequest]
val responseMessage = handleSubmit(requestMessage, responseServlet)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
val content = responseMessage.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(responseServlet.getOutputStream)
out.write(content)
out.close()
}

private def handleSubmit(request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = {
private def handleSubmit(
requestMessage: SubmitRestProtocolRequest,
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
// The response should have already been validated on the client.
// In case this is not true, validate it ourselves to avoid potential NPEs.
try {
request.validate()
request match {
case submitRequest: CreateSubmissionRequest =>
val driverDescription = buildDriverDescription(submitRequest)
val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
submitResponse.message = response.message
submitResponse.success = response.success.toString
submitResponse.submissionId = response.driverId.orNull
submitResponse
case unexpected => handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
}
requestMessage.validate()
} catch {
case e: Exception => handleError(formatException(e))
case e: SubmitRestProtocolException =>
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError(formatException(e))
}
requestMessage match {
case submitRequest: CreateSubmissionRequest =>
val driverDescription = buildDriverDescription(submitRequest)
val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
DeployMessages.RequestSubmitDriver(driverDescription), master.self, askTimeout)
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
submitResponse.message = response.message
submitResponse.success = response.success.toString
submitResponse.submissionId = response.driverId.orNull
submitResponse
case unexpected =>
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
}
}

Expand Down Expand Up @@ -331,7 +333,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
// Translate all fields to the relevant Spark properties
val conf = new SparkConf(false)
.setAll(sparkProperties)
.set("spark.master", masterUrl)
.set("spark.master", master.masterUrl)
.set("spark.app.name", appName)
jars.foreach { j => conf.set("spark.jars", j) }
files.foreach { f => conf.set("spark.files", f) }
Expand Down Expand Up @@ -362,19 +364,25 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
/**
* A default servlet that handles error cases that are not captured by other servlets.
*/
private[spark] class ErrorServlet extends HttpServlet {
private[spark] class ErrorServlet extends StandaloneRestServlet {
private val expectedVersion = StandaloneRestServer.PROTOCOL_VERSION
override def service(request: HttpServletRequest, response: HttpServletResponse): Unit = {
protected override def service(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val path = request.getPathInfo
val parts = path.stripPrefix("/").split("/")
if (parts.nonEmpty) {
val version = parts.head
if (version != expectedVersion) {
response.sendError(800, s"Incompatible protocol version $version")
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
val error = handleError(s"Incompatible protocol version $version")
sendResponse(error, response)
return
}
}
response.sendError(801,
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
val error = handleError(
s"Unexpected path $path: Please submit requests through /$expectedVersion/submissions/")
sendResponse(error, response)
}
}

0 comments on commit 792e112

Please sign in to comment.