diff --git a/core/pom.xml b/core/pom.xml index a5a178079bc57..aff0d989d01bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -44,6 +44,11 @@ + + org.apache.spark + network + ${project.version} + net.java.dev.jets3t jets3t diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 373ce795a309e..867173e04714e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.netty.{NettyBlockTransferService} import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -40,7 +40,6 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} - /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -233,12 +232,14 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. - val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new NettyBlockTransferService(conf) - } else { - new NioBlockTransferService(conf, securityManager) - } + // TODO: This is only netty by default for initial testing -- it should not be merged as such!!! + val blockTransferService = + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + case "netty" => + new NettyBlockTransferService(conf) + case "nio" => + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 0eeffe0e7c5e6..1745d52c81923 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,8 +17,8 @@ package org.apache.spark.network -import org.apache.spark.storage.StorageLevel - +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { @@ -27,10 +27,10 @@ trait BlockDataManager { * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - def getBlockData(blockId: String): ManagedBuffer + def getBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. 
*/ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index dd70e26647939..e35fdb4e95899 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -19,6 +19,8 @@ package org.apache.spark.network import java.util.EventListener +import org.apache.spark.network.buffer.ManagedBuffer + /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index d3ed683c7e880..8287a0fc81cfe 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,16 +18,18 @@ package org.apache.spark.network import java.io.Closeable -import java.nio.ByteBuffer + +import org.apache.spark.network.buffer.ManagedBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration +import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.Utils private[spark] -abstract class BlockTransferService extends Closeable { +abstract class BlockTransferService extends Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -92,10 +94,7 @@ abstract class BlockTransferService extends Closeable { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result = Left(new NioManagedBuffer(ret)) + result = Left(data) lock.notify() } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala deleted file mode 100644 index dd808d2500fbc..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.io._ -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode - -import scala.util.Try - -import com.google.common.io.ByteStreams -import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} -import io.netty.channel.DefaultFileRegion - -import org.apache.spark.util.{ByteBufferInputStream, Utils} - - -/** - * This interface provides an immutable view for data in the form of bytes. The implementation - * should specify how the data is provided: - * - * - [[FileSegmentManagedBuffer]]: data backed by part of a file - * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer - * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf - * - * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. - * In that case, if the buffer is going to be passed around to a different thread, retain/release - * should be called. - */ -private[spark] -abstract class ManagedBuffer { - // Note that all the methods are defined with parenthesis because their implementations can - // have side effects (io operations). - - /** Number of bytes of the data. */ - def size: Long - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - def nioByteBuffer(): ByteBuffer - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - def inputStream(): InputStream - - /** - * Increment the reference count by one if applicable. - */ - def retain(): this.type - - /** - * If applicable, decrement the reference count by one and deallocates the buffer if the - * reference count reaches zero. - */ - def release(): this.type - - /** - * Convert the buffer into an Netty object, used to write the data out. 
- */ - private[network] def convertToNetty(): AnyRef -} - - -/** - * A [[ManagedBuffer]] backed by a segment in a file - */ -private[spark] -final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) - extends ManagedBuffer { - - override def size: Long = length - - override def nioByteBuffer(): ByteBuffer = { - var channel: FileChannel = null - try { - channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) - } catch { - case e: IOException => - Try(channel.size).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - } finally { - if (channel != null) { - Utils.tryLog(channel.close()) - } - } - } - - override def inputStream(): InputStream = { - var is: FileInputStream = null - try { - is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) - } catch { - case e: IOException => - if (is != null) { - Utils.tryLog(is.close()) - } - Try(file.length).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - case e: Throwable => - if (is != null) { - Utils.tryLog(is.close()) - } - throw e - } - } - - override def toString: String = s"${getClass.getName}($file, $offset, $length)" -} - - -/** - * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. - */ -private[spark] -final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - - override def size: Long = buf.remaining() - - override def nioByteBuffer() = buf.duplicate() - - override def inputStream() = new ByteBufferInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) - - // [[ByteBuffer]] is managed by the JVM garbage collector itself. - override def retain(): this.type = this - override def release(): this.type = this -} - - -/** - * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. - */ -private[spark] -final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - - override def size: Long = buf.readableBytes() - - override def nioByteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = buf - - override def retain(): this.type = { - buf.retain() - this - } - - override def release(): this.type = { - buf.release() - this - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala deleted file mode 100644 index 6bdbf88d337ce..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.TimeoutException - -import scala.concurrent.{Future, promise} - -import io.netty.channel.{ChannelFuture, ChannelFutureListener} - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} -import org.apache.spark.storage.StorageLevel - - -/** - * Client for [[NettyBlockTransferService]]. The connection to server must have been established - * using [[BlockClientFactory]] before instantiating this. - * - * This class is used to make requests to the server , while [[BlockClientHandler]] is responsible - * for handling responses from the server. - * - * Concurrency: thread safe and can be called from multiple threads. - * - * @param cf the ChannelFuture for the connection. - * @param handler [[BlockClientHandler]] for handling outstanding requests. - */ -@throws[TimeoutException] -private[netty] -class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { - - private[this] val serverAddr = cf.channel().remoteAddress().toString - - def isActive: Boolean = cf.channel().isActive - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. - */ - def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Sending request $blockIds to $serverAddr" - } - - blockIds.foreach { blockId => - handler.addFetchRequest(blockId, listener) - } - - cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Sending request $blockIds to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. 
- val errorMsg = - s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - handler.removeFetchRequest(blockId) - listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) - } - } - } - }) - } - - def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = - { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Uploading block ($blockId) to $serverAddr" - } - val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) - - val p = promise[Unit]() - handler.addUploadRequest(blockId, p) - f.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - } - } - }) - - p.future - } - - /** Close the connection. This does NOT block till the connection is closed. */ - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala deleted file mode 100644 index 8021cfdf42d1a..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.{ConcurrentHashMap, TimeoutException} - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel._ -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.util.internal.PlatformDependent - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.util.Utils - - -/** - * Factory for creating [[BlockClient]] by using createClient. - * - * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] - * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. 
- */ -private[netty] -class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") - - /** Socket channel type, initialized by [[init]] depending ioMode. */ - private[this] var socketChannelClass: Class[_ <: Channel] = _ - - /** Thread pool shared by all clients. */ - private[this] var workerGroup: EventLoopGroup = _ - - private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] - - // The encoders are stateless and can be shared among multiple clients. - private[this] val encoder = new ClientRequestEncoder - private[this] val decoder = new ServerResponseDecoder - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) - } - - // For auto mode, first try epoll (only available on Linux), then nio. - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockClient = { - // Get connection from the connection pool first. - // If it is not found or not active, create a new one. - val cachedClient = connectionPool.get((remoteHost, remotePort)) - if (cachedClient != null && cachedClient.isActive) { - return cachedClient - } - - logDebug(s"Creating new connection to $remoteHost:$remotePort") - - // There is a chance two threads are creating two different clients connecting to the same host. - // But that's probably ok ... 
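The comment above flags the small race in which two threads may build separate clients for the same (host, port). Below is a minimal, hedged sketch of one way to keep the check-then-create pattern while bounding that race, using only the JDK's ConcurrentHashMap; `Client`, `ClientPool`, and `connect` are hypothetical stand-ins, not types from this patch.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-ins for BlockClient and the connect step; not part of the patch.
trait Client {
  def isActive: Boolean
  def close(): Unit
}

class ClientPool(connect: (String, Int) => Client) {
  private val pool = new ConcurrentHashMap[(String, Int), Client]

  def get(host: String, port: Int): Client = {
    val key = (host, port)
    val cached = pool.get(key)
    if (cached != null && cached.isActive) {
      return cached
    }
    // Two threads can reach this point for the same host; putIfAbsent keeps exactly
    // one live client in the pool and the loser closes its redundant connection.
    val fresh = connect(host, port)
    val existing = pool.putIfAbsent(key, fresh)
    if (existing == null || !existing.isActive) {
      pool.put(key, fresh) // no-op if putIfAbsent already stored `fresh`
      fresh
    } else {
      fresh.close()
      existing
    }
  }
}
```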
- - val handler = new BlockClientHandler - - val bootstrap = new Bootstrap - bootstrap.group(workerGroup) - .channel(socketChannelClass) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) - - bootstrap.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler) - } - }) - - // Connect to the remote server - val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) - if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") - } - - val client = new BlockClient(cf, handler) - connectionPool.put((remoteHost, remotePort), client) - client - } - - /** Close all connections in the connection pool, and shutdown the worker thread pool. */ - override def close(): Unit = { - val iter = connectionPool.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.close() - connectionPool.remove(entry.getKey) - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } - - /** - * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled because the ByteBufs are allocated by the event loop thread, but released by the - * executor thread rather than the event loop thread. Those thread-local caches actually delay - * the recycling of buffers, leading to larger memory usage. - */ - private def createPooledByteBufAllocator(): PooledByteBufAllocator = { - def getPrivateStaticField(name: String): Int = { - val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) - f.setAccessible(true) - f.getInt(null) - } - new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), - getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), - getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - 0, // tinyCacheSize - 0, // smallCacheSize - 0 // normalCacheSize - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala deleted file mode 100644 index 5e28a07a461fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Promise - -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging -import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} - - -/** - * Handler that processes server responses, in response to requests issued from [[BlockClient]]. - * It works by tracking the list of outstanding requests (and their callbacks). - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[netty] -class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = - new ConcurrentHashMap[String, BlockFetchingListener] - - private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = - new ConcurrentHashMap[String, Promise[Unit]] - - def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { - outstandingFetches.put(blockId, listener) - } - - def removeFetchRequest(blockId: String): Unit = { - outstandingFetches.remove(blockId) - } - - def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { - outstandingUploads.put(blockId, promise) - } - - /** - * Fire the failure callback for all outstanding requests. This is called when we have an - * uncaught exception or pre-mature connection termination. - */ - private def failOutstandingRequests(cause: Throwable): Unit = { - val iter1 = outstandingFetches.entrySet().iterator() - while (iter1.hasNext) { - val entry = iter1.next() - entry.getValue.onBlockFetchFailure(entry.getKey, cause) - } - // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests - // as well. But I guess that is ok given the caller will fail as soon as any requests fail. 
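Picking up the TODO above: one hedged option is to remove exactly the entries that were failed instead of issuing a blanket clear(), so requests registered concurrently are left untouched. A self-contained sketch with a hypothetical `Listener` type standing in for the real callback interface:

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for BlockFetchingListener; not part of the patch.
trait Listener {
  def onFailure(blockId: String, cause: Throwable): Unit
}

object FailOutstanding {
  /** Fails and removes only the entries seen during iteration, leaving later registrations alone. */
  def apply(outstanding: ConcurrentHashMap[String, Listener], cause: Throwable): Unit = {
    val iter = outstanding.entrySet().iterator()
    while (iter.hasNext) {
      val entry = iter.next()
      // remove(key, value) succeeds only if the mapping is unchanged, so a request
      // re-registered under the same id after this snapshot is not clobbered.
      if (outstanding.remove(entry.getKey, entry.getValue)) {
        entry.getValue.onFailure(entry.getKey, cause)
      }
    }
  }
}
```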
- outstandingFetches.clear() - - val iter2 = outstandingUploads.entrySet().iterator() - while (iter2.hasNext) { - val entry = iter2.next() - entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) - } - outstandingUploads.clear() - } - - override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { - if (outstandingFetches.size() > 0) { - logError("Still have " + outstandingFetches.size() + " requests outstanding " + - s"when connection from ${ctx.channel.remoteAddress} is closed") - failOutstandingRequests(new RuntimeException( - s"Connection from ${ctx.channel.remoteAddress} closed")) - } - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - if (outstandingFetches.size() > 0) { - logError( - s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) - failOutstandingRequests(cause) - } - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { - val server = ctx.channel.remoteAddress.toString - response match { - case BlockFetchSuccess(blockId, buf) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning(s"Got a response for block $blockId from $server but it is not outstanding") - buf.release() - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchSuccess(blockId, buf) - buf.release() - } - case BlockFetchFailure(blockId, errorMsg) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning( - s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) - } - case BlockUploadSuccess(blockId) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.success(Unit) - } - case BlockUploadFailure(blockId, error) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.failure(new BlockUploadFailureException(blockId)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala deleted file mode 100644 index e2eb7c379f14d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} - -import org.apache.spark.Logging -import org.apache.spark.network.BlockDataManager -import org.apache.spark.util.Utils - - -/** - * Server for the [[NettyBlockTransferService]]. - */ -private[netty] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) - extends Closeable with Logging { - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val threadFactory = Utils.namedThreadFactory("spark-netty-server") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. - conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("clientRequestDecoder", new ClientRequestDecoder) - .addLast("serverResponseEncoder", new ServerResponseEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - // _hostName = addr.getHostName - _hostName = Utils.localHostName() - - logInfo(s"Server started ${_hostName}:${_port}") - } - - /** Shutdown the server. 
*/ - def close(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala deleted file mode 100644 index 44687f0b770e9..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockDataManager} -import org.apache.spark.storage.StorageLevel - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by the pipeline setup by BlockServerChannelInitializer. - */ -private[netty] class BlockServerHandler(dataProvider: BlockDataManager) - extends SimpleChannelInboundHandler[ClientRequest] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { - request match { - case BlockFetchRequest(blockIds) => - blockIds.foreach(processFetchRequest(ctx, _)) - case BlockUploadRequest(blockId, data, level) => - processUploadRequest(ctx, blockId, data, level) - } - } // end of channelRead0 - - private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - // First make sure we can find the block. If not, send error back to the user. 
- var buf: ManagedBuffer = null - try { - buf = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - } - ) - } // end of processBlockRequest - - private def processUploadRequest( - ctx: ChannelHandlerContext, - blockId: String, - data: ManagedBuffer, - level: StorageLevel): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - try { - dataProvider.putBlockData(blockId, data, level) - ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } catch { - case e: Throwable => - logError(s"Error processing uploaded block $blockId", e) - ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } - } // end of processUploadRequest -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala new file mode 100644 index 0000000000000..aefd8a6335b2a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util + +import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.Utils + +/** + * Responsible for holding the state for a request for a single set of blocks. This assumes that + * the chunks will be returned in the same order as requested, and that there will be exactly + * one chunk per block. + * + * Upon receipt of any block, the listener will be called back. 
Upon failure part way through, + * the listener will receive a failure callback for each outstanding block. + */ +class NettyBlockFetcher( + serializer: Serializer, + client: SluiceClient, + blockIds: Seq[String], + listener: BlockFetchingListener) + extends Logging { + + require(blockIds.nonEmpty) + + val ser = serializer.newInstance() + + var streamHandle: ShuffleStreamHandle = _ + + val chunkCallback = new ChunkReceivedCallback { + // On receipt of a chunk, pass it upwards as a block. + def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { + buffer.retain() + listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) + } + + // On receipt of a failure, fail every block from chunkIndex onwards. + def onFailure(chunkIndex: Int, e: Throwable): Unit = { + blockIds.drop(chunkIndex).foreach { blockId => + listener.onBlockFetchFailure(blockId, e); + } + } + } + + // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. + client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + try { + streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) + logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (i <- 0 until streamHandle.numChunks) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback) + } + } catch { + case e: Exception => + logError("Failed while starting block fetches", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def onFailure(e: Throwable): Unit = { + logError("Failed while starting block fetches") + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + }) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala new file mode 100644 index 0000000000000..c8658ec98b82c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} +import org.apache.spark.storage.BlockId + +import scala.collection.JavaConversions._ + +/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ +case class OpenBlocks(blockIds: Seq[BlockId]) + +/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ +case class ShuffleStreamHandle(streamId: Long, numChunks: Int) + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + */ +class NettyBlockRpcServer( + serializer: Serializer, + streamManager: DefaultStreamManager, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { + val ser = serializer.newInstance() + val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + logTrace(s"Received request: $message") + message match { + case OpenBlocks(blockIds) => + val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + responseContext.onSuccess( + ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7f979dccd0f5..7576d51e22175 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,38 +17,39 @@ package org.apache.spark.network.netty -import scala.concurrent.Future - import org.apache.spark.SparkConf import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory} +import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer} +import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import scala.concurrent.Future /** - * A [[BlockTransferService]] implementation based on Netty. - * - * See protocol.scala for the communication protocol between server and client + * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -private[spark] -final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + var client: SluiceClient = _ - private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. + val serializer = new JavaSerializer(conf) - private[this] var server: BlockServer = _ - private[this] var clientFactory: BlockClientFactory = _ + // Create a SluiceConfig using SparkConf. 
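As an aside on the serializer chosen a few lines up: the RPC payloads in this patch (for example OpenBlocks and ShuffleStreamHandle) travel over the wire via Java serialization, which ties both endpoints to the exact class definitions — the cross-version concern the TODO points at. A small, hedged illustration using plain java.io and a hypothetical `Handle` case class; this is not the patch's own code.

```scala
import java.io._

// Hypothetical stand-in for a serialized RPC payload such as ShuffleStreamHandle.
case class Handle(streamId: Long, numChunks: Int) extends Serializable

object JavaSerializationRoundTrip {
  def serialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(obj)
    out.close()
    bytes.toByteArray
  }

  def deserialize[T](data: Array[Byte]): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(data))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val payload = serialize(Handle(42L, 8))
    // Both sides must agree on the class definition; renaming a field or changing
    // its type breaks deserialization, which is why a cross-version-friendly
    // format is preferable for long-lived wire protocols.
    println(deserialize[Handle](payload)) // prints Handle(42,8)
  }
}
```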
+ private[this] val sluiceConf = new SluiceConfig( + new ConfigProvider { override def get(name: String) = conf.get(name) }) - override def init(blockDataManager: BlockDataManager): Unit = { - server = new BlockServer(nettyConf, blockDataManager) - clientFactory = new BlockClientFactory(nettyConf) - } + private[this] var server: SluiceServer = _ + private[this] var clientFactory: SluiceClientFactory = _ - override def close(): Unit = { - if (server != null) { - server.close() - } - if (clientFactory != null) { - clientFactory.close() - } + override def init(blockDataManager: BlockDataManager): Unit = { + val streamManager = new DefaultStreamManager + val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) + server = new SluiceServer(sluiceConf, streamManager, rpcHandler) + clientFactory = new SluiceClientFactory(sluiceConf) } override def fetchBlocks( @@ -56,29 +57,21 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + val client = clientFactory.createClient(hostName, port) + new NettyBlockFetcher(serializer, client, blockIds, listener) } + override def hostName: String = Utils.localHostName() + + override def port: Int = server.getPort + + // TODO: Implement override def uploadBlock( hostname: String, port: Int, blockId: String, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = { - clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) - } + level: StorageLevel): Future[Unit] = ??? - override def hostName: String = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.hostName - } - - override def port: Int = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.port - } + override def close(): Unit = server.close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala deleted file mode 100644 index 13942f3d0adcd..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer -import java.util.{List => JList} - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelHandler.Sharable -import io.netty.handler.codec._ - -import org.apache.spark.Logging -import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} -import org.apache.spark.storage.StorageLevel - - -/** Messages from the client to the server. */ -private[netty] -sealed trait ClientRequest { - def id: Byte -} - -/** - * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can - * correspond to multiple [[ServerResponse]]s. - */ -private[netty] -final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { - override def id = 0 -} - -/** - * Request to upload a block to the server. Currently the server does not ack the upload request. - */ -private[netty] -final case class BlockUploadRequest( - blockId: String, - data: ManagedBuffer, - level: StorageLevel) - extends ClientRequest { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - - -/** Messages from server to client (usually in response to some [[ClientRequest]]. */ -private[netty] -sealed trait ServerResponse { - def id: Byte -} - -/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ -private[netty] -final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 0 -} - -/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ -private[netty] -final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - -/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */ -private[netty] -final case class BlockUploadSuccess(blockId: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 2 -} - -/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */ -private[netty] -final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 3 -} - - -/** - * Encoder for [[ClientRequest]] used in client side. - * - * This encoder is stateless so it is safe to be shared by multiple threads. 
- */ -@Sharable -private[netty] -final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { - override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { - in match { - case BlockFetchRequest(blocks) => - // 8 bytes: frame size - // 1 byte: BlockFetchRequest vs BlockUploadRequest - // 4 byte: num blocks - // then for each block id write 1 byte for blockId.length and then blockId itself - val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) - val buf = ctx.alloc().buffer(frameLength) - - buf.writeLong(frameLength) - buf.writeByte(in.id) - buf.writeInt(blocks.size) - blocks.foreach { blockId => - ProtocolUtils.writeBlockId(buf, blockId) - } - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadRequest(blockId, data, level) => - // 8 bytes: frame size - // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) - // 1 byte: blockId.length - // data itself (length can be derived from: frame size - 1 - blockId.length) - val headerLength = 8 + 1 + 1 + blockId.length + 5 - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - - // Call this before we add header to out so in case of exceptions - // we don't send anything at all. - val body = data.convertToNetty() - - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - header.writeInt(level.toInt) - header.writeByte(level.replication) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - } - } -} - - -/** - * Decoder in the server side to decode client requests. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { - override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = - { - val msgTypeId = in.readByte() - val decoded = msgTypeId match { - case 0 => // BlockFetchRequest - val numBlocks = in.readInt() - val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } - BlockFetchRequest(blockIds) - - case 1 => // BlockUploadRequest - val blockId = ProtocolUtils.readBlockId(in) - val level = new StorageLevel(in.readInt(), in.readByte()) - - val ret = ByteBuffer.allocate(in.readableBytes()) - ret.put(in.nioBuffer()) - ret.flip() - BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) - } - - assert(decoded.id == msgTypeId) - out.add(decoded) - } -} - - -/** - * Encoder used by the server side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@Sharable -private[netty] -final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { - override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { - in match { - case BlockFetchSuccess(blockId, data) => - // Handle the body first so if we encounter an error getting the body, we can respond - // with an error instead. - var body: AnyRef = null - try { - body = data.convertToNetty() - } catch { - case e: Exception => - // Re-encode this message as BlockFetchFailure. 
- logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) - encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) - return - } - - // If we got here, body cannot be null - // 8 bytes = long for frame length - // 1 byte = message id (type) - // 1 byte = block id length - // followed by block id itself - val headerLength = 8 + 1 + 1 + blockId.length - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - - case BlockFetchFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadSuccess(blockId) => - val frameLength = 8 + 1 + 1 + blockId.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - } - } -} - - -/** - * Decoder in the client side to decode server responses. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { - override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { - val msgId = in.readByte() - val decoded = msgId match { - case 0 => // BlockFetchSuccess - val blockId = ProtocolUtils.readBlockId(in) - in.retain() - BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) - - case 1 => // BlockFetchFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockFetchFailure(blockId, new String(errorBytes)) - - case 2 => // BlockUploadSuccess - BlockUploadSuccess(ProtocolUtils.readBlockId(in)) - - case 3 => // BlockUploadFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockUploadFailure(blockId, new String(errorBytes)) - } - - assert(decoded.id == msgId) - out.add(decoded) - } -} - - -private[netty] object ProtocolUtils { - - /** LengthFieldBasedFrameDecoder used before all decoders. */ - def createFrameDecoder(): ByteToMessageDecoder = { - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 8 - // lengthAdjustment = -8, i.e. exclude the 8 byte length itself - // initialBytesToStrip = 8, i.e. strip out the length field itself - new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) - } - - // TODO(rxin): Make sure these work for all charsets. 
- def readBlockId(in: ByteBuf): String = { - val numBytesToRead = in.readByte().toInt - val bytes = new Array[Byte](numBytesToRead) - in.readBytes(bytes) - new String(bytes) - } - - def writeBlockId(out: ByteBuf, blockId: String): Unit = { - out.writeByte(blockId.length) - out.writeBytes(blockId.getBytes) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index e942b43d9cc4a..bce1069548437 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -19,12 +19,13 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Future - -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} + +import scala.concurrent.Future /** @@ -153,12 +154,11 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => { + case e: Exception => logError("Exception handling buffer message", e) val errorMessage = Message.createBufferMessage(msg.id) errorMessage.hasError = true Some(errorMessage) - } } case otherMessage: Any => @@ -174,13 +174,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case BlockMessage.TYPE_PUT_BLOCK => val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + msg + "]") - putBlock(msg.id.toString, msg.data, msg.level) + putBlock(msg.id, msg.data, msg.level) None case BlockMessage.TYPE_GET_BLOCK => val msg = new GetBlock(blockMessage.getId) logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id.toString) + val buffer = getBlock(msg.id) if (buffer == null) { return None } @@ -190,7 +190,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } } - private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) @@ -198,7 +198,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa + " with data size: " + bytes.limit) } - private def getBlock(blockId: String): ByteBuffer = { + private def getBlock(blockId: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) val buffer = blockDataManager.getBlockData(blockId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 439981d232349..c35aa2481ad03 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ 
-24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.{SparkEnv, SparkConf, Logging} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 4ab34336d3f01..6a9fa4ec65d5d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -21,7 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.SparkEnv -import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.storage._ /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 63863cc0250a3..b521f0c7fc77e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer - -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ac0599f30ef22..4d8b5c1e1b084 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,15 +17,13 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} +import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.concurrent.ExecutionContext.Implicits.global - -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} import scala.util.Random import akka.actor.{ActorSystem, Props} @@ -35,11 +33,11 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import 
org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -215,17 +213,17 @@ private[spark] class BlockManager( * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - override def getBlockData(blockId: String): ManagedBuffer = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) + override def getBlockData(blockId: BlockId): ManagedBuffer = { + if (blockId.isShuffle) { + shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) + .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get new NioManagedBuffer(buffer) } else { - throw new BlockNotFoundException(blockId) + throw new BlockNotFoundException(blockId.toString) } } } @@ -233,8 +231,8 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.nioByteBuffer(), level) + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(blockId, data.nioByteBuffer(), level) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala index 9ef453605f4f1..81f5f2d31dbd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -17,5 +17,4 @@ package org.apache.spark.storage - class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d095452a261db..23313fe9271fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,14 +19,13 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.{Logging, TaskContext} /** @@ -228,7 +227,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val blockId = iter.next() try { - val buf = blockManager.getBlockData(blockId.toString) + val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() results.put(new FetchResult(blockId, 0, 
buf)) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 2fc7c7d9b8312..1e35abaab5353 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private[spark] def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,7 +98,6 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { - /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala deleted file mode 100644 index 2d4baafcf03d0..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import scala.concurrent.{Await, future} -import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.SparkConf - - -class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { - - private val conf = new SparkConf - private var server1: BlockServer = _ - private var server2: BlockServer = _ - - override def beforeAll() { - server1 = new BlockServer(new NettyConfig(conf), null) - server2 = new BlockServer(new NettyConfig(conf), null) - } - - override def afterAll() { - if (server1 != null) { - server1.close() - } - if (server2 != null) { - server2.close() - } - } - - test("BlockClients created are active and reused") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server1.hostName, server1.port) - val c3 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c3.isActive) - assert(c1 === c2) - assert(c1 !== c3) - factory.close() - } - - test("never return inactive clients") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - c1.close() - - // Block until c1 is no longer active - val f = future { - while (c1.isActive) { - Thread.sleep(10) - } - } - Await.result(f, 3.seconds) - assert(!c1.isActive) - - // Create c2, which should be different from c1 - val c2 = factory.createClient(server1.hostName, server1.port) - assert(c1 !== c2) - factory.close() - } - - test("BlockClients are close when BlockClientFactory is stopped") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c2.isActive) - factory.close() - assert(!c1.isActive) - assert(!c2.isActive) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala deleted file mode 100644 index 4c3a649081574..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.mockito.Mockito._ -import org.mockito.Matchers.{any, eq => meq} - -import org.scalatest.{FunSuite, PrivateMethodTester} - -import org.apache.spark.network._ - - -class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { - - /** Helper method to get num. outstanding requests from a private field using reflection. */ - private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val f = handler.getClass.getDeclaredField( - "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") - f.setAccessible(true) - f.get(handler).asInstanceOf[java.util.Map[_, _]].size - } - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) - verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) - verify(listener, times(0)).onBlockFetchSuccess(any(), any()) - verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon uncaught exception") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("b1", listener) - handler.addFetchRequest("b2", listener) - handler.addFetchRequest("b3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon connection close") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("c1", listener) - handler.addFetchRequest("c2", listener) - handler.addFetchRequest("c3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.finish() - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, 
times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala deleted file mode 100644 index 8d1b7276f4082..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.api.java.StorageLevels - - -/** - * Test client/server encoder/decoder protocol. - */ -class ProtocolSuite extends FunSuite { - - /** - * Helper to test server to client message protocol by encoding a message and decoding it. - */ - private def testServerToClient(msg: ServerResponse) { - val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) - serverChannel.writeOutbound(msg) - - val clientChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ServerResponseDecoder) - - // Drain all server outbound messages and write them to the client's server decoder. - while (!serverChannel.outboundMessages().isEmpty) { - clientChannel.writeInbound(serverChannel.readOutbound()) - } - - assert(clientChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. - assert(msg === clientChannel.readInbound()) - } - - /** - * Helper to test client to server message protocol by encoding a message and decoding it. - */ - private def testClientToServer(msg: ClientRequest) { - val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) - clientChannel.writeOutbound(msg) - - val serverChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ClientRequestDecoder) - - // Drain all client outbound messages and write them to the server's decoder. - while (!clientChannel.outboundMessages().isEmpty) { - serverChannel.writeInbound(clientChannel.readOutbound()) - } - - assert(serverChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. 
- assert(msg === serverChannel.readInbound()) - } - - test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { - testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) - } - - test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { - testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) - } - - test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { - testServerToClient(BlockFetchFailure("abcd", "this is an error")) - } - - test("server to client protocol - BlockFetchFailure(\"\", \"\")") { - testServerToClient(BlockFetchFailure("", "")) - } - - test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { - testClientToServer(BlockFetchRequest(Seq.empty[String])) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { - testClientToServer(BlockFetchRequest(Seq("b1"))) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { - testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) - } - - test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { - testClientToServer( - BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) - } - - test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { - testClientToServer( - BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala deleted file mode 100644 index 35ff90a2dabc5..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer -import java.util.{Collections, HashSet} -import java.util.concurrent.{TimeUnit, Semaphore} - -import scala.collection.JavaConversions._ - -import io.netty.buffer.Unpooled - -import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.Span -import org.scalatest.time.Seconds - -import org.apache.spark.SparkConf -import org.apache.spark.network._ -import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} - - -/** -* Test cases that create real clients and servers and connect. 
-*/ -class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { - - val bufSize = 100000 - var buf: ByteBuffer = _ - var testFile: File = _ - var server: BlockServer = _ - var clientFactory: BlockClientFactory = _ - - val bufferBlockId = "buffer_block" - val fileBlockId = "file_block" - - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - - override def beforeAll() = { - buf = ByteBuffer.allocate(bufSize) - for (i <- 1 to bufSize) { - buf.put(i.toByte) - } - buf.flip() - - testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { - override def getBlockData(blockId: String): ManagedBuffer = { - if (blockId == bufferBlockId) { - new NioManagedBuffer(buf) - } else if (blockId == fileBlockId) { - new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) - } else { - throw new BlockNotFoundException(blockId) - } - } - - /** - * Put the block locally, using the given storage level. - */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? - }) - - clientFactory = new BlockClientFactory(new SparkConf) - } - - override def afterAll() = { - server.close() - clientFactory.close() - } - - /** A ByteBuf for buffer_block */ - lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) - - /** A ByteBuf for file_block */ - lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { - val client = clientFactory.createClient(server.hostName, server.port) - val sem = new Semaphore(0) - val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) - val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) - - client.fetchBlocks( - blockIds, - new BlockFetchingListener { - override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { - errorBlockIds.add(blockId) - sem.release() - } - - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - data.retain() - receivedBlockIds.add(blockId) - receivedBuffers.add(data) - sem.release() - } - } - ) - if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server") - } - client.close() - (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) - } - - test("fetch a ByteBuffer block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a FileSegment block via zero-copy send") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) - assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) - assert(blockIds.isEmpty) - assert(buffers.isEmpty) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and FileSegment block") { - val (blockIds, buffers, 
failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) - assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("shutting down server should also close client") { - val client = clientFactory.createClient(server.hostName, server.port) - server.close() - eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala deleted file mode 100644 index e47e4d03fa898..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled - -import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} - - -/** - * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). - * - * Used for testing. 
- */ -class TestManagedBuffer(len: Int) extends ManagedBuffer { - - require(len <= Byte.MaxValue) - - private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) - - private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) - - override def size: Long = underlying.size - - override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() - - override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() - - override def inputStream(): InputStream = underlying.inputStream() - - override def toString: String = s"${getClass.getName}($len)" - - override def equals(other: Any): Boolean = other match { - case otherBuf: ManagedBuffer => - val nioBuf = otherBuf.nioByteBuffer() - if (nioBuf.remaining() != len) { - return false - } else { - var i = 0 - while (i < len) { - if (nioBuf.get() != i) { - return false - } - i += 1 - } - return true - } - case _ => false - } - - override def retain(): this.type = this - - override def release(): this.type = this -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index ba47fe5e25b9b..6790388f96603 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockManager import org.apache.spark.storage.{ShuffleBlockId, FileSegment} @@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) - assert(expected.offset === segment.offset) - assert(expected.length === segment.length) + assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) + assert(expected.offset === segment.getOffset) + assert(expected.length === segment.getLength) } test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 7d4086313fcc1..3beb503b206f2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.TestSerializer @@ -71,7 +72,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) localBlocks.foreach { case (blockId, buf) => - 
doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) + doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } // Make sure remote blocks would return diff --git a/network/common/pom.xml b/network/common/pom.xml new file mode 100644 index 0000000000000..e3b7e328701b4 --- /dev/null +++ b/network/common/pom.xml @@ -0,0 +1,94 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + network + jar + Shuffle Streaming Service + http://spark.apache.org/ + + network + + + + + + io.netty + netty-all + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + provided + + + + + junit + junit + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + + + + target/java/classes + target/java/test-classes + + + org.apache.maven.plugins + maven-surefire-plugin + 2.17 + + false + + **/Test*.java + **/*Test.java + **/*Suite.java + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..224f1e6c515ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.base.Objects; +import com.google.common.io.ByteStreams; +import io.netty.channel.DefaultFileRegion; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A {@link ManagedBuffer} backed by a segment in a file. + */ +public final class FileSegmentManagedBuffer extends ManagedBuffer { + + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + + private final File file; + private final long offset; + private final long length; + + public FileSegmentManagedBuffer(File file, long offset, long length) { + this.file = file; + this.offset = offset; + this.length = length; + } + + @Override + public long size() { + return length; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + FileChannel channel = null; + try { + channel = new RandomAccessFile(file, "r").getChannel(); + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
+ if (length < MIN_MEMORY_MAP_BYTES) { + ByteBuffer buf = ByteBuffer.allocate((int) length); + channel.read(buf, offset); + buf.flip(); + return buf; + } else { + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + } + } catch (IOException e) { + try { + if (channel != null) { + long size = channel.size(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } + throw new IOException("Error in opening " + this, e); + } finally { + JavaUtils.closeQuietly(channel); + } + } + + @Override + public InputStream inputStream() throws IOException { + FileInputStream is = null; + try { + is = new FileInputStream(file); + is.skip(offset); + return ByteStreams.limit(is, length); + } catch (IOException e) { + try { + if (is != null) { + long size = file.length(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } finally { + JavaUtils.closeQuietly(is); + } + throw new IOException("Error in opening " + this, e); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(is); + throw e; + } + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } + + public File getFile() { return file; } + + public long getOffset() { return offset; } + + public long getLength() { return length; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("offset", offset) + .add("length", length) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java new file mode 100644 index 0000000000000..1735f5540c61b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * This interface provides an immutable view for data in the form of bytes. 
The implementation + * should specify how the data is provided: + * + * - {@link FileSegmentManagedBuffer}: data backed by part of a file + * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer + * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. + */ +public abstract class ManagedBuffer { + + /** Number of bytes of the data. */ + public abstract long size(); + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + public abstract ByteBuffer nioByteBuffer() throws IOException; + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + public abstract InputStream inputStream() throws IOException; + + /** + * Increment the reference count by one if applicable. + */ + public abstract ManagedBuffer retain(); + + /** + * If applicable, decrement the reference count by one and deallocate the buffer if the + * reference count reaches zero. + */ + public abstract ManagedBuffer release(); + + /** + * Convert the buffer into a Netty object, used to write the data out. + */ + public abstract Object convertToNetty() throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java new file mode 100644 index 0000000000000..d928980423f1f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; + +/** + * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
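A minimal sketch of the retain/release contract documented on ManagedBuffer above, for a buffer handed off to another thread. Only retain(), release() and nioByteBuffer() come from the abstract class; the helper name, the executor and the processing step are assumptions made for illustration:

    // Hypothetical helper (not part of this patch): consume a reference-counted buffer
    // on a different thread without letting it be deallocated underneath us.
    static void consumeAsync(final ManagedBuffer buffer, ExecutorService executor) {
      buffer.retain();                                  // hold a reference across the thread handoff
      executor.submit(new Runnable() {
        @Override
        public void run() {
          try {
            ByteBuffer bytes = buffer.nioByteBuffer();  // read-only view of the data
            // ... application-specific processing of `bytes` ...
          } catch (IOException e) {
            // log / propagate the error as appropriate
          } finally {
            buffer.release();                           // drop the reference once done
          }
        }
      });
    }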
+ */ +public final class NettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public NettyManagedBuffer(ByteBuf buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.nioBuffer(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return buf; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000000000..3953ef89fbf88 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
+ */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} + diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000000000..40a1fe67b1c5b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + private final int chunkIndex; + + public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) { + super(errorMsg, cause); + this.chunkIndex = chunkIndex; + } + + public ChunkFetchFailureException(int chunkIndex, String errorMsg) { + super(errorMsg); + this.chunkIndex = chunkIndex; + } + + public int getChunkIndex() { return chunkIndex; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000000000..519e6cb470d0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk fetch. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000000000..6ec960d795420 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side.
*/ + void onFailure(Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java new file mode 100644 index 0000000000000..1f7d3b0234e38 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.util.UUID; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of Sluice. The convenience method + * "sendRPC" is provided to enable control plane communication between the client and server to + * perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of SluiceClient using {@link SluiceClientFactory}. A single SluiceClient + * may be used for multiple streams, but any given stream must be restricted to a single client, + * in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. 
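To make the workflow described above concrete, a minimal hypothetical caller might look like the following. Only fetchChunk, ChunkReceivedCallback and ManagedBuffer.retain() are taken from this patch; the factory variable, host, port and streamId (negotiated out of band, for example through a sendRpc exchange), as well as createClient's exact signature, are assumptions of the sketch:

    // Hypothetical usage; `factory` is a SluiceClientFactory, and `streamId` identifies a
    // stream already registered with the remote StreamManager.
    SluiceClient client = factory.createClient(host, port);
    client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        buffer.retain();   // the buffer is release()'d as soon as this call returns
        // ... hand the buffer to its consumer, which must call release() when finished ...
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        // a failure may also fail later outstanding chunks of the same stream
      }
    });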
+ */ +public class SluiceClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); + + private final ChannelFuture cf; + private final SluiceClientHandler handler; + + private final String serverAddr; + + SluiceClient(ChannelFuture cf, SluiceClientHandler handler) { + this.cf = cf; + this.handler = handler; + + if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) { + serverAddr = cf.channel().remoteAddress().toString(); + } else { + serverAddr = ""; + } + } + + public boolean isActive() { + return cf.channel().isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. + * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single SluiceClient + * is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + // Fail all blocks. + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + future.cause().printStackTrace(); + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg)); + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending RPC to {}", serverAddr); + + final long tag = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(tag, callback); + + cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); + } else { + // Fail all blocks. 
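+ // (Unlike chunk fetches, only the single RPC callback registered under this tag needs to
+ // be failed.)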
+ String errorMsg = String.format("Failed to send request %s to %s: %s", tag, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(tag); + callback.onFailure(new RuntimeException(errorMsg)); + } + } + }); + } + + @Override + public void close() { + cf.channel().close(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java new file mode 100644 index 0000000000000..17491dc3f8720 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Factory for creating {@link SluiceClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * {@link SluiceClient} for the same remote host. It also shares a single worker thread pool for + * all {@link SluiceClient}s. 
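+ *
+ * A rough sketch of typical use (the host and port are placeholders; createClient throws
+ * TimeoutException if the connection cannot be established in time):
+ *
+ *   SluiceClientFactory factory = new SluiceClientFactory(conf);
+ *   SluiceClient client = factory.createClient("remote-host", 12345);
+ *   // ... fetch chunks / send RPCs via the client ...
+ *   factory.close();  // closes pooled connections and shuts down the worker thread pool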
+ */ +public class SluiceClientFactory implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class); + + private final SluiceConfig conf; + private final Map connectionPool; + private final ClientRequestEncoder encoder; + private final ServerResponseDecoder decoder; + + private final Class socketChannelClass; + private final EventLoopGroup workerGroup; + + public SluiceClientFactory(SluiceConfig conf) { + this.conf = conf; + this.connectionPool = new ConcurrentHashMap(); + this.encoder = new ClientRequestEncoder(); + this.decoder = new ServerResponseDecoder(); + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + SluiceClient cachedClient = connectionPool.get(address); + if (cachedClient != null && cachedClient.isActive()) { + return cachedClient; + } + + logger.debug("Creating new connection to " + address); + + // There is a chance two threads are creating two different clients connecting to the same host. + // But that's probably ok, as long as the caller hangs on to their client for a single stream. + final SluiceClientHandler handler = new SluiceClientHandler(); + + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()); + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + + bootstrap.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + ch.pipeline() + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler); + } + }); + + // Connect to the remote server + ChannelFuture cf = bootstrap.connect(address); + if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + throw new TimeoutException( + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } + + SluiceClient client = new SluiceClient(cf, handler); + connectionPool.put(address, client); + return client; + } + + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ + @Override + public void close() { + for (SluiceClient client : connectionPool.values()) { + client.close(); + } + connectionPool.clear(); + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + } + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. 
Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private PooledByteBufAllocator createPooledByteBufAllocator() { + return new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ); + } + + /** Used to get defaults from Netty's private static fields. */ + private int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java new file mode 100644 index 0000000000000..ed20b032931c3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.response.ServerResponse; + +/** + * Handler that processes server responses, in response to requests issued from [[SluiceClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. 
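+ *
+ * Outstanding chunk fetches are keyed by their {@link StreamChunkId} and outstanding RPCs by
+ * their tag, so each incoming response can be routed back to the callback that registered it.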
+ */ +public class SluiceClientHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class); + + private final Map outstandingFetches = + new ConcurrentHashMap(); + + private final Map outstandingRpcs = + new ConcurrentHashMap(); + + public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + outstandingFetches.put(streamChunkId, callback); + } + + public void removeFetchRequest(StreamChunkId streamChunkId) { + outstandingFetches.remove(streamChunkId); + } + + public void addRpcRequest(long tag, RpcResponseCallback callback) { + outstandingRpcs.put(tag, callback); + } + + public void removeRpcRequest(long tag) { + outstandingRpcs.remove(tag); + } + + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or pre-mature connection termination. + */ + private void failOutstandingRequests(Throwable cause) { + for (Map.Entry entry : outstandingFetches.entrySet()) { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } + // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests + // as well. But I guess that is ok given the caller will fail as soon as any requests fail. + outstandingFetches.clear(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + if (outstandingFetches.size() > 0) { + SocketAddress remoteAddress = ctx.channel().remoteAddress(); + logger.error("Still have {} requests outstanding when contention from {} is closed", + outstandingFetches.size(), remoteAddress); + failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + } + super.channelUnregistered(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (outstandingFetches.size() > 0) { + logger.error(String.format("Exception in connection from %s: %s", + ctx.channel().remoteAddress(), cause.getMessage()), cause); + failOutstandingRequests(cause); + } + ctx.close(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { + String server = ctx.channel().remoteAddress().toString(); + if (message instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Got a response for block {} from {} but it is not outstanding", + resp.streamChunkId, server); + resp.buffer.release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer); + resp.buffer.release(); + } + } else if (message instanceof ChunkFetchFailure) { + ChunkFetchFailure resp = (ChunkFetchFailure) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Got a response for block {} from {} ({}) but it is not outstanding", + resp.streamChunkId, server, resp.errorString); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onFailure(resp.streamChunkId.chunkIndex, + new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString)); + } + } else if (message instanceof RpcResponse) { + RpcResponse resp = (RpcResponse) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + if (listener == null) { + logger.warn("Got a response for RPC {} from 
{} ({} bytes) but it is not outstanding", + resp.tag, server, resp.response.length); + } else { + outstandingRpcs.remove(resp.tag); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + if (listener == null) { + logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", + resp.tag, server, resp.errorString); + } else { + outstandingRpcs.remove(resp.tag); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } + } + + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000000000..363ea5ecfa936 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000000000..d46a263884807 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. +*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java new file mode 100644 index 0000000000000..a79eb363cf58c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure). 
+ */ +public final class ChunkFetchRequest implements ClientRequest { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java new file mode 100644 index 0000000000000..db075c44b4cda --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** Messages from the client to the server. */ +public interface ClientRequest extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** + * Preceding every serialized ClientRequest is the type, which allows us to deserialize + * the request. 
+ */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), RpcRequest(1); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 request types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch(id) { + case 0: return ChunkFetchRequest; + case 1: return RpcRequest; + default: throw new IllegalArgumentException("Unknown request type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java new file mode 100644 index 0000000000000..a937da4cecae0 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +/** + * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}. 
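+ *
+ * As a worked example: a ChunkFetchRequest for stream 100, chunk 0 is written by
+ * {@link ClientRequestEncoder} as a 21-byte frame -- an 8-byte frame length, a 1-byte type
+ * tag (0 for ChunkFetchRequest) and a 12-byte StreamChunkId. By the time the message reaches
+ * this decoder only the type tag and body remain readable.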
+ */ +@ChannelHandler.Sharable +public final class ClientRequestDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ClientRequest.Type msgType = ClientRequest.Type.decode(in); + ClientRequest decoded = decode(msgType, in); + assert decoded.type() == msgType; + assert in.readableBytes() == 0; + out.add(decoded); + } + + private ClientRequest decode(ClientRequest.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java new file mode 100644 index 0000000000000..bcff4a0a25568 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; + +/** + * Encoder for {@link ClientRequest} used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ClientRequestEncoder extends MessageToMessageEncoder { + @Override + public void encode(ChannelHandlerContext ctx, ClientRequest in, List out) { + ClientRequest.Type msgType = in.type(); + // Write 8 bytes for the frame's length, followed by the request type and request itself. + int frameLength = 8 + msgType.encodedLength() + in.encodedLength(); + ByteBuf buf = ctx.alloc().buffer(frameLength); + buf.writeLong(frameLength); + msgType.encode(buf); + in.encode(buf); + assert buf.writableBytes() == 0; + out.add(buf); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java new file mode 100644 index 0000000000000..126370330f723 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single {@link org.apache.spark.network.protocol.response.ServerResponse} + * (either success or failure). + */ +public final class RpcRequest implements ClientRequest { + /** Tag is used to link an RPC request with its response. */ + public final long tag; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long tag, byte[] message) { + this.tag = tag; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + 4 + message.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(message.length); + buf.writeBytes(message); + } + + public static RpcRequest decode(ByteBuf buf) { + long tag = buf.readLong(); + int messageLen = buf.readInt(); + byte[] message = new byte[messageLen]; + buf.readBytes(message); + return new RpcRequest(tag, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return tag == o.tag && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java new file mode 100644 index 0000000000000..3a57d71b4f3ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an + * error fetching the chunk. + */ +public final class ChunkFetchFailure implements ServerResponse { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java new file mode 100644 index 0000000000000..874dc4f5940cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when a chunk + * exists and has been successfully fetched. 
+ * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ServerResponse { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include buffer itself. See {@link ServerResponseEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java new file mode 100644 index 0000000000000..274920b28bced --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. 
*/ +public final class RpcFailure implements ServerResponse { + public final long tag; + public final String errorString; + + public RpcFailure(long tag, String errorString) { + this.tag = tag; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static RpcFailure decode(ByteBuf buf) { + long tag = buf.readLong(); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new RpcFailure(tag, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return tag == o.tag && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java new file mode 100644 index 0000000000000..0c6f8acdcdc4b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. 
*/ +public final class RpcResponse implements ServerResponse { + public final long tag; + public final byte[] response; + + public RpcResponse(long tag, byte[] response) { + this.tag = tag; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + 4 + response.length; } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(response.length); + buf.writeBytes(response); + } + + public static RpcResponse decode(ByteBuf buf) { + long tag = buf.readLong(); + int responseLen = buf.readInt(); + byte[] response = new byte[responseLen]; + buf.readBytes(response); + return new RpcResponse(tag, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return tag == o.tag && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("response", response) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java new file mode 100644 index 0000000000000..335f9e8ea69f9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages from server to client (usually in response to some + * {@link org.apache.spark.network.protocol.request.ClientRequest}. + */ +public interface ServerResponse extends Encodable { + /** Used to identify this response type. */ + Type type(); + + /** + * Preceding every serialized ServerResponse is the type, which allows us to deserialize + * the response. 
+ */ + public static enum Type implements Encodable { + ChunkFetchSuccess(0), ChunkFetchFailure(1), RpcResponse(2), RpcFailure(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 response types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch(id) { + case 0: return ChunkFetchSuccess; + case 1: return ChunkFetchFailure; + case 2: return RpcResponse; + case 3: return RpcFailure; + default: throw new IllegalArgumentException("Unknown response type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java new file mode 100644 index 0000000000000..e06198284e620 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +/** + * Decoder used by the client side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. 
+ */ +@ChannelHandler.Sharable +public final class ServerResponseDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ServerResponse.Type msgType = ServerResponse.Type.decode(in); + ServerResponse decoded = decode(msgType, in); + assert decoded.type() == msgType; + out.add(decoded); + } + + private ServerResponse decode(ServerResponse.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java new file mode 100644 index 0000000000000..069f42463a8fe --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ServerResponseEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(ServerResponseEncoder.class); + + @Override + public void encode(ChannelHandlerContext ctx, ServerResponse in, List out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. + if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. 
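+ // The failure is sent back as a ChunkFetchFailure so the client's outstanding fetch still
+ // receives a callback rather than hanging.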
+ logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + ServerResponse.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java new file mode 100644 index 0000000000000..04814d9a88c4a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator, which are individually + * fetched as chunks by the client. + */ +public class DefaultStreamManager extends StreamManager { + private final AtomicLong nextStreamId; + private final Map streams; + + /** State of a single stream. */ + private static class StreamState { + final Iterator buffers; + + int curChunk = 0; + + StreamState(Iterator buffers) { + this.buffers = buffers; + } + } + + public DefaultStreamManager() { + // Start with a random stream id to help identifying different streams. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + return state.buffers.next(); + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. 
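+ // Chunks the client never fetched would otherwise never be release()'d, leaking the
+ // resources backing their ManagedBuffers.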
+ StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + public long registerStream(Iterator buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } + + public void unregisterStream(long streamId) { + streams.remove(streamId); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..abfbe66d875e8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. + */ +public interface RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + */ + void receive(byte[] message, RpcResponseCallback callback); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java new file mode 100644 index 0000000000000..aa81271024156 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Server for the efficient, low-level streaming service. + */ +public class SluiceServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); + + private final SluiceConfig conf; + private final StreamManager streamManager; + private final RpcHandler rpcHandler; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port; + + public SluiceServer(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + this.conf = conf; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + + init(); + } + + public int getPort() { return port; } + + private void init() { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder()) + .addLast("serverResponseEncoder", new ServerResponseEncoder()) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. 
+ .addLast("handler", new SluiceServerHandler(streamManager, rpcHandler)); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly(); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java new file mode 100644 index 0000000000000..fad72fbfc711b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.RpcRequest; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler keeps + * track of which streams have been fetched via this channel, in order to clean them up if the + * channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline setup by {@link SluiceServer}. + */ +public class SluiceServerHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceServerHandler.class); + + /** Returns each chunk part of a stream. 
*/ + private final StreamManager streamManager; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** List of all stream ids that have been read on this handler, used for cleanup. */ + private final Set streamIds; + + public SluiceServerHandler(StreamManager streamManager, RpcHandler rpcHandler) { + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.error("Exception in connection from " + ctx.channel().remoteAddress(), cause); + ctx.close(); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + super.channelUnregistered(ctx); + // Inform the StreamManager that these streams will no longer be read from. + for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, ClientRequest request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest(ctx, (ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest(ctx, (RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFetchRequest req) { + final String client = ctx.channel().remoteAddress().toString(); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(ctx, new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(ctx, new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final ChannelHandlerContext ctx, final RpcRequest req) { + try { + rpcHandler.receive(req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(ctx, new RpcResponse(req.tag, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. 
+ */ + private void respond(final ChannelHandlerContext ctx, final Encodable result) { + final String remoteAddress = ctx.channel().remoteAddress().toString(); + ctx.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + ctx.close(); + } + } + } + ); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000000000..2e07f5a270cb9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link SluiceServerHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one + * client connection, meaning that getChunk() for a particular stream will be called serially and + * that once the connection associated with the stream is closed, that stream will never be used + * again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. 
+ */ + public void connectionTerminated(long streamId) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 0000000000000..2dc0e248ae835 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link SluiceConfig} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java new file mode 100644 index 0000000000000..cef88c0091eff --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** Uses System properties to obtain config values. */ +public class DefaultConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000000000..91cb3e0e6f8f6 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on certain machines. + * AUTO is used to select EPOLL if it's available, or NIO otherwise. + */ +public enum IOMode { + NIO, EPOLL, AUTO +} diff --git a/core/src/main/scala/org/apache/spark/network/exceptions.scala b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 65% rename from core/src/main/scala/org/apache/spark/network/exceptions.scala rename to network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index d918d358c4adb..fafdcad04aeb6 100644 --- a/core/src/main/scala/org/apache/spark/network/exceptions.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.util; -class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) - extends Exception(errorMsg, cause) { +import java.io.Closeable; - def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) -} - - -class BlockUploadFailureException(blockId: String, cause: Throwable) - extends Exception(s"Failed to fetch block $blockId", cause) { +import com.google.common.io.Closeables; - def this(blockId: String) = this(blockId, null) +public class JavaUtils { + /** Closes the given object, ignoring IOExceptions. 
*/ + @SuppressWarnings("deprecation") + public static void closeQuietly(Closeable closable) { + Closeables.closeQuietly(closable); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000000000..3d20dc9e1c1cd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. + */ +public class NettyUtils { + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); + + switch(mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class getClientChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. 
*/ + public static Class getServerChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns EPOLL if it's available on this system, NIO otherwise. */ + private static IOMode autoselectMode() { + return Epoll.isAvailable() ? IOMode.EPOLL : IOMode.NIO; + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala rename to network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java index 7c3074e939794..26fa3229c4721 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java @@ -15,35 +15,37 @@ * limitations under the License. */ -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf +package org.apache.spark.network.util; /** - * A central location that tracks all the settings we exposed to users. + * A central location that tracks all the settings we expose to users. */ -private[spark] -class NettyConfig(conf: SparkConf) { +public class SluiceConfig { + private final ConfigProvider conf; + + public SluiceConfig(ConfigProvider conf) { + this.conf = conf; + } /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + /** IO mode: nio, epoll, or auto (try epoll first and then nio). */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } /** Connect timeout in secs. Default 120 secs. */ - private[netty] val connectTimeoutMs = { - conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; } - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } /** Number of threads used in the client thread pool. 
Default to 0, which is 2x#cores. */ - private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } /** * Receive buffer size (SO_RCVBUF). @@ -52,10 +54,8 @@ class NettyConfig(conf: SparkConf) { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } /** Send buffer size (SO_SNDBUF). */ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } } diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java new file mode 100644 index 0000000000000..d20528558cae1 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.SluiceConfig; + +public class IntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static SluiceServer server; + static SluiceClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + + SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + server = new SluiceServer(conf, streamManager, new NoOpRpcHandler()); + clientFactory = new SluiceClientFactory(conf); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set successChunks; + public Set failedChunks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List chunkIndices) throws Exception { + SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new 
LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java new file mode 100644 index 0000000000000..af35709319957 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -0,0 +1,26 @@ +package org.apache.spark.network;/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.server.RpcHandler; + +public class NoOpRpcHandler implements RpcHandler { + @Override + public void receive(byte[] message, RpcResponseCallback callback) { + callback.onSuccess(new byte[0]); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000000000..cf74a9d8993fe --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.ServerResponse; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(ServerResponse msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new ServerResponseEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ServerResponseDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(ClientRequest msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new ClientRequestEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ClientRequestDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void s2cChunkFetchSuccess() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + } + + @Test + public void s2cBlockFetchFailure() { + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + } + + @Test + public void c2sChunkFetchRequest() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java new file mode 100644 index 0000000000000..e6b59b9ad8e5c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SluiceConfig; + +public class SluiceClientFactorySuite { + private SluiceConfig conf; + private SluiceServer server1; + private SluiceServer server2; + + @Before + public void setUp() { + conf = new SluiceConfig(new DefaultConfigProvider()); + StreamManager streamManager = new DefaultStreamManager(); + RpcHandler rpcHandler = new NoOpRpcHandler(); + server1 = new SluiceServer(conf, streamManager, rpcHandler); + server2 = new SluiceServer(conf, streamManager, rpcHandler); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws Exception { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java new file mode 100644 index 
0000000000000..cab0597fb948a --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.SluiceClientHandler; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; + +public class SluiceClientHandlerSuite { + @Test + public void handleSuccessfulFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } + + @Test + public void handleFailedFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.writeInbound(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } + + @Test + public void clearAllOutstandingRequests() { + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new ChunkFetchSuccess(new StreamChunkId(1, 0), new 
TestManagedBuffer(12))); + channel.pipeline().fireExceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java new file mode 100644 index 0000000000000..7e7554af70f42 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. 
+ */ +public class TestManagedBuffer extends ManagedBuffer { + + private final int len; + private NettyManagedBuffer underlying; + + public TestManagedBuffer(int len) { + Preconditions.checkArgument(len <= Byte.MAX_VALUE); + this.len = len; + byte[] byteArray = new byte[len]; + for (int i = 0; i < len; i ++) { + byteArray[i] = (byte) i; + } + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + } + + + @Override + public long size() { + return underlying.size(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return underlying.nioByteBuffer(); + } + + @Override + public InputStream inputStream() throws IOException { + return underlying.inputStream(); + } + + @Override + public ManagedBuffer retain() { + underlying.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + underlying.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return underlying.convertToNetty(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ManagedBuffer) { + try { + ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); + if (nioBuf.remaining() != len) { + return false; + } else { + for (int i = 0; i < len; i ++) { + if (nioBuf.get() != i) { + return false; + } + } + return true; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java new file mode 100644 index 0000000000000..56a2b805f154c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.net.InetAddress; + +public class TestUtils { + public static String getLocalHost() { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/pom.xml b/pom.xml index 7756c89b00cad..b0d39cfec1e8d 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ graphx mllib tools + network/common streaming sql/catalyst sql/core diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8a1b2d3b91327..71041e7fe1a14 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -51,8 +51,6 @@ object MimaExcludes { // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.network.netty.PathResolver"), ProblemFilters.exclude[MissingClassProblem]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 7149dbc12a365..190373e0cb5f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -122,7 +122,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { sender: ActorRef ) { if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + throw new Exception("Register received for unexpected type " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host)
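
Taken together, the new network/common classes in this patch compose roughly as sketched below. This is a minimal, illustrative wiring of SluiceConfig, DefaultStreamManager, RpcHandler, SluiceServer and SluiceClientFactory using only constructors and methods that appear in the diff above; the echo RPC handler, the buffer contents, and the use of "localhost" are assumptions made for the example, not part of the change.

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.SluiceClient;
import org.apache.spark.network.client.SluiceClientFactory;
import org.apache.spark.network.server.DefaultStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.SluiceServer;
import org.apache.spark.network.util.DefaultConfigProvider;
import org.apache.spark.network.util.SluiceConfig;

public class SluiceExample {
  public static void main(String[] args) throws Exception {
    // Configuration is resolved from system properties via DefaultConfigProvider.
    SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider());

    // Register a stream of two in-memory chunks with the default stream manager.
    DefaultStreamManager streamManager = new DefaultStreamManager();
    List<ManagedBuffer> chunks = Arrays.<ManagedBuffer>asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1, 2, 3})),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {4, 5, 6})));
    long streamId = streamManager.registerStream(chunks.iterator());

    // RPC handler that simply echoes the request back (illustrative only).
    RpcHandler rpcHandler = new RpcHandler() {
      @Override
      public void receive(byte[] message, RpcResponseCallback callback) {
        callback.onSuccess(message);
      }
    };

    // Starting the server binds to spark.shuffle.io.port (default 0 = random free port).
    SluiceServer server = new SluiceServer(conf, streamManager, rpcHandler);

    // Fetch chunk 0 of the registered stream from a client.
    SluiceClientFactory clientFactory = new SluiceClientFactory(conf);
    SluiceClient client = clientFactory.createClient("localhost", server.getPort());
    client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        System.out.println("Fetched chunk " + chunkIndex + " (" + buffer.size() + " bytes)");
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        e.printStackTrace();
      }
    });

    // fetchChunk is asynchronous; a real caller would wait for the callback before
    // closing (IntegrationSuite above uses a Semaphore for this). Omitted for brevity.
    client.close();
    clientFactory.close();
    server.close();
  }
}

The chunk arrives at the callback as the same ManagedBuffer subtype the server produced (NioManagedBuffer here, FileSegmentManagedBuffer for on-disk blocks), which is what lets the shuffle path avoid an extra copy on the server side.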
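
The ConfigProvider abstraction is what keeps this module free of a compile-time dependency on SparkConf: an embedder supplies its own provider and SluiceConfig layers the typed getters and defaults on top. A minimal sketch of such a provider, assuming only the abstract get(String) contract shown above (the map-backed class and its name are invented for illustration, not part of the patch):

import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;

import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.SluiceConfig;

/** A ConfigProvider backed by an in-memory map instead of system properties. */
class MapConfigProvider extends ConfigProvider {
  private final Map<String, String> values = new HashMap<String, String>();

  MapConfigProvider(Map<String, String> values) {
    this.values.putAll(values);
  }

  @Override
  public String get(String name) {
    String value = values.get(name);
    if (value == null) {
      // ConfigProvider's typed getters catch this and fall back to their defaults.
      throw new NoSuchElementException(name);
    }
    return value;
  }
}

With this in place, new SluiceConfig(new MapConfigProvider(...)) resolves keys such as spark.shuffle.io.mode or spark.shuffle.io.serverThreads from the map and otherwise falls back to the defaults documented in SluiceConfig.java.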