Retry the submit-application request to multiple nodes (apache#69)
* Retry the submit-application request to multiple nodes.

* Fix doc style comment

* Check node unschedulable, log retry failures
mccheah authored and ash211 committed Mar 8, 2017
1 parent b2e6877 commit 6ee3be5
Showing 5 changed files with 117 additions and 28 deletions.
Client.scala
@@ -361,11 +361,13 @@ private[spark] class Client(
         DEFAULT_BLOCKMANAGER_PORT.toString)
       val driverSubmitter = buildDriverSubmissionClient(kubernetesClient, service,
         driverSubmitSslOptions)
-      val ping = Retry.retry(5, 5.seconds) {
+      val ping = Retry.retry(5, 5.seconds,
+        Some("Failed to contact the driver server")) {
         driverSubmitter.ping()
       }
       ping onFailure {
         case t: Throwable =>
+          logError("Ping failed to the driver server", t)
           submitCompletedFuture.setException(t)
           kubernetesClient.services().delete(service)
       }
@@ -532,17 +534,6 @@ private[spark] class Client(
       kubernetesClient: KubernetesClient,
       service: Service,
       driverSubmitSslOptions: SSLOptions): KubernetesSparkRestApi = {
-    val servicePort = service
-      .getSpec
-      .getPorts
-      .asScala
-      .filter(_.getName == SUBMISSION_SERVER_PORT_NAME)
-      .head
-      .getNodePort
-    // NodePort is exposed on every node, so just pick one of them.
-    // TODO be resilient to node failures and try all of them
-    val node = kubernetesClient.nodes.list.getItems.asScala.head
-    val nodeAddress = node.getStatus.getAddresses.asScala.head.getAddress
     val urlScheme = if (driverSubmitSslOptions.enabled) {
       "https"
     } else {
@@ -551,15 +542,23 @@
         " to secure this step.")
       "http"
     }
+    val servicePort = service.getSpec.getPorts.asScala
+      .filter(_.getName == SUBMISSION_SERVER_PORT_NAME)
+      .head.getNodePort
+    val nodeUrls = kubernetesClient.nodes.list.getItems.asScala
+      .filterNot(_.getSpec.getUnschedulable)
+      .flatMap(_.getStatus.getAddresses.asScala.map(address => {
+        s"$urlScheme://${address.getAddress}:$servicePort"
+      })).toArray
+    require(nodeUrls.nonEmpty, "No nodes found to contact the driver!")
     val (trustManager, sslContext): (X509TrustManager, SSLContext) =
       if (driverSubmitSslOptions.enabled) {
         buildSslConnectionConfiguration(driverSubmitSslOptions)
       } else {
         (null, SSLContext.getDefault)
       }
-    val url = s"$urlScheme://$nodeAddress:$servicePort"
     HttpClientUtil.createClient[KubernetesSparkRestApi](
-      url,
+      uris = nodeUrls,
       sslSocketFactory = sslContext.getSocketFactory,
       trustContext = trustManager)
   }
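
Note: the lookup above now enumerates the addresses of every schedulable node instead of picking the first node's first address; because the submission server is exposed through a NodePort, the same port is reachable on any healthy node. A minimal standalone sketch of the same enumeration, assuming a fabric8 KubernetesClient is in scope (the null-guard on the unschedulable flag is our addition, not part of the patch):

    import scala.collection.JavaConverters._

    import io.fabric8.kubernetes.client.KubernetesClient

    def submissionUrls(
        client: KubernetesClient,
        urlScheme: String,
        servicePort: Int): Array[String] = {
      client.nodes.list.getItems.asScala
        // getUnschedulable returns a java.lang.Boolean that may be null when
        // the flag was never set, so unwrap it before filtering.
        .filterNot(node => Option(node.getSpec.getUnschedulable).exists(_.booleanValue))
        // A node may advertise several addresses (internal, external, hostname);
        // build one candidate URL per address.
        .flatMap(_.getStatus.getAddresses.asScala)
        .map(address => s"$urlScheme://${address.getAddress}:$servicePort")
        .toArray
    }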
Retry.scala
@@ -19,24 +19,36 @@ package org.apache.spark.deploy.kubernetes
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
-private[spark] object Retry {
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+
+private[spark] object Retry extends Logging {
 
   private def retryableFuture[T]
-      (times: Int, interval: Duration)
+      (attempt: Int, maxAttempts: Int, interval: Duration, retryMessage: Option[String])
       (f: => Future[T])
       (implicit executionContext: ExecutionContext): Future[T] = {
     f recoverWith {
-      case _ if times > 0 => {
-        Thread.sleep(interval.toMillis)
-        retryableFuture(times - 1, interval)(f)
-      }
+      case error: Throwable =>
+        if (attempt <= maxAttempts) {
+          retryMessage.foreach { message =>
+            logWarning(s"$message - attempt $attempt of $maxAttempts", error)
+          }
+          Thread.sleep(interval.toMillis)
+          retryableFuture(attempt + 1, maxAttempts, interval, retryMessage)(f)
+        } else {
+          Future.failed(retryMessage.map(message =>
+            new SparkException(s"$message - reached $maxAttempts attempts," +
+              s" and aborting task.", error)
+          ).getOrElse(error))
+        }
     }
   }
 
   def retry[T]
-      (times: Int, interval: Duration)
+      (times: Int, interval: Duration, retryMessage: Option[String] = None)
       (f: => T)
       (implicit executionContext: ExecutionContext): Future[T] = {
-    retryableFuture(times, interval)(Future[T] { f })
+    retryableFuture(1, times, interval, retryMessage)(Future[T] { f })
   }
 }
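
Note: a hedged usage sketch of the reworked helper. The flaky() function is hypothetical and stands in for driverSubmitter.ping() from Client.scala:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // Fails twice, then succeeds, to exercise the retry-and-log path.
    var calls = 0
    def flaky(): String = {
      calls += 1
      if (calls < 3) throw new RuntimeException(s"boom #$calls") else "pong"
    }

    val result: Future[String] =
      Retry.retry(5, 1.second, Some("Failed to contact the driver server")) {
        flaky()
      }
    // Each failed attempt logs "Failed to contact the driver server - attempt N of 5";
    // if all attempts fail, the future fails with a SparkException wrapping the
    // last error.
    println(Await.result(result, 1.minute)) // prints "pong"

One subtlety of the new code: the counter starts at 1 and retries while attempt <= maxAttempts, so the block can execute up to times + 1 invocations in total before the failure propagates.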
HttpClientUtil.scala
@@ -20,7 +20,7 @@ import javax.net.ssl.{SSLContext, SSLSocketFactory, X509TrustManager}
 
 import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
 import com.fasterxml.jackson.module.scala.DefaultScalaModule
-import feign.Feign
+import feign.{Client, Feign, Request, Response}
 import feign.Request.Options
 import feign.jackson.{JacksonDecoder, JacksonEncoder}
 import feign.jaxrs.JAXRSContract
@@ -32,7 +32,7 @@ import org.apache.spark.status.api.v1.JacksonMessageWriter
 private[spark] object HttpClientUtil {
 
   def createClient[T: ClassTag](
-      uri: String,
+      uris: Array[String],
       sslSocketFactory: SSLSocketFactory = SSLContext.getDefault.getSocketFactory,
       trustContext: X509TrustManager = null,
       readTimeoutMillis: Int = 20000,
@@ -45,13 +45,24 @@
       .registerModule(new DefaultScalaModule)
       .setDateFormat(JacksonMessageWriter.makeISODateFormat)
     objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
-    val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
+    val target = new MultiServerFeignTarget[T](uris)
+    val baseHttpClient = new feign.okhttp.OkHttpClient(httpClientBuilder.build())
+    val resetTargetHttpClient = new Client {
+      override def execute(request: Request, options: Options): Response = {
+        val response = baseHttpClient.execute(request, options)
+        if (response.status() >= 200 && response.status() < 300) {
+          target.reset()
+        }
+        response
+      }
+    }
     Feign.builder()
-      .client(new feign.okhttp.OkHttpClient(httpClientBuilder.build()))
+      .client(resetTargetHttpClient)
       .contract(new JAXRSContract)
       .encoder(new JacksonEncoder(objectMapper))
       .decoder(new JacksonDecoder(objectMapper))
       .options(new Options(connectTimeoutMillis, readTimeoutMillis))
-      .target(clazz, uri)
+      .retryer(target)
+      .target(target)
   }
 }
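
Note: the wrapping Client calls target.reset() after any 2xx response, so a request that succeeds leaves the next request with the full shuffled candidate list rather than the tail remaining from earlier failovers. A minimal call-site sketch, with hypothetical endpoint addresses and TLS left at the defaults:

    val api = HttpClientUtil.createClient[KubernetesSparkRestApi](
      uris = Array(
        "http://10.0.0.1:31000",
        "http://10.0.0.2:31000"))
    // Each request starts from this thread's shuffled candidate list; a retryable
    // connection failure makes Feign consult the target, which fails over to the
    // next candidate. Any 2xx response restores the full list for the next call.
    api.ping()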
MultiServerFeignTarget.scala (new file)
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.rest.kubernetes
+
+import feign.{Request, RequestTemplate, RetryableException, Retryer, Target}
+import scala.reflect.ClassTag
+import scala.util.Random
+
+private[kubernetes] class MultiServerFeignTarget[T : ClassTag](
+    private val servers: Seq[String]) extends Target[T] with Retryer {
+  require(servers.nonEmpty, "Must provide at least one server URI.")
+
+  private val threadLocalShuffledServers = new ThreadLocal[Seq[String]] {
+    override def initialValue(): Seq[String] = Random.shuffle(servers)
+  }
+
+  override def `type`(): Class[T] = {
+    implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
+  }
+
+  override def url(): String = threadLocalShuffledServers.get.head
+
+  /**
+   * Cloning the target is done on every request, for use on the current
+   * thread - thus it's important that clone returns a "fresh" target.
+   */
+  override def clone(): Retryer = {
+    reset()
+    this
+  }
+
+  override def name(): String = {
+    s"${getClass.getSimpleName} with servers [${servers.mkString(",")}]"
+  }
+
+  override def apply(requestTemplate: RequestTemplate): Request = {
+    if (!requestTemplate.url().startsWith("http")) {
+      requestTemplate.insert(0, url())
+    }
+    requestTemplate.request()
+  }
+
+  override def continueOrPropagate(e: RetryableException): Unit = {
+    threadLocalShuffledServers.set(threadLocalShuffledServers.get.drop(1))
+    if (threadLocalShuffledServers.get.isEmpty) {
+      throw e
+    }
+  }
+
+  def reset(): Unit = {
+    threadLocalShuffledServers.set(Random.shuffle(servers))
+  }
+}
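
Note: failover state lives in a ThreadLocal of the shuffled server list, so concurrent submitters retry independently of one another. A rough lifecycle sketch under those assumptions, with hypothetical server addresses:

    val target = new MultiServerFeignTarget[KubernetesSparkRestApi](
      Seq("http://10.0.0.1:31000", "http://10.0.0.2:31000"))

    // Feign clones the retryer at the start of each request, which reshuffles
    // this thread's list; the request then resolves against the list's head.
    val firstChoice = target.url()

    // On a RetryableException, Feign calls continueOrPropagate: the failed
    // server is dropped, so url() yields the next candidate. When no
    // candidates remain, the exception is rethrown to the caller.

    // After any successful response, the client wrapper in HttpClientUtil
    // calls reset(), restoring and reshuffling the full list for this thread.
    target.reset()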
Minikube.scala
@@ -123,7 +123,7 @@ private[spark] object Minikube extends Logging {
       .build()
     val sslContext = SSLUtils.sslContext(kubernetesConf)
     val trustManager = SSLUtils.trustManagers(kubernetesConf)(0).asInstanceOf[X509TrustManager]
-    HttpClientUtil.createClient[T](url, sslContext.getSocketFactory, trustManager)
+    HttpClientUtil.createClient[T](Array(url), sslContext.getSocketFactory, trustManager)
   }
 
   def executeMinikubeSsh(command: String): Unit = {