Skip to content

Commit

Permalink
Include unknown fields, if any, in server response
Browse files Browse the repository at this point in the history
... in case the client wants to propagate this to the user in
the future.
  • Loading branch information
Andrew Or committed Feb 5, 2015
1 parent 9fee16f commit cbd670b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import com.google.common.base.Charsets
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
import org.apache.spark.util.{AkkaUtils, Utils}
Expand Down Expand Up @@ -131,6 +133,22 @@ private[spark] abstract class StandaloneRestServlet extends HttpServlet with Log
out.close()
}

/**
* Return any fields in the client request message that the server does not know about.
*
* The mechanism for this is to reconstruct the JSON on the server side and compare the
* diff between this JSON and the one generated on the client side. Any fields that are
* only in the client JSON are treated as unexpected.
*/
protected def findUnknownFields(
requestJson: String,
requestMessage: SubmitRestProtocolMessage): Array[String] = {
val clientSideJson = parse(requestJson)
val serverSideJson = parse(requestMessage.toJson)
val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson)
unknown.asInstanceOf[JObject].obj.map { case (k, _) => k }.toArray
}

/** Return a human readable String representation of the exception. */
protected def formatException(e: Exception): String = {
val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n")
Expand Down Expand Up @@ -259,6 +277,11 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
val responseMessage = handleSubmit(requestMessage, responseServlet)
val unknownFields = findUnknownFields(requestMessageJson, requestMessage)
if (unknownFields.nonEmpty) {
// If there are fields that the server does not know about, warn the client
responseMessage.unknownFields = unknownFields
}
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.apache.spark.deploy.rest
private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
var serverSparkVersion: String = null
var success: String = null
var unknownFields: Array[String] = null
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(serverSparkVersion, "serverSparkVersion")
Expand Down

0 comments on commit cbd670b

Please sign in to comment.