From 14e37f7ce36f42e07fbe4f5382c0674754f523c9 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 15:56:40 -0700 Subject: [PATCH] Address Reynold's comments --- .../network/netty/NettyBlockFetcher.scala | 2 +- .../shuffle/FileShuffleBlockManager.scala | 4 +- .../storage/ShuffleBlockFetcherIterator.scala | 7 ++-- .../spark/network/TransportContext.java | 27 +++++++----- .../buffer/FileSegmentManagedBuffer.java | 3 +- .../spark/network/buffer/ManagedBuffer.java | 3 +- .../network/buffer/NettyManagedBuffer.java | 4 +- .../network/buffer/NioManagedBuffer.java | 2 +- .../client/ChunkFetchFailureException.java | 10 +---- .../spark/network/client/TransportClient.java | 22 +++++----- .../client/TransportClientFactory.java | 40 ++++++++++-------- .../client/TransportResponseHandler.java | 42 +++++++++---------- .../{response => }/ChunkFetchFailure.java | 7 +--- .../{request => }/ChunkFetchRequest.java | 6 +-- .../{response => }/ChunkFetchSuccess.java | 8 ++-- .../spark/network/protocol/Encodable.java | 6 +++ .../spark/network/protocol/Message.java | 2 +- .../{response => }/MessageDecoder.java | 8 +--- .../{response => }/MessageEncoder.java | 4 +- .../{request => }/RequestMessage.java | 2 +- .../{response => }/ResponseMessage.java | 2 +- .../protocol/{response => }/RpcFailure.java | 20 ++++----- .../protocol/{request => }/RpcRequest.java | 22 +++++----- .../protocol/{response => }/RpcResponse.java | 20 ++++----- .../network/server/DefaultStreamManager.java | 15 ++++--- ...dler.java => TransportChannelHandler.java} | 27 +++++++----- .../server/TransportRequestHandler.java | 30 ++++++------- .../spark/network/server/TransportServer.java | 9 +++- .../network/ChunkFetchIntegrationSuite.java | 1 - .../apache/spark/network/ProtocolSuite.java | 16 +++---- .../spark/network/RpcIntegrationSuite.java | 1 - .../SystemPropertyConfigProvider.java | 4 +- .../spark/network/TestManagedBuffer.java | 4 +- .../network/TransportClientFactorySuite.java | 1 - .../TransportResponseHandlerSuite.java | 8 ++-- 35 files changed, 203 insertions(+), 186 deletions(-) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ChunkFetchFailure.java (91%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/ChunkFetchRequest.java (90%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ChunkFetchSuccess.java (88%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/MessageDecoder.java (88%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/MessageEncoder.java (96%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/RequestMessage.java (95%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ResponseMessage.java (94%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/RpcFailure.java (79%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/RpcRequest.java (78%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/RpcResponse.java (79%) rename network/common/src/main/java/org/apache/spark/network/server/{TransportClientHandler.java => TransportChannelHandler.java} (79%) rename network/common/src/{main/java/org/apache/spark/network/util => test/java/org/apache/spark/network}/SystemPropertyConfigProvider.java (92%) diff --git 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala index 344d17e7bf661..8c5ffd8da6bbb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -87,7 +87,7 @@ class NettyBlockFetcher( } override def onFailure(e: Throwable): Unit = { - logError("Failed while starting block fetches") + logError("Failed while starting block fetches", e) blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) } }) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index c35aa2481ad03..1fb5b2c4546bd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 23313fe9271fd..0d6f3bf003a9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,12 +21,11 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} -import org.apache.spark.serializer.Serializer import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.serializer.Serializer import org.apache.spark.util.{CompletionIterator, Utils} -import org.apache.spark.{Logging, TaskContext} - /** * An iterator that fetches multiple blocks. 
For local blocks, it fetches from the local block @@ -285,7 +284,7 @@ final class ShuffleBlockFetcherIterator( val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { None } else { - val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream()) val iter = serializer.newInstance().deserializeStream(is).asIterator Some(CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index da0decac7e064..854aa6685f85f 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -25,10 +25,10 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.response.MessageDecoder; -import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportClientHandler; +import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; @@ -37,7 +37,12 @@ /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to - * setup Netty Channel pipelines with a {@link TransportClientHandler}. + * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. + * + * The TransportClient provides two communication protocols: control-plane RPCs and data-plane + * "chunk fetching". RPCs are handled outside the scope of the TransportContext (i.e., by a + * user-provided handler), and that handler is responsible for setting up streams which can then + * be fetched in chunks through the data plane using zero-copy IO. * * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each * channel. As each TransportChannelHandler contains a TransportClient, this enables server @@ -71,16 +76,16 @@ public TransportServer createServer() { /** * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and - * has a {@link org.apache.spark.network.server.TransportClientHandler} to handle request or + * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or * response messages. * * @return Returns the created TransportChannelHandler, which includes a TransportClient that can * be used to communicate on this channel. The TransportClient is directly associated with a * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
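To make the control-plane/data-plane split concrete, here is a minimal wiring sketch in the spirit of the test suites touched below (imports elided; the TransportContext constructor and the exact RpcHandler interface shape are assumptions, while TransportClientFactory's constructor and createServer()/createClient()/getPort() appear in this patch):

    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    RpcHandler rpcHandler = new RpcHandler() {
      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        callback.onSuccess(message);  // control plane: echo the RPC payload back
      }
    };
    TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
    TransportServer server = context.createServer();
    TransportClientFactory factory = new TransportClientFactory(context);
    // createClient throws TimeoutException if the connection cannot be established in time
    TransportClient client = factory.createClient("localhost", server.getPort());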
*/ - public TransportClientHandler initializePipeline(SocketChannel channel) { + public TransportChannelHandler initializePipeline(SocketChannel channel) { try { - TransportClientHandler channelHandler = createChannelHandler(channel); + TransportChannelHandler channelHandler = createChannelHandler(channel); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) @@ -100,12 +105,12 @@ public TransportClientHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private TransportClientHandler createChannelHandler(Channel channel) { + private TransportChannelHandler createChannelHandler(Channel channel) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); - TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, streamManager, - rpcHandler); - return new TransportClientHandler(client, responseHandler, requestHandler); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, + streamManager, rpcHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 224f1e6c515ea..a02f692a674b2 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -40,6 +40,7 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). * Avoid unless there's a good reason not to. */ + // TODO: Make this configurable private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; private final File file; @@ -88,7 +89,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { FileInputStream is = null; try { is = new FileInputStream(file); diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index 1735f5540c61b..a415db593a788 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -43,6 +43,7 @@ public abstract class ManagedBuffer { * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the * returned ByteBuffer should not affect the content of this buffer. */ + // TODO: Deprecate this, usage may require expensive memory mapping or allocation. public abstract ByteBuffer nioByteBuffer() throws IOException; /** @@ -50,7 +51,7 @@ public abstract class ManagedBuffer { * necessarily check for the length of bytes read, so the caller is responsible for making sure * it does not go over the limit. 
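The rename from inputStream() to createInputStream() signals that each call creates a fresh stream which the caller owns. A caller-side sketch (the NioManagedBuffer ByteBuffer constructor is an assumption; it is not shown in this diff):

    ManagedBuffer buffer = new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 1, 2, 3 }));
    InputStream in = buffer.createInputStream();  // a new stream per call
    try {
      while (in.read() != -1) { }  // consume; per the javadoc, no bounds check is done for us
    } finally {
      in.close();
      buffer.release();  // drop our reference once done
    }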
*/ - public abstract InputStream inputStream() throws IOException; + public abstract InputStream createInputStream() throws IOException; /** * Increment the reference count by one if applicable. diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index d928980423f1f..c806bfa45bef3 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { return new ByteBufInputStream(buf); } @@ -64,7 +64,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return buf; + return buf.duplicate(); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index 3953ef89fbf88..f55b884bc45ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java index 40a1fe67b1c5b..1fbdcd6780785 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -21,17 +21,11 @@ * General exception caused by a remote exception while fetching a chunk. 
*/ public class ChunkFetchFailureException extends RuntimeException { - private final int chunkIndex; - - public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) { + public ChunkFetchFailureException(String errorMsg, Throwable cause) { super(errorMsg, cause); - this.chunkIndex = chunkIndex; } - public ChunkFetchFailureException(int chunkIndex, String errorMsg) { + public ChunkFetchFailureException(String errorMsg) { super(errorMsg); - this.chunkIndex = chunkIndex; } - - public int getChunkIndex() { return chunkIndex; } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 75e26cb7e60c1..b1732fcde21f1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -28,9 +28,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; import org.apache.spark.network.util.NettyUtils; /** @@ -106,7 +106,7 @@ public void fetchChunk( public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, timeTaken); } else { String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, @@ -114,6 +114,7 @@ public void operationComplete(ChannelFuture future) throws Exception { logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); + channel.close(); } } }); @@ -126,24 +127,25 @@ public void operationComplete(ChannelFuture future) throws Exception { public void sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending RPC to {}", serverAddr); + logger.trace("Sending RPC to {}", serverAddr); - final long tag = UUID.randomUUID().getLeastSignificantBits(); - handler.addRpcRequest(tag, callback); + final long requestId = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(tag, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); + logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", tag, + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, serverAddr, future.cause()); logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(tag); + handler.removeRpcRequest(requestId); 
callback.onFailure(new RuntimeException(errorMsg, future.cause())); + channel.close(); } } }); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index c351858bfe30d..10eb9ef7a025f 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -23,6 +23,7 @@ import java.net.SocketAddress; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; @@ -37,7 +38,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.server.TransportClientHandler; +import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -66,6 +67,7 @@ public TransportClientFactory(TransportContext context) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + // TODO: Make thread pool name configurable. this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); } @@ -100,17 +102,13 @@ public TransportClient createClient(String remoteHost, int remotePort) throws Ti // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + final AtomicReference client = new AtomicReference(); + bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - TransportClientHandler channelHandler = context.initializePipeline(ch); - TransportClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); - if (oldClient != null) { - logger.debug("Two clients were created concurrently, second one will be disposed."); - ch.close(); - // Note: this type of failure is still considered a success by Netty, and thus the - // ChannelFuture will complete successfully. - } + TransportChannelHandler clientHandler = context.initializePipeline(ch); + client.set(clientHandler.getClient()); } }); @@ -119,23 +117,31 @@ public void initChannel(SocketChannel ch) { if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new TimeoutException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } else if (cf.cause() != null) { + throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - TransportClient client = connectionPool.get(address); - if (client == null) { - // The only way we should be able to reach here is if the client we created started out - // in the "inactive" state, and someone else simultaneously tried to create another client to - // the same server. This is an error condition, as the first client failed to connect. - throw new IllegalStateException("Client was unset! 
Must have been immediately inactive."); + // Successful connection + assert client.get() != null : "Channel future completed successfully with null client"; + TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + if (oldClient == null) { + return client.get(); + } else { + logger.debug("Two clients were created concurrently, second one will be disposed."); + client.get().close(); + return oldClient; } - return client; } /** Close all connections in the connection pool, and shutdown the worker thread pool. */ @Override public void close() { for (TransportClient client : connectionPool.values()) { - client.close(); + try { + client.close(); + } catch (RuntimeException e) { + logger.warn("Ignoring exception during close", e); + } } connectionPool.clear(); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 187b20d27656b..d8965590b34da 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -25,12 +25,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.NettyUtils; @@ -63,12 +63,12 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { outstandingFetches.remove(streamChunkId); } - public void addRpcRequest(long tag, RpcResponseCallback callback) { - outstandingRpcs.put(tag, callback); + public void addRpcRequest(long requestId, RpcResponseCallback callback) { + outstandingRpcs.put(requestId, callback); } - public void removeRpcRequest(long tag) { - outstandingRpcs.remove(tag); + public void removeRpcRequest(long requestId) { + outstandingRpcs.remove(requestId); } /** @@ -115,7 +115,7 @@ public void handle(ResponseMessage message) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { - logger.warn("Got a response for block {} from {} but it is not outstanding", + logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); resp.buffer.release(); } else { @@ -127,31 +127,31 @@ public void handle(ResponseMessage message) { ChunkFetchFailure resp = (ChunkFetchFailure) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { - logger.warn("Got a response for block {} from {} ({}) but it is not outstanding", + logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", resp.streamChunkId, remoteAddress, resp.errorString); } else { 
outstandingFetches.remove(resp.streamChunkId); - listener.onFailure(resp.streamChunkId.chunkIndex, - new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString)); + listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( + "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); } } else if (message instanceof RpcResponse) { RpcResponse resp = (RpcResponse) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { - logger.warn("Got a response for RPC {} from {} ({} bytes) but it is not outstanding", - resp.tag, remoteAddress, resp.response.length); + logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", + resp.requestId, remoteAddress, resp.response.length); } else { - outstandingRpcs.remove(resp.tag); + outstandingRpcs.remove(resp.requestId); listener.onSuccess(resp.response); } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { - logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", - resp.tag, remoteAddress, resp.errorString); + logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", + resp.requestId, remoteAddress, resp.errorString); } else { - outstandingRpcs.remove(resp.tag); + outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } } else { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java similarity index 91% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index cb3cbcd0a53ca..152af98ced7ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -15,17 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.StreamChunkId; - /** - * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an - * error fetching the chunk. + * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
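With the chunkIndex field gone from ChunkFetchFailureException, the index is now delivered only through the callback, and the failed StreamChunkId is named in the exception message instead. A client-side sketch (fetchChunk's full parameter list is cut off in the hunk above and assumed here):

    client.fetchChunk(streamId, chunkIndex, new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        buffer.retain();  // retain if the buffer is used after this callback returns
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        // e is a ChunkFetchFailureException such as
        // "Failure while fetching StreamChunkId{...}: <server error string>"
      }
    });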
*/ public final class ChunkFetchFailure implements ResponseMessage { public final StreamChunkId streamChunkId; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java similarity index 90% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 99cbb8777a873..980947cf13f6b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -15,16 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.StreamChunkId; - /** * Request to fetch a single chunk of a stream. This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ public final class ChunkFetchRequest implements RequestMessage { public final StreamChunkId streamChunkId; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 6bc26a64b9945..ff4936470c697 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -15,18 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; -import org.apache.spark.network.protocol.StreamChunkId; /** - * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when a chunk - * exists and has been successfully fetched. + * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. * * Note that the server-side encoding of this message does NOT include the buffer itself, as this * may be written by Netty in a more efficient manner (i.e., zero-copy write). @@ -49,7 +47,7 @@ public int encodedLength() { return streamChunkId.encodedLength(); } - /** Encoding does NOT include buffer itself. See {@link MessageEncoder}. */ + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}.
*/ @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java index 363ea5ecfa936..b4e299471b41a 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -22,6 +22,12 @@ /** * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + * + * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by + * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than + * just copying data from it), then you must retain() the ByteBuf. + * + * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. */ public interface Encodable { /** Number of bytes of the encoded form of this object. */ diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 6731b3f53ae82..d568370125fd4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,7 +19,7 @@ import io.netty.buffer.ByteBuf; -/** Messages from the client to the server. */ +/** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3ae80305803eb..81f8d7f96350f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.List; @@ -26,10 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; - /** * Decoder used by the client side to decode server-to-client responses. * This decoder is stateless so it is safe to be shared by multiple threads.
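The decode contract documented above is easiest to see on a concrete message. A hypothetical example, not part of this patch (a Message.Type.Ping enum value and a matching case in MessageDecoder's switch would also be needed):

    public final class Ping implements RequestMessage {
      public final long id;

      public Ping(long id) { this.id = id; }

      @Override
      public Type type() { return Type.Ping; }  // assumed new enum value

      @Override
      public int encodedLength() { return 8; }

      @Override
      public void encode(ByteBuf buf) { buf.writeLong(id); }

      /** Invoked by MessageDecoder; copies from the ByteBuf, so no retain() is needed. */
      public static Ping decode(ByteBuf buf) { return new Ping(buf.readLong()); }
    }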
@@ -43,7 +39,7 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); Message decoded = decode(msgType, in); assert decoded.type() == msgType; - logger.debug("Received message " + msgType + ": " + decoded); + logger.trace("Received message " + msgType + ": " + decoded); out.add(decoded); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 5ca8de42a6429..4cb8becc3ed22 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.List; @@ -26,8 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.Message; - /** * Encoder used by the server side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java index 58abce25d9a2a..31b15bb17a327 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import org.apache.spark.network.protocol.Message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java index 8f545e91d1d8e..6edffd11cf1e2 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import org.apache.spark.network.protocol.Message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 1f161f7957543..e239d4ffbd29c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -15,19 +15,19 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. */ +/** Response to {@link RpcRequest} for a failed RPC. */ public final class RpcFailure implements ResponseMessage { - public final long tag; + public final long requestId; public final String errorString; - public RpcFailure(long tag, String errorString) { - this.tag = tag; + public RpcFailure(long requestId, String errorString) { + this.requestId = requestId; this.errorString = errorString; } @@ -41,25 +41,25 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } public static RpcFailure decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int numErrorStringBytes = buf.readInt(); byte[] errorBytes = new byte[numErrorStringBytes]; buf.readBytes(errorBytes); - return new RpcFailure(tag, new String(errorBytes, Charsets.UTF_8)); + return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); } @Override public boolean equals(Object other) { if (other instanceof RpcFailure) { RpcFailure o = (RpcFailure) other; - return tag == o.tag && errorString.equals(o.errorString); + return requestId == o.requestId && errorString.equals(o.errorString); } return false; } @@ -67,7 +67,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("errorString", errorString) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java similarity index 78% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 810da7a689c13..099e934ae018c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import java.util.Arrays; @@ -25,17 +25,17 @@ /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ public final class RpcRequest implements RequestMessage { - /** Tag is used to link an RPC request with its response. */ - public final long tag; + /** Used to link an RPC request with its response. */ + public final long requestId; /** Serialized message to send to remote RpcHandler. */ public final byte[] message; - public RpcRequest(long tag, byte[] message) { - this.tag = tag; + public RpcRequest(long requestId, byte[] message) { + this.requestId = requestId; this.message = message; } @@ -49,24 +49,24 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); buf.writeInt(message.length); buf.writeBytes(message); } public static RpcRequest decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int messageLen = buf.readInt(); byte[] message = new byte[messageLen]; buf.readBytes(message); - return new RpcRequest(tag, message); + return new RpcRequest(requestId, message); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return tag == o.tag && Arrays.equals(message, o.message); + return requestId == o.requestId && Arrays.equals(message, o.message); } return false; } @@ -74,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("message", message) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 40623ce31c666..ed479478325b6 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -15,20 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.Arrays; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. */ +/** Response to {@link RpcRequest} for a successful RPC. 
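The encode/decode pair above fixes the wire layout of the renamed requestId field. A test-style round trip, grounded in the code shown in this hunk (Unpooled is Netty's buffer factory; in practice the frame length and message type are added by MessageEncoder/MessageDecoder):

    RpcRequest req = new RpcRequest(42L, new byte[] { 1, 2, 3 });
    ByteBuf buf = Unpooled.buffer(req.encodedLength());
    req.encode(buf);  // writes requestId (8 bytes), payload length (4 bytes), payload
    assert req.equals(RpcRequest.decode(buf));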
*/ public final class RpcResponse implements ResponseMessage { - public final long tag; + public final long requestId; public final byte[] response; - public RpcResponse(long tag, byte[] response) { - this.tag = tag; + public RpcResponse(long requestId, byte[] response) { + this.requestId = requestId; this.response = response; } @@ -40,24 +40,24 @@ public RpcResponse(long tag, byte[] response) { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); buf.writeInt(response.length); buf.writeBytes(response); } public static RpcResponse decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int responseLen = buf.readInt(); byte[] response = new byte[responseLen]; buf.readBytes(response); - return new RpcResponse(tag, response); + return new RpcResponse(requestId, response); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return tag == o.tag && Arrays.equals(response, o.response); + return requestId == o.requestId && Arrays.equals(response, o.response); } return false; } @@ -65,7 +65,7 @@ @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("response", response) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java index d93607a7c31ea..9688705569634 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -42,6 +42,8 @@ public class DefaultStreamManager extends StreamManager { private static class StreamState { final Iterator buffers; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure + // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator buffers) { @@ -50,7 +52,8 @@ private static class StreamState { } public DefaultStreamManager() { - // Start with a random stream id to help identifying different streams. + // For debugging purposes, start with a random stream id to help identify different streams. + // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); streams = new ConcurrentHashMap(); } @@ -87,13 +90,15 @@ public void connectionTerminated(long streamId) { } } + /** + * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to + * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a + * client connection is closed before the iterator is fully drained, then the remaining buffers + * will all be release()'d.
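A server-side sketch of the lifecycle this new javadoc describes (buffer contents are illustrative and NioManagedBuffer's constructor is an assumption):

    DefaultStreamManager streamManager = new DefaultStreamManager();
    List<ManagedBuffer> chunks = Arrays.asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 0 })),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] { 1 })));
    long streamId = streamManager.registerStream(chunks.iterator());
    // Clients must now request chunk 0 then chunk 1, one at a time and in order
    // (per the StreamState.curChunk comment above); each buffer is release()'d
    // after transfer, or all remaining ones if the connection terminates early.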
+ */ public long registerStream(Iterator buffers) { long myStreamId = nextStreamId.getAndIncrement(); streams.put(myStreamId, new StreamState(buffers)); return myStreamId; } - - public void unregisterStream(long streamId) { - streams.remove(streamId); - } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 08cc1b1f95de6..e491367fa4528 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -25,14 +25,13 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.request.RequestMessage; -import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.util.NettyUtils; /** - * A handler which is used for delegating requests to the - * {@link TransportRequestHandler} and responses to the - * {@link org.apache.spark.network.client.TransportResponseHandler}. + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. * * All channels created in the transport layer are bidirectional. When the Client initiates a Netty * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server @@ -42,14 +41,14 @@ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. 
*/ -public class TransportClientHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(TransportClientHandler.class); +public class TransportChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; private final TransportResponseHandler responseHandler; private final TransportRequestHandler requestHandler; - public TransportClientHandler( + public TransportChannelHandler( TransportClient client, TransportResponseHandler responseHandler, TransportRequestHandler requestHandler) { @@ -73,8 +72,16 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - requestHandler.channelUnregistered(); - responseHandler.channelUnregistered(); + try { + requestHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } super.channelUnregistered(ctx); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 08a2a3ec52f8b..352f865935b11 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -31,13 +31,13 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.request.RequestMessage; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.util.NettyUtils; /** @@ -66,10 +66,10 @@ public class TransportRequestHandler extends MessageHandler { private final Set streamIds; public TransportRequestHandler( - Channel channel, - TransportClient reverseClient, - StreamManager streamManager, - RpcHandler rpcHandler) { + Channel channel, + TransportClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; this.streamManager = streamManager; @@ -124,17 +124,17 @@ private void processRpcRequest(final RpcRequest req) { rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { - respond(new RpcResponse(req.tag, response)); + respond(new 
RpcResponse(req.requestId, response)); } @Override public void onFailure(Throwable e) { - respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } }); } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); - respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 973fb05f57944..243070750d6e7 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -47,7 +47,7 @@ public class TransportServer implements Closeable { private ServerBootstrap bootstrap; private ChannelFuture channelFuture; - private int port; + private int port = -1; public TransportServer(TransportContext context) { this.context = context; @@ -56,7 +56,12 @@ public TransportServer(TransportContext context) { init(); } - public int getPort() { return port; } + public int getPort() { + if (port == -1) { + throw new IllegalStateException("Server not initialized"); + } + return port; + } private void init() { diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 00ed7b527abd5..738dca9b6a9ee 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -45,7 +45,6 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 6932760c44fee..43dc0cf8c7194 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -24,14 +24,14 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.MessageDecoder; -import org.apache.spark.network.protocol.response.MessageEncoder; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcRequest; 
+import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 19ce9c6a8d826..9f216dd2d722d 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -38,7 +38,6 @@ import org.apache.spark.network.server.DefaultStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java index f15ec8d294258..f4e0a2426a3d2 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java +++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java @@ -15,10 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network.util; +package org.apache.spark.network; import java.util.NoSuchElementException; +import org.apache.spark.network.util.ConfigProvider; + /** Uses System properties to obtain config values. 
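Moving SystemPropertyConfigProvider under test makes its intended use explicit. A test-style sketch (the "spark.shuffle.io.mode" property key is an assumption; only conf.ioMode() is visible in this patch):

    System.setProperty("spark.shuffle.io.mode", "NIO");
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());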
*/ public class SystemPropertyConfigProvider extends ConfigProvider { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index 7e7554af70f42..38113a918f795 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -59,8 +59,8 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { - return underlying.inputStream(); + public InputStream createInputStream() throws IOException { + return underlying.createInputStream(); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f76b4bc55182d..3ef964616f0c5 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -32,7 +32,6 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 6e360e96099f4..17a03ebe88a93 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -29,11 +29,11 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; public class TransportResponseHandlerSuite { @Test