[SPARK-44275][CONNECT] Add configurable retry mechanism to Scala Spark Connect

### What changes were proposed in this pull request?
This PR introduces a configurable retry mechanism for the Scala `SparkConnectClient`.
Parameters for the exponential backoff and a filter selecting which exceptions trigger a retry are passed to the client via the existing (now extended) `Configuration` class. By default, only a `StatusRuntimeException` with status `UNAVAILABLE` triggers a retry, which mirrors the default retry behavior of the Python client.

The retry logic could later be moved into the gRPC stub that may be introduced [here](apache#41743).
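
A minimal sketch of how the new pieces compose, using only the `GrpcRetryHandler` and `GrpcRetryHandler.RetryPolicy` added in this PR (both are `private[client]`, so this assumes code living in `org.apache.spark.sql.connect.client`; `stub` and `request` are illustrative stand-ins, not part of this change):

```scala
import scala.concurrent.duration.FiniteDuration

// A policy retrying up to 5 times, sleeping 100 ms, 400 ms, 1.6 s, ... capped at 30 s.
// All parameters have defaults, so GrpcRetryHandler.RetryPolicy() also works.
val policy = GrpcRetryHandler.RetryPolicy(
  maxRetries = 5,
  initialBackoff = FiniteDuration(100, "ms"),
  maxBackoff = FiniteDuration(30, "s"),
  backoffMultiplier = 4.0)

val retryHandler = new GrpcRetryHandler(policy)

// Unary RPCs are wrapped the same way CustomSparkConnectBlockingStub does below: the block
// is re-invoked with exponential backoff as long as policy.canRetry accepts the error and
// retries remain.
val response = retryHandler.retry {
  stub.analyzePlan(request)
}
```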

### Why are the changes needed?
A few existing failure modes are better handled with a retry than with an immediate error.
For example, a command submitted while the cluster is still starting can succeed once the server becomes reachable, instead of failing with an exception.
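
For a rough sense of how much startup time the default policy below covers (`maxRetries = 15`, `initialBackoff = 50 ms`, `backoffMultiplier = 4`, `maxBackoff = 1 min`), a back-of-the-envelope calculation of the sleep schedule:

```scala
// Mirrors the sleep in GrpcRetryHandler.retry: maxBackoff min initialBackoff * multiplier^i
val waitsMs = (0 until 15).map(i => math.min(60000.0, 50.0 * math.pow(4.0, i)))
// 50, 200, 800, 3200, 12800, 51200 ms, then 60000 ms for each of the remaining nine retries
println(waitsMs.sum / 1000) // ≈ 608 seconds, i.e. roughly ten minutes of total waiting
```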

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests included.
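
For illustration only (not one of the tests shipped with this PR), the retry path can be exercised with a flaky call roughly like this; it assumes code in the `org.apache.spark.sql.connect.client` package, since `GrpcRetryHandler` is `private[client]`:

```scala
import scala.concurrent.duration.FiniteDuration

import io.grpc.{Status, StatusRuntimeException}

var attempts = 0
val handler = new GrpcRetryHandler(
  GrpcRetryHandler.RetryPolicy(initialBackoff = FiniteDuration(1, "ms")))

val result = handler.retry {
  attempts += 1
  if (attempts < 3) {
    // UNAVAILABLE is retried by the default canRetry (retryException)
    throw new StatusRuntimeException(Status.UNAVAILABLE)
  }
  "ok"
}

assert(result == "ok")
assert(attempts == 3) // two retried failures, then success
```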

Closes apache#41829 from dillitz/spark-44275-retries.

Authored-by: Robert Dillitz <robert.dillitz@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
dillitz authored and HyukjinKwon committed Jul 7, 2023
1 parent 17fac56 commit e3c2bf5
Showing 7 changed files with 341 additions and 23 deletions.
ArtifactManager.scala
@@ -31,7 +31,6 @@ import scala.util.control.NonFatal
 
 import Artifact._
 import com.google.protobuf.ByteString
-import io.grpc.ManagedChannel
 import io.grpc.stub.StreamObserver
 import org.apache.commons.codec.digest.DigestUtils.sha256Hex
 
@@ -44,20 +43,23 @@ import org.apache.spark.util.{SparkFileUtils, SparkThreadUtils}
  * The Artifact Manager is responsible for handling and transferring artifacts from the local
  * client to the server (local/remote).
  * @param userContext
  *   The user context the artifact manager operates in.
  * @param sessionId
  *   An unique identifier of the session which the artifact manager belongs to.
- * @param channel
+ * @param bstub
+ *   A blocking stub to the server.
+ * @param stub
+ *   An async stub to the server.
  */
 class ArtifactManager(
     userContext: proto.UserContext,
     sessionId: String,
-    channel: ManagedChannel) {
+    bstub: CustomSparkConnectBlockingStub,
+    stub: CustomSparkConnectStub) {
   // Using the midpoint recommendation of 32KiB for chunk size as specified in
   // https://github.com/grpc/grpc.github.io/issues/371.
   private val CHUNK_SIZE: Int = 32 * 1024
 
-  private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
-  private[this] val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
   private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]
 
   /**
CustomSparkConnectBlockingStub.scala
@@ -18,34 +18,51 @@ package org.apache.spark.sql.connect.client
 
 import io.grpc.ManagedChannel
 
-import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ConfigRequest, ConfigResponse, ExecutePlanRequest, ExecutePlanResponse, InterruptRequest, InterruptResponse}
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto._
 
-private[client] class CustomSparkConnectBlockingStub(channel: ManagedChannel) {
+private[client] class CustomSparkConnectBlockingStub(
+    channel: ManagedChannel,
+    retryPolicy: GrpcRetryHandler.RetryPolicy) {
 
-  private val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+  private val stub = SparkConnectServiceGrpc.newBlockingStub(channel)
+  private val retryHandler = new GrpcRetryHandler(retryPolicy)
 
   def executePlan(request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = {
     GrpcExceptionConverter.convert {
-      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](stub.executePlan(request))
+      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+        retryHandler.RetryIterator(request, stub.executePlan))
     }
   }
 
   def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
     GrpcExceptionConverter.convert {
-      stub.analyzePlan(request)
+      retryHandler.retry {
+        stub.analyzePlan(request)
+      }
     }
   }
 
   def config(request: ConfigRequest): ConfigResponse = {
     GrpcExceptionConverter.convert {
-      stub.config(request)
+      retryHandler.retry {
+        stub.config(request)
+      }
     }
   }
 
   def interrupt(request: InterruptRequest): InterruptResponse = {
     GrpcExceptionConverter.convert {
-      stub.interrupt(request)
+      retryHandler.retry {
+        stub.interrupt(request)
+      }
     }
   }
 
+  def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = {
+    GrpcExceptionConverter.convert {
+      retryHandler.retry {
+        stub.artifactStatus(request)
+      }
+    }
+  }
 }
CustomSparkConnectStub.scala
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.client

import io.grpc.ManagedChannel
import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc}

private[client] class CustomSparkConnectStub(
    channel: ManagedChannel,
    retryPolicy: GrpcRetryHandler.RetryPolicy) {

  private val stub = SparkConnectServiceGrpc.newStub(channel)
  private val retryHandler = new GrpcRetryHandler(retryPolicy)

  def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
      : StreamObserver[AddArtifactsRequest] = {
    retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts)
  }
}
GrpcRetryHandler.scala
@@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.client

import scala.annotation.tailrec
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

import io.grpc.{Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver

private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler.RetryPolicy) {

  /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
   * @param fn
   *   The function to retry.
   * @param currentRetryNum
   *   Current number of retries.
   * @tparam T
   *   The return type of the function.
   * @return
   *   The result of the function.
   */
  @tailrec final def retry[T](fn: => T, currentRetryNum: Int = 0): T = {
    if (currentRetryNum > retryPolicy.maxRetries) {
      throw new IllegalArgumentException(
        s"The number of retries ($currentRetryNum) must not exceed " +
          s"the maximum number of retries (${retryPolicy.maxRetries}).")
    }
    try {
      return fn
    } catch {
      case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries =>
        Thread.sleep(
          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
    }
    retry(fn, currentRetryNum + 1)
  }

  /**
   * Generalizes the retry logic for RPC calls that return an iterator.
   * @param request
   *   The request to send to the server.
   * @param call
   *   The function that calls the RPC.
   * @tparam T
   *   The type of the request.
   * @tparam U
   *   The type of the response.
   */
  class RetryIterator[T, U](request: T, call: T => java.util.Iterator[U])
      extends java.util.Iterator[U] {

    private var opened = false // we only retry if it fails on first call when using the iterator
    private var iterator = call(request)

    private def retryIter[V](f: java.util.Iterator[U] => V) = {
      if (!opened) {
        opened = true
        var firstTry = true
        retry {
          if (firstTry) {
            // on first try, we use the initial iterator.
            firstTry = false
          } else {
            // on retry, we need to call the RPC again.
            iterator = call(request)
          }
          f(iterator)
        }
      } else {
        f(iterator)
      }
    }

    override def next: U = {
      retryIter(_.next())
    }

    override def hasNext: Boolean = {
      retryIter(_.hasNext())
    }
  }

  object RetryIterator {
    def apply[T, U](request: T, call: T => java.util.Iterator[U]): RetryIterator[T, U] =
      new RetryIterator(request, call)
  }

  /**
   * Generalizes the retry logic for RPC calls that return a StreamObserver.
   * @param request
   *   The request to send to the server.
   * @param call
   *   The function that calls the RPC.
   * @tparam T
   *   The type of the request.
   * @tparam U
   *   The type of the response.
   */
  class RetryStreamObserver[T, U](request: T, call: T => StreamObserver[U])
      extends StreamObserver[U] {

    private var opened = false // only retries on first call
    private var streamObserver = call(request)

    override def onNext(v: U): Unit = {
      if (!opened) {
        opened = true
        var firstTry = true
        retry {
          if (firstTry) {
            // on first try, we use the initial streamObserver.
            firstTry = false
          } else {
            // on retry, we need to call the RPC again.
            streamObserver = call(request)
          }
          streamObserver.onNext(v)
        }
      } else {
        streamObserver.onNext(v)
      }
    }

    override def onError(throwable: Throwable): Unit = {
      opened = true
      streamObserver.onError(throwable)
    }

    override def onCompleted(): Unit = {
      opened = true
      streamObserver.onCompleted()
    }
  }

  object RetryStreamObserver {
    def apply[T, U](request: T, call: T => StreamObserver[U]): RetryStreamObserver[T, U] =
      new RetryStreamObserver(request, call)
  }
}

private[client] object GrpcRetryHandler {

  /**
   * Default canRetry in [[RetryPolicy]].
   * @param e
   *   The exception to check.
   * @return
   *   true if the exception is a [[StatusRuntimeException]] with code UNAVAILABLE.
   */
  private[client] def retryException(e: Throwable): Boolean = {
    e match {
      case e: StatusRuntimeException => e.getStatus.getCode == Status.Code.UNAVAILABLE
      case _ => false
    }
  }

  /**
   * [[RetryPolicy]] configures the retry mechanism in [[GrpcRetryHandler]].
   *
   * @param maxRetries
   *   Maximum number of retries.
   * @param initialBackoff
   *   Start value of the exponential backoff (ms).
   * @param maxBackoff
   *   Maximal value of the exponential backoff (ms).
   * @param backoffMultiplier
   *   Multiplicative base of the exponential backoff.
   * @param canRetry
   *   Function that determines whether a retry is to be performed in the event of an error.
   */
  case class RetryPolicy(
      maxRetries: Int = 15,
      initialBackoff: FiniteDuration = FiniteDuration(50, "ms"),
      maxBackoff: FiniteDuration = FiniteDuration(1, "min"),
      backoffMultiplier: Double = 4.0,
      canRetry: Throwable => Boolean = retryException) {}
}