diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72716567ca99b..294a58fafc365 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -31,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.network.ConnectionManager +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.cm.CMBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -59,8 +60,8 @@ class SparkEnv ( val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, + val blockTransferService: BlockTransferService, val blockManager: BlockManager, - val connectionManager: ConnectionManager, val securityManager: SecurityManager, val httpFileServer: HttpFileServer, val sparkFilesDir: String, @@ -79,6 +80,7 @@ class SparkEnv ( Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() + blockTransferService.stop() broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -223,14 +225,14 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockTransferService = new CMBlockTransferService(conf, securityManager) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker, shuffleManager) - - val connectionManager = blockManager.connectionManager + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) @@ -278,8 +280,8 @@ object SparkEnv extends Logging { mapOutputTracker, shuffleManager, broadcastManager, + blockTransferService, blockManager, - connectionManager, securityManager, httpFileServer, sparkFilesDir, diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala similarity index 56% rename from core/src/main/scala/org/apache/spark/network/ReceiverTest.scala rename to core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 53a6038a9b59e..e0e91724271c8 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,21 +17,20 @@ package org.apache.spark.network -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.storage.StorageLevel -private[spark] object ReceiverTest { - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */ - val buffer = ByteBuffer.wrap("response".getBytes("utf-8")) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} +trait BlockDataManager { + + /** + * Interface to get local block data. + * + * @return Some(buffer) if the block exists locally, and None if it doesn't. + */ + def getBlockData(blockId: String): Option[ManagedBuffer] + /** + * Put the block locally, using the given storage level. + */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala new file mode 100644 index 0000000000000..c1dfcf1c12d39 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.util.EventListener + + +/** + * Listener callback interface for [[BlockTransferService.fetchBlocks]]. + */ +trait BlockFetchingListener extends EventListener { + + /** + * Called once per successfully fetched block. + */ + def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit + + /** + * Called upon failures. + */ + def onBlockFetchFailure(exception: Exception): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala new file mode 100644 index 0000000000000..0aa4a85531fa6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import org.apache.spark.storage.StorageLevel + + +abstract class BlockTransferService { + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + def init(blockDataManager: BlockDataManager) + + /** + * Tear down the transfer service. + */ + def stop(): Unit + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + def port: Int + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + def hostName: String + + /** + * Fetch a sequence of blocks from a remote node, available only after [[init]] is invoked. + * + * This takes a sequence so the implementation can batch requests. + */ + def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit + + /** + * Fetch a single block from a remote node, available only after [[init]] is invoked. + * + * This is functionally equivalent to + * {{{ + * fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 + * }}} + */ + def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = { + //fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 + null + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. + */ + def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 4894ecd41f6eb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.io.Source - -import org.apache.spark._ - -private[spark] object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - // - the master URL - a list slaves to run connectionTest on - // [num of tasks] - the number of parallel tasks to be initiated default is number of slave - // hosts [size of msg in MB (integer)] - the size of messages to be sent in each task, - // default is 10 [count] - how many times to run, default is 3 [await time in seconds] : - // await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest [num of tasks] " + - "[size of msg in MB (integer)] [count] [await time in seconds)] ") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /* println("Slaves") */ - /* slaves.foreach(println) */ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second - println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + - "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => - { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - } - } - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * - 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala new file mode 100644 index 0000000000000..f51724593a9b6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.io.{File, FileInputStream, InputStream} +import java.nio.ByteBuffer + +import io.netty.buffer.{ByteBufInputStream, ByteBuf, Unpooled} +import io.netty.channel.DefaultFileRegion + +import org.apache.spark.storage.FileSegment +import org.apache.spark.util.ByteBufferInputStream + + +/** + * Provides a buffer abstraction that allows pooling and reuse. + */ +abstract class ManagedBuffer { + // Note that all the methods are defined with parenthesis because their implementations can + // have side effects (io operations). + + def byteBuffer(): ByteBuffer = throw new UnsupportedOperationException + + def fileSegment(): Option[FileSegment] = None + + def inputStream(): InputStream = throw new UnsupportedOperationException + + def release(): Unit = throw new UnsupportedOperationException + + def size: Long + + private[network] def toNetty(): AnyRef +} + + +/** + * A ManagedBuffer backed by a segment in a file. + */ +final class FileSegmentManagedBuffer(file: File, offset: Long, length: Long) + extends ManagedBuffer { + + override def size: Long = length + + override private[network] def toNetty(): AnyRef = { + val fileChannel = new FileInputStream(file).getChannel + new DefaultFileRegion(fileChannel, offset, length) + } +} + + +/** + * A ManagedBuffer backed by [[java.nio.ByteBuffer]]. + */ +final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { + + override def byteBuffer() = buf + + override def inputStream() = new ByteBufferInputStream(buf) + + override def size: Long = buf.remaining() + + override private[network] def toNetty(): AnyRef = Unpooled.wrappedBuffer(buf) +} + + +/** + * A ManagedBuffer backed by a Netty [[ByteBuf]]. + */ +final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { + + override def byteBuffer() = buf.nioBuffer() + + override def inputStream() = new ByteBufInputStream(buf) + + override def release(): Unit = buf.release() + + override def size: Long = buf.readableBytes() + + override private[network] def toNetty(): AnyRef = buf +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala deleted file mode 100644 index ea2ad104ecae1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.Try - -private[spark] object SenderTest { - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest ") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val conf = new SparkConf - val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /* println("Started timer at " + startTime) */ - val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) - val responseStr: String = Try(Await.result(promise, Duration.Inf)) - .map { response => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array, "utf-8") - }.getOrElse("none") - - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms - // * 1000.0) + " MB/s" - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + - (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/storage/BlockMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala index a2bfce7b4a0fa..107d0131efd74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -import org.apache.spark.network._ +import scala.collection.mutable.{ArrayBuffer, StringBuilder} private[spark] case class GetBlock(id: BlockId) private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala rename to core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala index 973d85c0a9b3a..b0f770261c199 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark._ -import org.apache.spark.network._ +import org.apache.spark.storage.{StorageLevel, TestBlockId} + +import scala.collection.mutable.ArrayBuffer private[spark] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/BufferMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala index af35f1fc3e459..5f7761838ab33 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.storage.BlockManager +import scala.collection.mutable.ArrayBuffer + private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) extends Message(Message.BUFFER_MESSAGE, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala new file mode 100644 index 0000000000000..3b61c0ee852c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.cm + +import java.nio.ByteBuffer + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} +import org.apache.spark.network._ +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + + +/** + * A [[BlockTransferService]] implementation based on our [[ConnectionManager]]. + */ +final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService with Logging { + + private var cm: ConnectionManager = _ + + private var blockDataManager: BlockDataManager = _ + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + override def port: Int = cm.id.port + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + override def hostName: String = cm.id.host + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + override def init(blockDataManager: BlockDataManager): Unit = { + this.blockDataManager = blockDataManager + cm = new ConnectionManager( + conf.getInt("spark.blockManager.port", 0), + conf, + securityManager, + "Connection manager for block manager") + cm.onReceiveMessage(onBlockMessageReceive) + } + + /** + * Tear down the transfer service. + */ + override def stop(): Unit = { + if (cm != null) { + cm.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + + val cmId = new ConnectionManagerId(hostName, port) + val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => + BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) + }) + + val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + + // If succeeds in getting blocks from a remote connection manager, put the block in results. + future.onSuccess { case message => + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + listener.onBlockFetchFailure( + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + } + } + }(cm.futureExecContext) + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. + */ + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel) { + val msg = PutBlock(BlockId(blockId), blockData.byteBuffer(), level) + val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) + val remoteCmId = new ConnectionManagerId(hostName, port) + + // TODO: Not wait infinitely. + Await.result(cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage), + Duration.Inf) + } + + private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) + Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + case otherMessage: Any => + logError("Unknown type message received: " + otherMessage) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => + val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + logDebug("Received [" + msg + "]") + putBlock(msg.id.toString, msg.data, msg.level) + None + + case BlockMessage.TYPE_GET_BLOCK => + val msg = new GetBlock(blockMessage.getId) + logDebug("Received [" + msg + "]") + val buffer = getBlock(msg.id.toString) + if (buffer == null) { + return None + } + Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) + + case _ => None + } + } + + private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) + blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.limit) + } + + private def getBlock(blockId: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("GetBlock " + blockId + " started from " + startTimeMs) + val buffer = blockDataManager.getBlockData(blockId).orNull + logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + buffer.byteBuffer() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/cm/Connection.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/Connection.scala rename to core/src/main/scala/org/apache/spark/network/cm/Connection.scala index 5285ec82c1b64..080c3e7dd42a8 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/Connection.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net._ import java.nio._ import java.nio.channels._ -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} - import org.apache.spark._ +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} + private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/ConnectionId.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala index d579c165a1917..7b358a4d25988 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/ConnectionManager.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala index 578d806263006..f9e35fb793faa 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala @@ -15,31 +15,25 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.io.IOException +import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ -import java.net._ -import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} +import java.util.{Timer, TimerTask} -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue +import org.apache.spark._ +import org.apache.spark.util.{SystemClock, Utils} -import scala.concurrent.{Await, ExecutionContext, Future, Promise} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps -import org.apache.spark._ -import org.apache.spark.util.{SystemClock, Utils} - private[spark] class ConnectionManager( port: Int, conf: SparkConf, @@ -904,7 +898,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { - import ExecutionContext.Implicits.global + import scala.concurrent.ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala index 57f7586883af1..b6b2cb0db4291 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net.InetSocketAddress diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/cm/Message.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/Message.scala rename to core/src/main/scala/org/apache/spark/network/cm/Message.scala index 04ea50f62918c..5b5bcc2d966e8 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/Message.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net.InetSocketAddress import java.nio.ByteBuffer diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/MessageChunk.scala rename to core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala index d0f986a12bfe0..95b46cd11f6b9 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[network] +private[cm] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { val size = if (buffer == null) 0 else buffer.remaining diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala rename to core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala index f3ecca5f992e0..7087c7ad6c50b 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm -import java.net.InetAddress -import java.net.InetSocketAddress +import java.net.{InetAddress, InetSocketAddress} import java.nio.ByteBuffer private[spark] class MessageChunkHeader( diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/SecurityMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala index 9af9e2e8e9e59..f59df06fb3d96 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala @@ -15,15 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder - import org.apache.spark._ -import org.apache.spark.network._ + +import scala.collection.mutable.{ArrayBuffer, StringBuilder} /** * SecurityMessage is class that contains the connectionId and sasl token diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 87ef9bb0b43c6..dd0421a5c15ac 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.network.cm.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 12b475658e29d..6cf9305977a3c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { @@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer, - shuffleMetrics: ShuffleReadMetrics) + serializer: Serializer) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockTransferService, + blockManager, + blocksByAddress, + serializer, + SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7bed97a63f0f6..88a5f1e5ddf58 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, - readMetrics) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index ca60ec78b62ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue -import scala.util.{Failure, Success} - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.executor.ShuffleReadMetrics -import org.apache.spark.network.BufferMessage -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - -/** - * A block fetcher iterator interface. There are two implementations: - * - * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. - * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. - * - * Eventually we would like the two to converge and use a single NIO-based communication layer, - * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), - * NIO would perform poorly and thus the need for the Netty OIO one. - */ - -private[storage] -trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { - val size = blocks.map(_._2).sum - } - - /** - * Result of a fetch from a remote block. A failure is represented as size == -1. - * @param blockId block id - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. - */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - // TODO: Refactor this whole thing to make code more reusable. - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BlockFetcherIterator { - - import blockManager._ - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number blocks fetched (local + remote). Also number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[BlockId]() - - // A queue to hold our results. - protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - protected val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - protected var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onComplete { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can - // be incrementing bytes read at the same time (SPARK-2625). - readMetrics.remoteBytesRead += networkSize - readMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case Failure(exception) => { - logError("Could not get block(s) from " + cmId, exception) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) - - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - var totalBlocks = 0 - for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size - if (address == blockManagerId) { - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= targetRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") - curRequestSize = 0 - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - try { - // getLocalFromDisk never return None but throws BlockException - val iter = getLocalFromDisk(id, serializer).get - // Pass 0 as size since it's not in flight - readMetrics.localBlocksFetched += 1 - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } catch { - case e: Exception => { - logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) - return - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numFetches = remoteRequests.size - fetchRequests.size - logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue - // as they arrive. - @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (! result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of BasicBlockFetcherIterator - - class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - - override protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - - // This could throw a TimeoutException. In that case we will just retry the task. - val client = blockManager.nettyBlockClientFactory.createClient( - cmId.host, req.address.nettyPort) - val blocks = req.blocks.map(_._1.toString) - - client.fetchBlocks( - blocks, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - logError(s"Could not get block(s) from $cmId with error: $errorMsg") - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { - // Increment the reference count so the buffer won't be recycled. - // TODO: This could result in memory leaks when the task is stopped due to exception - // before the iterator is exhausted. - data.retain() - val buf = data.byteBuffer() - val blockSize = buf.remaining() - val bid = BlockId(blockId) - - // TODO: remove code duplication between here and BlockManager.dataDeserialization. - results.put(new FetchResult(bid, sizeMap(bid), () => { - def createIterator: Iterator[Any] = { - val stream = blockManager.wrapForCompression(bid, data.inputStream()) - serializer.newInstance().deserializeStream(stream).asIterator - } - new LazyInitIterator(createIterator) { - // Release the buffer when we are done traversing it. - override def close(): Unit = data.release() - } - })) - - readMetrics.synchronized { - readMetrics.remoteBytesRead += blockSize - readMetrics.remoteBlocksFetched += 1 - } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - ) - } - } - // End of NettyBlockFetcherIterator -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 12a92d44f4c36..cc5d505303fc2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} +import scala.concurrent.ExecutionContext.Implicits.global + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -32,8 +34,6 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ -import org.apache.spark.network.netty.client.BlockFetchingClientFactory -import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ @@ -60,18 +60,16 @@ private[spark] class BlockManager( defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) extends BlockDataProvider with Logging { + //blockTransferService.init(this) + private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) - val connectionManager = - new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") - - implicit val futureExecContext = connectionManager.futureExecContext private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -90,31 +88,8 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } - private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { - if (useNetty) new BlockFetchingClientFactory(conf) else null - } - - private val nettyBlockServer: BlockServer = { - if (useNetty) { - val server = new BlockServer(conf, this) - logInfo(s"Created NettyBlockServer binding to port: ${server.port}") - server - } else { - null - } - } - - private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 - val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) - - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory - // for receiving shuffle outputs) - val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 + executorId, blockTransferService.hostName, blockTransferService.port) // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) @@ -157,11 +132,11 @@ private[spark] class BlockManager( master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) = { + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker, shuffleManager) + conf, mapOutputTracker, shuffleManager, blockTransferService) } /** @@ -170,7 +145,6 @@ private[spark] class BlockManager( */ private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - BlockManagerWorker.startBlockManagerWorker(this) } /** @@ -527,8 +501,8 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) + val data = blockTransferService.fetchBlock(loc.host, loc.port, blockId.toString).byteBuffer() + if (data != null) { if (asBlockResult) { return Some(new BlockResult( @@ -562,28 +536,6 @@ private[spark] class BlockManager( None } - /** - * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns - * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. Expects a size in bytes to be provided for each block fetched, - * so that we can control the maxMegabytesInFlight for the fetch. - */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { - val iter = - if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } - iter.initialize() - iter - } - def putIterator( blockId: BlockId, values: Iterator[Any], @@ -836,12 +788,15 @@ private[spark] class BlockManager( data.rewind() logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + s"To node: $peer") - val putBlock = PutBlock(blockId, data, tLevel) - val cmId = new ConnectionManagerId(peer.host, peer.port) - val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId) - if (!syncPutBlockSuccess) { - logError(s"Failed to call syncPutBlock to $peer") + + try { + blockTransferService.uploadBlock( + peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + } catch { + case e: Exception => + logError(s"Failed to replicate block to $peer", e) } + logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) } @@ -1066,40 +1021,13 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - - def getIterator: Iterator[Any] = { - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator - } - - if (blockId.isShuffle) { - /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators - * at the beginning. The wrapping will cost some memory (compression instance - * initialization, etc.). Reducer reads shuffle blocks one by one so we could do the - * wrapping lazily to save memory. */ - class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { - lazy val proxy = f - override def hasNext: Boolean = proxy.hasNext - override def next(): Any = proxy.next() - } - new LazyProxyIterator(getIterator) - } else { - getIterator - } + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + serializer.newInstance().deserializeStream(stream).asIterator } def stop(): Unit = { - connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() - - if (nettyBlockClientFactory != null) { - nettyBlockClientFactory.stop() - } - if (nettyBlockServer != null) { - nettyBlockServer.stop() - } - actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b1585bd8199d1..f39510160e637 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -36,11 +36,10 @@ import org.apache.spark.util.Utils class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int, - private var nettyPort_ : Int - ) extends Externalizable { + private var port_ : Int) + extends Externalizable { - private def this() = this(null, null, 0, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only def executorId: String = executorId_ @@ -60,32 +59,29 @@ class BlockManagerId private ( def port: Int = port_ - def nettyPort: Int = nettyPort_ override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) - out.writeInt(nettyPort_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() - nettyPort_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -100,11 +96,10 @@ private[spark] object BlockManagerId { * @param execId ID of the executor. * @param host Host name of the block manager. * @param port Port of the block manager. - * @param nettyPort Optional port for the Netty-based shuffle sender. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int, nettyPort: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index bf002a42d5dc5..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import org.apache.spark.Logging -import org.apache.spark.network._ -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.{Try, Failure, Success} - -/** - * A network interface for BlockManager. Each slave should have one - * BlockManagerWorker. - * - * TODO: Use event model. - */ -private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => { - logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - logDebug("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => None - } - } - - private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(id: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + id + " started from " + startTimeMs) - val buffer = blockManager.getLocalBytes(id) match { - case Some(bytes) => bytes - case None => null - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer - } -} - -private[spark] object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - resultMessage.isSuccess - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - responseMessage match { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) - } - case Failure(exception) => logDebug("No response message received") - } - null - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala new file mode 100644 index 0000000000000..d4ed33eb506cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils + + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * @param blockManager + * @param blocksByAddress + * @param serializer + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + */ +private[spark] +final class ShuffleBlockFetcherIterator( + context: TaskContext, + blockTransferService: BlockTransferService, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + serializer: Serializer, + maxBytesInFlight: Long) + extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + + import ShuffleBlockFetcherIterator._ + + /** + * Total number of blocks to fetch. This can be smaller than the total number of blocks + * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * + * This should equal localBlocks.size + remoteBlocks.size. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks proccessed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTime = System.currentTimeMillis + + /** Local blocks to fetch, excluding zero-sized blocks. */ + private[this] val localBlocks = new ArrayBuffer[BlockId]() + + /** Remote blocks to fetch, excluding zero-sized blocks. */ + private[this] val remoteBlocks = new HashSet[BlockId]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private[this] val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private[this] var bytesInFlight = 0L + + private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + + initialize() + + private[this] def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + + // so we can look up the size of each blockID + val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val blockIds = req.blocks.map(_._1.toString) + + blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), + () => blockManager.dataDeserialize(BlockId(blockId), data.byteBuffer(), serializer) + )) + shuffleMetrics.remoteBytesRead += data.size + shuffleMetrics.remoteBlocksFetched += 1 + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + + override def onBlockFetchFailure(exception: Exception): Unit = { + + } + } + ) + // case Failure(exception) => { + // logError("Could not get block(s) from " + cmId, exception) + // for ((blockId, size) <- req.blocks) { + // results.put(new FetchResult(blockId, -1, null)) + // } + // } + // } + } + + private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + + // Tracks total number of blocks (including zero sized blocks) + var totalBlocks = 0 + for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size + if (address == blockManager.blockManagerId) { + // Filter out zero-sized blocks + localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + numBlocksToFetch += localBlocks.size + } else { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(BlockId, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocks += blockId + numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= targetRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curBlocks = new ArrayBuffer[(BlockId, Long)] + logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + remoteRequests + } + + private[this] def fetchLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlocks) { + try { + shuffleMetrics.localBlocksFetched += 1 + results.put(new FetchResult(id, 0, () => blockManager.getLocalFromDisk(id, serializer).get)) + logDebug("Got local block " + id) + } catch { + case e: Exception => + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return + } + } + } + + private[this] def initialize(): Unit = { + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numFetches = remoteRequests.size - fetchRequests.size + logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + fetchLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + override def next(): (BlockId, Option[Iterator[Any]]) = { + numBlocksProcessed += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) + if (!result.failed) { + bytesInFlight -= result.size + } + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } +} + + +private[storage] +object ShuffleBlockFetcherIterator { + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index aa83ea90ee9ee..8a836bbba274c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -102,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) + new MapOutputTrackerMaster(conf), new HashShuffleManager(conf), null) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index db7384705fc1b..a7543454eca1f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -295,8 +295,7 @@ private[spark] object JsonProtocol { def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = { ("Executor ID" -> blockManagerId.executorId) ~ ("Host" -> blockManagerId.host) ~ - ("Port" -> blockManagerId.port) ~ - ("Netty Port" -> blockManagerId.nettyPort) + ("Port" -> blockManagerId.port) } def jobResultToJson(jobResult: JobResult): JValue = { @@ -644,8 +643,7 @@ private[spark] object JsonProtocol { val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val port = (json \ "Port").extract[Int] - val nettyPort = (json \ "Netty Port").extract[Int] - BlockManagerId(executorId, host, port, nettyPort) + BlockManagerId(executorId, host, port) } def jobResultFromJson(json: JValue): JobResult = { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 41c294f727b3c..5406fcc2ac839 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import org.apache.spark.network.cm.{GetBlock, BlockManagerWorker, ConnectionManagerId} import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ @@ -24,8 +25,7 @@ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.SparkContext._ -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala similarity index 97% rename from core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala rename to core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala index e2f4d4c57cdb5..258492051173a 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala @@ -15,23 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.io.IOException import java.nio._ -import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite -import org.mockito.Mockito._ -import org.mockito.Matchers._ - -import scala.concurrent.TimeoutException -import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ +import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index bcbfe8baf36ad..56d5907d4f2c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -24,16 +24,17 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global -import org.scalatest.{FunSuite, Matchers} - import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.stubbing.Answer import org.mockito.invocation.InvocationOnMock -import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, Message} +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.network.cm._ import org.apache.spark.executor.ShuffleReadMetrics +import org.apache.spark.storage.BlockFetcherIterator._ + class BlockFetcherIteratorSuite extends FunSuite with Matchers { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f32ce6f9fcc7f..8c458b99e6bc3 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,10 +21,15 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.apache.spark.shuffle.hash.HashShuffleManager import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any @@ -38,17 +43,13 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.{Message, ConnectionManagerId} +import org.apache.spark.network.cm._ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.implicitConversions -import scala.language.postfixOps class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter with PrivateMethodTester {