diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 89361efec44a4..4985d4202ed6b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -186,7 +186,7 @@ private[spark] object HttpBroadcast extends Logging { * and delete the associated broadcast file. */ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) if (removeFromDriver) { val file = new File(broadcastDir, BroadcastBlockId(id).name) files.remove(file.toString) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 07ef54bb120b9..51f1592cef752 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -232,7 +232,7 @@ private[spark] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. */ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ca23513c4dc64..3c0941e195724 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -820,10 +820,22 @@ private[spark] class BlockManager( // from RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long) { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { + case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid + case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId) } + } + /** * Remove a block from both memory and disk. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ff3f22b3b092a..4579c0d959553 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -126,6 +126,13 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveShuffle(shuffleId)) } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { + askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 646ccb7fa74f6..4cc4227fd87e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -100,6 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus removeShuffle(shuffleId) sender ! true + case RemoveBroadcast(broadcastId, removeFromDriver) => + removeBroadcast(broadcastId, removeFromDriver) + sender ! true + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -151,9 +155,15 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def removeShuffle(shuffleId: Int) { // Nothing to do in the BlockManagerMasterActor data structures val removeMsg = RemoveShuffle(shuffleId) - blockManagerInfo.values.foreach { bm => - bm.slaveActor ! removeMsg - } + blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } + } + + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + // TODO(aor): Consolidate usages of + val removeMsg = RemoveBroadcast(broadcastId) + blockManagerInfo.values + .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } + .foreach { bm => bm.slaveActor ! removeMsg } } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 4c5b31d0abe44..3ea710ebc786e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -22,9 +22,11 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef private[storage] object BlockManagerMessages { + ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerSlave // Remove a block from the slaves that have it. This can only be used to remove @@ -37,10 +39,15 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific shuffle. case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerMaster case class RegisterBlockManager( @@ -57,8 +64,7 @@ private[storage] object BlockManagerMessages { var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { + extends ToBlockManagerMaster with Externalizable { def this() = this(null, null, null, 0, 0) // For deserialization only @@ -80,7 +86,8 @@ private[storage] object BlockManagerMessages { } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 9a12481b7f6d5..8c2ccbe6a7e66 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -46,5 +46,8 @@ class BlockManagerSlaveActor( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } + + case RemoveBroadcast(broadcastId, _) => + blockManager.removeBroadcast(broadcastId) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ad87fda140476..e541591ee7582 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -461,10 +461,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 11e22145ebb88..77d9825434706 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,8 +28,8 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {