Skip to content

Commit

Permalink
Merge pull request #1 from andrewor14/cleanup
Browse files Browse the repository at this point in the history
I am merging this. I will take one more detailed look in the context of my original changes in the main PR.
  • Loading branch information
tdas committed Apr 2, 2014
2 parents 7edbc98 + f0aabb1 commit 762a4d8
Show file tree
Hide file tree
Showing 38 changed files with 1,460 additions and 928 deletions.
133 changes: 85 additions & 48 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,105 +21,106 @@ import java.lang.ref.{ReferenceQueue, WeakReference}

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

/** Listener class used for testing when any item has been cleaned by the Cleaner class */
private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
}
/**
 * Cleaning tasks that the ContextCleaner can carry out.
 *
 * The trait is sealed, so all task variants are declared in this file and the compiler can
 * check pattern matches over CleanupTask for exhaustiveness. The variants are final because
 * they are plain data carriers that are not meant to be subclassed.
 */
private sealed trait CleanupTask

/** Task to clean up the RDD identified by `rddId`. */
private final case class CleanRDD(rddId: Int) extends CleanupTask

/** Task to clean up the shuffle identified by `shuffleId`. */
private final case class CleanShuffle(shuffleId: Int) extends CleanupTask

/** Task to clean up the broadcast variable identified by `broadcastId`. */
private final case class CleanBroadcast(broadcastId: Long) extends CleanupTask

/**
 * Weak reference to an object of interest, paired with the CleanupTask to run for it.
 *
 * The reference is created against `referenceQueue`. Once the referent becomes only weakly
 * reachable, this reference is added to that queue and `task` can be read back from it.
 */
private class CleanupTaskWeakReference(
    val task: CleanupTask,
    referent: AnyRef,
    referenceQueue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, referenceQueue)

/**
* An asynchronous cleaner for RDD, shuffle, and broadcast state.
*
* This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

/** Classes to represent cleaning tasks */
private sealed trait CleanupTask
private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
// TODO: add CleanBroadcast
private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
with SynchronizedBuffer[CleanupTaskWeakReference]

private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask]
with SynchronizedBuffer[WeakReferenceWithCleanupTask]
private val referenceQueue = new ReferenceQueue[AnyRef]

private val listeners = new ArrayBuffer[CleanerListener]
with SynchronizedBuffer[CleanerListener]

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

private val REF_QUEUE_POLL_TIMEOUT = 100

@volatile private var stopped = false

private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask)
extends WeakReference(referent, referenceQueue)
/** Add `listener` to the listeners that are told when objects have been cleaned. */
def attachListener(listener: CleanerListener): Unit = {
  listeners += listener
}

/** Start the cleaner */
/**
 * Launch the cleaning thread. The thread is named "ContextCleaner" and marked as a daemon
 * before it is started.
 */
def start(): Unit = {
  cleaningThread.setName("ContextCleaner")
  cleaningThread.setDaemon(true)
  cleaningThread.start()
}

/** Stop the cleaner */
/**
 * Ask the cleaning thread to finish. The `stopped` flag is raised before the thread is
 * interrupted, so the flag is already set when the interrupt is observed.
 */
def stop(): Unit = {
  stopped = true
  cleaningThread.interrupt()
}

/**
* Register a RDD for cleanup when it is garbage collected.
*/
/** Track `rdd` so that a CleanRDD task for its id is issued once it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]): Unit = {
  val task = CleanRDD(rdd.id)
  registerForCleanup(rdd, task)
}

/**
* Register a shuffle dependency for cleanup when it is garbage collected.
*/
/**
 * Track `shuffleDependency` so that a CleanShuffle task for its shuffle id is issued once
 * the dependency is garbage collected.
 */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]): Unit = {
  val task = CleanShuffle(shuffleDependency.shuffleId)
  registerForCleanup(shuffleDependency, task)
}

/** Cleanup RDD. */
def cleanupRDD(rdd: RDD[_]) {
doCleanupRDD(rdd.id)
}

/** Cleanup shuffle. */
def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
doCleanupShuffle(shuffleDependency.shuffleId)
}

/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener) {
listeners += listener
/**
 * Track `broadcast` so that a CleanBroadcast task for its id is issued once the broadcast
 * is garbage collected.
 */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = {
  val task = CleanBroadcast(broadcast.id)
  registerForCleanup(broadcast, task)
}

/**
 * Register an object for cleanup.
 *
 * Wraps `objectForCleanup` in a weak reference that carries `task` and appends that reference
 * to `referenceBuffer`; keepCleaning removes the entry again once it has been dequeued.
 */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
// NOTE(review): two references are appended here, one per wrapper class. This looks like the
// pre-change and post-change line of a diff shown together -- confirm which one is intended.
referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task)
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
}

/** Keep cleaning RDDs and shuffle data */
/** Keep cleaning RDD, shuffle, and broadcast state. */
private def keepCleaning() {
while (!isStopped) {
while (!stopped) {
try {
val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT))
.map(_.asInstanceOf[WeakReferenceWithCleanupTask])
val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
.map(_.asInstanceOf[CleanupTaskWeakReference])
reference.map(_.task).foreach { task =>
logDebug("Got cleaning task " + task)
referenceBuffer -= reference.get
task match {
case CleanRDD(rddId) => doCleanupRDD(rddId)
case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
}
}
} catch {
case ie: InterruptedException =>
if (!isStopped) logWarning("Cleaning thread interrupted")
if (!stopped) logWarning("Cleaning thread interrupted")
case t: Throwable => logError("Error in cleaning thread", t)
}
}
Expand All @@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private def doCleanupRDD(rddId: Int) {
try {
logDebug("Cleaning RDD " + rddId)
sc.unpersistRDD(rddId, false)
sc.unpersistRDD(rddId, blocking = false)
listeners.foreach(_.rddCleaned(rddId))
logInfo("Cleaned RDD " + rddId)
} catch {
Expand All @@ -150,10 +151,46 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

private def mapOutputTrackerMaster =
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
/**
 * Perform broadcast cleanup.
 *
 * Calls `broadcastManager.unbroadcast` for `broadcastId` with `removeFromDriver = true` and
 * then notifies every attached listener. Non-fatal failures are logged and swallowed so that
 * one failed cleanup does not prevent later ones. Fatal errors and InterruptedException are
 * deliberately not caught here; they propagate to the caller.
 */
private def doCleanupBroadcast(broadcastId: Long): Unit = {
  try {
    logDebug("Cleaning broadcast " + broadcastId)
    broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
    listeners.foreach(_.broadcastCleaned(broadcastId))
    logInfo("Cleaned broadcast " + broadcastId)
  } catch {
    // NonFatal does not match VirtualMachineError, InterruptedException and similar, so those
    // are no longer reported as an ordinary cleanup failure and then dropped.
    case scala.util.control.NonFatal(e) =>
      logError("Error cleaning broadcast " + broadcastId, e)
  }
}

// Shortcuts into the SparkEnv of `sc`. They are defs, so `sc.env` is read on every call.
private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
// The cast fails with ClassCastException if the tracker is not a MapOutputTrackerMaster.
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

// Used for testing

/** Clean up `rdd` immediately by its id, without going through the reference queue. */
def cleanupRDD(rdd: RDD[_]): Unit = {
  val rddId = rdd.id
  doCleanupRDD(rddId)
}

/** Clean up the shuffle of `shuffleDependency` immediately, by its shuffle id. */
def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]): Unit = {
  val shuffleId = shuffleDependency.shuffleId
  doCleanupShuffle(shuffleId)
}

private def isStopped = stopped
/** Clean up `broadcast` immediately by its id, without going through the reference queue. */
def cleanupBroadcast[T](broadcast: Broadcast[T]): Unit = {
  val broadcastId = broadcast.id
  doCleanupBroadcast(broadcastId)
}
}

private object ContextCleaner {
  /**
   * Timeout passed to ReferenceQueue.remove on each polling round of the cleaning thread
   * (milliseconds, per the java.lang.ref.ReferenceQueue API).
   */
  private val REF_QUEUE_POLL_TIMEOUT: Int = 100
}

/**
 * Callback interface for being told when the cleaner has cleaned an item. Used for testing.
 */
private[spark] trait CleanerListener {
  /** Called for the RDD with id `rddId`. */
  def rddCleaned(rddId: Int): Unit

  /** Called for the shuffle with id `shuffleId`. */
  def shuffleCleaned(shuffleId: Int): Unit

  /** Called for the broadcast with id `broadcastId`. */
  def broadcastCleaned(broadcastId: Long): Unit
}
11 changes: 5 additions & 6 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}

/**
* Called from executors to get the server URIs and
* output sizes of the map outputs of a given shuffle
* Called from executors to get the server URIs and output sizes of the map outputs of
* a given shuffle.
*/
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
Expand Down Expand Up @@ -218,10 +218,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
private var cacheEpoch = epoch

/**
* Timestamp based HashMap for storing mapStatuses and cached serialized statuses
* in the master, so that statuses are dropped only by explicit deregistering or
* by TTL-based cleaning (if set). Other than these two
* scenarios, nothing should be dropped from this HashMap.
* Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
* so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
Expand Down
8 changes: 6 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
Expand Down Expand Up @@ -230,6 +229,7 @@ class SparkContext(

private[spark] val cleaner = new ContextCleaner(this)
cleaner.start()

postEnvironmentUpdate()

/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
Expand Down Expand Up @@ -643,7 +643,11 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T): org.apache.spark.broadcast.Broadcast[T] = {
  // The result type is spelled out (fully qualified) because this is public API and should
  // not depend on what the compiler infers from the broadcast manager.
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  // Hand the new broadcast to the ContextCleaner before returning it to the caller.
  cleaner.registerBroadcastForCleanup(bc)
  bc
}

/**
* Add a file to be downloaded with this Spark job on every node.
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ object SparkEnv extends Logging {
} else {
new MapOutputTrackerWorker(conf)
}

// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
mapOutputTracker.trackerActor = registerOrLookup(
Expand Down
67 changes: 27 additions & 40 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
package org.apache.spark.broadcast

import java.io.Serializable
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark._
import org.apache.spark.SparkException

/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
Expand Down Expand Up @@ -51,49 +50,37 @@ import org.apache.spark._
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

override def toString = "Broadcast(" + id + ")"
}

private[spark]
class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
extends Logging with Serializable {

private var initialized = false
private var broadcastFactory: BroadcastFactory = null

initialize()
protected var _isValid: Boolean = true

// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = conf.get(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
/**
* Whether this Broadcast is actually usable. This should be false once persisted state is
* removed from the driver.
*/
def isValid: Boolean = _isValid

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
def value: T

initialized = true
}
/**
 * Drop the state persisted for this broadcast on the executors. An executor that uses the
 * broadcast afterwards has to fetch it remotely again.
 */
def unpersist(): Unit

/**
 * Drop the state persisted for this broadcast on the executors as well as on the driver.
 * Implementations are expected to set `_isValid` to false so that `isValid` reports it.
 */
private[spark] def destroy(): Unit

/**
 * Guard for operations that need a usable broadcast: throws a SparkException when
 * `_isValid` is false, and does nothing otherwise.
 */
protected def assertValid(): Unit = {
  val usable = _isValid
  if (!usable) {
    val message = "Attempted to use %s after it has been destroyed!".format(toString)
    throw new SparkException(message)
  }
}

def stop() {
broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())

def isDriver = _isDriver
override def toString = "Broadcast(" + id + ")"
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
  /** Prepare the factory for use; `isDriver` tells it whether it runs on the driver. */
  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit

  /** Create a broadcast variable for `value` under the given `id`. */
  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]

  /**
   * Remove the state of the broadcast with the given `id`.
   * NOTE(review): `removeFromDriver` presumably selects whether the driver's copy is removed
   * as well -- confirm against the implementations.
   */
  def unbroadcast(id: Long, removeFromDriver: Boolean): Unit

  /** Shut the factory down. */
  def stop(): Unit
}
Loading

0 comments on commit 762a4d8

Please sign in to comment.