Address Reynold's comments
aarondav committed Oct 28, 2014
1 parent 8dfcceb commit 14e37f7
Showing 35 changed files with 203 additions and 186 deletions.
@@ -87,7 +87,7 @@ class NettyBlockFetcher(
}

override def onFailure(e: Throwable): Unit = {
logError("Failed while starting block fetches")
logError("Failed while starting block fetches", e)
blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
}
})
@@ -24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConversions._

import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}

/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -21,12 +21,11 @@ import java.util.concurrent.LinkedBlockingQueue

import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
import org.apache.spark.serializer.Serializer
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{CompletionIterator, Utils}
import org.apache.spark.{Logging, TaskContext}


/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -285,7 +284,7 @@ final class ShuffleBlockFetcherIterator(
val iteratorOpt: Option[Iterator[Any]] = if (result.failed) {
None
} else {
val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream())
val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
val iter = serializer.newInstance().deserializeStream(is).asIterator
Some(CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
@@ -25,10 +25,10 @@
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.response.MessageDecoder;
import org.apache.spark.network.protocol.response.MessageEncoder;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportClientHandler;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.StreamManager;
@@ -37,7 +37,12 @@

/**
* Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
* setup Netty Channel pipelines with a {@link TransportClientHandler}.
* setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
*
* There are two communication protocols that the TransportClient provides, control-plane RPCs and
* data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the
* TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams
* which can be streamed through the data plane in chunks using zero-copy IO.
*
* The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
* channel. As each TransportChannelHandler contains a TransportClient, this enables server
@@ -71,16 +76,16 @@ public TransportServer createServer() {

/**
* Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
* has a {@link org.apache.spark.network.server.TransportClientHandler} to handle request or
* has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
* response messages.
*
* @return Returns the created TransportChannelHandler, which includes a TransportClient that can
* be used to communicate on this channel. The TransportClient is directly associated with a
* ChannelHandler to ensure all users of the same channel get the same TransportClient object.
*/
public TransportClientHandler initializePipeline(SocketChannel channel) {
public TransportChannelHandler initializePipeline(SocketChannel channel) {
try {
TransportClientHandler channelHandler = createChannelHandler(channel);
TransportChannelHandler channelHandler = createChannelHandler(channel);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
@@ -100,12 +105,12 @@ public TransportClientHandler initializePipeline(SocketChannel channel) {
* ResponseMessages. The channel is expected to have been successfully created, though certain
* properties (such as the remoteAddress()) may not be available yet.
*/
private TransportClientHandler createChannelHandler(Channel channel) {
private TransportChannelHandler createChannelHandler(Channel channel) {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, streamManager,
rpcHandler);
return new TransportClientHandler(client, responseHandler, requestHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
streamManager, rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
}

public TransportConf getConf() { return conf; }
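For orientation, here is a minimal wiring sketch (not part of the patch) showing how the pieces named in the new class comment fit together: one TransportContext backs both a TransportServer and a TransportClientFactory, and each channel's TransportClient carries the control-plane RPCs and data-plane chunk fetches. The TransportContext constructor arguments are an assumption here; createServer(), the TransportClientFactory(TransportContext) constructor, and createClient(host, port) all appear in this diff.

```java
import java.util.concurrent.TimeoutException;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.TransportConf;

public class TransportWiringSketch {
  static TransportClient connect(
      TransportConf conf,
      StreamManager streamManager,
      RpcHandler rpcHandler,
      String host,
      int port) throws TimeoutException {
    // Assumed constructor: one context holds the conf plus the server-side handlers.
    TransportContext context = new TransportContext(conf, streamManager, rpcHandler);

    // Server side: serves both planes (RPCs and chunk fetches) on one port.
    TransportServer server = context.createServer();

    // Client side: each channel's pipeline (see initializePipeline) carries a
    // TransportClient for issuing RPCs and chunk-fetch requests.
    TransportClientFactory factory = new TransportClientFactory(context);
    return factory.createClient(host, port);
  }
}
```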
@@ -40,6 +40,7 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer {
* Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
* Avoid unless there's a good reason not to.
*/
// TODO: Make this configurable
private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;

private final File file;
@@ -88,7 +89,7 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream inputStream() throws IOException {
public InputStream createInputStream() throws IOException {
FileInputStream is = null;
try {
is = new FileInputStream(file);
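The 2 MB threshold above guards the nioByteBuffer() path: small segments are read straight into a heap buffer, larger ones are memory-mapped. A standalone sketch of that decision follows; the readSegment helper and its offset/length parameters are hypothetical, and only the constant and the surrounding comment come from the patch.

```java
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

final class FileSegmentReadSketch {
  // Mirrors MIN_MEMORY_MAP_BYTES above; segments below this size are read directly.
  private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;

  static ByteBuffer readSegment(File file, long offset, long length) throws IOException {
    try (RandomAccessFile raf = new RandomAccessFile(file, "r");
         FileChannel channel = raf.getChannel()) {
      if (length < MIN_MEMORY_MAP_BYTES) {
        // Small segment: a plain read avoids the cost and JVM instability of mmap.
        ByteBuffer buf = ByteBuffer.allocate((int) length);
        channel.position(offset);
        while (buf.remaining() > 0) {
          if (channel.read(buf) == -1) {
            throw new IOException("Reached EOF before reading the full segment");
          }
        }
        buf.flip();
        return buf;
      } else {
        // Large segment: memory-map just the requested range.
        return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
      }
    }
  }
}
```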
@@ -43,14 +43,15 @@ public abstract class ManagedBuffer {
* Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
* returned ByteBuffer should not affect the content of this buffer.
*/
// TODO: Deprecate this, usage may require expensive memory mapping or allocation.
public abstract ByteBuffer nioByteBuffer() throws IOException;

/**
* Exposes this buffer's data as an InputStream. The underlying implementation does not
* necessarily check for the length of bytes read, so the caller is responsible for making sure
* it does not go over the limit.
*/
public abstract InputStream inputStream() throws IOException;
public abstract InputStream createInputStream() throws IOException;

/**
* Increment the reference count by one if applicable.
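Since the javadoc warns that createInputStream() does not bound reads itself, a consumer typically drains the stream to EOF and then drops its reference. A hedged consumer sketch: the drain helper is hypothetical, and it relies only on the createInputStream() and release() methods visible in this diff.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.spark.network.buffer.ManagedBuffer;

final class ManagedBufferDrainSketch {
  static byte[] drain(ManagedBuffer buffer) throws IOException {
    InputStream in = buffer.createInputStream();
    try {
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      byte[] chunk = new byte[8192];
      int read;
      // createInputStream() does not bound the read for us, so stop at EOF explicitly.
      while ((read = in.read(chunk)) != -1) {
        out.write(chunk, 0, read);
      }
      return out.toByteArray();
    } finally {
      in.close();
      buffer.release();  // drop our reference once the data has been copied out
    }
  }
}
```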
@@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream inputStream() throws IOException {
public InputStream createInputStream() throws IOException {
return new ByteBufInputStream(buf);
}

@@ -64,7 +64,7 @@ public ManagedBuffer release() {

@Override
public Object convertToNetty() throws IOException {
return buf;
return buf.duplicate();
}

@Override
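Returning buf.duplicate() instead of buf matters because Netty's write path advances the reader index of whatever ByteBuf it serializes; a duplicate shares the bytes but keeps independent indices, so the original buffer stays readable. A small standalone illustration (not part of the patch):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class DuplicateSketch {
  public static void main(String[] args) {
    ByteBuf original = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4});
    ByteBuf view = original.duplicate();

    view.readBytes(new byte[4]);  // simulate Netty draining the duplicate during a write

    // The original's readerIndex is untouched, so it can still be read or converted again.
    System.out.println("original readable:  " + original.readableBytes());  // 4
    System.out.println("duplicate readable: " + view.readableBytes());      // 0
  }
}
```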
@@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream inputStream() throws IOException {
public InputStream createInputStream() throws IOException {
return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
}

@@ -21,17 +21,11 @@
* General exception caused by a remote exception while fetching a chunk.
*/
public class ChunkFetchFailureException extends RuntimeException {
private final int chunkIndex;

public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) {
public ChunkFetchFailureException(String errorMsg, Throwable cause) {
super(errorMsg, cause);
this.chunkIndex = chunkIndex;
}

public ChunkFetchFailureException(int chunkIndex, String errorMsg) {
public ChunkFetchFailureException(String errorMsg) {
super(errorMsg);
this.chunkIndex = chunkIndex;
}

public int getChunkIndex() { return chunkIndex; }
}
@@ -28,9 +28,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.request.ChunkFetchRequest;
import org.apache.spark.network.protocol.request.RpcRequest;
import org.apache.spark.network.util.NettyUtils;

/**
@@ -106,14 +106,15 @@ public void fetchChunk(
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
timeTaken);
} else {
String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
channel.close();
}
}
});
@@ -126,24 +127,25 @@ public void operationComplete(ChannelFuture future) throws Exception {
public void sendRpc(byte[] message, final RpcResponseCallback callback) {
final String serverAddr = NettyUtils.getRemoteAddress(channel);
final long startTime = System.currentTimeMillis();
logger.debug("Sending RPC to {}", serverAddr);
logger.trace("Sending RPC to {}", serverAddr);

final long tag = UUID.randomUUID().getLeastSignificantBits();
handler.addRpcRequest(tag, callback);
final long requestId = UUID.randomUUID().getLeastSignificantBits();
handler.addRpcRequest(requestId, callback);

channel.writeAndFlush(new RpcRequest(tag, message)).addListener(
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
long timeTaken = System.currentTimeMillis() - startTime;
logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken);
logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
} else {
String errorMsg = String.format("Failed to send RPC %s to %s: %s", tag,
String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(tag);
handler.removeRpcRequest(requestId);
callback.onFailure(new RuntimeException(errorMsg, future.cause()));
channel.close();
}
}
});
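From the caller's side, the rename from tag to requestId is invisible; a control-plane call is still just sendRpc(bytes, callback), and on a send failure the client now logs, drops the outstanding request, and closes the channel before invoking onFailure. A usage sketch under assumptions: onSuccess(byte[]) is presumed to be the other half of RpcResponseCallback (only onFailure(Throwable) appears in this diff), and the package paths mirror the imports above.

```java
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class RpcCallSketch {
  static void ping(TransportClient client) {
    byte[] request = "ping".getBytes(StandardCharsets.UTF_8);
    client.sendRpc(request, new RpcResponseCallback() {
      @Override
      public void onSuccess(byte[] response) {
        System.out.println("RPC reply: " + new String(response, StandardCharsets.UTF_8));
      }

      @Override
      public void onFailure(Throwable e) {
        // The client has already logged, removed the outstanding request, and closed
        // the channel; the caller only handles its own error path here.
        System.err.println("RPC failed: " + e);
      }
    });
  }
}
```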
@@ -23,6 +23,7 @@
import java.net.SocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
@@ -37,7 +38,7 @@
import org.slf4j.LoggerFactory;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportClientHandler;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -66,6 +67,7 @@ public TransportClientFactory(TransportContext context) {

IOMode ioMode = IOMode.valueOf(conf.ioMode());
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
// TODO: Make thread pool name configurable.
this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
}

@@ -100,17 +102,13 @@ public TransportClient createClient(String remoteHost, int remotePort) throws Ti
// Use pooled buffers to reduce temporary buffer allocation
bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());

final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();

bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportClientHandler channelHandler = context.initializePipeline(ch);
TransportClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient());
if (oldClient != null) {
logger.debug("Two clients were created concurrently, second one will be disposed.");
ch.close();
// Note: this type of failure is still considered a success by Netty, and thus the
// ChannelFuture will complete successfully.
}
TransportChannelHandler clientHandler = context.initializePipeline(ch);
client.set(clientHandler.getClient());
}
});

@@ -119,23 +117,31 @@ public void initChannel(SocketChannel ch) {
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new TimeoutException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}

TransportClient client = connectionPool.get(address);
if (client == null) {
// The only way we should be able to reach here is if the client we created started out
// in the "inactive" state, and someone else simultaneously tried to create another client to
// the same server. This is an error condition, as the first client failed to connect.
throw new IllegalStateException("Client was unset! Must have been immediately inactive.");
// Successful connection
assert client.get() != null : "Channel future completed successfully with null client";
TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
if (oldClient == null) {
return client.get();
} else {
logger.debug("Two clients were created concurrently, second one will be disposed.");
client.get().close();
return oldClient;
}
return client;
}

/** Close all connections in the connection pool, and shutdown the worker thread pool. */
@Override
public void close() {
for (TransportClient client : connectionPool.values()) {
client.close();
try {
client.close();
} catch (RuntimeException e) {
logger.warn("Ignoring exception during close", e);
}
}
connectionPool.clear();

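The reworked createClient resolves the connect race with putIfAbsent: both racers finish building a client, the map decides the winner, and the loser closes its duplicate and adopts the winner's client. A generic sketch of that pattern using plain Closeable stand-ins rather than the Spark classes:

```java
import java.io.Closeable;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class PutIfAbsentSketch {
  private final ConcurrentMap<String, Closeable> pool = new ConcurrentHashMap<String, Closeable>();

  Closeable getOrCreate(String address) throws IOException {
    Closeable fresh = connect(address);              // may race with another caller
    Closeable existing = pool.putIfAbsent(address, fresh);
    if (existing == null) {
      return fresh;                                   // we won: our connection is now shared
    }
    fresh.close();                                    // we lost: dispose of the duplicate
    return existing;
  }

  private Closeable connect(String address) throws IOException {
    // Stand-in for the real connect; returns a dummy closeable resource.
    return new Closeable() {
      @Override public void close() {}
    };
  }
}
```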
(remaining changed files not shown)
