[SPARK-4188] [Core] Perform network-level retry of shuffle file fetches
This adds a RetryingBlockFetcher to the NettyBlockTransferService, which wraps our typical OneForOneBlockFetcher and adds retry logic in the event of an IOException.

This sort of retry allows us to avoid marking an entire executor as failed when a fetch fails transiently, for example during a garbage collection pause or under high network load.
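At a high level, the connection-and-fetch step is factored into a starter callback that the retry layer can invoke again for whichever blocks are still outstanding. Below is a condensed, self-contained sketch of that pattern; the interfaces and class names are simplified stand-ins, not the actual Spark API (the real RetryingBlockFetcher, not shown in this excerpt, additionally tracks which blocks are still outstanding and performs retries on a separate thread):

    import java.io.IOException;

    interface FetchListener {
      void onBlockFetchSuccess(String blockId);
      void onBlockFetchFailure(String blockId, Throwable cause);
    }

    // Mirrors the role of RetryingBlockFetcher.BlockFetchStarter: opens a fresh
    // connection and starts fetching the given blocks, throwing on IO problems.
    interface FetchStarter {
      void createAndStart(String[] blockIds, FetchListener listener) throws IOException;
    }

    class SimpleRetryingFetcher {
      private final FetchStarter starter;
      private final int maxRetries;
      private final long retryWaitMs;

      SimpleRetryingFetcher(FetchStarter starter, int maxRetries, long retryWaitMs) {
        this.starter = starter;
        this.maxRetries = maxRetries;
        this.retryWaitMs = retryWaitMs;
      }

      void fetch(String[] blockIds, FetchListener listener) throws InterruptedException {
        for (int attempt = 0; ; attempt++) {
          try {
            starter.createAndStart(blockIds, listener); // IOExceptions surface here
            return;
          } catch (IOException e) {
            if (attempt >= maxRetries) {
              // Out of retries: report failure for every requested block.
              for (String id : blockIds) listener.onBlockFetchFailure(id, e);
              return;
            }
            Thread.sleep(retryWaitMs); // back off before opening a new connection
          }
        }
      }
    }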

TODO:
- [x] unit tests
- [x] put in ExternalShuffleClient too

Author: Aaron Davidson <aaron@databricks.com>

Closes #3101 from aarondav/retry and squashes the following commits:

72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy
c7fd107 [Aaron Davidson] Fix unit tests
e80e4c2 [Aaron Davidson] Address initial comments
6f594cd [Aaron Davidson] Fix unit test
05ff43c [Aaron Davidson] Add to external shuffle client and add unit test
66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches

(cherry picked from commit f165b2b)
Signed-off-by: Reynold Xin <rxin@databricks.com>
aarondav authored and rxin committed Nov 7, 2014
1 parent cbe9a6c commit c1ea5c5
Showing 16 changed files with 668 additions and 45 deletions.
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       listener: BlockFetchingListener): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
-      val client = clientFactory.createClient(host, port)
-      new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-        .start(OpenBlocks(blockIds.map(BlockId.apply)))
+      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+          val client = clientFactory.createClient(host, port)
+          new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+            .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        }
+      }
+
+      val maxRetries = transportConf.maxIORetries()
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener)
+      }
     } catch {
       case e: Exception =>
         logError("Exception while beginning fetchBlocks", e)
network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,7 +18,9 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.util.UUID;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
 import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
               serverAddr, future.cause());
             logger.error(errorMsg, future.cause());
             handler.removeFetchRequest(streamChunkId);
-            callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
             channel.close();
+            try {
+              callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+            } catch (Exception e) {
+              logger.error("Uncaught exception in RPC response callback handler!", e);
+            }
           }
         }
       });
@@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
               serverAddr, future.cause());
             logger.error(errorMsg, future.cause());
             handler.removeRpcRequest(requestId);
-            callback.onFailure(new RuntimeException(errorMsg, future.cause()));
             channel.close();
+            try {
+              callback.onFailure(new IOException(errorMsg, future.cause()));
+            } catch (Exception e) {
+              logger.error("Uncaught exception in RPC response callback handler!", e);
+            }
           }
         }
       });
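Two things change in these failure paths: the channel is now closed before the callback runs, and the callback invocation itself is guarded so that an exception thrown by user-supplied listener code is logged instead of propagating into the Netty event loop. A minimal sketch of that guard pattern, with hypothetical names (not the Spark API):

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    final class CallbackGuard {
      private static final Logger logger = LoggerFactory.getLogger(CallbackGuard.class);

      interface FailureCallback {
        void onFailure(Throwable cause);
      }

      // Invokes user code defensively: a misbehaving callback must not be able to
      // throw into the I/O thread that drives all other connections on this loop.
      static void failSafely(FailureCallback callback, Throwable cause) {
        try {
          callback.onFailure(cause);
        } catch (Exception e) {
          logger.error("Uncaught exception in callback handler!", e);
        }
      }
    }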
@@ -175,6 +185,8 @@ public void onFailure(Throwable e) {
 
     try {
       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (ExecutionException e) {
+      throw Throwables.propagate(e.getCause());
     } catch (Exception e) {
       throw Throwables.propagate(e);
     }
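The new ExecutionException branch propagates the underlying cause rather than the wrapper, so a synchronous RPC failure surfaces with the original IOException directly in its cause chain instead of buried one level deeper, which keeps exception-type-based handling simple. A small self-contained illustration of the difference (a hypothetical demo, not Spark code):

    import java.io.IOException;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    import com.google.common.base.Throwables;

    public class UnwrapDemo {
      public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Future<Void> result = pool.submit(new Callable<Void>() {
          @Override
          public Void call() throws IOException {
            throw new IOException("connection reset"); // simulated network failure
          }
        });
        try {
          result.get();
        } catch (ExecutionException e) {
          // Throwables.propagate wraps a checked cause in RuntimeException, so after
          // unwrapping the caller sees RuntimeException -> IOException, rather than
          // RuntimeException -> ExecutionException -> IOException.
          throw Throwables.propagate(e.getCause());
        } finally {
          pool.shutdown();
        }
      }
    }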
network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,12 +18,12 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -93,15 +92,17 @@ public TransportClientFactory(
    *
    * Concurrency: This method is safe to call from multiple threads.
    */
-  public TransportClient createClient(String remoteHost, int remotePort) {
+  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
     // Get connection from the connection pool first.
     // If it is not found or not active, create a new one.
     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
     TransportClient cachedClient = connectionPool.get(address);
     if (cachedClient != null) {
       if (cachedClient.isActive()) {
         logger.trace("Returning cached connection to {}: {}", address, cachedClient);
         return cachedClient;
       } else {
+        logger.info("Found inactive connection to {}, closing it.", address);
         connectionPool.remove(address, cachedClient); // Remove inactive clients.
       }
     }
@@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) {
     long preConnect = System.currentTimeMillis();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
-      throw new RuntimeException(
+      throw new IOException(
         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
     } else if (cf.cause() != null) {
-      throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
 
     TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public void close() {
    */
   private PooledByteBufAllocator createPooledByteBufAllocator() {
     return new PooledByteBufAllocator(
-      PlatformDependent.directBufferPreferred(),
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
       getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
       getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
       getPrivateStaticField("DEFAULT_PAGE_SIZE"),
network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.client;
 
+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -94,7 +95,7 @@ public void channelUnregistered() {
       String remoteAddress = NettyUtils.getRemoteAddress(channel);
       logger.error("Still have {} requests outstanding when connection from {} is closed",
         numOutstandingRequests(), remoteAddress);
-      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
     }
   }
 
network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
     // All messages have the frame length, message type, and message itself.
     int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
     long frameLength = headerLength + bodyLength;
-    ByteBuf header = ctx.alloc().buffer(headerLength);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
     header.writeLong(frameLength);
     msgType.encode(header);
     in.encode(header);
network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -28,6 +28,7 @@
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -71,11 +72,14 @@ private void init(int portToBind) {
       NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
     EventLoopGroup workerGroup = bossGroup;
 
+    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+      .option(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.ALLOCATOR, allocator);
 
     if (conf.backLog() > 0) {
       bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
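For context, Netty's single-argument PooledByteBufAllocator constructor takes a preferDirect flag; passing false keeps pooled allocations on the JVM heap even when direct buffers are available. A quick sketch of the behavior being configured here (illustrative demo, not Spark code):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.PooledByteBufAllocator;

    public class AllocatorDemo {
      public static void main(String[] args) {
        // preferDirect = false: buffer() hands out heap-backed buffers, which is
        // what spark.shuffle.io.preferDirectBufs=false requests for the server.
        PooledByteBufAllocator heapOnly = new PooledByteBufAllocator(false);
        ByteBuf buf = heapOnly.buffer(256);
        System.out.println(buf.hasArray()); // true: backed by a heap byte[]
        buf.release();
      }
    }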
network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -37,13 +37,17 @@
  * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
  */
 public class NettyUtils {
-  /** Creates a Netty EventLoopGroup based on the IOMode. */
-  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
-    ThreadFactory threadFactory = new ThreadFactoryBuilder()
+  /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+  public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+    return new ThreadFactoryBuilder()
       .setDaemon(true)
-      .setNameFormat(threadPrefix + "-%d")
+      .setNameFormat(threadPoolPrefix + "-%d")
       .build();
+  }
+
+  /** Creates a Netty EventLoopGroup based on the IOMode. */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = createThreadFactory(threadPrefix);
 
     switch (mode) {
       case NIO:
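The extracted helper is a thin wrapper over Guava's ThreadFactoryBuilder. A standalone usage sketch (the thread-name prefix is illustrative; any executor works the same way):

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ThreadFactory;

    import com.google.common.util.concurrent.ThreadFactoryBuilder;

    public class NamedDaemonThreads {
      public static void main(String[] args) {
        // Daemon threads named "shuffle-client-0", "shuffle-client-1", ...
        ThreadFactory factory = new ThreadFactoryBuilder()
            .setDaemon(true)
            .setNameFormat("shuffle-client-%d")
            .build();
        ExecutorService pool = Executors.newCachedThreadPool(factory);
        pool.execute(new Runnable() {
          @Override
          public void run() {
            System.out.println(Thread.currentThread().getName());
          }
        });
        pool.shutdown();
      }
    }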
network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) {
   /** IO mode: nio or epoll */
   public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
 
+  /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+  public boolean preferDirectBufs() {
+    return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+  }
+
   /** Connect timeout in secs. Default 120 secs. */
   public int connectionTimeoutMs() {
     return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public int connectionTimeoutMs() {
 
   /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
   public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+  /**
+   * Max number of times we will try IO exceptions (such as connection timeouts) per request.
+   * If set to 0, we will not do any retries.
+   */
+  public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+  /**
+   * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+   * Only relevant if maxIORetries > 0.
+   */
+  public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
 }
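Together these knobs mean a fetch hitting an IOException is retried up to spark.shuffle.io.maxRetries times, pausing spark.shuffle.io.retryWaitMs between attempts; setting maxRetries to 0 restores the old fail-fast behavior. A hedged sketch of exercising the new accessors through a map-backed provider (in a real deployment the values come from SparkConf; the anonymous subclass below is illustrative and assumes ConfigProvider exposes a single abstract get(String) method):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.NoSuchElementException;

    import org.apache.spark.network.util.ConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class RetryConfDemo {
      public static void main(String[] args) {
        final Map<String, String> settings = new HashMap<String, String>();
        settings.put("spark.shuffle.io.maxRetries", "3");     // 0 disables retries entirely
        settings.put("spark.shuffle.io.retryWaitMs", "5000"); // pause between attempts

        ConfigProvider provider = new ConfigProvider() {
          @Override
          public String get(String name) {
            String value = settings.get(name);
            if (value == null) { throw new NoSuchElementException(name); }
            return value;
          }
        };

        TransportConf conf = new TransportConf(provider);
        System.out.println(conf.maxIORetries());    // 3
        System.out.println(conf.ioRetryWaitTime()); // 5000
      }
    }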
network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network;
 
+import java.io.IOException;
 import java.util.concurrent.TimeoutException;
 
 import org.junit.After;
@@ -57,7 +58,7 @@ public void tearDown() {
   }
 
   @Test
-  public void createAndReuseBlockClients() throws TimeoutException {
+  public void createAndReuseBlockClients() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException {
   }
 
   @Test
-  public void neverReturnInactiveClients() throws Exception {
+  public void neverReturnInactiveClients() throws IOException, InterruptedException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     c1.close();
@@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception {
   }
 
   @Test
-  public void closeBlockClientsWithFactory() throws TimeoutException {
+  public void closeBlockClientsWithFactory() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.IOException;
 import java.util.List;
 
 import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public void init(String appId) {
 
   @Override
   public void fetchBlocks(
-      String host,
-      int port,
-      String execId,
+      final String host,
+      final int port,
+      final String execId,
       String[] blockIds,
       BlockFetchingListener listener) {
     assert appId != null : "Called before init()";
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
-      TransportClient client = clientFactory.createClient(host, port);
-      new OneForOneBlockFetcher(client, blockIds, listener)
-        .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+      RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+          new RetryingBlockFetcher.BlockFetchStarter() {
+            @Override
+            public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+                throws IOException {
+              TransportClient client = clientFactory.createClient(host, port);
+              new OneForOneBlockFetcher(client, blockIds, listener)
+                .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+            }
+          };
+
+      int maxRetries = conf.maxIORetries();
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener);
+      }
     } catch (Exception e) {
       logger.error("Exception while beginning fetchBlocks", e);
       for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public void registerWithShuffleServer(
       String host,
       int port,
       String execId,
-      ExecutorShuffleInfo executorInfo) {
+      ExecutorShuffleInfo executorInfo) throws IOException {
     assert appId != null : "Called before init()";
     TransportClient client = clientFactory.createClient(host, port);
     byte[] registerExecutorMessage =
network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -51,9 +51,6 @@ public OneForOneBlockFetcher(
       TransportClient client,
       String[] blockIds,
       BlockFetchingListener listener) {
-    if (blockIds.length == 0) {
-      throw new IllegalArgumentException("Zero-sized blockIds array");
-    }
     this.client = client;
     this.blockIds = blockIds;
     this.listener = listener;
@@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) {
    * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
    */
   public void start(Object openBlocksMessage) {
+    if (blockIds.length == 0) {
+      throw new IllegalArgumentException("Zero-sized blockIds array");
+    }
+
     client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
       @Override
       public void onSuccess(byte[] response) {
@@ -95,7 +96,7 @@ public void onSuccess(byte[] response) {
             client.fetchChunk(streamHandle.streamId, i, chunkCallback);
           }
         } catch (Exception e) {
-          logger.error("Failed while starting block fetches", e);
+          logger.error("Failed while starting block fetches after success", e);
           failRemainingBlocks(blockIds, e);
         }
       }
[Remaining changed files in this diff were not loaded.]