[SPARK-4188] [Core] Perform network-level retry of shuffle file fetches #3101

Closed · wants to merge 6 commits · showing changes from all commits
File: NettyBlockTransferService.scala

@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, …}
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
-      val client = clientFactory.createClient(host, port)
-      new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-        .start(OpenBlocks(blockIds.map(BlockId.apply)))
+      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+          val client = clientFactory.createClient(host, port)
+          new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+            .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        }
+      }
+
+      val maxRetries = transportConf.maxIORetries()
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener)
+      }
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
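The indirection through BlockFetchStarter above is what makes network-level retry possible: each attempt must create a fresh TransportClient, since the connection that just failed may be dead. RetryingBlockFetcher itself is added elsewhere in this PR and its source is not shown in this diff, so the loop below is only a simplified, synchronous sketch of the idea (the real class also tracks which blocks are still outstanding and retries failures reported through the listener, on a separate thread):

```java
import java.io.IOException;
import java.util.concurrent.TimeUnit;

import com.google.common.util.concurrent.Uninterruptibles;

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.RetryingBlockFetcher;
import org.apache.spark.network.util.TransportConf;

/** Simplified sketch only; the real RetryingBlockFetcher is asynchronous. */
class NaiveRetryLoop {
  static void fetchWithRetries(
      TransportConf conf,
      RetryingBlockFetcher.BlockFetchStarter fetchStarter,
      String[] blockIds,
      BlockFetchingListener listener) {
    int retryCount = 0;
    while (true) {
      try {
        // Every attempt goes back through the starter, so a fresh
        // TransportClient is created; the old connection may be dead.
        fetchStarter.createAndStart(blockIds, listener);
        return;
      } catch (IOException e) {
        if (retryCount++ >= conf.maxIORetries()) {
          // Retry budget exhausted: report failure for each requested block.
          for (String blockId : blockIds) {
            listener.onBlockFetchFailure(blockId, e);
          }
          return;
        }
        Uninterruptibles.sleepUninterruptibly(conf.ioRetryWaitTime(), TimeUnit.MILLISECONDS);
      }
    }
  }
}
```

With the default spark.shuffle.io.maxRetries of 3, a fetch is attempted up to four times before the listener sees onBlockFetchFailure.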
File: TransportClient.java

@@ -18,7 +18,9 @@
package org.apache.spark.network.client;

import java.io.Closeable;
+import java.io.IOException;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
-          callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
}
}
});
@@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
-          callback.onFailure(new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
}
}
});
@@ -175,6 +185,8 @@ public void onFailure(Throwable e) {

try {
return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (ExecutionException e) {
+      throw Throwables.propagate(e.getCause());
} catch (Exception e) {
throw Throwables.propagate(e);
}
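The new ExecutionException branch is subtle but matters for the retry path: result.get() wraps any failure from the async callback in an ExecutionException, and without unwrapping it, callers (and retry logic keyed on IOException) would see the wrapper rather than the underlying network error. A self-contained illustration of the difference, using Guava's SettableFuture to simulate a failed RPC:

```java
import java.io.IOException;
import java.util.concurrent.ExecutionException;

import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;

public class UnwrapDemo {
  public static void main(String[] args) {
    SettableFuture<byte[]> result = SettableFuture.create();
    result.setException(new IOException("connection reset")); // simulate a failed RPC

    try {
      try {
        result.get();
      } catch (ExecutionException e) {
        // Unwrap first: propagate the IOException itself (rewrapped in an
        // unchecked RuntimeException), not the ExecutionException shell.
        throw Throwables.propagate(e.getCause());
      } catch (Exception e) {
        throw Throwables.propagate(e);
      }
    } catch (RuntimeException e) {
      // The cause chain now bottoms out at the original IOException, so
      // instanceof checks can still classify this as a network error.
      System.out.println(e.getCause() instanceof IOException); // true
    }
  }
}
```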
File: TransportClientFactory.java

@@ -18,12 +18,12 @@
package org.apache.spark.network.client;

import java.io.Closeable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

@@ -93,15 +92,17 @@ public TransportClientFactory(
*
* Concurrency: This method is safe to call from multiple threads.
*/
-  public TransportClient createClient(String remoteHost, int remotePort) {
+  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
if (cachedClient != null) {
if (cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
logger.info("Found inactive connection to {}, closing it.", address);
connectionPool.remove(address, cachedClient); // Remove inactive clients.
}
}
@@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) {
long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
-      throw new RuntimeException(
+      throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
-      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}

TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public void close() {
*/
private PooledByteBufAllocator createPooledByteBufAllocator() {
return new PooledByteBufAllocator(
-      PlatformDependent.directBufferPreferred(),
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
getPrivateStaticField("DEFAULT_PAGE_SIZE"),
File: TransportResponseHandler.java

@@ -17,6 +17,7 @@

package org.apache.spark.network.client;

+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@@ -94,7 +95,7 @@ public void channelUnregistered() {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
numOutstandingRequests(), remoteAddress);
-    failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+    failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
}
}

File: MessageEncoder.java

@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
// All messages have the frame length, message type, and message itself.
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
long frameLength = headerLength + bodyLength;
-    ByteBuf header = ctx.alloc().buffer(headerLength);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);

Review comment from a contributor on this line:

> Why heap explicitly here? Wouldn't your spark.shuffle.io.preferDirectBufs flag already control it?

Reply from the author:

> The header is a very small buffer; I thought it was overkill to try to get direct byte bufs for it.

header.writeLong(frameLength);
msgType.encode(header);
in.encode(header);
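To expand on the exchange above: with a pooled allocator, alloc().buffer() returns whatever the allocator prefers, which can be off-heap (direct) memory, while heapBuffer() always allocates on the Java heap. Since the header is only a few dozen bytes (an 8-byte frame length plus the encoded message type and message), routing it to the heap avoids spending direct memory on tiny allocations. A small demo of that behavior, assuming a JVM where direct buffers are available:

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;

public class AllocDemo {
  public static void main(String[] args) {
    // preferDirect = true: buffer() hands back direct memory (on a typical
    // JVM), but heapBuffer() still returns a heap buffer -- which is the
    // point of the change above: tiny per-message headers skip the
    // direct-memory pool entirely.
    PooledByteBufAllocator alloc = new PooledByteBufAllocator(true);
    ByteBuf generic = alloc.buffer(16);
    ByteBuf heap = alloc.heapBuffer(16);
    System.out.println(generic.isDirect()); // true (allocator prefers direct)
    System.out.println(heap.isDirect());    // false (always heap)
    generic.release();
    heap.release();
  }
}
```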
File: TransportServer.java

@@ -28,6 +28,7 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -71,11 +72,14 @@ private void init(int portToBind) {
NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
EventLoopGroup workerGroup = bossGroup;

+    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+      .option(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.ALLOCATOR, allocator);

if (conf.backLog() > 0) {
bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
File: NettyUtils.java

@@ -37,13 +37,17 @@
* Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
*/
public class NettyUtils {
-  /** Creates a Netty EventLoopGroup based on the IOMode. */
-  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
-    ThreadFactory threadFactory = new ThreadFactoryBuilder()
+  /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+  public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+    return new ThreadFactoryBuilder()
       .setDaemon(true)
-      .setNameFormat(threadPrefix + "-%d")
+      .setNameFormat(threadPoolPrefix + "-%d")
       .build();
+  }
+
+  /** Creates a Netty EventLoopGroup based on the IOMode. */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = createThreadFactory(threadPrefix);

switch (mode) {
case NIO:
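Extracting createThreadFactory makes the daemon-thread naming convention reusable outside createEventLoop; a plausible consumer is the executor that runs delayed retries in the new RetryingBlockFetcher, though that class is not shown in this diff, so the pool name below is illustrative:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.spark.network.util.NettyUtils;

public class RetryPoolExample {
  public static void main(String[] args) throws InterruptedException {
    // Daemon threads named "Block Fetch Retry-0", "Block Fetch Retry-1", ...
    // The pool name is illustrative; only createThreadFactory is from the diff.
    ExecutorService retryPool =
        Executors.newCachedThreadPool(NettyUtils.createThreadFactory("Block Fetch Retry"));
    retryPool.execute(() -> System.out.println(Thread.currentThread().getName()));
    retryPool.shutdown();
    retryPool.awaitTermination(5, TimeUnit.SECONDS);
  }
}
```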
File: TransportConf.java

@@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) {
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }

+  /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+  public boolean preferDirectBufs() {
+    return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+  }

/** Connect timeout in secs. Default 120 secs. */
public int connectionTimeoutMs() {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public int connectionTimeoutMs() {

/** Timeout for a single round trip of SASL token exchange, in milliseconds. */
public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }

+  /**
+   * Max number of times we will retry IO exceptions (such as connection timeouts) per request.
+   * If set to 0, we will not do any retries.
+   */
+  public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+  /**
+   * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+   * Only relevant if maxIORetries > 0.
+   */
+  public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
}
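Putting the new configuration surface together, the sketch below wires the retry and buffer knobs through a map-backed ConfigProvider. The anonymous provider is hypothetical; it assumes ConfigProvider's contract of throwing NoSuchElementException for missing keys, so that the defaults above kick in when a key is unset:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;

import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class RetryConfExample {
  public static void main(String[] args) {
    final Map<String, String> settings = new HashMap<String, String>();
    settings.put("spark.shuffle.io.maxRetries", "3");           // 0 disables retries
    settings.put("spark.shuffle.io.retryWaitMs", "5000");       // pause between attempts
    settings.put("spark.shuffle.io.preferDirectBufs", "false"); // force heap pools

    TransportConf conf = new TransportConf(new ConfigProvider() {
      @Override
      public String get(String name) {
        if (!settings.containsKey(name)) {
          throw new NoSuchElementException(name);
        }
        return settings.get(name);
      }
    });

    System.out.println(conf.maxIORetries());     // 3
    System.out.println(conf.ioRetryWaitTime());  // 5000
    System.out.println(conf.preferDirectBufs()); // false
  }
}
```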
File: TransportClientFactorySuite.java

@@ -17,6 +17,7 @@

package org.apache.spark.network;

+import java.io.IOException;
import java.util.concurrent.TimeoutException;

import org.junit.After;
@@ -57,7 +58,7 @@ public void tearDown() {
}

@Test
-  public void createAndReuseBlockClients() throws TimeoutException {
+  public void createAndReuseBlockClients() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException {
}

@Test
-  public void neverReturnInactiveClients() throws Exception {
+  public void neverReturnInactiveClients() throws IOException, InterruptedException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
c1.close();
@@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception {
}

@Test
-  public void closeBlockClientsWithFactory() throws TimeoutException {
+  public void closeBlockClientsWithFactory() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
File: ExternalShuffleClient.java

@@ -17,6 +17,7 @@

package org.apache.spark.network.shuffle;

+import java.io.IOException;
import java.util.List;

import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public void init(String appId) {

@Override
public void fetchBlocks(
-      String host,
-      int port,
-      String execId,
+      final String host,
+      final int port,
+      final String execId,
String[] blockIds,
BlockFetchingListener listener) {
assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
-      TransportClient client = clientFactory.createClient(host, port);
-      new OneForOneBlockFetcher(client, blockIds, listener)
-        .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+      RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+        new RetryingBlockFetcher.BlockFetchStarter() {
+          @Override
+          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+              throws IOException {
+            TransportClient client = clientFactory.createClient(host, port);
+            new OneForOneBlockFetcher(client, blockIds, listener)
+              .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+          }
+        };
+
+      int maxRetries = conf.maxIORetries();
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener);
+      }
} catch (Exception e) {
logger.error("Exception while beginning fetchBlocks", e);
for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public void registerWithShuffleServer(
String host,
int port,
String execId,
-      ExecutorShuffleInfo executorInfo) {
+      ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
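Whether retries are enabled or not, callers observe outcomes only through the BlockFetchingListener passed to fetchBlocks; once the retry budget is exhausted, every remaining block receives an onBlockFetchFailure callback. A toy listener, for illustration only (real callers hand the buffers off to the shuffle read path instead of printing):

```java
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;

/** Toy listener: logs each outcome rather than consuming the data. */
public class LoggingBlockListener implements BlockFetchingListener {
  @Override
  public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
    System.out.println("Fetched " + blockId + " (" + data.size() + " bytes)");
  }

  @Override
  public void onBlockFetchFailure(String blockId, Throwable exception) {
    // With maxRetries > 0, this fires only after the retry budget is spent.
    System.err.println("Gave up on " + blockId + ": " + exception);
  }
}
```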
File: OneForOneBlockFetcher.java

@@ -51,9 +51,6 @@ public OneForOneBlockFetcher(
TransportClient client,
String[] blockIds,
BlockFetchingListener listener) {
-    if (blockIds.length == 0) {
-      throw new IllegalArgumentException("Zero-sized blockIds array");
-    }
this.client = client;
this.blockIds = blockIds;
this.listener = listener;
@@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) {
* {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
*/
public void start(Object openBlocksMessage) {
+    if (blockIds.length == 0) {
+      throw new IllegalArgumentException("Zero-sized blockIds array");
+    }
+
client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
@Override
public void onSuccess(byte[] response) {
@@ -95,7 +96,7 @@ public void onSuccess(byte[] response) {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
} catch (Exception e) {
logger.error("Failed while starting block fetches", e);
logger.error("Failed while starting block fetches after success", e);
failRemainingBlocks(blockIds, e);
}
}