SPARK-41415/SPARK-42090 Backport to 3.3 #39634

Closed
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 
@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
         SaslMessage msg = new SaslMessage(appId, payload);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
         msg.encode(buf);
+        ByteBuffer response;
         buf.writeBytes(msg.body().nioByteBuffer());
-
-        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        try {
+          response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        } catch (RuntimeException ex) {
+          // We know it is a Sasl timeout here if it is a TimeoutException.
+          if (ex.getCause() instanceof TimeoutException) {
+            throw new SaslTimeoutException(ex.getCause());
+          } else {
+            throw ex;
+          }
+        }
         payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
 
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
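Taken together, the two changes above mean a SASL handshake that times out now surfaces as a dedicated exception type instead of a bare `RuntimeException`. The sketch below is illustrative only and not part of the PR: `SaslTimeoutSketch`, `fakeSendRpcSync`, and `bootstrapRoundTrip` are made-up names, with `fakeSendRpcSync` standing in for `TransportClient.sendRpcSync`, which reports the underlying `TimeoutException` wrapped in a `RuntimeException`. It assumes `spark-network-common` (which provides the new class) is on the classpath.

```java
import java.nio.ByteBuffer;
import java.util.concurrent.TimeoutException;

import org.apache.spark.network.sasl.SaslTimeoutException;

public class SaslTimeoutSketch {

  // Stand-in for TransportClient.sendRpcSync: synchronous RPC helpers surface the
  // underlying TimeoutException wrapped in a RuntimeException.
  private static ByteBuffer fakeSendRpcSync() {
    throw new RuntimeException(new TimeoutException("RPC did not reply within authRTTimeoutMs"));
  }

  // Mirrors the pattern added to doBootstrap above: only genuine timeouts are re-wrapped,
  // every other RuntimeException propagates unchanged.
  private static ByteBuffer bootstrapRoundTrip() {
    try {
      return fakeSendRpcSync();
    } catch (RuntimeException ex) {
      if (ex.getCause() instanceof TimeoutException) {
        throw new SaslTimeoutException(ex.getCause());
      }
      throw ex;
    }
  }

  public static void main(String[] args) {
    try {
      bootstrapRoundTrip();
    } catch (SaslTimeoutException e) {
      // Downstream retry logic can now branch on the exception type.
      System.out.println("SASL timeout detected: " + e.getCause().getMessage());
    }
  }
}
```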
@@ -374,6 +374,13 @@ public boolean useOldFetchProtocol() {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }
 
+  /** Whether to enable sasl retries or not. The number of retries is dictated by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
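The new flag is off by default. The snippet below is a hypothetical driver-side configuration (the class name `EnableSaslRetries` is made up): only `spark.shuffle.sasl.enableRetries` is introduced by this PR, while `spark.shuffle.io.maxRetries` and `spark.shuffle.io.retryWait` are the pre-existing settings that bound and pace the retries.

```java
import org.apache.spark.SparkConf;

public class EnableSaslRetries {
  public static void main(String[] args) {
    // Sketch only: spark.shuffle.io.maxRetries caps how many times a SASL timeout
    // (or IOException) is retried, and spark.shuffle.io.retryWait is the pause
    // between attempts.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.sasl.enableRetries", "true")
        .set("spark.shuffle.io.maxRetries", "5")
        .set("spark.shuffle.io.retryWait", "5s");
    System.out.println(conf.toDebugString());
  }
}
```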
@@ -24,12 +24,15 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -85,6 +88,17 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;
 
+  // Number of times SASL timeout has been retried without success.
+  // If we see maxRetries consecutive failures, the request is failed.
+  // On the other hand, if sasl succeeds and we are able to send other requests subsequently,
+  // we reduce the SASL failures from retryCount (since SASL failures were part of
+  // connection bootstrap - which ended up being successful).
+  // spark.network.auth.rpcTimeout is much lower than spark.network.timeout and others -
+  // and so sasl is more susceptible to failures when remote service
+  // (like external shuffle service) is under load: but once it succeeds, we do not want to
+  // include it as part of request retries.
+  private int saslRetryCount = 0;
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +113,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
    */
   private RetryingBlockTransferListener currentListener;
 
+  /** Whether sasl retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;
 
   public RetryingBlockTransferor(
@@ -115,6 +132,8 @@ public RetryingBlockTransferor(
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslRetryCount = 0;
   }
 
   public RetryingBlockTransferor(
@@ -158,7 +177,7 @@ private void transferAllOutstanding() {
         numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
 
       if (shouldRetry(e)) {
-        initiateRetry();
+        initiateRetry(e);
       } else {
         for (String bid : blockIdsToTransfer) {
           listener.onBlockTransferFailure(bid, e);
@@ -171,7 +190,10 @@ private void transferAllOutstanding() {
    * Lightweight method which initiates a retry in a different thread. The retry will involve
    * calling transferAllOutstanding() after a configured wait time.
    */
-  private synchronized void initiateRetry() {
+  private synchronized void initiateRetry(Throwable e) {
+    if (enableSaslRetries && e instanceof SaslTimeoutException) {
+      saslRetryCount += 1;
+    }
     retryCount += 1;
     currentListener = new RetryingBlockTransferListener();
 
@@ -187,13 +209,30 @@ private synchronized void initiateRetry() {
 
   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
       || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    // If this is a non SASL request failure, reduce earlier SASL failures from retryCount
+    // since some subsequent SASL attempt was successful
+    if (!isSaslTimeout && saslRetryCount > 0) {
+      Preconditions.checkState(retryCount >= saslRetryCount,
+        "retryCount must be greater than or equal to saslRetryCount");
+      retryCount -= saslRetryCount;
+      saslRetryCount = 0;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+      hasRemainingRetries && errorHandler.shouldRetryError(e);
+    return shouldRetry;
   }
 
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
+  }
+
   /**
@@ -211,6 +250,14 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
         if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
           outstandingBlocksIds.remove(blockId);
           shouldForwardSuccess = true;
+          // If there were SASL failures earlier, remove them from retryCount, as there was
+          // a SASL success (and some other request post bootstrap was also successful).
+          if (saslRetryCount > 0) {
+            Preconditions.checkState(retryCount >= saslRetryCount,
+              "retryCount must be greater than or equal to saslRetryCount");
+            retryCount -= saslRetryCount;
+            saslRetryCount = 0;
+          }
         }
       }
 
@@ -227,7 +274,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) {
       synchronized (RetryingBlockTransferor.this) {
         if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
           if (shouldRetry(exception)) {
-            initiateRetry();
+            initiateRetry(exception);
           } else {
             if (errorHandler.shouldLogError(exception)) {
               logger.error(
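To make the counter bookkeeping in the diff above concrete, here is a small standalone simulation; it is not part of the PR, and `RetryAccountingSketch` is a made-up name. The two methods mirror the logic of `initiateRetry(Throwable)` and `shouldRetry(Throwable)` from `RetryingBlockTransferor`, minus threading and the `Preconditions` check, and assume `spark.shuffle.sasl.enableRetries` is on. The scenario is two SASL timeouts during bootstrap followed by an ordinary `IOException` on a fetch, showing that the SASL retries are refunded and do not eat into the `spark.shuffle.io.maxRetries` budget.

```java
public class RetryAccountingSketch {
  private static final int maxRetries = 3;
  private static int retryCount = 0;
  private static int saslRetryCount = 0;

  // Mirrors shouldRetry(Throwable): a non-SASL failure refunds earlier SASL retries,
  // because the handshake must have succeeded for the request to get this far.
  private static boolean shouldRetry(boolean isSaslTimeout, boolean isIOException) {
    if (!isSaslTimeout && saslRetryCount > 0) {
      retryCount -= saslRetryCount;
      saslRetryCount = 0;
    }
    return (isSaslTimeout || isIOException) && retryCount < maxRetries;
  }

  // Mirrors initiateRetry(Throwable): SASL timeouts are counted separately as well.
  private static void initiateRetry(boolean isSaslTimeout) {
    if (isSaslTimeout) {
      saslRetryCount += 1;
    }
    retryCount += 1;
  }

  public static void main(String[] args) {
    // Two consecutive SASL timeouts while bootstrapping the connection.
    if (shouldRetry(true, false)) initiateRetry(true);   // retryCount=1, saslRetryCount=1
    if (shouldRetry(true, false)) initiateRetry(true);   // retryCount=2, saslRetryCount=2

    // Bootstrap succeeds on the next attempt; later a block fetch hits an IOException.
    // The two SASL retries are refunded first, so the full retry budget is still available.
    boolean retry = shouldRetry(false, true);
    System.out.println("retry=" + retry
        + ", retryCount=" + retryCount + ", saslRetryCount=" + saslRetryCount);
    // Prints: retry=true, retryCount=0, saslRetryCount=0
  }
}
```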