From e427a9eeb8d6b5def3a5ff1b766458588d8b05a9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 13 Feb 2014 19:14:31 -0800 Subject: [PATCH] Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap to BoundedHashMaps and TimeStampedWeakValueHashMap for the necessary hashmap behavior. --- .../org/apache/spark/ContextCleaner.scala | 126 +++++++++++ .../scala/org/apache/spark/Dependency.scala | 6 + .../org/apache/spark/MapOutputTracker.scala | 64 ++++-- .../scala/org/apache/spark/SparkContext.scala | 12 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 10 + .../apache/spark/scheduler/ResultTask.scala | 12 +- .../spark/scheduler/ShuffleMapTask.scala | 17 +- .../spark/storage/BlockManagerMaster.scala | 15 ++ .../storage/BlockManagerMasterActor.scala | 12 + .../spark/storage/BlockManagerMessages.scala | 3 + .../storage/BlockManagerSlaveActor.scala | 3 + .../spark/storage/DiskBlockManager.scala | 5 + .../spark/storage/ShuffleBlockManager.scala | 33 ++- .../apache/spark/util/BoundedHashMap.scala | 45 ++++ .../apache/spark/util/MetadataCleaner.scala | 4 +- .../util/TimeStampedWeakValueHashMap.scala | 84 +++++++ .../spark/util/WrappedJavaHashMap.scala | 126 +++++++++++ .../apache/spark/ContextCleanerSuite.scala | 210 ++++++++++++++++++ .../apache/spark/MapOutputTrackerSuite.scala | 25 ++- .../spark/storage/DiskBlockManagerSuite.scala | 3 +- .../spark/util/WrappedJavaHashMapSuite.scala | 189 ++++++++++++++++ 22 files changed, 946 insertions(+), 60 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ContextCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala create mode 100644 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala new file mode 100644 index 0000000000000..1cc4271f8cf33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} + +import org.apache.spark.rdd.RDD + +/** Listener class used for testing when any item has been cleaned by the Cleaner class */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) +} + +/** + * Cleans RDDs and shuffle data. This should be instantiated only on the driver. + */ +private[spark] class ContextCleaner(env: SparkEnv) extends Logging { + + /** Classes to represent cleaning tasks */ + private sealed trait CleaningTask + private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask + private case class CleanShuffle(id: Int) extends CleaningTask + // TODO: add CleanBroadcast + + private val QUEUE_CAPACITY = 1000 + private val queue = new ArrayBlockingQueue[CleaningTask](QUEUE_CAPACITY) + + protected val listeners = new ArrayBuffer[CleanerListener] + with SynchronizedBuffer[CleanerListener] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + private var stopped = false + + /** Start the cleaner */ + def start() { + cleaningThread.setDaemon(true) + cleaningThread.start() + } + + /** Stop the cleaner */ + def stop() { + synchronized { stopped = true } + cleaningThread.interrupt() + } + + /** Clean all data and metadata related to a RDD, including shuffle files and metadata */ + def cleanRDD(rdd: RDD[_]) { + enqueue(CleanRDD(rdd.sparkContext, rdd.id)) + logDebug("Enqueued RDD " + rdd + " for cleaning up") + } + + def cleanShuffle(shuffleId: Int) { + enqueue(CleanShuffle(shuffleId)) + logDebug("Enqueued shuffle " + shuffleId + " for cleaning up") + } + + def attachListener(listener: CleanerListener) { + listeners += listener + } + /** Enqueue a cleaning task */ + private def enqueue(task: CleaningTask) { + queue.put(task) + } + + /** Keep cleaning RDDs and shuffle data */ + private def keepCleaning() { + try { + while (!isStopped) { + val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS)) + if (taskOpt.isDefined) { + logDebug("Got cleaning task " + taskOpt.get) + taskOpt.get match { + case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId) + case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId) + } + } + } + } catch { + case ie: java.lang.InterruptedException => + if (!isStopped) logWarning("Cleaning thread interrupted") + } + } + + /** Perform RDD cleaning */ + private def doCleanRDD(sc: SparkContext, rddId: Int) { + logDebug("Cleaning rdd "+ rddId) + sc.env.blockManager.master.removeRdd(rddId, false) + sc.persistentRdds.remove(rddId) + listeners.foreach(_.rddCleaned(rddId)) + logInfo("Cleaned rdd "+ rddId) + } + + /** Perform shuffle cleaning */ + private def doCleanShuffle(shuffleId: Int) { + logDebug("Cleaning shuffle "+ shuffleId) + mapOutputTrackerMaster.unregisterShuffle(shuffleId) + blockManager.master.removeShuffle(shuffleId) + listeners.foreach(_.shuffleCleaned(shuffleId)) + logInfo("Cleaned shuffle " + shuffleId) + } + + private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + private def blockManager = env.blockManager + + private def isStopped = synchronized { stopped } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cc30105940d1a..dba0604ab4866 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ 
b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -52,6 +52,12 @@ class ShuffleDependency[K, V]( extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + + override def finalize() { + if (rdd != null) { + rdd.sparkContext.cleaner.cleanShuffle(shuffleId) + } + } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 30d182b008930..bf291bf71bb61 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -17,19 +17,19 @@ package org.apache.spark +import scala.Some +import scala.collection.mutable.{HashSet, Map} +import scala.concurrent.Await + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashSet -import scala.concurrent.Await -import scala.concurrent.duration._ - import akka.actor._ import akka.pattern.ask import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) @@ -51,23 +51,21 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { +private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + /** This HashMap needs to have different storage behavior for driver and worker */ + protected val mapStatuses: Map[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. protected var epoch: Long = 0 protected val epochLock = new java.lang.Object - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) - // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. private def askTracker(message: Any): Any = { @@ -138,8 +136,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } - } - else { + } else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } @@ -151,13 +148,12 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } protected def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) + mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime) } def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() - metadataCleaner.cancel() trackerActor = null } @@ -182,15 +178,42 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + + /** + * Bounded HashMap for storing serialized statuses in the worker. This allows + * the HashMap stay bounded in memory-usage. 
Things dropped from this HashMap will be + * automatically repopulated by fetching them again from the driver. + */ + protected val MAX_MAP_STATUSES = 100 + protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true) +} + + private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { // Cache a serialized version of the output statuses for each shuffle to send them out faster private var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + /** + * Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped + * only by explicit deregistering or by ttl-based cleaning (if set). Other than these two + * scenarios, nothing should be dropped from this HashMap. + */ + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + + /** + * Bounded HashMap for storing serialized statuses in the master. This allows + * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be + * automatically repopulated by serializing the lost statuses again . + */ + protected val MAX_SERIALIZED_STATUSES = 100 + private val cachedSerializedStatuses = + new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -224,6 +247,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -260,9 +287,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) bytes } - protected override def cleanup(cleanupTime: Long) { - super.cleanup(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) + def contains(shuffleId: Int): Boolean = { + mapStatuses.contains(shuffleId) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566472e597958..9fab2a7e0c707 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,8 +48,10 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType, - ClosureCleaner} +import org.apache.spark.util._ +import scala.Some +import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage.StorageStatus /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark @@ -150,7 +152,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] + private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) @@ -202,6 +204,9 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() + private[spark] val cleaner = new ContextCleaner(env) + cleaner.start() + ui.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -784,6 +789,7 @@ class SparkContext( dagScheduler = null if (dagSchedulerCopy != null) { metadataCleaner.cancel() + cleaner.stop() dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ed788560e79f1..23dbe18fd2576 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -181,7 +181,7 @@ object SparkEnv extends Logging { val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { - new MapOutputTracker(conf) + new MapOutputTrackerWorker(conf) } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8010bb68e31dd..37168d7cd5969 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1012,6 +1012,13 @@ abstract class RDD[T: ClassTag]( checkpointData.flatMap(_.getCheckpointFile) } + def cleanup() { + sc.cleaner.cleanRDD(this) + dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId) + .foreach(sc.cleaner.cleanShuffle) + } + // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1091,4 +1098,7 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } + override def finalize() { + cleanup() + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 28f3ba53b8425..671faf42a9278 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -21,20 +21,16 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData -import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.util.BoundedHashMap private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. 
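
For context before the hunk below: the patch replaces this TimeStampedHashMap/MetadataCleaner pair with a size-bounded cache. The eviction mechanism it relies on is java.util.LinkedHashMap's removeEldestEntry hook, which the new BoundedHashMap (added later in this patch) wraps. A minimal, self-contained sketch of that mechanism; the object name and the capacity of 4 are illustrative only and not part of the patch:

import java.util.LinkedHashMap
import java.util.Map.{Entry => JMapEntry}

object BoundedCacheSketch extends App {
  val bound = 4 // illustrative capacity; the patch uses 100 for the serialized-task caches

  // accessOrder = false drops the oldest-inserted entry once the bound is exceeded;
  // passing true instead gives LRU eviction, which is what BoundedHashMap's useLRU flag selects.
  val cache = new LinkedHashMap[Int, Array[Byte]](16, 0.75f, false) {
    override protected def removeEldestEntry(eldest: JMapEntry[Int, Array[Byte]]): Boolean = {
      size() > bound
    }
  }

  (1 to 6).foreach(stageId => cache.put(stageId, Array.empty[Byte]))
  assert(!cache.containsKey(1) && cache.containsKey(6)) // stages 1 and 2 have been evicted
  println("cached stage ids: " + cache.keySet())
}
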
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) + val MAX_CACHE_SIZE = 100 + val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index a37ead563271a..df3a7b9ee37ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,29 +17,24 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner} -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData - +import org.apache.spark.util.BoundedHashMap private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) + val MAX_CACHE_SIZE = 100 + val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c54e4f2664753..55d8349ea9d2c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -82,6 +82,14 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } + /** + * Check if block manager master has a block. Note that this can be used to check for only + * those blocks that are expected to be reported to block manager master. + */ + def contains(blockId: BlockId) = { + !getLocations(blockId).isEmpty + } + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) @@ -113,6 +121,13 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo } } + /** + * Remove all blocks belonging to the given shuffle. 
+ */ + def removeShuffle(shuffleId: Int) { + askDriverWithReply(RemoveShuffle(shuffleId)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 2c1a4e2f5d3a1..8b972672c8117 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -95,6 +95,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act case RemoveRdd(rddId) => sender ! removeRdd(rddId) + case RemoveShuffle(shuffleId) => + removeShuffle(shuffleId) + sender ! true + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -143,6 +147,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act }.toSeq) } + private def removeShuffle(shuffleId: Int) { + // Nothing to do in the BlockManagerMasterActor data structures + val removeMsg = RemoveShuffle(shuffleId) + blockManagerInfo.values.map { bm => + bm.slaveActor ! removeMsg + } + } + private def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 45f51da288548..98a3b68748ada 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -35,6 +35,9 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific shuffle. + case class RemoveShuffle(shuffleId: Int) + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 3a65e55733834..eeeee07ebb722 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -36,5 +36,8 @@ class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { case RemoveRdd(rddId) => val numBlocksRemoved = blockManager.removeRdd(rddId) sender ! numBlocksRemoved + + case RemoveShuffle(shuffleId) => + blockManager.shuffleBlockManager.removeShuffle(shuffleId) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3e1c38744d78..cdee285a1cbd4 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,6 +90,11 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Check if disk block manager has a block */ + def contains(blockId: BlockId): Boolean = { + getBlockLocation(blockId).file.exists() + } + /** Produces a unique block id and File suitable for intermediate results. 
*/ def createTempBlock(): (TempBlockId, File) = { var blockId = new TempBlockId(UUID.randomUUID()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb134cc..ed03f189fb4ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,23 +169,32 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } + /** Remove all the blocks / files related to a particular shuffle */ + def removeShuffle(shuffleId: ShuffleId) { + shuffleStates.get(shuffleId) match { + case Some(state) => + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + logInfo("Deleted all files for shuffle " + shuffleId) + case None => + logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + } + } + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } - } - }) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId)) } } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala new file mode 100644 index 0000000000000..0095b8a38d7b6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala @@ -0,0 +1,45 @@ +package org.apache.spark.util + +import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} + +import java.util.{Collections, LinkedHashMap} +import java.util.Map.{Entry => JMapEntry} +import scala.reflect.ClassTag + +/** + * A map that bounds the number of key-value pairs present in it. It can be configured to + * drop least recently inserted or used pair. It exposes a scala.collection.mutable.Map interface + * to allow it to be a drop-in replacement of Scala HashMaps. Internally, a Java LinkedHashMap is + * used to get insert-order or access-order behavior. Note that the LinkedHashMap is not + * thread-safe and hence, it is wrapped in a Collections.synchronizedMap. + * However, getting the Java HashMap's iterator and using it can still lead to + * ConcurrentModificationExceptions. Hence, the iterator() function is overridden to copy the + * all pairs into an ArrayBuffer and then return the iterator to the ArrayBuffer. Also, + * the class apply the trait SynchronizedMap which ensures that all calls to the Scala Map API + * are synchronized. This together ensures that ConcurrentModificationException is never thrown. 
+ * @param bound max number of key-value pairs + * @param useLRU true = least recently used/accessed will be dropped when bound is reached, + * false = earliest inserted will be dropped + */ +private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) + extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { + + protected[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( + bound / 8, (0.75).toFloat, useLRU) { + override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { + size() > bound + } + }) + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new BoundedHashMap[K1, V1](bound, useLRU) + } + + /** + * Overriding iterator to make sure that the internal Java HashMap's iterator + * is not concurrently modified. + */ + override def iterator: Iterator[(A, B)] = { + (new ArrayBuffer[(A, B)] ++= super.iterator).iterator + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index b0febe906ade3..1953e4cd2b59e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -62,8 +62,8 @@ private[spark] class MetadataCleaner( private[spark] object MetadataCleanerType extends Enumeration { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, - SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER, + SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value type MetadataCleanerType = Value diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala new file mode 100644 index 0000000000000..43848def0ffe6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -0,0 +1,84 @@ +package org.apache.spark.util + +import scala.collection.{JavaConversions, immutable} + +import java.util +import java.lang.ref.WeakReference +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.Logging + +private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { + def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) +} + + +private[spark] class TimeStampedWeakValueHashMap[A, B] + extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + + protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { + new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() + } + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new TimeStampedWeakValueHashMap[K1, V1]() + } + + override def get(key: A): Option[B] = { + Option(internalJavaMap.get(key)) match { + case Some(weakValue) => + val value = weakValue.weakValue.get + if (value == null) cleanupKey(key) + Option(value) + case None => + None + } + } + + @inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = { + new TimeStampedWeakValue(currentTime, v) + } + + @inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = { + iv.weakValue.get + } + + override def iterator: Iterator[(A, B)] = { + val jIterator = internalJavaMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).flatMap(kv => 
{ + val key = kv.getKey + val value = kv.getValue.weakValue.get + if (value == null) { + cleanupKey(key) + Seq.empty + } else { + Seq((key, value)) + } + }) + } + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime`, + * calling the supplied function on each such entry before removing. + */ + def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) { + val iterator = internalJavaMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue.timestamp < threshTime) { + val value = entry.getValue.weakValue.get + if (f != null && value != null) { + f(entry.getKey, value) + } + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def cleanupKey(key: A) { + // TODO: Consider cleaning up keys to empty weak ref values automatically in future. + } + + private def currentTime = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala new file mode 100644 index 0000000000000..e7c66e494678b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala @@ -0,0 +1,126 @@ +package org.apache.spark.util + +import scala.collection.mutable.Map +import java.util.{Map => JMap} +import java.util.Map.{Entry => JMapEntry} +import scala.collection.{immutable, JavaConversions} +import scala.reflect.ClassTag + +/** + * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the + * exposed key-value type is different from the internal type. This allows Scala HashMaps to be + * hot replaceable with these Java HashMaps. + * + * While Java <-> Scala conversion methods exists, its hard to understand the performance + * implications and thread safety of the Scala wrapper. This class allows you to convert + * between types and applying the necessary overridden methods to take care of performance. + * + * Note that the threading behavior of an implementation of WrappedJavaHashMap is tied to that of + * the internal Java HashMap used in the implementation. Each implementation must use + * necessary traits (e.g, scala.collection.mutable.SynchronizedMap), etc. to achieve the + * desired thread safety. + * + * @tparam K External key type + * @tparam V External value type + * @tparam IK Internal key type + * @tparam IV Internal value type + */ +private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] { + + /* Methods that must be defined. */ + + /** Internal Java HashMap that is being wrapped. */ + protected[util] val internalJavaMap: JMap[IK, IV] + + /** Method to get a new instance of the internal Java HashMap. */ + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] + + /* + Methods that convert between internal and external types. These implementations + optimistically assume that the internal types are same as external types. These must + be overridden if the internal and external types are different. Otherwise there will be + runtime exceptions. 
+ */ + + @inline protected def externalKeyToInternalKey(k: K): IK = { + k.asInstanceOf[IK] // works only if K is same or subclass of K + } + + @inline protected def externalValueToInternalValue(v: V): IV = { + v.asInstanceOf[IV] // works only if V is same or subclass of + } + + @inline protected def internalKeyToExternalKey(ik: IK): K = { + ik.asInstanceOf[K] + } + + @inline protected def internalValueToExternalValue(iv: IV): V = { + iv.asInstanceOf[V] + } + + @inline protected def internalPairToExternalPair(ip: JMapEntry[IK, IV]): (K, V) = { + (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue) ) + } + + /* Implicit functions to convert the types. */ + + @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) + + @inline implicit private def convExtValueToIntValue(v: V) = externalValueToInternalValue(v) + + @inline implicit private def convIntKeyToExtKey(ia: IK) = internalKeyToExternalKey(ia) + + @inline implicit private def convIntValueToExtValue(ib: IV) = internalValueToExternalValue(ib) + + @inline implicit private def convIntPairToExtPair(ip: JMapEntry[IK, IV]) = { + internalPairToExternalPair(ip) + } + + def get(key: K): Option[V] = { + Option(internalJavaMap.get(key)) + } + + def iterator: Iterator[(K, V)] = { + val jIterator = internalJavaMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) + } + + def +=(kv: (K, V)): this.type = { + internalJavaMap.put(kv._1, kv._2) + this + } + + def -=(key: K): this.type = { + internalJavaMap.remove(key) + this + } + + override def + [V1 >: V](kv: (K, V1)): Map[K, V1] = { + val newMap = newInstance[K, V1]() + newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) + newMap += kv + newMap + } + + override def - (key: K): Map[K, V] = { + val newMap = newInstance[K, V]() + newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) + newMap -= key + } + + override def foreach[U](f: ((K, V)) => U) { + while(iterator.hasNext) { + f(iterator.next()) + } + } + + override def empty: Map[K, V] = newInstance[K, V]() + + override def size: Int = internalJavaMap.size + + override def filter(p: ((K, V)) => Boolean): Map[K, V] = { + newInstance[K, V]() ++= iterator.filter(p) + } + + def toMap: immutable.Map[K, V] = iterator.toMap +} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala new file mode 100644 index 0000000000000..2ec314aa632f3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -0,0 +1,210 @@ +package org.apache.spark + +import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} +import org.apache.spark.rdd.RDD +import scala.util.Random +import java.lang.ref.WeakReference + +class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + implicit val defaultTimeout = timeout(10000 millis) + + before { + sc = new SparkContext("local[2]", "CleanerSuite") + } + + test("cleanup RDD") { + val rdd = newRDD.persist() + rdd.count() + val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + cleaner.cleanRDD(rdd) + tester.assertCleanup + } + + test("cleanup shuffle") 
{ + val rdd = newShuffleRDD + rdd.count() + val tester = new CleanerTester(sc, shuffleIds = Seq(0)) + cleaner.cleanShuffle(0) + tester.assertCleanup + } + + test("automatically cleanup RDD") { + var rdd = newRDD.persist() + rdd.count() + + // test that GC does not cause RDD cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + doGC() + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes RDD cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null // make RDD out of scope + doGC() + postGCTester.assertCleanup + } + + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD + rdd.count() + + // test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + doGC() + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope + doGC() + postGCTester.assertCleanup + } + + test("automatically cleanup RDD + shuffle") { + + def randomRDD: RDD[_] = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + val buffer = new ArrayBuffer[RDD[_]] + for (i <- 1 to 1000) { + buffer += randomRDD + } + + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + intercept[Exception] { + preGCTester.assertCleanup(timeout(1000 millis)) + } + + // test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) + buffer.clear() + doGC() + postGCTester.assertCleanup + } + + def newRDD = sc.makeRDD(1 to 10) + + def newPairRDD = newRDD.map(_ -> 1) + + def newShuffleRDD = newPairRDD.reduceByKey(_ + _) + + def doGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + } + + def cleaner = sc.cleaner +} + + +/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. 
*/ +class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) + extends Logging { + + val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + + val cleanerListener = new CleanerListener { + def rddCleaned(rddId: Int): Unit = { + toBeCleanedRDDIds -= rddId + logInfo("RDD "+ rddId + " cleaned") + } + + def shuffleCleaned(shuffleId: Int): Unit = { + toBeCleanedShuffleIds -= shuffleId + logInfo("Shuffle " + shuffleId + " cleaned") + } + } + + logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) + preCleanupValidate() + sc.cleaner.attachListener(cleanerListener) + + def assertCleanup(implicit waitTimeout: Eventually.Timeout) { + try { + eventually(waitTimeout, interval(10 millis)) { + assert(isAllCleanedUp) + } + Thread.sleep(100) // to allow async cleanup actions to be completed + postCleanupValidate() + } finally { + logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString) + } + } + + private def preCleanupValidate() { + assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") + + // Verify the RDDs have been persisted and blocks are present + assert(rddIds.forall(sc.persistentRdds.contains), + "One or more RDDs have not been persisted, cannot start cleaner test") + assert(rddIds.forall(rddId => blockManager.master.contains(rddBlockId(rddId))), + "One or more RDDs' blocks cannot be found in block manager, cannot start cleaner test") + + // Verify the shuffle ids are registered and blocks are present + assert(shuffleIds.forall(mapOutputTrackerMaster.contains), + "One or more shuffles have not been registered cannot start cleaner test") + assert(shuffleIds.forall(shuffleId => diskBlockManager.contains(shuffleBlockId(shuffleId))), + "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") + } + + private def postCleanupValidate() { + // Verify all the RDDs have been persisted + assert(rddIds.forall(!sc.persistentRdds.contains(_))) + assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) + + // Verify all the shuffle have been deregistered and cleaned up + assert(shuffleIds.forall(!mapOutputTrackerMaster.contains(_))) + assert(shuffleIds.forall(shuffleId => !diskBlockManager.contains(shuffleBlockId(shuffleId)))) + } + + private def uncleanedResourcesToString = { + s""" + |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} + |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} + """.stripMargin + } + + private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty + + private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + + private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) + + private def blockManager = sc.env.blockManager + + private def diskBlockManager = blockManager.diskBlockManager + + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 930c2523caf8c..7675a47552ba4 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -54,11 +54,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register 
and fetch") { + test("master register shuffle and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) tracker.registerShuffle(10, 2) + assert(tracker.contains(10)) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) @@ -73,7 +74,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and unregister and fetch") { + test("master register and unregister shuffle") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + assert(tracker.contains(10)) + assert(tracker.getServerStatuses(10, 0).nonEmpty) + tracker.unregisterShuffle(10) + assert(!tracker.contains(10)) + assert(tracker.getServerStatuses(10, 0).isEmpty) + } + + test("master register shuffle and unregister mapoutput and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) @@ -105,7 +124,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 829f389460f3b..d3d22bc1d6c0e 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -58,8 +58,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assertSegmentEquals(blockId, blockId.name, 0, 10) - + assert(diskBlockManager.contains(blockId)) newFile.delete() + assert(!diskBlockManager.contains(blockId)) } test("block appending") { diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala new file mode 100644 index 0000000000000..9fc4681b524e9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -0,0 +1,189 @@ +package org.apache.spark.util + +import scala.collection.mutable.{HashMap, Map} + +import java.util + +import org.scalatest.FunSuite +import scala.util.Random +import java.lang.ref.WeakReference + +class WrappedJavaHashMapSuite 
extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new HashMap[String, String]()) + + // Test a simple WrappedJavaHashMap + testMap(new TestMap[String, String]()) + + // Test BoundedHashMap + testMap(new BoundedHashMap[String, String](100, true)) + + testMapThreadSafety(new BoundedHashMap[String, String](100, true)) + + // Test TimeStampedHashMap + testMap(new TimeStampedHashMap[String, String]) + + testMapThreadSafety(new TimeStampedHashMap[String, String]) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis() + assert(map.internalMap.get("k1")._2 < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis() + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.internalMap.get("k1")._2 < threshTime1) + assert(map1.internalMap.get("k2")._2 >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + // Test TimeStampedHashMap + testMap(new TimeStampedWeakValueHashMap[String, String]) + + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]) + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String]() + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis() + assert(map.internalJavaMap.get("k1").timestamp < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + } + + + test("TimeStampedWeakValueHashMap - get not returning null when weak reference is cleared") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + + map("k1") = strongRef + assert(map("k1") === strongRef) + + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. 
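
The gc()/runFinalization() calls in this test mirror what drives the new automatic-cleanup path elsewhere in the patch: RDD.finalize() and ShuffleDependency.finalize() simply hand an id over to ContextCleaner, whose daemon thread does the actual cleanup. A minimal sketch of that handoff pattern, using illustrative names (TrackedResource, CleanResource) rather than the actual Spark classes:

import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

object CleanerSketch {
  // What needs cleaning; mirrors ContextCleaner's private CleaningTask hierarchy in spirit.
  sealed trait CleaningTask
  case class CleanResource(id: Int) extends CleaningTask

  private val queue = new ArrayBlockingQueue[CleaningTask](1000)

  // The finalizer only enqueues a task, so no blocking work runs on the JVM finalizer thread.
  class TrackedResource(val id: Int) {
    override def finalize() { queue.put(CleanResource(id)) }
  }

  private val cleaningThread = new Thread() {
    override def run() {
      while (true) {
        Option(queue.poll(100, TimeUnit.MILLISECONDS)) match {
          case Some(CleanResource(id)) => println("cleaned resource " + id)
          case _ => // nothing queued yet; keep polling
        }
      }
    }
  }
  cleaningThread.setDaemon(true)
  cleaningThread.start()
}

Once a TrackedResource (or, in the patch, a persisted RDD or a ShuffleDependency) becomes unreachable, a gc()/runFinalization() loop like the one below is enough to trigger the enqueue, which is exactly what ContextCleanerSuite's doGC() helper relies on.
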
+ while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.internalJavaMap.get("k1").weakValue.get == null) + assert(map.get("k1") === None) + } + + def testMap(hashMapConstructor: => Map[String, String]) { + def newMap() = hashMapConstructor + + val name = newMap().getClass.getSimpleName + + test(name + " - basic test") { + val testMap1 = newMap() + + // put and get + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[Exception] { + testMap1("k2") // Map.apply() causes exception + } + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + val testMap2 = newMap() + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + testMap2("k1") = "v1" + + // multi remove + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // new instance + } + } + + def testMapThreadSafety(hashMapConstructor: => Map[String, String]) { + def newMap() = hashMapConstructor + + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 100).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble.toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t : Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} + +class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { + protected[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() + + protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { + new TestMap[K1, V1] + } +} \ No newline at end of file
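
For completeness, the core behavior behind the other new utility, TimeStampedWeakValueHashMap (and the reason SparkContext.persistentRdds switches to it), is that values are held only through WeakReference, so the map by itself no longer keeps an RDD alive. A standalone sketch of that behavior, not using the Spark classes; the map layout and the "rdd-1" key are illustrative:

import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap

object WeakValueMapSketch extends App {
  // Values are referenced only weakly, so the map does not keep them reachable on its own.
  val map = new ConcurrentHashMap[String, WeakReference[AnyRef]]()

  var value: AnyRef = new Object
  map.put("rdd-1", new WeakReference(value))
  assert(map.get("rdd-1").get eq value) // a strong reference still exists, so the value is visible

  value = null // drop the only strong reference
  val startTime = System.currentTimeMillis
  while (System.currentTimeMillis - startTime < 10000 && map.get("rdd-1").get != null) {
    System.gc() // best effort, same caveat as ContextCleanerSuite.doGC()
    System.runFinalization()
    Thread.sleep(100)
  }

  // The key is still present, but its referent has (normally) been cleared by now; a lookup
  // that unwraps the WeakReference (as TimeStampedWeakValueHashMap.get does) returns None.
  println("value still reachable through the map? " + (map.get("rdd-1").get != null))
}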