diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 848b62f9de71b..6db95a45189b3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -102,9 +102,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
       println("... waiting before polling master for driver state")
       Thread.sleep(5000)
       println("... polling master for driver state")
-      val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout)
+      val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout.duration)
        .mapTo[DriverStatusResponse]
-      val statusResponse = Await.result(statusFuture, timeout)
+      val statusResponse = timeout.awaitResult(statusFuture)
       statusResponse.found match {
         case false =>
           println(s"ERROR: Cluster master did not recognize $driverId")
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 43c8a934c311a..773ffdb778e03 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -194,8 +194,9 @@ private[spark] class AppClient(
     if (actor != null) {
       try {
         val timeout = RpcUtils.askTimeout(conf)
-        val future = actor.ask(StopAppClient)(timeout)
-        Await.result(future, timeout)
+        val future = actor.ask(StopAppClient)(timeout.duration)
+        // TODO(bryanc) - RpcTimeout use awaitResult ???
+        Await.result(future, timeout.duration)
       } catch {
         case e: TimeoutException =>
           logInfo("Stop request to Master timed out; it may already be shut down.")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index fccceb3ea528b..01015af3e8dda 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,7 +23,6 @@ import java.text.SimpleDateFormat
 import java.util.Date
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.util.Random
@@ -940,8 +939,8 @@ private[deploy] object Master extends Logging {
     val actor = actorSystem.actorOf(
       Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
     val timeout = RpcUtils.askTimeout(conf)
-    val portsRequest = actor.ask(BoundPortsRequest)(timeout)
-    val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
+    val portsRequest = actor.ask(BoundPortsRequest)(timeout.duration)
+    val portsResponse = timeout.awaitResult(portsRequest).asInstanceOf[BoundPortsResponse]
     (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
   }
 }
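The three hunks above, like most of the hunks that follow, make the same mechanical substitution: `RpcUtils.askTimeout` now returns an `RpcTimeout` instead of a bare `FiniteDuration`, so Akka `ask` calls take `timeout.duration` and blocking waits go through `timeout.awaitResult(...)` rather than `Await.result(...)`. The payoff is in the error message. A minimal sketch of the difference, not part of the patch; it assumes it is compiled inside `org.apache.spark.rpc`, since the `RpcTimeout` class added to RpcEnv.scala below is `private[spark]`:

```scala
package org.apache.spark.rpc

import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

import org.apache.spark.SparkConf

object AwaitResultDemo {
  def main(args: Array[String]): Unit = {
    // A future that never completes, to force a timeout.
    val never: Future[String] = Promise[String]().future

    // Old pattern: the message says only "Futures timed out after [1 second]".
    try Await.result(never, 1.second) catch {
      case te: TimeoutException => println(te.getMessage)
    }

    // New pattern: awaitResult rethrows with the controlling property appended,
    // e.g. "... [1 second] This timeout is controlled by spark.rpc.askTimeout".
    val timeout = RpcTimeout(new SparkConf().set("spark.rpc.askTimeout", "1s"),
      "spark.rpc.askTimeout")
    try timeout.awaitResult(never) catch {
      case te: TimeoutException => println(te.getMessage)
    }
  }
}
```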
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 06e265f99e231..f3a8af32550b5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.master.ui
 
 import javax.servlet.http.HttpServletRequest
 
-import scala.concurrent.Await
 import scala.xml.Node
 
 import akka.pattern.ask
@@ -38,8 +37,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
   /** Executor details for a particular application */
   def render(request: HttpServletRequest): Seq[Node] = {
     val appId = request.getParameter("appId")
-    val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
-    val state = Await.result(stateFuture, timeout)
+    val stateFuture = (master ? RequestMasterState)(timeout.duration).mapTo[MasterStateResponse]
+    val state = timeout.awaitResult(stateFuture)
     val app = state.activeApps.find(_.id == appId).getOrElse({
       state.completedApps.find(_.id == appId).getOrElse(null)
     })
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 756927682cd24..f0b270d799d23 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.master.ui
 
 import javax.servlet.http.HttpServletRequest
 
-import scala.concurrent.Await
 import scala.xml.Node
 
 import akka.pattern.ask
@@ -36,8 +35,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
   private val timeout = parent.timeout
 
   def getMasterState: MasterStateResponse = {
-    val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
-    Await.result(stateFuture, timeout)
+    val stateFuture = (master ? RequestMasterState)(timeout.duration).mapTo[MasterStateResponse]
+    timeout.awaitResult(stateFuture)
   }
 
   override def renderJson(request: HttpServletRequest): JValue = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
index 9f9f27d71e1ae..14e46d95e143a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.deploy.worker.ui
 
-import scala.concurrent.Await
 import scala.xml.Node
 
 import akka.pattern.ask
@@ -36,14 +35,14 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
   private val timeout = parent.timeout
 
   override def renderJson(request: HttpServletRequest): JValue = {
-    val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
-    val workerState = Await.result(stateFuture, timeout)
+    val stateFuture = (workerActor ? RequestWorkerState)(timeout.duration).mapTo[WorkerStateResponse]
+    val workerState = timeout.awaitResult(stateFuture)
     JsonProtocol.writeWorkerState(workerState)
   }
 
   def render(request: HttpServletRequest): Seq[Node] = {
-    val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
-    val workerState = Await.result(stateFuture, timeout)
+    val stateFuture = (workerActor ? RequestWorkerState)(timeout.duration).mapTo[WorkerStateResponse]
+    val workerState = timeout.awaitResult(stateFuture)
     val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
 
     val runningExecutors = workerState.executors
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 69181edb9ad44..4f5df21499adc 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.rpc
 
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.Future
 import scala.reflect.ClassTag
 
 import org.apache.spark.util.RpcUtils
@@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
    *
    * This method only sends the message once and never retries.
    */
-  def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+  def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
 
   /**
    * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
@@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
    * @tparam T type of the reply message
    * @return the reply message from the corresponding [[RpcEndpoint]]
    */
-  def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+  def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
     // TODO: Consider removing multiple attempts
     var attempts = 0
     var lastException: Exception = null
@@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
       attempts += 1
       try {
         val future = ask[T](message, timeout)
-        val result = Await.result(future, timeout)
+        val result = timeout.awaitResult(future)
         if (result == null) {
           throw new SparkException("Actor returned null")
         }
@@ -110,7 +109,10 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
         lastException = e
         logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
       }
-      Thread.sleep(retryWaitMs)
+
+      if (attempts < maxRetries) {
+        Thread.sleep(retryWaitMs)
+      }
     }
 
     throw new SparkException(
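One behavioral fix rides along in `askWithRetry` above: the old loop called `Thread.sleep(retryWaitMs)` unconditionally, so a caller paid one extra `retryWaitMs` of latency after the final failed attempt, right before the `SparkException` was thrown. The new guard skips that last sleep. The loop shape as a standalone sketch, with hypothetical stand-in values for the `maxRetries` and `retryWaitMs` fields the class reads from the conf:

```scala
object RetryLoopSketch {
  def main(args: Array[String]): Unit = {
    val maxRetries = 3      // stand-in for the conf-derived retry count
    val retryWaitMs = 100L  // stand-in for the conf-derived retry wait
    var attempts = 0
    while (attempts < maxRetries) {
      attempts += 1
      println(s"attempt $attempts")
      // ... ask and await here; on success, return the result ...
      if (attempts < maxRetries) {
        Thread.sleep(retryWaitMs) // back off only when another attempt follows
      }
    }
    // falls through to the "Error sending message ..." exception after
    // maxRetries failures, with no pointless trailing sleep
  }
}
```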
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 12b6b28d4d7ec..ca1a31cef8cfd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -18,7 +18,9 @@ package org.apache.spark.rpc
 
 import java.net.URI
+import java.util.concurrent.TimeoutException
 
+import scala.concurrent.duration._
 import scala.concurrent.{Await, Future}
 import scala.language.postfixOps
 
@@ -94,7 +97,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
    */
   def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
-    Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+    Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout.duration)
   }
 
   /**
@@ -182,3 +185,68 @@ private[spark] object RpcAddress {
     RpcAddress(host, port)
   }
 }
+
+
+/**
+ * Associates a timeout with a configuration property so that a TimeoutException can be
+ * traced back to the controlling property.
+ * @param timeout the timeout duration
+ * @param description description to be displayed in a timeout exception
+ */
+private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
+
+  /** Get the timeout duration */
+  def duration: FiniteDuration = timeout
+
+  /** Get the message associated with this timeout */
+  def message: String = description
+
+  /** Amends the standard message of TimeoutException to include the description */
+  def amend(te: TimeoutException): TimeoutException = {
+    new TimeoutException(te.getMessage() + " " + description)
+  }
+
+  /** Wait on a future result to catch and amend a TimeoutException */
+  def awaitResult[T](future: Future[T]): T = {
+    try {
+      Await.result(future, duration)
+    }
+    catch {
+      case te: TimeoutException =>
+        throw amend(te)
+    }
+  }
+
+  // TODO(bryanc) wrap Await.ready also
+}
+
+private[spark] object RpcTimeout {
+
+  private[this] val messagePrefix = "This timeout is controlled by "
+
+  /**
+   * Look up the timeout property in the configuration and create
+   * an RpcTimeout with the property key in the description.
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @throws NoSuchElementException if property is not set
+   */
+  def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
+    new RpcTimeout(timeout, messagePrefix + timeoutProp)
+  }
+
+  /**
+   * Look up the timeout property in the configuration and create
+   * an RpcTimeout with the property key in the description.
+   * Uses the given default value if the property is not set.
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @param defaultValue default timeout value in seconds if property not found
+   */
+  def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
+    new RpcTimeout(timeout, messagePrefix + timeoutProp)
+  }
+
+}
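The two factory overloads differ only in how a missing property is handled. A hypothetical usage sketch, placed in `org.apache.spark.rpc` so the `private[spark]` class and companion are visible:

```scala
package org.apache.spark.rpc

import org.apache.spark.SparkConf

object RpcTimeoutDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.rpc.askTimeout", "30s")

    val ask = RpcTimeout(conf, "spark.rpc.askTimeout")
    println(ask.duration) // 30 seconds
    println(ask.message)  // This timeout is controlled by spark.rpc.askTimeout

    // The one-arg form throws NoSuchElementException when the property is
    // unset (SparkConf.getTimeAsSeconds has no default to fall back on);
    // the two-arg form substitutes the supplied default instead.
    val lookup = RpcTimeout(conf, "spark.rpc.lookupTimeout", "120s")
    println(lookup.duration) // 120 seconds
  }
}
```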
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index ba0d468f111ef..34ea6103e4abb 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.concurrent.Future
-import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
@@ -212,7 +211,7 @@ private[spark] class AkkaRpcEnv private[akka] (
 
   override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
     import actorSystem.dispatcher
-    actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
+    actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
       map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
   }
 
@@ -293,9 +292,9 @@ private[akka] class AkkaRpcEndpointRef(
     actorRef ! AkkaMessage(message, false)
   }
 
-  override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
+  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
     import scala.concurrent.ExecutionContext.Implicits.global
-    actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+    actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
       case msg @ AkkaMessage(message, reply) =>
         if (reply) {
           logError(s"Receive $msg but the sender cannot reply")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index a85e1c7632973..12be1beccde1b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.apache.spark.rpc.RpcEndpointRef
@@ -105,7 +105,7 @@ class BlockManagerMaster(
       logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
     }
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -117,7 +117,7 @@ class BlockManagerMaster(
       logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
     }
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -131,7 +131,7 @@ class BlockManagerMaster(
         s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}")
     }
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -169,7 +169,7 @@ class BlockManagerMaster(
     val response = driverEndpoint.
       askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
     val (blockManagerIds, futures) = response.unzip
-    val result = Await.result(Future.sequence(futures), timeout)
+    val result = timeout.awaitResult(Future.sequence(futures))
     if (result == null) {
       throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
     }
@@ -192,7 +192,7 @@ class BlockManagerMaster(
       askSlaves: Boolean): Seq[BlockId] = {
     val msg = GetMatchingBlockIds(filter, askSlaves)
     val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
-    Await.result(future, timeout)
+    timeout.awaitResult(future)
   }
 
   /** Stop the driver endpoint, called only on the Spark driver node */
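In the `getBlockStatus` hunk above, a single `awaitResult` bounds the whole batch of slave responses gathered with `Future.sequence`, so one slow block manager surfaces as one property-labelled TimeoutException rather than an anonymous one. A toy sketch of that shape, with hypothetical values, again assuming `org.apache.spark.rpc` visibility:

```scala
package org.apache.spark.rpc

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._

object SequenceAwaitSketch {
  def main(args: Array[String]): Unit = {
    val timeout = new RpcTimeout(5.seconds, "This timeout is controlled by spark.rpc.askTimeout")
    // Stand-ins for the per-slave BlockStatus futures in getBlockStatus.
    val futures: Seq[Future[Int]] = Seq(Future(1), Future(2), Future(3))
    // One bounded wait for the whole batch; a timeout anywhere is rethrown
    // with the controlling property appended to the message.
    val results: Seq[Int] = timeout.awaitResult(Future.sequence(futures))
    println(results.sum)
  }
}
```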
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index de3316d083a22..8b53146b0f969 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.util
 
+import org.apache.spark.rpc.RpcTimeout
+
 import scala.collection.JavaConversions.mapAsJavaMap
-import scala.concurrent.Await
-import scala.concurrent.duration.FiniteDuration
 
 import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
 import akka.pattern.ask
@@ -147,7 +147,7 @@ private[spark] object AkkaUtils extends Logging {
   def askWithReply[T](
       message: Any,
       actor: ActorRef,
-      timeout: FiniteDuration): T = {
+      timeout: RpcTimeout): T = {
     askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
   }
 
@@ -160,7 +160,7 @@ private[spark] object AkkaUtils extends Logging {
       actor: ActorRef,
       maxAttempts: Int,
       retryInterval: Long,
-      timeout: FiniteDuration): T = {
+      timeout: RpcTimeout): T = {
     // TODO: Consider removing multiple attempts
     if (actor == null) {
       throw new SparkException(s"Error sending message [message = $message]" +
@@ -171,8 +171,8 @@ private[spark] object AkkaUtils extends Logging {
     while (attempts < maxAttempts) {
       attempts += 1
       try {
-        val future = actor.ask(message)(timeout)
-        val result = Await.result(future, timeout)
+        val future = actor.ask(message)(timeout.duration)
+        val result = timeout.awaitResult(future)
         if (result == null) {
           throw new SparkException("Actor returned null")
         }
@@ -200,7 +200,7 @@ private[spark] object AkkaUtils extends Logging {
     val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
     val timeout = RpcUtils.lookupTimeout(conf)
     logInfo(s"Connecting to $name: $url")
-    Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+    timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
   }
 
   def makeExecutorRef(
@@ -214,7 +214,7 @@ private[spark] object AkkaUtils extends Logging {
     val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
     val timeout = RpcUtils.lookupTimeout(conf)
     logInfo(s"Connecting to $name: $url")
-    Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+    timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
   }
 
   def protocol(actorSystem: ActorSystem): String = {
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index f16cc8e7e42c6..a853f39d9a493 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.util
 
-import scala.concurrent.duration._
 import scala.language.postfixOps
 
 import org.apache.spark.{SparkEnv, SparkConf}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
 
 object RpcUtils {
 
@@ -47,14 +46,24 @@ object RpcUtils {
   }
 
   /** Returns the default Spark timeout to use for RPC ask operations. */
-  def askTimeout(conf: SparkConf): FiniteDuration = {
-    conf.getTimeAsSeconds("spark.rpc.askTimeout",
-      conf.get("spark.network.timeout", "120s")) seconds
+  def askTimeout(conf: SparkConf): RpcTimeout = {
+    try {
+      RpcTimeout(conf, "spark.rpc.askTimeout")
+    }
+    catch {
+      case _: NoSuchElementException =>
+        RpcTimeout(conf, "spark.network.timeout", "120s")
+    }
   }
 
   /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */
-  def lookupTimeout(conf: SparkConf): FiniteDuration = {
-    conf.getTimeAsSeconds("spark.rpc.lookupTimeout",
-      conf.get("spark.network.timeout", "120s")) seconds
+  def lookupTimeout(conf: SparkConf): RpcTimeout = {
+    try {
+      RpcTimeout(conf, "spark.rpc.lookupTimeout")
+    }
+    catch {
+      case _: NoSuchElementException =>
+        RpcTimeout(conf, "spark.network.timeout", "120s")
+    }
   }
 }
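The try/catch fallback preserves the old precedence: `spark.rpc.askTimeout` if set, else `spark.network.timeout`, else 120s, while making the property that actually won show up in any timeout message. Note the catch is narrowed to `NoSuchElementException`, which is what `SparkConf.getTimeAsSeconds` throws for a missing key per the `@throws` on the factory; catching `Throwable` here would silently swallow a malformed value. A quick sketch of the resulting behavior, placed in `org.apache.spark.util` so the `private[spark]` return type is visible:

```scala
package org.apache.spark.util

import org.apache.spark.SparkConf

object TimeoutFallbackSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false) // neither timeout property set
    val t = RpcUtils.askTimeout(conf)
    println(t.duration) // 120 seconds, from the spark.network.timeout default
    println(t.message)  // names spark.network.timeout, the property that applied

    conf.set("spark.rpc.askTimeout", "10s")
    println(RpcUtils.askTimeout(conf).message) // now names spark.rpc.askTimeout
  }
}
```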