forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-44275][CONNECT] Add configurable retry mechanism to Scala Spar…
…k Connect ### What changes were proposed in this pull request? This PR introduces a configurable retry mechanism for the Scala `SparkConnectClient`. Parameters for the exponential backoff and a filter for exceptions to retry are passed to the client via the existing (extended) `Configuration` class. By default, only `StatusRuntimeException` with code UNAVAILABLE triggers a retry - this mirrors the default retry behavior in the Python client. One might want to move the retry logic into the GRPC stub that will potentially be introduced [here](apache#41743). ### Why are the changes needed? There are a few existing exceptions that one might want to handle with a retry. For example, this would allow one to not exit with an exception when executing a command while the cluster is still starting. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tests included. Closes apache#41829 from dillitz/spark-44275-retries. Authored-by: Robert Dillitz <robert.dillitz@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
- Loading branch information
1 parent
17fac56
commit e3c2bf5
Showing
7 changed files
with
341 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
...lient/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.sql.connect.client | ||
|
||
import io.grpc.ManagedChannel | ||
import io.grpc.stub.StreamObserver | ||
|
||
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} | ||
|
||
/**
 * Async stub for the Spark Connect service that transparently retries the
 * `addArtifacts` RPC according to the supplied [[GrpcRetryHandler.RetryPolicy]].
 *
 * @param channel
 *   The gRPC channel used to reach the Spark Connect server.
 * @param retryPolicy
 *   Policy controlling exponential backoff and which errors are retried.
 */
private[client] class CustomSparkConnectStub(
    channel: ManagedChannel,
    retryPolicy: GrpcRetryHandler.RetryPolicy) {

  private val asyncStub = SparkConnectServiceGrpc.newStub(channel)
  private val grpcRetryHandler = new GrpcRetryHandler(retryPolicy)

  /**
   * Opens the `addArtifacts` stream, wrapping the returned request observer so
   * that a failure on the first interaction is retried per the configured policy.
   */
  def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
      : StreamObserver[AddArtifactsRequest] =
    grpcRetryHandler.RetryStreamObserver(responseObserver, asyncStub.addArtifacts)
}
196 changes: 196 additions & 0 deletions
196
...nect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.connect.client | ||
|
||
import scala.annotation.tailrec | ||
import scala.concurrent.duration.FiniteDuration | ||
import scala.util.control.NonFatal | ||
|
||
import io.grpc.{Status, StatusRuntimeException} | ||
import io.grpc.stub.StreamObserver | ||
|
||
/**
 * Encapsulates the client-side retry logic for Spark Connect RPCs. Plain calls,
 * iterator-returning RPCs, and StreamObserver-based RPCs can all be retried with
 * exponential backoff according to the configured [[GrpcRetryHandler.RetryPolicy]].
 */
private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler.RetryPolicy) {

  /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
   * @param fn
   *   The function to retry.
   * @param currentRetryNum
   *   Current number of retries.
   * @tparam T
   *   The return type of the function.
   * @return
   *   The result of the function.
   */
  @tailrec final def retry[T](fn: => T, currentRetryNum: Int = 0): T = {
    // Defensive guard: the recursion below only increments while below the limit,
    // so this can only trigger when a caller passes an out-of-range retry count.
    if (currentRetryNum > retryPolicy.maxRetries) {
      throw new IllegalArgumentException(
        s"The number of retries ($currentRetryNum) must not exceed " +
          s"the maximum number of retries (${retryPolicy.maxRetries}).")
    }
    try {
      return fn
    } catch {
      case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries =>
        // Exponential backoff: initialBackoff * multiplier^attempt, capped at maxBackoff.
        Thread.sleep(
          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
    }
    retry(fn, currentRetryNum + 1)
  }

  /**
   * Generalizes the retry logic for RPC calls that return an iterator.
   * Only the first interaction with the iterator is retried; once a response has
   * been observed, failures propagate to the caller unchanged.
   * @param request
   *   The request to send to the server.
   * @param call
   *   The function that calls the RPC.
   * @tparam T
   *   The type of the request.
   * @tparam U
   *   The type of the response.
   */
  class RetryIterator[T, U](request: T, call: T => java.util.Iterator[U])
      extends java.util.Iterator[U] {

    private var opened = false // we only retry if it fails on first call when using the iterator
    private var iterator = call(request)

    // Applies `f` to the underlying iterator, re-issuing the RPC on retryable
    // failures of the very first access.
    private def retryIter[V](f: java.util.Iterator[U] => V) = {
      if (!opened) {
        opened = true
        var firstTry = true
        retry {
          if (firstTry) {
            // on first try, we use the initial iterator.
            firstTry = false
          } else {
            // on retry, we need to call the RPC again.
            iterator = call(request)
          }
          f(iterator)
        }
      } else {
        f(iterator)
      }
    }

    override def next: U = {
      retryIter(_.next())
    }

    override def hasNext: Boolean = {
      retryIter(_.hasNext())
    }
  }

  object RetryIterator {
    def apply[T, U](request: T, call: T => java.util.Iterator[U]): RetryIterator[T, U] =
      new RetryIterator(request, call)
  }

  /**
   * Generalizes the retry logic for RPC calls that return a StreamObserver.
   * Only the first `onNext` is retried; after the stream has been used (or has
   * errored/completed), calls are forwarded to the underlying observer directly.
   * @param request
   *   The request to send to the server.
   * @param call
   *   The function that calls the RPC.
   * @tparam T
   *   The type of the request.
   * @tparam U
   *   The type of the response.
   */
  class RetryStreamObserver[T, U](request: T, call: T => StreamObserver[U])
      extends StreamObserver[U] {

    private var opened = false // only retries on first call
    private var streamObserver = call(request)

    override def onNext(v: U): Unit = {
      if (!opened) {
        opened = true
        var firstTry = true
        retry {
          if (firstTry) {
            // on first try, we use the initial streamObserver.
            firstTry = false
          } else {
            // on retry, we need to call the RPC again.
            streamObserver = call(request)
          }
          streamObserver.onNext(v)
        }
      } else {
        streamObserver.onNext(v)
      }
    }

    override def onError(throwable: Throwable): Unit = {
      // An error terminates the stream; never retry past this point.
      opened = true
      streamObserver.onError(throwable)
    }

    override def onCompleted(): Unit = {
      // Completion terminates the stream; never retry past this point.
      opened = true
      streamObserver.onCompleted()
    }
  }

  object RetryStreamObserver {
    def apply[T, U](request: T, call: T => StreamObserver[U]): RetryStreamObserver[T, U] =
      new RetryStreamObserver(request, call)
  }
}
|
||
private[client] object GrpcRetryHandler {

  /**
   * Default canRetry in [[RetryPolicy]].
   * @param e
   *   The exception to check.
   * @return
   *   true if the exception is a [[StatusRuntimeException]] with code UNAVAILABLE.
   */
  private[client] def retryException(e: Throwable): Boolean = e match {
    case statusEx: StatusRuntimeException =>
      statusEx.getStatus.getCode == Status.Code.UNAVAILABLE
    case _ =>
      false
  }

  /**
   * [[RetryPolicy]] configures the retry mechanism in [[GrpcRetryHandler]].
   *
   * @param maxRetries
   *   Maximum number of retries.
   * @param initialBackoff
   *   Start value of the exponential backoff (ms).
   * @param maxBackoff
   *   Maximal value of the exponential backoff (ms).
   * @param backoffMultiplier
   *   Multiplicative base of the exponential backoff.
   * @param canRetry
   *   Function that determines whether a retry is to be performed in the event of an error.
   */
  case class RetryPolicy(
      maxRetries: Int = 15,
      initialBackoff: FiniteDuration = FiniteDuration(50, "ms"),
      maxBackoff: FiniteDuration = FiniteDuration(1, "min"),
      backoffMultiplier: Double = 4.0,
      canRetry: Throwable => Boolean = retryException)
}
Oops, something went wrong.