Skip to content

Commit

Permalink
Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-sp…
Browse files Browse the repository at this point in the history
…ark into cleanup

Conflicts:
	core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
	core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
	core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
  • Loading branch information
andrewor14 committed Mar 26, 2014
2 parents 6c9dcf6 + 80dd977 commit c7ccef1
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 43 deletions.
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,13 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* If `registerBlocks` is true, workers will notify driver about blocks they create
* and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
}

/**
* Add a file to be downloaded with this Spark job on every node.
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ import org.apache.spark._
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Removes all blocks of this broadcast from memory (and disk if removeSource is true).
*
* @param removeSource Whether to remove data from disk as well.
* Will cause errors if broadcast is accessed on workers afterwards
* (e.g. in case of RDD re-computation due to executor failure).
*/
def unpersist(removeSource: Boolean = false)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

Expand Down Expand Up @@ -92,8 +101,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager:

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

def isDriver = _isDriver
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
def stop(): Unit
}
49 changes: 37 additions & 12 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,24 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(blockId)
SparkEnv.get.blockManager.removeBlock(blockId)
}

if (removeSource) {
HttpBroadcast.synchronized {
HttpBroadcast.cleanupById(id)
}
}
}

def blockId = BroadcastBlockId(id)

HttpBroadcast.synchronized {
Expand All @@ -54,7 +67,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
Expand All @@ -71,8 +84,8 @@ class HttpBroadcastFactory extends BroadcastFactory {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { HttpBroadcast.stop() }
}
Expand Down Expand Up @@ -136,8 +149,10 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)

def write(id: Long, value: Any) {
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
Expand Down Expand Up @@ -183,20 +198,30 @@ private object HttpBroadcast extends Logging {
obj
}

def deleteFile(fileName: String) {
try {
new File(fileName).delete()
logInfo("Deleted broadcast file '" + fileName + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
}
}

def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
iterator.remove()
deleteFile(file)
}
}
}

def cleanupById(id: Long) {
val file = getFile(id).getAbsolutePath
files.internalMap.remove(file)
deleteFile(file)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,68 @@ import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils

private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(broadcastId)
SparkEnv.get.blockManager.removeBlock(broadcastId)
}

if (!removeSource) {
//We can't tell BlockManager master to remove blocks from all nodes except driver,
//so we need to save them here in order to store them on disk later.
//This may be inefficient if blocks were already dropped to disk,
//but since unpersist is supposed to be called right after working with
//a broadcast this should not happen (and getting them from memory is cheap).
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)

for (pid <- 0 until totalBlocks) {
val pieceId = pieceBlockId(pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
}
}

for (pid <- 0 until totalBlocks) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid))
}
}

if (removeSource) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.removeBlock(metaId)
}
} else {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.dropFromMemory(metaId)
}

for (i <- 0 until totalBlocks) {
val pieceId = pieceBlockId(i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true)
}
}
arrayOfBlocks = null
}
}

def broadcastId = BroadcastBlockId(id)
private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList

TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
Expand All @@ -54,7 +110,6 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = tInfo.totalBlocks

// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
Expand All @@ -63,7 +118,7 @@ extends Broadcast[T](id) with Logging with Serializable {

// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
val pieceId = pieceBlockId(i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
Expand Down Expand Up @@ -93,7 +148,7 @@ extends Broadcast[T](id) with Logging with Serializable {
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)

// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
Expand All @@ -116,7 +171,6 @@ extends Broadcast[T](id) with Logging with Serializable {

def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
Expand All @@ -139,9 +193,9 @@ extends Broadcast[T](id) with Logging with Serializable {
}

// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
val recvOrder = new Random().shuffle(pieceIds)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
val pieceId = pieceBlockId(pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
Expand Down Expand Up @@ -245,8 +299,8 @@ class TorrentBroadcastFactory extends BroadcastFactory {
TorrentBroadcast.initialize(isDriver, conf)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { TorrentBroadcast.stop() }
}
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ private[spark] class BlockManager(
}
}

/**
* For testing. Returns number of blocks BlockManager knows about that are in memory.
*/
def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))

/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
Expand Down Expand Up @@ -812,6 +817,13 @@ private[spark] class BlockManager(
}

/**
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
}

/**
* Remove all blocks belonging to the given RDD.
* @return The number of blocks removed.
*/
Expand Down
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,27 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}

/**
* Try to free up a given amount of space to store a particular block, but can fail if
* either the block is bigger than our memory or it would require replacing another block
* from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null if called from ensureFreeSpace as only one
// thread should be dropping blocks and removing entries.
// However the check is required in other cases.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
blockManager.dropFromMemory(blockId, data)
}
}

/**
* Tries to free up a given amount of space to store a particular block, but can fail and return
* false if either the block is bigger than our memory or it would require replacing another
* block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* don't fit into memory that we want to avoid).
*
* Assume that a lock is held by the caller to ensure only one thread is dropping blocks.
Expand Down Expand Up @@ -254,19 +272,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - (currentMemory - selectedMemory) >= space) {
logInfo(selectedBlocks.size + " blocks selected for dropping")
for (blockId <- selectedBlocks) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null as only one thread should be dropping
// blocks and removing entries. However the check is still here for
// future safety.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
dropFromMemory(blockId)
}
return ResultWithDroppedBlocks(success = true, droppedBlocks)
} else {
Expand Down
Loading

0 comments on commit c7ccef1

Please sign in to comment.