Skip to content

Commit

Permalink
Replace SubmitRestProtocolAction with class name
Browse files Browse the repository at this point in the history
This makes it easier for users to define their own messages.
Rather than forcing the users to introduce an action class that
extends our inflexible enum, they can now implement custom logic
in their REST servers depending on the action of the request.
  • Loading branch information
Andrew Or committed Jan 29, 2015
1 parent df90e8b commit 8d43486
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@
package org.apache.spark.deploy.rest

class DriverStatusRequest extends SubmitRestProtocolRequest {
protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_REQUEST
private val driverId = new SubmitRestProtocolField[String]

def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

override def validate(): Unit = {
super.validate()
assertFieldIsSet(driverId, "driver_id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.deploy.rest

class DriverStatusResponse extends SubmitRestProtocolResponse {
protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE
private val driverId = new SubmitRestProtocolField[String]
private val success = new SubmitRestProtocolField[Boolean]
private val driverState = new SubmitRestProtocolField[String]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.deploy.rest

class ErrorResponse extends SubmitRestProtocolResponse {
protected override val action = SubmitRestProtocolAction.ERROR
override def validate(): Unit = {
super.validate()
assertFieldIsSet(message, "message")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@
package org.apache.spark.deploy.rest

class KillDriverRequest extends SubmitRestProtocolRequest {
protected override val action = SubmitRestProtocolAction.KILL_DRIVER_REQUEST
private val driverId = new SubmitRestProtocolField[String]

def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

override def validate(): Unit = {
super.validate()
assertFieldIsSet(driverId, "driver_id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.deploy.rest

class KillDriverResponse extends SubmitRestProtocolResponse {
protected override val action = SubmitRestProtocolAction.KILL_DRIVER_RESPONSE
private val driverId = new SubmitRestProtocolField[String]
private val success = new SubmitRestProtocolField[Boolean]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.util.JsonProtocol

class SubmitDriverRequest extends SubmitRestProtocolRequest {
protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST
private val appName = new SubmitRestProtocolField[String]
private val appResource = new SubmitRestProtocolField[String]
private val mainClass = new SubmitRestProtocolField[String]
Expand Down Expand Up @@ -62,7 +61,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
def getExecutorMemory: String = executorMemory.toString
def getTotalExecutorCores: String = totalExecutorCores.toString

// Special getters required for JSON de/serialization
// Special getters required for JSON serialization
@JsonProperty("appArgs")
private def getAppArgsJson: String = arrayToJson(getAppArgs)
@JsonProperty("sparkProperties")
Expand All @@ -85,7 +84,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
def setExecutorMemory(s: String): this.type = setField(executorMemory, s)
def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s)

// Special setters required for JSON de/serialization
// Special setters required for JSON deserialization
@JsonProperty("appArgs")
private def setAppArgsJson(s: String): Unit = {
appArgs.clear()
Expand Down Expand Up @@ -116,11 +115,11 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this }

private def arrayToJson(arr: Array[String]): String = {
if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else { null }
if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null
}

private def mapToJson(map: Map[String, String]): String = {
if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else { null }
if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null
}

override def validate(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.deploy.rest

class SubmitDriverResponse extends SubmitRestProtocolResponse {
protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE
private val success = new SubmitRestProtocolField[Boolean]
private val driverId = new SubmitRestProtocolField[String]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,6 @@

package org.apache.spark.deploy.rest

/**
* All possible values of the ACTION field in a SubmitRestProtocolMessage.
*/
abstract class SubmitRestProtocolAction
object SubmitRestProtocolAction {
case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction
case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction
case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction
case object KILL_DRIVER_RESPONSE extends SubmitRestProtocolAction
case object DRIVER_STATUS_REQUEST extends SubmitRestProtocolAction
case object DRIVER_STATUS_RESPONSE extends SubmitRestProtocolAction
case object ERROR extends SubmitRestProtocolAction
private val allActions =
Seq(SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE, KILL_DRIVER_REQUEST,
KILL_DRIVER_RESPONSE, DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE, ERROR)
private val allActionsMap = allActions.map { a => (a.toString, a) }.toMap

def fromString(action: String): SubmitRestProtocolAction = {
allActionsMap.get(action).getOrElse {
throw new IllegalArgumentException(s"Unknown action $action")
}
}
}

class SubmitRestProtocolField[T] {
protected var value: Option[T] = None
def isSet: Boolean = value.isDefined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.util.Utils
import org.apache.spark.deploy.rest.SubmitRestProtocolAction._

@JsonInclude(Include.NON_NULL)
@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
Expand All @@ -34,12 +33,12 @@ abstract class SubmitRestProtocolMessage {
import SubmitRestProtocolMessage._

private val messageType = Utils.getFormattedClassName(this)
protected val action: SubmitRestProtocolAction
protected val action: String = camelCaseToUnderscores(decapitalize(messageType))
protected val sparkVersion = new SubmitRestProtocolField[String]
protected val message = new SubmitRestProtocolField[String]

// Required for JSON de/serialization and not explicitly used
private def getAction: String = action.toString
private def getAction: String = action
private def setAction(s: String): this.type = this

// Spark version implementation depends on whether this is a request or a response
Expand Down Expand Up @@ -124,24 +123,22 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {

object SubmitRestProtocolMessage {
private val mapper = new ObjectMapper
private val packagePrefix = this.getClass.getPackage.getName

def fromJson(json: String): SubmitRestProtocolMessage = {
val fields = parse(json).asInstanceOf[JObject].obj
val action = fields
def parseAction(json: String): String = {
parse(json).asInstanceOf[JObject].obj
.find { case (f, _) => f == "action" }
.map { case (_, v) => v.asInstanceOf[JString].s }
.getOrElse {
throw new IllegalArgumentException(s"Could not find action field in message:\n$json")
}
val clazz = SubmitRestProtocolAction.fromString(action) match {
case SUBMIT_DRIVER_REQUEST => classOf[SubmitDriverRequest]
case SUBMIT_DRIVER_RESPONSE => classOf[SubmitDriverResponse]
case KILL_DRIVER_REQUEST => classOf[KillDriverRequest]
case KILL_DRIVER_RESPONSE => classOf[KillDriverResponse]
case DRIVER_STATUS_REQUEST => classOf[DriverStatusRequest]
case DRIVER_STATUS_RESPONSE => classOf[DriverStatusResponse]
case ERROR => classOf[ErrorResponse]
throw new IllegalArgumentException(s"Could not find action field in message:\n$json")
}
}

def fromJson(json: String): SubmitRestProtocolMessage = {
val action = parseAction(json)
val className = underscoresToCamelCase(action).capitalize
val clazz = Class.forName(packagePrefix + "." + className)
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
fromJson(json, clazz)
}

Expand Down Expand Up @@ -178,6 +175,14 @@ object SubmitRestProtocolMessage {
}
newString.toString()
}

private def decapitalize(s: String): String = {
if (s != null && s.nonEmpty) {
s(0).toLower + s.substring(1)
} else {
s
}
}
}

object SubmitRestProtocolRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ package org.apache.spark.deploy.rest
import org.json4s.jackson.JsonMethods._
import org.scalatest.FunSuite

case object DUMMY_REQUEST extends SubmitRestProtocolAction
case object DUMMY_RESPONSE extends SubmitRestProtocolAction

class DummyRequest extends SubmitRestProtocolRequest {
protected override val action = DUMMY_REQUEST
private val active = new SubmitRestProtocolField[Boolean]
private val age = new SubmitRestProtocolField[Int]
private val name = new SubmitRestProtocolField[String]
Expand All @@ -45,9 +41,7 @@ class DummyRequest extends SubmitRestProtocolRequest {
}
}

class DummyResponse extends SubmitRestProtocolResponse {
protected override val action = DUMMY_RESPONSE
}
class DummyResponse extends SubmitRestProtocolResponse

/**
* Tests for the stable application submission REST protocol.
Expand Down Expand Up @@ -325,7 +319,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val dummyRequestJson =
"""
|{
| "action" : "DUMMY_REQUEST",
| "action" : "dummy_request",
| "active" : "true",
| "age" : "25",
| "client_spark_version" : "1.2.3",
Expand All @@ -336,15 +330,15 @@ class SubmitRestProtocolSuite extends FunSuite {
private val dummyResponseJson =
"""
|{
| "action" : "DUMMY_RESPONSE",
| "action" : "dummy_response",
| "server_spark_version" : "3.3.4"
|}
""".stripMargin

private val submitDriverRequestJson =
"""
|{
| "action" : "SUBMIT_DRIVER_REQUEST",
| "action" : "submit_driver_request",
| "app_args" : "[\"two slices\",\"a hint of cinnamon\"]",
| "app_name" : "SparkPie",
| "app_resource" : "honey-walnut-cherry.jar",
Expand All @@ -369,7 +363,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val submitDriverResponseJson =
"""
|{
| "action" : "SUBMIT_DRIVER_RESPONSE",
| "action" : "submit_driver_response",
| "driver_id" : "driver_123",
| "server_spark_version" : "1.2.3",
| "success" : "true"
Expand All @@ -379,7 +373,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val killDriverRequestJson =
"""
|{
| "action" : "KILL_DRIVER_REQUEST",
| "action" : "kill_driver_request",
| "client_spark_version" : "1.2.3",
| "driver_id" : "driver_123"
|}
Expand All @@ -388,7 +382,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val killDriverResponseJson =
"""
|{
| "action" : "KILL_DRIVER_RESPONSE",
| "action" : "kill_driver_response",
| "driver_id" : "driver_123",
| "server_spark_version" : "1.2.3",
| "success" : "true"
Expand All @@ -398,7 +392,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val driverStatusRequestJson =
"""
|{
| "action" : "DRIVER_STATUS_REQUEST",
| "action" : "driver_status_request",
| "client_spark_version" : "1.2.3",
| "driver_id" : "driver_123"
|}
Expand All @@ -407,7 +401,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val driverStatusResponseJson =
"""
|{
| "action" : "DRIVER_STATUS_RESPONSE",
| "action" : "driver_status_response",
| "driver_id" : "driver_123",
| "driver_state" : "RUNNING",
| "server_spark_version" : "1.2.3",
Expand All @@ -420,7 +414,7 @@ class SubmitRestProtocolSuite extends FunSuite {
private val errorJson =
"""
|{
| "action" : "ERROR",
| "action" : "error_response",
| "message" : "Field not found in submit request: X",
| "server_spark_version" : "1.2.3"
|}
Expand Down

0 comments on commit 8d43486

Please sign in to comment.