diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 09a60571238ea..3935c8772252e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle * Base class for dependencies. */ @DeveloperApi -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable +abstract class Dependency[T] extends Serializable { + def rdd: RDD[T] +} /** @@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable * partition of the child RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { +abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { /** * Get the parent partitions for a child partition. * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ def getParents(partitionId: Int): Seq[Int] + + override def rdd: RDD[T] = _rdd } /** * :: DeveloperApi :: - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will @@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient rdd: RDD[_ <: Product2[K, V]], + @transient _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + extends Dependency[Product2[K, V]] { + + override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] - val shuffleId: Int = rdd.context.newShuffleId() + val shuffleId: Int = _rdd.context.newShuffleId() - val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( - shuffleId, rdd.partitions.size, this) + val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( + shuffleId, _rdd.partitions.size, this) - rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..fb4c86716bb8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging { // TODO: Cache.stop()? 
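The Dependency refactoring above is what lets ShuffleDependency keep its parent RDD in a @transient constructor parameter: rdd is now an accessor rather than a serialized superclass field, so the parent RDD is simply absent once the dependency is serialized and shipped to executors. A minimal, self-contained sketch of that behavior, using illustrative stand-ins (FakeRDD, Dep, ShuffleDep) rather than the real Spark types:

import java.io._

// Illustrative stand-ins for the real Spark types; this is a sketch, not Spark code.
class FakeRDD(val id: Int) extends Serializable

abstract class Dep[T] extends Serializable {
  def rdd: FakeRDD // an accessor, not a serialized superclass field
}

class ShuffleDep(@transient private val _rdd: FakeRDD) extends Dep[Int] {
  // Valid on the "driver"; returns null after deserialization, which is acceptable
  // because the executor side never needs the parent RDD through the dependency.
  override def rdd: FakeRDD = _rdd
}

object TransientRddDemo {
  def main(args: Array[String]): Unit = {
    val dep = new ShuffleDep(new FakeRDD(42))
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(dep)
    oos.close()
    val copy = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      .readObject().asInstanceOf[ShuffleDep]
    println(dep.rdd.id)        // 42 before the round trip
    println(copy.rdd == null)  // true: the @transient parent RDD was not shipped
  }
}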
env.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() listenerBus.stop() eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a6abc49c5359e..726b3f2bbeea7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } + def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } + def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) // ======================================================================= // Other internal methods and fields diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index c3b2a33fb54d0..f67e5f1857979 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() } logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } @@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} +// Used for synchronization +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc6142ab79d03..50186d097a632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger @@ -35,6 +35,7 @@ import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast 
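The isCheckpointed rewrite above relies on the fact that, for any Option, exists(p) is equivalent to map(p).getOrElse(false); it is the same check spelled more directly. A trivial standalone illustration:

object OptionExistsDemo extends App {
  val some: Option[Boolean] = Some(true)
  val none: Option[Boolean] = None

  // exists(p) and map(p).getOrElse(false) agree for both the Some and None cases.
  assert(some.exists(identity) == some.map(identity).getOrElse(false))
  assert(none.exists(identity) == none.map(identity).getOrElse(false))
  println("equivalent")
}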
import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD @@ -114,6 +115,10 @@ class DAGScheduler( private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ private def initializeEventProcessActor() { @@ -361,9 +366,6 @@ class DAGScheduler( // data structures based on StageId stageIdToStage -= stageId - ShuffleMapTask.removeStage(stageId) - ResultTask.removeStage(stageId) - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -691,49 +693,83 @@ class DAGScheduler( } } - /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties + } else { + // this stage will be assigned to "default" pool + null + } + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null + try { + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + val taskBinaryBytes: Array[Byte] = + if (stage.isShuffleMap) { + closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() + } else { + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. 
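The comment block above captures the core of this patch: the (RDD, closure) pair is serialized once per stage, broadcast, and then deserialized independently by every task, so each task works on its own copy of the objects in the closure. A dependency-free sketch of that serialize-once, deserialize-per-task isolation (plain Java serialization stands in for the closure serializer and the broadcast; StageClosure is an invented placeholder):

import java.io._

// Illustrative stand-in for the (rdd, shuffleDep) or (rdd, func) pair.
case class StageClosure(counters: Array[Int]) extends Serializable

object TaskBinaryDemo {
  private def serialize(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.close()
    bos.toByteArray
  }

  private def deserialize[T](bytes: Array[Byte]): T =
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[T]

  def main(args: Array[String]): Unit = {
    // "Driver": serialize once; in Spark these bytes are then broadcast to executors.
    val taskBinary: Array[Byte] = serialize(StageClosure(Array(0, 0, 0)))

    // "Tasks": each deserialization yields an independent copy, so mutation by one
    // task (e.g. through Hadoop's non-thread-safe JobConf) cannot leak into another.
    val task1 = deserialize[StageClosure](taskBinary)
    val task2 = deserialize[StageClosure](taskBinary)
    task1.counters(0) = 99
    println(task2.counters(0)) // still 0
  }
}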
+ case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString) + runningStages -= stage + return + case NonFatal(e) => + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + val part = stage.rdd.partitions(p) + tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ResultTask(stage.id, taskBinary, part, locs, id) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } - if (tasks.size > 0) { - runningStages += stage - // SparkListenerStageSubmitted should be posted before testing whether tasks are - // serializable. If tasks are not serializable, a SparkListenerStageCompleted event - // will be posted, which should always come after a corresponding SparkListenerStageSubmitted - // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - // Preemptively serialize a task to make sure it can be serialized. We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and // cluster schedulers. + // + // We've already serialized RDDs and closures in taskBinary, but here we check for all other + // objects such as Partition. try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + closureSerializer.serialize(tasks.head) } catch { case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) @@ -752,6 +788,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.info.submissionTime = Some(clock.getTime()) } else { + // Because we posted SparkListenerStageSubmitted earlier, we should post + // SparkListenerStageCompleted here in case there are no tasks to run. 
+ listenerBus.post(SparkListenerStageCompleted(stage.info)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index bbf9f7388b074..d09fd7aa57642 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,134 +17,56 @@ package org.apache.spark.scheduler -import scala.language.existentials +import java.nio.ByteBuffer import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = - { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = - { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - (rdd, func) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD /** * A task that sends back the output to the driver application. * - * See [[org.apache.spark.scheduler.Task]] for more information. + * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd input to func - * @param func a function to apply on a partition of the RDD - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * partition of the given RDD. Once deserialized, the type should be + * (RDD[T], (TaskContext, Iterator[T]) => U). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). 
*/ private[spark] class ResultTask[T, U]( stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient locs: Seq[TaskLocation], - var outputId: Int) - extends Task[U](stageId, _partitionId) with Externalizable { + val outputId: Int) + extends Task[U](stageId, partition.index) with Serializable { - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) null else rdd.partitions(partitionId) - - @transient private val preferredLocs: Seq[TaskLocation] = { + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { + // Deserialize the RDD and the func using the broadcast variables. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) try { - func(context, rdd.iterator(split, context)) + func(context, rdd.iterator(partition, context)) } finally { context.executeOnCompleteCallbacks() } } + // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partitionId = in.readInt() - outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fdaf1de83f051..11255c07469d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,134 +17,55 @@ package org.apache.spark.scheduler -import scala.language.existentials - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import java.nio.ByteBuffer -import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. 
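Both rewritten task types deserialize taskBinary with Thread.currentThread.getContextClassLoader, because classes from the user's application jars are visible only through the executor's context class loader, not the JVM default. A minimal sketch of the underlying idiom, an ObjectInputStream that resolves classes through an explicit loader (roughly what a Java-serialization-based closure serializer does internally; the class name here is made up):

import java.io._

// Resolve classes through a caller-supplied ClassLoader instead of the default one,
// so types shipped with the application can be found during deserialization.
class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
    extends ObjectInputStream(in) {
  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    Class.forName(desc.getName, false, loader)
}

object LoaderAwareDemo {
  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject("hello")
    oos.close()
    val loader = Thread.currentThread.getContextClassLoader
    val in = new LoaderAwareObjectInputStream(new ByteArrayInputStream(bos.toByteArray), loader)
    println(in.readObject()) // hello
  }
}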
- private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] - (rdd, dep) - } - - // Since both the JarSet and FileSet have the same format this is used for both. - def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - HashMap(set.toSeq: _*) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - /** - * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner - * specified in the ShuffleDependency). - * - * See [[org.apache.spark.scheduler.Task]] for more information. - * +* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner +* specified in the ShuffleDependency). +* +* See [[org.apache.spark.scheduler.Task]] for more information. +* * @param stageId id of the stage this task belongs to - * @param rdd the final RDD in this stage - * @param dep the ShuffleDependency - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, + * the type should be (RDD[_], ShuffleDependency[_, _, _]). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_, _, _], - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partitionId) - with Externalizable - with Logging { + extends Task[MapStatus](stageId, partition.index) with Logging { - protected def this() = this(0, null, null, 0, null) + /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ + def this(partitionId: Int) { + this(0, null, new Partition { override def index = 0 }, null) + } @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partitionId) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partitionId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - override def runTask(context: TaskContext): MapStatus = { + // Deserialize the RDD using the broadcast variable. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8cbb9050f393b..912fcb1b9cf62 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 13b415cccb647..ad20f9b937ac1 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark import java.lang.ref.WeakReference +import org.apache.spark.broadcast.Broadcast + +import scala.collection.mutable import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.language.postfixOps @@ -52,9 +55,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } - test("cleanup RDD") { - val rdd = newRDD.persist() + val rdd = newRDD().persist() val collected = rdd.collect().toList val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) @@ -67,7 +69,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup shuffle") { - val 
(rdd, shuffleDeps) = newRDDWithShuffleDependencies + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() val collected = rdd.collect().toList val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) @@ -80,7 +82,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup broadcast") { - val broadcast = newBroadcast + val broadcast = newBroadcast() val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) // Explicit cleanup @@ -89,7 +91,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup RDD") { - var rdd = newRDD.persist() + var rdd = newRDD().persist() rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference @@ -107,7 +109,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup shuffle") { - var rdd = newShuffleRDD + var rdd = newShuffleRDD() rdd.count() // Test that GC does not cause shuffle cleanup due to a strong reference @@ -125,7 +127,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup broadcast") { - var broadcast = newBroadcast + var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) @@ -144,11 +146,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -162,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. 
+ val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { @@ -175,11 +184,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val numRdds = 10 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -193,21 +202,29 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } //------ Helper functions ------ - def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { + private def newRDD() = sc.makeRDD(1 to 10) + private def newPairRDD() = newRDD().map(_ -> 1) + private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) + private def newBroadcast() = sc.broadcast(1 to 100) + + private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) } } - val rdd = newShuffleRDD + val rdd = newShuffleRDD() // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) @@ -216,34 +233,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo (rdd, shuffleDeps) } - def randomRdd = { + private def randomRdd() = { val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) + case 0 => newRDD() + case 1 => newShuffleRDD() + case 2 => newPairRDD.join(newPairRDD()) } if (Random.nextBoolean()) rdd.persist() rdd.count() rdd } - def randomBroadcast = { + private def randomBroadcast() = { sc.broadcast(Random.nextInt(Int.MaxValue)) } /** Run GC and make sure it actually has run */ - def runGC() { + private def runGC() { val weakRef = new WeakReference(new Object()) val startTime = System.currentTimeMillis System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. 
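The new assertion above uses a Scala pattern-matching detail worth calling out: the backticks around taskClosureBroadcastId make the pattern compare against the existing value of that variable, rather than binding a fresh variable that would match any broadcast id. A standalone sketch with invented stand-in types:

// Invented stand-ins for BlockId and BroadcastBlockId.
sealed trait Id
case class BroadcastId(broadcastId: Long, field: String) extends Id
case class RDDId(rddId: Int, splitIndex: Int) extends Id

object BacktickPatternDemo {
  def main(args: Array[String]): Unit = {
    val taskClosureBroadcastId = 7L
    val ids: Seq[Id] = Seq(BroadcastId(7L, "piece0"), BroadcastId(3L, ""), RDDId(0, 1))

    // With backticks the first case matches only broadcastId == 7; without them,
    // taskClosureBroadcastId would be a new binding and every BroadcastId would match.
    val matching = ids.filter {
      case BroadcastId(`taskClosureBroadcastId`, _) => true
      case _ => false
    }
    println(matching) // List(BroadcastId(7,piece0))
  }
}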
// Wait until a weak reference object has been GCed - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { System.gc() Thread.sleep(200) } } - def cleaner = sc.cleaner.get + private def cleaner = sc.cleaner.get } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6654ec2d7c656..c52080378c970 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -155,19 +155,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { override def getPartitions: Array[Partition] = Array(onlySplit) override val getDependencies = List[Dependency[_]]() override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - Array(1, 2, 3, 4).iterator - } + throw new Exception("injected failure") } }.cache() val thrown = intercept[Exception]{ rdd.collect() } assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) } test("empty RDD") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8bb5317cd2875..270f7e661045a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,31 +20,35 @@ package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.apache.spark.LocalSparkContext -import org.apache.spark.Partition -import org.apache.spark.SparkContext -import org.apache.spark.TaskContext +import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { - var completed = false + TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) + context.addOnCompleteCallback(() => TaskContextSuite.completed = true) sys.error("failed") } } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val task = new ResultTask[String, String]( + 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) } - assert(completed === true) + assert(TaskContextSuite.completed === true) } +} - case class StubPartition(val index: Int) extends Partition +private object TaskContextSuite { + @volatile var completed = false } + +private case class StubPartition(index: Int) extends Partition diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index b52f81877d557..86a271eb67000 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -66,7 +67,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0, null, null, 0, null) + var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) @@ -76,14 +77,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) .shuffleRead === 2000) @@ -91,7 +92,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) .shuffleRead === 1000) @@ -103,7 +104,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val metrics = new TaskMetrics() val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0, null, null, 0, null) + val task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) // Go through all the failure cases to make sure we are counting them as failures.
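Taken together, the test updates above show the new contract for constructing tasks by hand: rather than passing an RDD and a function to the constructor, a test serializes the pair with the closure serializer, broadcasts the bytes, and supplies a concrete Partition. A sketch restating the TaskContextSuite construction (it assumes a live local SparkContext named sc and, like the suite itself, must sit under the org.apache.spark.scheduler package because ResultTask is private[spark]):

package org.apache.spark.scheduler

import org.apache.spark.{SparkContext, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD

object ResultTaskConstructionSketch {
  // Mirrors TaskContextSuite: serialize (rdd, func) once, broadcast it, and build the
  // task against one of the RDD's real Partition objects.
  def makeResultTask(sc: SparkContext, rdd: RDD[String]): ResultTask[String, String] = {
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
    new ResultTask[String, String](0, taskBinary, rdd.partitions(0), Seq(), 0)
  }
}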