diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 647813772294e..69baaca8a2614 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 
@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
         SaslMessage msg = new SaslMessage(appId, payload);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
         msg.encode(buf);
+        ByteBuffer response;
         buf.writeBytes(msg.body().nioByteBuffer());
-
-        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        try {
+          response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+        } catch (RuntimeException ex) {
+          // We know it is a SASL timeout here if the cause is a TimeoutException.
+          if (ex.getCause() instanceof TimeoutException) {
+            throw new SaslTimeoutException(ex.getCause());
+          } else {
+            throw ex;
+          }
+        }
 
         payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
new file mode 100644
index 0000000000000..2533ae93f8de8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index f73e3ce2e0aa4..9dedd5d9849c4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -374,6 +374,13 @@ public boolean useOldFetchProtocol() {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }
 
+  /** Whether to enable SASL retries or not. The number of retries is dictated by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 512e4a52c8628..892de9916124f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -24,12 +24,15 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
@@ -85,6 +88,17 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;
 
+  // Number of times SASL timeout has been retried without success.
+  // If we see maxRetries consecutive failures, the request is failed.
+  // On the other hand, if SASL succeeds and we are able to send other requests subsequently,
+  // we reduce the SASL failures from retryCount (since SASL failures were part of
+  // connection bootstrap - which ended up being successful).
+  // spark.network.auth.rpcTimeout is much lower than spark.network.timeout and others -
+  // and so SASL is more susceptible to failures when the remote service
+  // (like the external shuffle service) is under load: but once it succeeds, we do not want to
+  // include it as part of request retries.
+  private int saslRetryCount = 0;
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +113,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
    */
   private RetryingBlockTransferListener currentListener;
 
+  /** Whether SASL retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;
 
   public RetryingBlockTransferor(
@@ -115,6 +132,8 @@ public RetryingBlockTransferor(
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslRetryCount = 0;
   }
 
   public RetryingBlockTransferor(
@@ -158,7 +177,7 @@ private void transferAllOutstanding() {
         numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
 
       if (shouldRetry(e)) {
-        initiateRetry();
+        initiateRetry(e);
       } else {
         for (String bid : blockIdsToTransfer) {
           listener.onBlockTransferFailure(bid, e);
@@ -171,7 +190,10 @@
    * Lightweight method which initiates a retry in a different thread. The retry will involve
    * calling transferAllOutstanding() after a configured wait time.
    */
-  private synchronized void initiateRetry() {
+  private synchronized void initiateRetry(Throwable e) {
+    if (enableSaslRetries && e instanceof SaslTimeoutException) {
+      saslRetryCount += 1;
+    }
     retryCount += 1;
     currentListener = new RetryingBlockTransferListener();
 
@@ -187,13 +209,30 @@ private synchronized void initiateRetry() {
 
   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
-      || (e.getCause() != null && e.getCause() instanceof IOException);
+      || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    // If this is a non-SASL request failure, reduce earlier SASL failures from retryCount,
+    // since some subsequent SASL attempt was successful.
+    if (!isSaslTimeout && saslRetryCount > 0) {
+      Preconditions.checkState(retryCount >= saslRetryCount,
+        "retryCount must be greater than or equal to saslRetryCount");
+      retryCount -= saslRetryCount;
+      saslRetryCount = 0;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+      hasRemainingRetries && errorHandler.shouldRetryError(e);
+    return shouldRetry;
+  }
+
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
   }
 
   /**
@@ -211,6 +250,14 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
         if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
           outstandingBlocksIds.remove(blockId);
           shouldForwardSuccess = true;
+          // If there were SASL failures earlier, remove them from retryCount, as there was
+          // a SASL success (and some other request post bootstrap was also successful).
+          if (saslRetryCount > 0) {
+            Preconditions.checkState(retryCount >= saslRetryCount,
+              "retryCount must be greater than or equal to saslRetryCount");
+            retryCount -= saslRetryCount;
+            saslRetryCount = 0;
+          }
         }
       }
 
@@ -227,7 +274,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) {
       synchronized (RetryingBlockTransferor.this) {
         if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
           if (shouldRetry(exception)) {
-            initiateRetry();
+            initiateRetry(exception);
           } else {
             if (errorHandler.shouldLogError(exception)) {
               logger.error(
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 1b44b061f3d81..31fe61841669c 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -20,13 +20,18 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeoutException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
+
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
 import org.mockito.stubbing.Stubber;
 
@@ -38,6 +43,7 @@
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 
 import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
 
@@ -49,6 +55,18 @@ public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
   private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+  private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
+
+  private static final int MAX_RETRIES = 2;
+
+  @Before
+  public void initMap() {
+    configMap = new HashMap<String, String>() {{
+      put("spark.shuffle.io.maxRetries", Integer.toString(MAX_RETRIES));
+      put("spark.shuffle.io.retryWait", "0");
+    }};
+  }
 
   @Test
   public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +248,130 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
     verifyNoMoreInteractions(listener);
   }
 
+  @Test
+  public void testSaslTimeoutFailure() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", saslTimeoutException)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      // SaslTimeout will cause a retry. Since b0 fails, we will retry both.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
+
+  @Test
+  public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<Map<String, Object>> interactions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      interactions.add(
+        ImmutableMap.<String, Object>builder()
+          .put("b0", saslTimeoutException)
+          .build()
+      );
+    }
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1 because after the SASL exception is retried,
+    // retryCount should be set back to 0. Then after that, b1 encounters an
+    // exception that is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
+  }
+
+  @Test
+  public void testIOExceptionFailsConnectionEvenWithSaslException()
+      throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    SaslTimeoutException saslExceptionInitial = new SaslTimeoutException("initial",
+      new TimeoutException());
+    SaslTimeoutException saslExceptionFinal = new SaslTimeoutException("final",
+      new TimeoutException());
+    IOException ioException = new IOException();
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.of("b0", saslExceptionInitial),
+      ImmutableMap.of("b0", ioException),
+      ImmutableMap.of("b0", saslExceptionInitial),
+      ImmutableMap.of("b0", ioException),
+      ImmutableMap.of("b0", saslExceptionFinal),
+      // will not get invoked because the connection fails
+      ImmutableMap.of("b0", ioException),
+      // will not get invoked
+      ImmutableMap.of("b0", block0)
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslExceptionFinal);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES);
+  }
+
   /**
    * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
    * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -245,9 +387,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
       BlockFetchingListener listener)
       throws IOException, InterruptedException {
 
-    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
-      "spark.shuffle.io.maxRetries", "2",
-      "spark.shuffle.io.retryWait", "0"));
+    MapConfigProvider provider = new MapConfigProvider(configMap);
     TransportConf conf = new TransportConf("shuffle", provider);
     BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);
 
@@ -299,6 +439,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
+    _retryingBlockTransferor =
+      new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor.start();
   }
 }
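Reviewer note: the new `spark.shuffle.sasl.enableRetries` flag defaults to false, so this patch changes nothing unless a deployment opts in. A minimal sketch of opting in, assuming only the standard SparkConf API (the `spark.shuffle.io.*` keys are pre-existing and, per the TransportConf javadoc above, also bound the number of SASL retries):

    import org.apache.spark.SparkConf;

    SparkConf conf = new SparkConf()
        // New flag introduced by this patch (default: false).
        .set("spark.shuffle.sasl.enableRetries", "true")
        // Pre-existing shuffle retry knobs; maxRetries also caps SASL timeout retries.
        .set("spark.shuffle.io.maxRetries", "3")
        .set("spark.shuffle.io.retryWait", "5s");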
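Reviewer note: the subtle part of this change is the retryCount/saslRetryCount bookkeeping. SASL timeouts consume retry budget only while they are the most recent failures; the first non-SASL outcome (a successful block, or an IOException reaching shouldRetry) refunds them, since a completed bootstrap shows the earlier timeouts were transient. A self-contained model of that accounting (the class and method names here are illustrative, not part of the patch):

    // Simplified, illustrative model of the accounting in RetryingBlockTransferor.
    class RetryBudget {
      private final int maxRetries;
      private int retryCount = 0;     // retries charged against maxRetries
      private int saslRetryCount = 0; // trailing run of SASL-timeout retries

      RetryBudget(int maxRetries) { this.maxRetries = maxRetries; }

      // Mirrors shouldRetry(): a non-SASL failure refunds earlier SASL retries.
      boolean shouldRetry(boolean isSaslTimeout) {
        if (!isSaslTimeout && saslRetryCount > 0) {
          retryCount -= saslRetryCount;
          saslRetryCount = 0;
        }
        return retryCount < maxRetries;
      }

      // Mirrors initiateRetry(Throwable): charge one retry, tag it if SASL.
      void recordRetry(boolean isSaslTimeout) {
        if (isSaslTimeout) saslRetryCount += 1;
        retryCount += 1;
      }

      // Mirrors handleBlockTransferSuccess(): success also refunds SASL retries.
      void recordSuccess() {
        retryCount -= saslRetryCount;
        saslRetryCount = 0;
      }
    }

With maxRetries = 2 this reproduces testBlockTransferFailureAfterSasl: one SASL retry charges retryCount = 1, the subsequent b0 success refunds it to 0, and the later IOException retry leaves retryCount at 1 instead of exhausting the budget.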