diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index b136b62ed9296..6433a07de8293 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -27,6 +27,8 @@ import com.google.common.base.Charsets import org.eclipse.jetty.server.Server import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.{AkkaUtils, Utils} @@ -131,6 +133,22 @@ private[spark] abstract class StandaloneRestServlet extends HttpServlet with Log out.close() } + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown.asInstanceOf[JObject].obj.map { case (k, _) => k }.toArray + } + /** Return a human readable String representation of the exception. */ protected def formatException(e: Exception): String = { val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") @@ -259,6 +277,11 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) val responseMessage = handleSubmit(requestMessage, responseServlet) + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + responseMessage.unknownFields = unknownFields + } responseServlet.setContentType("application/json") responseServlet.setCharacterEncoding("utf-8") responseServlet.setStatus(HttpServletResponse.SC_OK) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index cdf9c0a82624d..3a050b7d709e4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -23,6 +23,7 @@ package org.apache.spark.deploy.rest private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { var serverSparkVersion: String = null var success: String = null + var unknownFields: Array[String] = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(serverSparkVersion, "serverSparkVersion")