Skip to content

Commit

Permalink
Reduce duplicate naming in REST field
Browse files Browse the repository at this point in the history
This commit also fixes a the standalone REST protocol test, which
would fail with ClassCastException if the server returns error for
the same reason explained in the previous commit.
  • Loading branch information
Andrew Or committed Feb 1, 2015
1 parent ade28fd commit 9229433
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest
* A request to query the status of a driver in the REST application submission protocol.
*/
class DriverStatusRequest extends SubmitRestProtocolRequest {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val driverId = new SubmitRestProtocolField[String]
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(driverId, "driverId")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest
* A response to the [[DriverStatusRequest]] in the REST application submission protocol.
*/
class DriverStatusResponse extends SubmitRestProtocolResponse {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val driverId = new SubmitRestProtocolField[String]
// standalone cluster mode only
private val driverState = new SubmitRestProtocolField[String]("driverState")
private val workerId = new SubmitRestProtocolField[String]("workerId")
private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort")
private val driverState = new SubmitRestProtocolField[String]
private val workerId = new SubmitRestProtocolField[String]
private val workerHostPort = new SubmitRestProtocolField[String]

def getDriverId: String = driverId.toString
def getDriverState: String = driverState.toString
Expand All @@ -39,6 +39,6 @@ class DriverStatusResponse extends SubmitRestProtocolResponse {

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(driverId, "driverId")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@

package org.apache.spark.deploy.rest

import com.fasterxml.jackson.annotation.JsonIgnore

/**
* An error response message used in the REST application submission protocol.
*/
class ErrorResponse extends SubmitRestProtocolResponse {
// request was unsuccessful
setSuccess("false")

// Don't bother logging success = false in the JSON
@JsonIgnore
override def getSuccess: String = super.getSuccess

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(message)
assert(!getSuccess.toBoolean, "The 'success' field cannot be true in an error response.")
assertFieldIsSet(message, "message")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest
* A request to kill a driver in the REST application submission protocol.
*/
class KillDriverRequest extends SubmitRestProtocolRequest {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val driverId = new SubmitRestProtocolField[String]
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(driverId, "driverId")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest
* A response to the [[KillDriverRequest]] in the REST application submission protocol.
*/
class KillDriverResponse extends SubmitRestProtocolResponse {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val driverId = new SubmitRestProtocolField[String]
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(driverId, "driverId")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,20 @@ import org.apache.spark.util.JsonProtocol
* A request to submit a driver in the REST application submission protocol.
*/
class SubmitDriverRequest extends SubmitRestProtocolRequest {
private val appName = new SubmitRestProtocolField[String]("appName")
private val appResource = new SubmitRestProtocolField[String]("appResource")
private val mainClass = new SubmitRestProtocolField[String]("mainClass")
private val jars = new SubmitRestProtocolField[String]("jars")
private val files = new SubmitRestProtocolField[String]("files")
private val pyFiles = new SubmitRestProtocolField[String]("pyFiles")
private val driverMemory = new SubmitRestProtocolField[String]("driverMemory")
private val driverCores = new SubmitRestProtocolField[Int]("driverCores")
private val driverExtraJavaOptions = new SubmitRestProtocolField[String]("driverExtraJavaOptions")
private val driverExtraClassPath = new SubmitRestProtocolField[String]("driverExtraClassPath")
private val driverExtraLibraryPath = new SubmitRestProtocolField[String]("driverExtraLibraryPath")
private val superviseDriver = new SubmitRestProtocolField[Boolean]("superviseDriver")
private val executorMemory = new SubmitRestProtocolField[String]("executorMemory")
private val totalExecutorCores = new SubmitRestProtocolField[Int]("totalExecutorCores")
private val appName = new SubmitRestProtocolField[String]
private val appResource = new SubmitRestProtocolField[String]
private val mainClass = new SubmitRestProtocolField[String]
private val jars = new SubmitRestProtocolField[String]
private val files = new SubmitRestProtocolField[String]
private val pyFiles = new SubmitRestProtocolField[String]
private val driverMemory = new SubmitRestProtocolField[String]
private val driverCores = new SubmitRestProtocolField[Int]
private val driverExtraJavaOptions = new SubmitRestProtocolField[String]
private val driverExtraClassPath = new SubmitRestProtocolField[String]
private val driverExtraLibraryPath = new SubmitRestProtocolField[String]
private val superviseDriver = new SubmitRestProtocolField[Boolean]
private val executorMemory = new SubmitRestProtocolField[String]
private val totalExecutorCores = new SubmitRestProtocolField[Int]

// Special fields
private val appArgs = new ArrayBuffer[String]
Expand Down Expand Up @@ -140,7 +140,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(appName)
assertFieldIsSet(appResource)
assertFieldIsSet(appName, "appName")
assertFieldIsSet(appResource, "appResource")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.spark.deploy.rest
* A response to the [[SubmitDriverRequest]] in the REST application submission protocol.
*/
class SubmitDriverResponse extends SubmitRestProtocolResponse {
private val driverId = new SubmitRestProtocolField[String]("driverId")
private val driverId = new SubmitRestProtocolField[String]
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.deploy.rest
/**
* A field used in [[SubmitRestProtocolMessage]]s.
*/
class SubmitRestProtocolField[T](val name: String) {
class SubmitRestProtocolField[T] {
protected var value: Option[T] = None
def isSet: Boolean = value.isDefined
def getValue: Option[T] = value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.deploy.rest
import com.fasterxml.jackson.annotation._
import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility
import com.fasterxml.jackson.annotation.JsonInclude.Include
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature}
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._

Expand All @@ -42,7 +42,7 @@ abstract class SubmitRestProtocolMessage {
val messageType = Utils.getFormattedClassName(this)
protected val action: String = messageType
protected val sparkVersion: SubmitRestProtocolField[String]
protected val message = new SubmitRestProtocolField[String]("message")
protected val message = new SubmitRestProtocolField[String]

// Required for JSON de/serialization and not explicitly used
private def getAction: String = action
Expand All @@ -64,7 +64,8 @@ abstract class SubmitRestProtocolMessage {
def toJson: String = {
validate()
val mapper = new ObjectMapper
pretty(parse(mapper.writeValueAsString(this)))
mapper.enable(SerializationFeature.INDENT_OUTPUT)
mapper.writeValueAsString(this)
}

/**
Expand All @@ -85,14 +86,13 @@ abstract class SubmitRestProtocolMessage {
if (action == null) {
throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType")
}
assertFieldIsSet(sparkVersion)
}

/** Assert that the specified field is set in this message. */
protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = {
protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = {
if (!field.isSet) {
throw new SubmitRestMissingFieldException(
s"Field '${field.name}' is missing in message $messageType.")
s"Field '$name' is missing in message $messageType.")
}
}

Expand Down Expand Up @@ -143,19 +143,23 @@ abstract class SubmitRestProtocolMessage {
* An abstract request sent from the client in the REST application submission protocol.
*/
abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage {
protected override val sparkVersion = new SubmitRestProtocolField[String]("client_spark_version")
protected override val sparkVersion = new SubmitRestProtocolField[String]
def getClientSparkVersion: String = sparkVersion.toString
def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s)
override def getSparkVersion: String = getClientSparkVersion
override def setSparkVersion(s: String) = setClientSparkVersion(s)
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(sparkVersion, "clientSparkVersion")
}
}

/**
* An abstract response sent from the server in the REST application submission protocol.
*/
abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version")
private val success = new SubmitRestProtocolField[Boolean]("success")
protected override val sparkVersion = new SubmitRestProtocolField[String]
private val success = new SubmitRestProtocolField[Boolean]

override def getSparkVersion: String = getServerSparkVersion
def getServerSparkVersion: String = sparkVersion.toString
Expand All @@ -167,7 +171,8 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(success)
assertFieldIsSet(sparkVersion, "serverSparkVersion")
assertFieldIsSet(success, "success")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B

test("kill empty driver") {
val response = client.killDriver(masterRestUrl, "driver-that-does-not-exist")
val killResponse = getResponse[KillDriverResponse](response, client)
val killResponse = getKillResponse(response)
val killSuccess = killResponse.getSuccess
assert(killSuccess === "false")
}
Expand All @@ -71,11 +71,11 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
val size = 500
val driverId = submitApplication(resultsFile, numbers, size)
val response = client.killDriver(masterRestUrl, driverId)
val killResponse = getResponse[KillDriverResponse](response, client)
val killResponse = getKillResponse(response)
val killSuccess = killResponse.getSuccess
waitUntilFinished(driverId)
val response2 = client.requestDriverStatus(masterRestUrl, driverId)
val statusResponse = getResponse[DriverStatusResponse](response2, client)
val statusResponse = getStatusResponse(response2)
val statusSuccess = statusResponse.getSuccess
val driverState = statusResponse.getDriverState
assert(killSuccess === "true")
Expand All @@ -87,7 +87,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B

test("request status for empty driver") {
val response = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist")
val statusResponse = getResponse[DriverStatusResponse](response, client)
val statusResponse = getStatusResponse(response)
val statusSuccess = statusResponse.getSuccess
assert(statusSuccess === "false")
}
Expand Down Expand Up @@ -130,7 +130,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
val args = new SparkSubmitArguments(commandLineArgs)
SparkSubmit.prepareSubmitEnvironment(args)
val response = client.submitDriver(args)
val submitResponse = getResponse[SubmitDriverResponse](response, client)
val submitResponse = getSubmitResponse(response)
submitResponse.getDriverId
}

Expand All @@ -140,7 +140,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
val expireTime = System.currentTimeMillis + maxSeconds * 1000
while (!finished) {
val response = client.requestDriverStatus(masterRestUrl, driverId)
val statusResponse = getResponse[DriverStatusResponse](response, client)
val statusResponse = getStatusResponse(response)
val driverState = statusResponse.getDriverState
finished =
driverState != DriverState.SUBMITTED.toString &&
Expand All @@ -151,17 +151,30 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
}
}

/** Return the response as the expected type, or fail with an informative error message. */
private def getResponse[T <: SubmitRestProtocolResponse](
response: SubmitRestProtocolResponse,
client: StandaloneRestClient): T = {
/** Return the response as a submit driver response, or fail with error otherwise. */
private def getSubmitResponse(response: SubmitRestProtocolResponse): SubmitDriverResponse = {
response match {
case error: ErrorResponse =>
fail(s"Error from the server:\n${error.getMessage}")
case _ =>
client.getResponse[T](response).getOrElse {
fail(s"Response type was unexpected: ${response.toJson}")
}
case s: SubmitDriverResponse => s
case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}")
case r => fail(s"Expected submit response. Actual: ${r.toJson}")
}
}

/** Return the response as a kill driver response, or fail with error otherwise. */
private def getKillResponse(response: SubmitRestProtocolResponse): KillDriverResponse = {
response match {
case k: KillDriverResponse => k
case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}")
case r => fail(s"Expected kill response. Actual: ${r.toJson}")
}
}

/** Return the response as a driver status response, or fail with error otherwise. */
private def getStatusResponse(response: SubmitRestProtocolResponse): DriverStatusResponse = {
response match {
case s: DriverStatusResponse => s
case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}")
case r => fail(s"Expected status response. Actual: ${r.toJson}")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,15 @@ class SubmitRestProtocolSuite extends FunSuite {
}

test("response to and from JSON") {
val response = new DummyResponse().setSparkVersion("3.3.4")
val response = new DummyResponse()
.setSparkVersion("3.3.4")
.setSuccess("true")
val json = response.toJson
assertJsonEquals(json, dummyResponseJson)
val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse])
assert(newResponse.getSparkVersion === "3.3.4")
assert(newResponse.getServerSparkVersion === "3.3.4")
assert(newResponse.getSuccess === "true")
assert(newResponse.getMessage === null)
}

Expand Down Expand Up @@ -300,7 +303,8 @@ class SubmitRestProtocolSuite extends FunSuite {
"""
|{
| "action" : "DummyResponse",
| "serverSparkVersion" : "3.3.4"
| "serverSparkVersion" : "3.3.4",
| "success": "true"
|}
""".stripMargin

Expand Down Expand Up @@ -403,9 +407,9 @@ class SubmitRestProtocolSuite extends FunSuite {

private class DummyResponse extends SubmitRestProtocolResponse
private class DummyRequest extends SubmitRestProtocolRequest {
private val active = new SubmitRestProtocolField[Boolean]("active")
private val age = new SubmitRestProtocolField[Int]("age")
private val name = new SubmitRestProtocolField[String]("name")
private val active = new SubmitRestProtocolField[Boolean]
private val age = new SubmitRestProtocolField[Int]
private val name = new SubmitRestProtocolField[String]

def getActive: String = active.toString
def getAge: String = age.toString
Expand All @@ -417,8 +421,8 @@ private class DummyRequest extends SubmitRestProtocolRequest {

protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(name)
assertFieldIsSet(age)
assertFieldIsSet(name, "name")
assertFieldIsSet(age, "age")
assert(age.getValue.get > 5, "Not old enough!")
}
}

0 comments on commit 9229433

Please sign in to comment.