diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3e4b40a7f7b4d..5cd2caed10297 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -641,13 +641,8 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each cluster only once.
-   *
-   * If `registerBlocks` is true, workers will notify driver about blocks they create
-   * and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
    */
-  def broadcast[T](value: T, registerBlocks: Boolean = false) = {
-    env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
-  }
+  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
 
   /**
    * Add a file to be downloaded with this Spark job on every node.
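The public entry point drops the `registerBlocks` parameter, so `sc.broadcast(value)` is again the only form; callers that passed `registerBlocks = true` (only the removed tests at the end of this diff did) no longer compile. A minimal usage sketch of the surviving API, assuming a running `SparkContext` named `sc`:

```scala
// Sketch only: `sc` is assumed to be an existing SparkContext.
val lookup = sc.broadcast(Map(1 -> "one", 2 -> "two"))

// Tasks read the shared value through `.value`; the payload travels to each
// node once rather than being serialized into every task closure.
val names = sc.parallelize(1 to 2).map(i => lookup.value(i)).collect()
```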
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 516e6ba4005c8..e3e1e4f29b107 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -18,9 +18,6 @@
 package org.apache.spark.broadcast
 
 import java.io.Serializable
-import java.util.concurrent.atomic.AtomicLong
-
-import org.apache.spark._
 
 /**
  * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@@ -53,56 +50,8 @@ import org.apache.spark._
 abstract class Broadcast[T](val id: Long) extends Serializable {
   def value: T
 
-  /**
-   * Removes all blocks of this broadcast from memory (and disk if removeSource is true).
-   *
-   * @param removeSource Whether to remove data from disk as well.
-   *                     Will cause errors if broadcast is accessed on workers afterwards
-   *                     (e.g. in case of RDD re-computation due to executor failure).
-   */
-  def unpersist(removeSource: Boolean = false)
-
   // We cannot have an abstract readObject here due to some weird issues with
   // readObject having to be 'private' in sub-classes.
   override def toString = "Broadcast(" + id + ")"
 }
-
-private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
-  extends Logging with Serializable {
-
-  private var initialized = false
-  private var broadcastFactory: BroadcastFactory = null
-
-  initialize()
-
-  // Called by SparkContext or Executor before using Broadcast
-  private def initialize() {
-    synchronized {
-      if (!initialized) {
-        val broadcastFactoryClass = conf.get(
-          "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-
-        broadcastFactory =
-          Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
-
-        // Initialize appropriate BroadcastFactory and BroadcastObject
-        broadcastFactory.initialize(isDriver, conf, securityManager)
-
-        initialized = true
-      }
-    }
-  }
-
-  def stop() {
-    broadcastFactory.stop()
-  }
-
-  private val nextBroadcastId = new AtomicLong(0)
-
-  def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)
-
-  def isDriver = _isDriver
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 7aff8d7bb670b..0a0bb6cca336c 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -28,6 +28,6 @@ import org.apache.spark.SparkConf
  */
 trait BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
-  def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
+  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
   def stop(): Unit
 }
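`unpersist` leaves the abstract `Broadcast` class, and `BroadcastFactory.newBroadcast` shrinks to three parameters, so any implementation wired in through `spark.broadcast.factory` must match the trimmed trait. A hypothetical illustration (not part of this patch) of what a conforming custom factory looks like; `LoggingBroadcastFactory` is an invented name, and it lives in `org.apache.spark.broadcast` because `HttpBroadcast` is `private[spark]`:

```scala
package org.apache.spark.broadcast

import org.apache.spark.{SecurityManager, SparkConf}

// Hypothetical example: delegates to HttpBroadcast but logs each creation,
// written against the three-parameter newBroadcast signature above.
class LoggingBroadcastFactory extends BroadcastFactory {
  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    HttpBroadcast.initialize(isDriver, conf, securityMgr)
  }

  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] = {
    println("creating broadcast " + id) // illustration only
    new HttpBroadcast[T](value, isLocal, id)
  }

  def stop() { HttpBroadcast.stop() }
}
```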
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
new file mode 100644
index 0000000000000..746e23e81931a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark._
+
+private[spark] class BroadcastManager(
+    val isDriver: Boolean,
+    conf: SparkConf,
+    securityManager: SecurityManager)
+  extends Logging with Serializable {
+
+  private var initialized = false
+  private var broadcastFactory: BroadcastFactory = null
+
+  initialize()
+
+  // Called by SparkContext or Executor before using Broadcast
+  private def initialize() {
+    synchronized {
+      if (!initialized) {
+        val broadcastFactoryClass =
+          conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+
+        broadcastFactory =
+          Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+
+        // Initialize appropriate BroadcastFactory and BroadcastObject
+        broadcastFactory.initialize(isDriver, conf, securityManager)
+
+        initialized = true
+      }
+    }
+  }
+
+  def stop() {
+    broadcastFactory.stop()
+  }
+
+  private val nextBroadcastId = new AtomicLong(0)
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean) = {
+    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  }
+
+}
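`BroadcastManager` moves out of `Broadcast.scala` into its own file essentially verbatim: the constructor takes `isDriver` as a plain `val` instead of the old `_isDriver` field plus `def isDriver` accessor, and `newBroadcast` loses `registerBlocks`. The factory is still chosen reflectively from `spark.broadcast.factory`, so swapping implementations remains a configuration change. A sketch, with placeholder master and app name:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Select the torrent-based factory through the same property the reflective
// lookup above reads; HttpBroadcastFactory remains the default.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("broadcast-demo")
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val sc = new SparkContext(conf)
```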
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 6c2413cea526a..374180e472805 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
 
-private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
+private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
-  def unpersist(removeSource: Boolean) {
-    HttpBroadcast.synchronized {
-      SparkEnv.get.blockManager.master.removeBlock(blockId)
-      SparkEnv.get.blockManager.removeBlock(blockId)
-    }
-
-    if (removeSource) {
-      HttpBroadcast.synchronized {
-        HttpBroadcast.cleanupById(id)
-      }
-    }
-  }
-
   def blockId = BroadcastBlockId(id)
 
   HttpBroadcast.synchronized {
@@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
       logInfo("Started reading broadcast variable " + id)
       val start = System.nanoTime
       value_ = HttpBroadcast.read[T](id)
-      SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
+      SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
       val time = (System.nanoTime - start) / 1e9
       logInfo("Reading broadcast variable " + id + " took " + time + " s")
     }
@@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
   }
 }
 
-/**
- * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
- */
-class HttpBroadcastFactory extends BroadcastFactory {
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
-    HttpBroadcast.initialize(isDriver, conf, securityMgr)
-  }
-
-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
-    new HttpBroadcast[T](value_, isLocal, id, registerBlocks)
-
-  def stop() { HttpBroadcast.stop() }
-}
-
 private object HttpBroadcast extends Logging {
   private var initialized = false
@@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging {
     logInfo("Broadcast server started at " + serverUri)
   }
 
-  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
-
   def write(id: Long, value: Any) {
-    val file = getFile(id)
+    val file = new File(broadcastDir, BroadcastBlockId(id).name)
     val out: OutputStream = {
       if (compress) {
         compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging {
     obj
   }
 
-  def deleteFile(fileName: String) {
-    try {
-      new File(fileName).delete()
-      logInfo("Deleted broadcast file '" + fileName + "'")
-    } catch {
-      case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
-    }
-  }
-
   def cleanup(cleanupTime: Long) {
     val iterator = files.internalMap.entrySet().iterator()
     while(iterator.hasNext) {
       val entry = iterator.next()
       val (file, time) = (entry.getKey, entry.getValue)
       if (time < cleanupTime) {
-        iterator.remove()
-        deleteFile(file)
+        try {
+          iterator.remove()
+          new File(file.toString).delete()
+          logInfo("Deleted broadcast file '" + file + "'")
+        } catch {
+          case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+        }
      }
    }
  }
-
-  def cleanupById(id: Long) {
-    val file = getFile(id).getAbsolutePath
-    files.internalMap.remove(file)
-    deleteFile(file)
-  }
 }
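With `unpersist`, `getFile`, `deleteFile`, and `cleanupById` gone, broadcast files written by the driver's HTTP server are reclaimed only through the time-based `cleanup(cleanupTime)` path, which the `MetadataCleaner` imported above drives. As I read the cleaner code of this era, that path only runs when a TTL is configured, e.g.:

```scala
import org.apache.spark.SparkConf

// Sketch, assuming the TTL-based MetadataCleaner: entries older than the TTL
// (in seconds) are handed to HttpBroadcast.cleanup and their files deleted.
val conf = new SparkConf().set("spark.cleaner.ttl", "3600")
```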
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
new file mode 100644
index 0000000000000..c4f0f149534a5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
+ */
+class HttpBroadcastFactory extends BroadcastFactory {
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+    HttpBroadcast.initialize(isDriver, conf, securityMgr)
+  }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new HttpBroadcast[T](value_, isLocal, id)
+
+  def stop() { HttpBroadcast.stop() }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 206765679e9ed..0828035c5d217 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,68 +26,12 @@ import org.apache.spark._
 import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
-private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
-  def unpersist(removeSource: Boolean) {
-    TorrentBroadcast.synchronized {
-      SparkEnv.get.blockManager.master.removeBlock(broadcastId)
-      SparkEnv.get.blockManager.removeBlock(broadcastId)
-    }
-
-    if (!removeSource) {
-      //We can't tell BlockManager master to remove blocks from all nodes except driver,
-      //so we need to save them here in order to store them on disk later.
-      //This may be inefficient if blocks were already dropped to disk,
-      //but since unpersist is supposed to be called right after working with
-      //a broadcast this should not happen (and getting them from memory is cheap).
-      arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
-
-      for (pid <- 0 until totalBlocks) {
-        val pieceId = pieceBlockId(pid)
-        TorrentBroadcast.synchronized {
-          SparkEnv.get.blockManager.getSingle(pieceId) match {
-            case Some(x) =>
-              arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
-            case None =>
-              throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
-          }
-        }
-      }
-    }
-
-    for (pid <- 0 until totalBlocks) {
-      TorrentBroadcast.synchronized {
-        SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid))
-      }
-    }
-
-    if (removeSource) {
-      TorrentBroadcast.synchronized {
-        SparkEnv.get.blockManager.removeBlock(metaId)
-      }
-    } else {
-      TorrentBroadcast.synchronized {
-        SparkEnv.get.blockManager.dropFromMemory(metaId)
-      }
-
-      for (i <- 0 until totalBlocks) {
-        val pieceId = pieceBlockId(i)
-        TorrentBroadcast.synchronized {
-          SparkEnv.get.blockManager.putSingle(
-            pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true)
-        }
-      }
-      arrayOfBlocks = null
-    }
-  }
-
   def broadcastId = BroadcastBlockId(id)
 
-  private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
-  private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
-  private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList
-
   TorrentBroadcast.synchronized {
     SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -110,6 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
       hasBlocks = tInfo.totalBlocks
 
       // Store meta-info
+      val metaId = BroadcastHelperBlockId(broadcastId, "meta")
       val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.putSingle(
@@ -118,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
       // Store individual pieces
       for (i <- 0 until totalBlocks) {
-        val pieceId = pieceBlockId(i)
+        val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
         TorrentBroadcast.synchronized {
           SparkEnv.get.blockManager.putSingle(
             pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
@@ -148,7 +93,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
       // This creates a tradeoff between memory usage and latency.
       // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
       SparkEnv.get.blockManager.putSingle(
-        broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
+        broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
 
       // Remove arrayOfBlocks from memory once value_ is on local cache
       resetWorkerVariables()
@@ -171,6 +116,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
   def receiveBroadcast(variableID: Long): Boolean = {
     // Receive meta-info
+    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
       TorrentBroadcast.synchronized {
@@ -193,9 +139,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     }
 
     // Receive actual blocks
-    val recvOrder = new Random().shuffle(pieceIds)
+    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
     for (pid <- recvOrder) {
-      val pieceId = pieceBlockId(pid)
+      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.getSingle(pieceId) match {
           case Some(x) =>
@@ -215,8 +161,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
   }
 }
 
-private object TorrentBroadcast
-extends Logging {
+private object TorrentBroadcast extends Logging {
   private var initialized = false
   private var conf: SparkConf = null
 
@@ -289,18 +234,3 @@ private[spark] case class TorrentInfo(
 
   @transient var hasBlocks = 0
 }
-
-/**
- * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
- */
-class TorrentBroadcastFactory extends BroadcastFactory {
-
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
-    TorrentBroadcast.initialize(isDriver, conf)
-  }
-
-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
-    new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)
-
-  def stop() { TorrentBroadcast.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
new file mode 100644
index 0000000000000..a51c438c57717
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
+ */
+class TorrentBroadcastFactory extends BroadcastFactory {
+
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+    TorrentBroadcast.initialize(isDriver, conf)
+  }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new TorrentBroadcast[T](value_, isLocal, id)
+
+  def stop() { TorrentBroadcast.stop() }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 84c87949adae4..ca23513c4dc64 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -209,11 +209,6 @@ private[spark] class BlockManager(
     }
   }
 
-  /**
-   * For testing. Returns number of blocks BlockManager knows about that are in memory.
-   */
-  def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))
-
   /**
    * Get storage level of local block. If no info exists for the block, then returns null.
    */
@@ -817,13 +812,6 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Drop a block from memory, possibly putting it on disk if applicable.
-   */
-  def dropFromMemory(blockId: BlockId) {
-    memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
-  }
-
-  /**
    * Remove all blocks belonging to the given RDD.
    * @return The number of blocks removed.
    */
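The torrent side mirrors the HTTP side: `unpersist` and the `metaId`/`pieceBlockId`/`pieceIds` helpers disappear, and the few places that need those block ids now construct them inline, while `BlockManager` loses the test-only `numberOfBlocksInMemory` hook and the single-argument `dropFromMemory` wrapper that only the removed unpersist paths called. For orientation, a sketch of the id scheme the inlined expressions reproduce; the rendered names in the comments follow the `BlockId` naming conventions of this codebase as I understand them:

```scala
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId}

val broadcastId = BroadcastBlockId(42)                        // name: broadcast_42
val metaId = BroadcastHelperBlockId(broadcastId, "meta")      // name: broadcast_42_meta
val piece0 = BroadcastHelperBlockId(broadcastId, "piece" + 0) // name: broadcast_42_piece0
```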
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 7d614aa4726b2..488f1ea9628f5 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -210,27 +210,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   }
 
   /**
-   * Drop a block from memory, possibly putting it on disk if applicable.
-   */
-  def dropFromMemory(blockId: BlockId) {
-    val entry = entries.synchronized { entries.get(blockId) }
-    // This should never be null if called from ensureFreeSpace as only one
-    // thread should be dropping blocks and removing entries.
-    // However the check is required in other cases.
-    if (entry != null) {
-      val data = if (entry.deserialized) {
-        Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
-      } else {
-        Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
-      }
-      blockManager.dropFromMemory(blockId, data)
-    }
-  }
-
-  /**
-   * Tries to free up a given amount of space to store a particular block, but can fail and return
-   * false if either the block is bigger than our memory or it would require replacing another
-   * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
+   * Try to free up a given amount of space to store a particular block, but can fail if
+   * either the block is bigger than our memory or it would require replacing another block
+   * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
    * don't fit into memory that we want to avoid).
    *
    * Assume that a lock is held by the caller to ensure only one thread is dropping blocks.
@@ -272,7 +254,19 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       if (maxMemory - (currentMemory - selectedMemory) >= space) {
         logInfo(selectedBlocks.size + " blocks selected for dropping")
         for (blockId <- selectedBlocks) {
-          dropFromMemory(blockId)
+          val entry = entries.synchronized { entries.get(blockId) }
+          // This should never be null as only one thread should be dropping
+          // blocks and removing entries. However the check is still here for
+          // future safety.
+          if (entry != null) {
+            val data = if (entry.deserialized) {
+              Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
+            } else {
+              Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+            }
+            val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
+            droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
+          }
         }
         return ResultWithDroppedBlocks(success = true, droppedBlocks)
       } else {
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index dad330d6513da..e022accee6d08 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -18,15 +18,9 @@
 package org.apache.spark
 
 import org.scalatest.FunSuite
-import org.scalatest.concurrent.Timeouts._
-import org.scalatest.time.{Millis, Span}
-import org.scalatest.concurrent.Eventually._
-import org.scalatest.time.SpanSugar._
-import org.scalatest.matchers.ShouldMatchers._
 
 class BroadcastSuite extends FunSuite with LocalSparkContext {
-
   override def afterEach() {
     super.afterEach()
     System.clearProperty("spark.broadcast.factory")
   }
@@ -88,47 +82,4 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
   }
 
-  def blocksExist(sc: SparkContext, numSlaves: Int) = {
-    val rdd = sc.parallelize(1 to numSlaves, numSlaves)
-    val workerBlocks = rdd.mapPartitions(_ => {
-      val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory()
-      Seq(blocks).iterator
-    })
-    val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory()
-
-    totalKnown > 0
-  }
-
-  def testUnpersist(bcFactory: String, removeSource: Boolean) {
-    test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) {
-      val numSlaves = 2
-      System.setProperty("spark.broadcast.factory", bcFactory)
-      sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
-      val list = List(1, 2, 3, 4)
-
-      assert(!blocksExist(sc, numSlaves))
-
-      val listBroadcast = sc.broadcast(list, true)
-      val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
-      assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
-
-      assert(blocksExist(sc, numSlaves))
-
-      listBroadcast.unpersist(removeSource)
-
-      eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
-        blocksExist(sc, numSlaves) should be (false)
-      }
-
-      if (!removeSource) {
-        val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
-        assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
-      }
-    }
-  }
-
-  for (removeSource <- Seq(true, false)) {
-    testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource)
-    testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource)
-  }
 }
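`MemoryStore` folds the body of the removed `dropFromMemory(blockId)` back into its eviction loop, now also recording each dropped block's status in `droppedBlocks`, and the test suite keeps only the plain access tests. Those surviving tests all share one shape; a sketch of it, with names copied from the suite (the test body runs inside `BroadcastSuite`, where `sc` comes from `LocalSparkContext`):

```scala
test("Accessing TorrentBroadcast variables in a local cluster") {
  val numSlaves = 4
  System.setProperty("spark.broadcast.factory",
    "org.apache.spark.broadcast.TorrentBroadcastFactory")
  sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
  val list = List(1, 2, 3, 4)
  val listBroadcast = sc.broadcast(list)
  val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
  assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
```

Net effect of the whole diff: broadcast state can no longer be dropped on demand, and the only cleanup left is the TTL-driven kind shown earlier.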