Commit 5016375: Address TD's comments
andrewor14 committed Apr 1, 2014
1 parent 7ed72fb commit 5016375
Showing 14 changed files with 181 additions and 84 deletions.
7 changes: 3 additions & 4 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -169,18 +169,17 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   // Used for testing
 
-  private[spark] def cleanupRDD(rdd: RDD[_]) {
+  def cleanupRDD(rdd: RDD[_]) {
     doCleanupRDD(rdd.id)
   }
 
-  private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
+  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
     doCleanupShuffle(shuffleDependency.shuffleId)
   }
 
-  private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) {
+  def cleanupBroadcast[T](broadcast: Broadcast[T]) {
     doCleanupBroadcast(broadcast.id)
   }
-
 }
 
 private object ContextCleaner {
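The point of dropping private[spark] on these hooks is to let tests trigger cleanup directly instead of waiting for weak-reference GC. A minimal sketch of such a test, assuming test code living in the org.apache.spark package and a local SparkContext; the names below are illustrative, not from this commit:

// Hypothetical test usage of the now-public cleanup hooks.
val sc = new SparkContext("local", "cleaner-test")
val rdd = sc.parallelize(1 to 100).persist()
rdd.count()

// The constructor signature is shown in the hunk header above.
val cleaner = new ContextCleaner(sc)
cleaner.cleanupRDD(rdd)   // synchronously removes the RDD's persisted state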
31 changes: 22 additions & 9 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -19,6 +19,8 @@ package org.apache.spark.broadcast
 
 import java.io.Serializable
 
+import org.apache.spark.SparkException
+
 /**
  * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
  * cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -49,25 +51,36 @@ import java.io.Serializable
  */
 abstract class Broadcast[T](val id: Long) extends Serializable {
 
+  protected var _isValid: Boolean = true
+
   /**
    * Whether this Broadcast is actually usable. This should be false once persisted state is
    * removed from the driver.
    */
-  protected var isValid: Boolean = true
+  def isValid: Boolean = _isValid
 
   def value: T
 
   /**
-   * Remove all persisted state associated with this broadcast. Overriding implementations
-   * should set isValid to false if persisted state is also removed from the driver.
-   *
-   * @param removeFromDriver Whether to remove state from the driver.
-   *                         If true, the resulting broadcast should no longer be valid.
+   * Remove all persisted state associated with this broadcast on the executors. The next use
+   * of this broadcast on the executors will trigger a remote fetch.
    */
-  def unpersist(removeFromDriver: Boolean)
+  def unpersist()
 
-  // We cannot define abstract readObject and writeObject here due to some weird issues
-  // with these methods having to be 'private' in sub-classes.
+  /**
+   * Remove all persisted state associated with this broadcast on both the executors and the
+   * driver. Overriding implementations should set isValid to false.
+   */
+  private[spark] def destroy()
+
+  /**
+   * If this broadcast is no longer valid, throw an exception.
+   */
+  protected def assertValid() {
+    if (!_isValid) {
+      throw new SparkException("Attempted to use %s when it is no longer valid!".format(toString))
+    }
+  }
 
   override def toString = "Broadcast(" + id + ")"
 }
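Splitting the old unpersist(removeFromDriver) into unpersist() and destroy() gives a broadcast two distinct teardown paths. A hedged usage sketch, assuming only the standard sc.broadcast entry point:

// Illustrative only: `sc` is an existing SparkContext.
val data = sc.broadcast(Seq(1, 2, 3))

// unpersist() drops only the executor-side copies; the next task that reads
// data.value re-fetches them from the driver, so the broadcast stays valid.
data.unpersist()
println(data.value)

// destroy() (private[spark], so not callable from user code) also removes
// driver-side state and sets _isValid = false, after which any use of the
// value or attempt to serialize fails fast in assertValid().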
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -25,7 +25,7 @@ private[spark] class BroadcastManager(
     val isDriver: Boolean,
     conf: SparkConf,
     securityManager: SecurityManager)
-  extends Logging with Serializable {
+  extends Logging {
 
   private var initialized = false
   private var broadcastFactory: BroadcastFactory = null
@@ -63,5 +63,4 @@ private[spark] class BroadcastManager(
   def unbroadcast(id: Long, removeFromDriver: Boolean) {
     broadcastFactory.unbroadcast(id, removeFromDriver)
   }
-
 }
25 changes: 17 additions & 8 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -31,7 +31,10 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
 private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override def value = value_
+  def value: T = {
+    assertValid()
+    value_
+  }
 
   val blockId = BroadcastBlockId(id)
 
@@ -45,17 +48,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
   }
 
   /**
-   * Remove all persisted state associated with this HTTP broadcast.
-   * @param removeFromDriver Whether to remove state from the driver.
+   * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
+  def unpersist() {
+    HttpBroadcast.unpersist(id, removeFromDriver = false)
+  }
+
+  /**
+   * Remove all persisted state associated with this HTTP Broadcast on both the executors
+   * and the driver.
+   */
-  override def unpersist(removeFromDriver: Boolean) {
-    isValid = !removeFromDriver
-    HttpBroadcast.unpersist(id, removeFromDriver)
+  private[spark] def destroy() {
+    _isValid = false
+    HttpBroadcast.unpersist(id, removeFromDriver = true)
   }
 
   // Used by the JVM when serializing this object
   private def writeObject(out: ObjectOutputStream) {
-    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    assertValid()
     out.defaultWriteObject()
   }
 
@@ -231,5 +241,4 @@ private[spark] object HttpBroadcast extends Logging {
       logError("Exception while deleting broadcast file: %s".format(file), e)
     }
   }
-
 }
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -29,7 +29,10 @@ import org.apache.spark.util.Utils
 private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override def value = value_
+  def value = {
+    assertValid()
+    value_
+  }
 
   val broadcastId = BroadcastBlockId(id)
 
@@ -47,7 +50,23 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     sendBroadcast()
   }
 
-  def sendBroadcast() {
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on the executors.
+   */
+  def unpersist() {
+    TorrentBroadcast.unpersist(id, removeFromDriver = false)
+  }
+
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on both the executors
+   * and the driver.
+   */
+  private[spark] def destroy() {
+    _isValid = false
+    TorrentBroadcast.unpersist(id, removeFromDriver = true)
+  }
+
+  private def sendBroadcast() {
     val tInfo = TorrentBroadcast.blockifyObject(value_)
     totalBlocks = tInfo.totalBlocks
     totalBytes = tInfo.totalBytes
@@ -71,18 +90,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     }
   }
 
-  /**
-   * Remove all persisted state associated with this Torrent broadcast.
-   * @param removeFromDriver Whether to remove state from the driver.
-   */
-  override def unpersist(removeFromDriver: Boolean) {
-    isValid = !removeFromDriver
-    TorrentBroadcast.unpersist(id, removeFromDriver)
-  }
-
+  // Used by the JVM when serializing this object
   private def writeObject(out: ObjectOutputStream) {
-    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    assertValid()
     out.defaultWriteObject()
   }
 
@@ -240,7 +250,6 @@ private[spark] object TorrentBroadcast extends Logging {
   def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
  }
-
 }
 
 private[spark] case class TorrentBlock(
1 change: 0 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1128,5 +1128,4 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
-
 }
21 changes: 4 additions & 17 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,9 +53,7 @@ private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: I
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
-// Leave field as an instance variable to avoid matching on it
-private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
-  var field = ""
+private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
   def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
 }
 
@@ -77,19 +75,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
   def name = "test_" + id
 }
 
-private[spark] object BroadcastBlockId {
-  def apply(broadcastId: Long, field: String) = {
-    val blockId = new BroadcastBlockId(broadcastId)
-    blockId.field = field
-    blockId
-  }
-}
-
 private[spark] object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
-  val BROADCAST = "broadcast_([0-9]+)".r
-  val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+  val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
   val TEST = "test_(.*)".r
@@ -100,10 +89,8 @@ private[spark] object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
-    case BROADCAST(broadcastId) =>
-      BroadcastBlockId(broadcastId.toLong)
-    case BROADCAST_FIELD(broadcastId, field) =>
-      BroadcastBlockId(broadcastId.toLong, field)
+    case BROADCAST(broadcastId, field) =>
+      BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
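Folding field into the case class lets one regex and one constructor handle both block-name forms. A self-contained sketch that mirrors the diff (the BlockId trait and sibling cases are elided here):

// Name generation and parsing round-trip through a single case class.
case class BroadcastBlockId(broadcastId: Long, field: String = "") {
  def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}

val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r

def parse(name: String): BroadcastBlockId = name match {
  // The second group captures the optional "_field" suffix, possibly empty.
  case BROADCAST(id, field) => BroadcastBlockId(id.toLong, field.stripPrefix("_"))
}

assert(parse("broadcast_5") == BroadcastBlockId(5))
assert(parse("broadcast_5_meta") == BroadcastBlockId(5, "meta"))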
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -832,7 +832,7 @@ private[spark] class BlockManager(
   def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     logInfo("Removing broadcast " + broadcastId)
     val blocksToRemove = blockInfo.keys.collect {
-      case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
+      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
   }
@@ -897,7 +897,7 @@
 
   def shouldCompress(blockId: BlockId): Boolean = blockId match {
     case ShuffleBlockId(_, _, _) => compressShuffle
-    case BroadcastBlockId(_) => compressBroadcast
+    case BroadcastBlockId(_, _) => compressBroadcast
     case RDDBlockId(_, _) => compressRdds
     case TempBlockId(_) => compressShuffleSpill
     case _ => false
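The rewritten collect relies on a stable-identifier (backtick) pattern: `broadcastId` in backticks matches against the method parameter's value rather than binding a fresh variable. A self-contained illustration with made-up names:

// Backticks turn a lowercase name into an equality check inside a pattern.
case class Block(owner: Long, field: String)

val target = 5L
val blocks = Seq(Block(5L, "meta"), Block(6L, ""), Block(5L, ""))

// Equivalent to `case b: Block if b.owner == target => b`, but the guard
// is expressed directly in the pattern.
val mine = blocks.collect { case b @ Block(`target`, _) => b }
assert(mine.size == 2)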
core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -106,9 +106,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply(RemoveBlock(blockId))
   }
 
-  /**
-   * Remove all blocks belonging to the given RDD.
-   */
+  /** Remove all blocks belonging to the given RDD. */
   def removeRdd(rddId: Int, blocking: Boolean) {
     val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
     future onFailure {
@@ -119,16 +117,12 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     }
   }
 
-  /**
-   * Remove all blocks belonging to the given shuffle.
-   */
+  /** Remove all blocks belonging to the given shuffle. */
   def removeShuffle(shuffleId: Int) {
     askDriverWithReply(RemoveShuffle(shuffleId))
   }
 
-  /**
-   * Remove all blocks belonging to the given broadcast.
-   */
+  /** Remove all blocks belonging to the given broadcast. */
   def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) {
     askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster))
   }
@@ -148,20 +142,21 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
   }
 
   /**
-   * Return the block's local status on all block managers, if any.
+   * Return the block's status on all block managers, if any.
    *
    * If askSlaves is true, this invokes the master to query each block manager for the most
    * updated block statuses. This is useful when the master is not informed of the given block
    * by all block managers.
+   *
+   * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+   * should not block on waiting for a block manager, which can in turn be waiting for the
+   * master actor for a response to a prior message.
    */
   def getBlockStatus(
       blockId: BlockId,
       askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
     val msg = GetBlockStatus(blockId, askSlaves)
-    /*
-     * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
-     * should not block on waiting for a block manager, which can in turn be waiting for the
-     * master actor for a response to a prior message.
-     */
     val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
     val (blockManagerIds, futures) = response.unzip
     val result = Await.result(Future.sequence(futures), timeout)
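The Futures-based protocol described in the docstring can be shown in miniature. A minimal sketch using only scala.concurrent; none of these names come from Spark:

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

// Each "block manager" replies with a Future instead of a blocking answer.
def queryWorker(workerId: Int): Future[Option[String]] =
  Future { Some(s"status-from-$workerId") }   // stand-in for an actor ask

// The coordinator fans out without blocking, so a worker that is itself
// waiting on the coordinator cannot produce a circular wait; only the
// final caller awaits the combined result.
val futures = (1 to 3).map(queryWorker)
val statuses = Await.result(Future.sequence(futures), 10.seconds)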
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -255,21 +255,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
   }
 
   /**
-   * Return the block's local status for all block managers, if any.
+   * Return the block's status for all block managers, if any.
    *
    * If askSlaves is true, the master queries each block manager for the most updated block
    * statuses. This is useful when the master is not informed of the given block by all block
    * managers.
+   *
+   * Rather than blocking on the block status query, the master actor should simply return a
+   * Future to avoid potential deadlocks. This can arise if there exists a block manager
+   * that is also waiting for this master actor's response to a previous message.
   */
   private def blockStatus(
       blockId: BlockId,
       askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
     import context.dispatcher
     val getBlockStatus = GetBlockStatus(blockId)
-    /*
-     * Rather than blocking on the block status query, master actor should simply return
-     * Futures to avoid potential deadlocks. This can arise if there exists a block manager
-     * that is also waiting for this master actor's response to a previous message.
-     */
     blockManagerInfo.values.map { info =>
       val blockStatusFuture =
         if (askSlaves) {
(4 more changed files not shown)
