[SPARK-2521] Broadcast RDD object (instead of sending it along with every task) #1498

Closed
wants to merge 12 commits
28 changes: 18 additions & 10 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class Dependency[T] extends Serializable {
def rdd: RDD[T]
}


/**
@@ -36,41 +38,47 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
Contributor:
Just FYI: this is an API-breaking change... probably not a huge deal, but FYI.
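To make the reviewer's point concrete, here is a sketch of how a user-defined dependency would have to change (the `CustomDependency` class is hypothetical, not part of Spark): the base `Dependency` no longer takes the parent RDD as a constructor parameter, so a subclass must supply it by overriding `def rdd`.

```scala
import org.apache.spark.Dependency
import org.apache.spark.rdd.RDD

// Before this patch, Dependency took the parent RDD as a constructor value:
//   class CustomDependency[T](rdd: RDD[T]) extends Dependency(rdd)

// After this patch, Dependency is a pure abstract class, so a subclass
// supplies the parent RDD by overriding `def rdd` instead:
class CustomDependency[T](parent: RDD[T]) extends Dependency[T] {
  override def rdd: RDD[T] = parent
}
```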

/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]

override def rdd: RDD[T] = _rdd
}


/**
* :: DeveloperApi ::
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
* the RDD is transient since we don't need it on the executor side.
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
@transient _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
extends Dependency[Product2[K, V]] {

override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]

val shuffleId: Int = rdd.context.newShuffleId()
val shuffleId: Int = _rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
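A minimal sketch of why the parent RDD is marked `@transient`, assuming local mode and Spark's `JavaSerializer`: serializing a `ShuffleDependency` (as happens when it ships inside a task) then captures only the shuffle metadata, not the parent RDD's whole lineage. On the deserialized copy `rdd` is simply null, which is fine because executors never call it.

```scala
import java.nio.ByteBuffer
import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, SparkContext}
import org.apache.spark.serializer.JavaSerializer

object TransientRddSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("transient-rdd-sketch")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(1 to 1000).map(i => (i, i))
    val dep = new ShuffleDependency[Int, Int, Int](pairs, new HashPartitioner(2))

    val ser = new JavaSerializer(conf).newInstance()
    // Because _rdd is @transient, the bytes hold only the partitioner, the
    // serializer choice, the shuffleId and the shuffle handle -- not the
    // parent RDD graph.
    val bytes: ByteBuffer = ser.serialize(dep)
    val onExecutor = ser.deserialize[ShuffleDependency[Int, Int, Int]](bytes)
    // onExecutor.rdd would be null here; tasks never use it on the executor.
    sc.stop()
  }
}
```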


2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
11 changes: 4 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

@@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)

/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
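The `isCheckpointed` rewrite leans on a standard-library equivalence: for any `Option`, `opt.map(p).getOrElse(false)` and `opt.exists(p)` produce the same Boolean. A tiny stand-alone illustration (the names are made up):

```scala
object ExistsSketch extends App {
  val maybeFile: Option[String] = Some("rdd-13-checkpoint")
  val oldStyle = maybeFile.map(_.nonEmpty).getOrElse(false) // pre-patch pattern
  val newStyle = maybeFile.exists(_.nonEmpty)               // post-patch pattern
  assert(oldStyle == newStyle) // holds for Some(...) and None alike
}
```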

// =======================================================================
// Other internal methods and fields
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

private[spark] object RDDCheckpointData {
def clearTaskCaches() {
ShuffleMapTask.clearCache()
ResultTask.clearCache()
}
}
// Used for synchronization
private[spark] object RDDCheckpointData
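The now-empty companion object is kept only so there is something to lock on. A generic illustration of that pattern (the names are hypothetical, not Spark code): an object with no members still has a monitor and can act as a shared, JVM-wide lock.

```scala
// Hypothetical names, illustrating the "used for synchronization" comment.
object CheckpointLock

class Bookkeeper {
  private var state: String = "initialized"

  def advance(next: String): Unit = CheckpointLock.synchronized {
    // All Bookkeeper instances serialize their updates on the same monitor.
    state = next
  }
}
```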
87 changes: 63 additions & 24 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -17,7 +17,7 @@

package org.apache.spark.scheduler

import java.io.{NotSerializableException, PrintWriter, StringWriter}
import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

@@ -35,6 +35,7 @@ import akka.pattern.ask
import akka.util.Timeout

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
@@ -114,6 +115,10 @@ class DAGScheduler(
private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))

// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()

private[scheduler] var eventProcessActor: ActorRef = _

private def initializeEventProcessActor() {
@@ -361,9 +366,6 @@ class DAGScheduler(
// data structures based on StageId
stageIdToStage -= stageId

ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)

logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -691,49 +693,83 @@ class DAGScheduler(
}
}


/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
var tasks = ArrayBuffer[Task[_]]()

val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
} else {
// this stage will be assigned to "default" pool
null
}

runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
// the serialized copy of the RDD and for each task we will deserialize it, which means each
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] =
if (stage.isShuffleMap) {
closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
} else {
closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
}
taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
runningStages -= stage
return
case NonFatal(e) =>
abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
runningStages -= stage
return
}

if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
val part = stage.rdd.partitions(p)
tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = stage.resultOfJob.get
for (id <- 0 until job.numPartitions if !job.finished(id)) {
val partition = job.partitions(id)
val locs = getPreferredLocs(stage.rdd, partition)
tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
}
}

val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
} else {
// this stage will be assigned to "default" pool
null
}

if (tasks.size > 0) {
runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
//
// We've already serialized RDDs and closures in taskBinary, but here we check for all other
// objects such as Partition.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
closureSerializer.serialize(tasks.head)
} catch {
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
Expand All @@ -752,6 +788,9 @@ class DAGScheduler(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.info.submissionTime = Some(clock.getTime())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
listenerBus.post(SparkListenerStageCompleted(stage.info))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
runningStages -= stage
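For context, the broadcast bytes are consumed on the executor side roughly as sketched below. The real change lives in ShuffleMapTask and ResultTask, which are outside this excerpt, so the helper shown here is an approximation rather than the actual task code.

```scala
import java.nio.ByteBuffer
import org.apache.spark.{ShuffleDependency, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

object TaskBinarySketch {
  // Each task deserializes its own copy of (rdd, shuffleDep) from the
  // broadcast bytes, so tasks never share mutable objects captured in the
  // closure (e.g. a non-thread-safe Hadoop JobConf).
  def deserialize(taskBinary: Broadcast[Array[Byte]]): (RDD[_], ShuffleDependency[_, _, _]) = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value),
      Thread.currentThread.getContextClassLoader)
  }
}
```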