Try our best to not submit tasks when the partitions are already completed
cloud-fan committed Apr 15, 2019
1 parent 0bb716b commit 9a7b053
Showing 9 changed files with 112 additions and 152 deletions.
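Before the per-file hunks, a self-contained toy sketch of the behaviour this change wires up (illustrative Scala written for this summary, not part of the commit; the class name is reused only for readability). The call chain the diff adds is: the DAGScheduler, on a Success from an older (zombie) stage attempt, calls TaskScheduler.notifyPartitionCompletion; TaskSchedulerImpl hands that to TaskResultGetter.enqueuePartitionCompletionNotification, which calls TaskSchedulerImpl.markPartitionCompleted, which forwards to TaskSetManager.markPartitionCompleted on every non-zombie attempt. The net effect: a late success from a zombie attempt marks the partition as finished in the still-active attempt, so the same task is never launched again.

object PartitionCompletionSketch {
  // Toy stand-in for a TaskSetManager: one flag per partition, index == partitionId.
  class ToyTaskSetManager(numTasks: Int, var isZombie: Boolean = false) {
    private val successful = Array.fill(numTasks)(false)
    private var tasksSuccessful = 0
    def markPartitionCompleted(partitionId: Int): Unit =
      if (!successful(partitionId)) {
        successful(partitionId) = true
        tasksSuccessful += 1
        if (tasksSuccessful == numTasks) isZombie = true // nothing left to run
      }
    def pendingPartitions: Seq[Int] = (0 until numTasks).filterNot(i => successful(i))
  }

  def main(args: Array[String]): Unit = {
    val zombieAttempt = new ToyTaskSetManager(4, isZombie = true) // earlier, failed attempt
    val activeAttempt = new ToyTaskSetManager(4)                  // retry of the same stage
    // A straggler from the zombie attempt finishes partition 2; as in the new
    // TaskSchedulerImpl.markPartitionCompleted, only non-zombie managers are told about it.
    Seq(zombieAttempt, activeAttempt).filter(!_.isZombie).foreach(_.markPartitionCompleted(2))
    println(activeAttempt.pendingPartitions) // Vector(0, 1, 3) -- partition 2 is skipped
  }
}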
@@ -1389,6 +1389,14 @@ private[spark] class DAGScheduler(

event.reason match {
case Success =>
+ // An earlier attempt of a stage (which is zombie) may still have running tasks. If these
+ // tasks complete, they still count and we can mark the corresponding partitions as
+ // finished. Here we notify the task scheduler to skip running tasks for the same partition,
+ // to save resource.
+ if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+ taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+ }
+
task match {
case rt: ResultTask[_, _] =>
// Cast to ResultStage here because it's part of the ResultTask
@@ -155,6 +155,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
}

+ def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
+ getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
+ scheduler.markPartitionCompleted(stageId, partitionId)
+ })
+ }
+
def stop() {
getTaskResultExecutor.shutdownNow()
}
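The notification is not applied inline; it is handed off to the task-result-getter thread pool. Below is a minimal stand-alone illustration of that hand-off (illustrative only, not part of the commit; the stated motivation — keeping the caller's thread free and serializing the update with normal task-result handling on the same pool — is my reading, not something the diff spells out).

import java.util.concurrent.Executors

object EnqueueNotificationSketch {
  private val taskResultExecutor = Executors.newFixedThreadPool(4)

  // Mirrors the shape of enqueuePartitionCompletionNotification: the caller returns
  // immediately and the actual state update runs later on the shared pool.
  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit =
    taskResultExecutor.execute(() =>
      println(s"marking stage $stageId / partition $partitionId as completed"))

  def main(args: Array[String]): Unit = {
    enqueuePartitionCompletionNotification(0, 2)
    taskResultExecutor.shutdown()
  }
}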
@@ -68,6 +68,10 @@ private[spark] trait TaskScheduler {
// Throw UnsupportedOperationException if the backend doesn't support kill tasks.
def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit

+ // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
+ // and they can skip running tasks for it.
+ def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit

@@ -301,6 +301,10 @@ private[spark] class TaskSchedulerImpl(
}
}

+ override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+ taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
+ }
+
/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -870,22 +874,21 @@ private[spark] class TaskSchedulerImpl(
}

/**
- * Marks the task has completed in all TaskSetManagers for the given stage.
+ * Marks the task has completed in the active TaskSetManager for the given stage.
*
* After stage failure and retry, there may be multiple TaskSetManagers for the stage.
- * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
- * do not also submit those same tasks. That also means that a task completion from an earlier
- * attempt can lead to the entire stage getting marked as successful.
+ * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt
+ * to skip submitting and running the task for the same partition, to save resource. That also
+ * means that a task completion from an earlier zombie attempt can lead to the entire stage
+ * getting marked as successful.
*/
- private[scheduler] def markPartitionCompletedInAllTaskSets(
+ private[scheduler] def markPartitionCompleted(
stageId: Int,
- partitionId: Int,
- taskInfo: TaskInfo) = {
- taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
- tsm.markPartitionCompleted(partitionId, taskInfo)
- }
+ partitionId: Int) = {
+ taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
+ tsm.markPartitionCompleted(partitionId)
+ })
}

}


@@ -806,9 +806,6 @@ private[spark] class TaskSetManager(
logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
" because task " + index + " has already completed successfully")
}
- // There may be multiple tasksets for this stage -- we let all of them know that the partition
- // was completed. This may result in some of the tasksets getting completed.
- sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
// "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -819,12 +816,9 @@
maybeFinishTaskSet()
}

- private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
+ private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
partitionToIndex.get(partitionId).foreach { index =>
if (!successful(index)) {
- if (speculationEnabled && !isZombie) {
- successfulTaskDurations.insert(taskInfo.duration)
- }
tasksSuccessful += 1
successful(index) = true
if (tasksSuccessful == numTasks) {
@@ -1041,7 +1035,11 @@
val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)

- if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+ // It's possible that a task is marked as completed by the scheduler, then the size of
+ // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the
+ // tasks that are submitted by this `TaskSetManager` and are completed successfully.
+ val numSuccessfulTasks = successfulTaskDurations.size()
+ if (numSuccessfulTasks >= minFinishedForSpeculation && numSuccessfulTasks > 0) {
val time = clock.getTimeMillis()
val medianDuration = successfulTaskDurations.median
val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)
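To see why the gate moved from tasksSuccessful to successfulTaskDurations.size(), here is a small self-contained sketch (hypothetical numbers, not the real TaskSetManager): once partitions can be marked completed by the scheduler on behalf of a zombie attempt, tasksSuccessful can grow without this manager ever recording a duration, and the old condition could reach the median of an empty heap — the SPARK-24677 failure mode exercised by the rewritten test further down.

object SpeculationCountSketch {
  def main(args: Array[String]): Unit = {
    val numTasks = 4
    val speculationQuantile = 0.75
    val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt // 3

    // Suppose every partition was finished by a zombie attempt, so this (active) manager only
    // had markPartitionCompleted called on it and never recorded a task duration of its own.
    val tasksSuccessful = 4
    val successfulTaskDurations = Seq.empty[Long] // stands in for the MedianHeap

    // Old gate: 4 >= 3, so the code would go on to take the median of an empty heap and throw.
    // New gate: 0 >= 3 is false, so the median is never touched.
    val oldGate = tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0
    val newGate = successfulTaskDurations.size >= minFinishedForSpeculation &&
      successfulTaskDurations.size > 0
    println(s"old gate enters the median path: $oldGate, new gate: $newGate")
  }
}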
@@ -134,6 +134,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
/** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
val cancelledStages = new HashSet[Int]()

+ val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
val taskScheduler = new TaskScheduler() {
override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
@@ -156,6 +158,13 @@
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def killAllTaskAttempts(
stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+ override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+ taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+ val tasks = ts.tasks.filter(_.partitionId == partitionId)
+ assert(tasks.length == 1)
+ tasksMarkedAsCompleted += tasks.head
+ }
+ }
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -246,6 +255,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
failure = null
sc.addSparkListener(sparkListener)
taskSets.clear()
+ tasksMarkedAsCompleted.clear()
cancelledStages.clear()
cacheLocations.clear()
results.clear()
@@ -658,6 +668,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
stageId: Int, interruptThread: Boolean, reason: String): Unit = {
throw new UnsupportedOperationException
}
+ override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+ throw new UnsupportedOperationException
+ }
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorHeartbeatReceived(
@@ -2862,6 +2875,57 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(latch.await(10, TimeUnit.SECONDS))
}

test("Completions in zombie tasksets update status of non-zombie taskset") {
val parts = 4
val shuffleMapRdd = new MyRDD(sc, parts, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
submit(reduceRdd, (0 until parts).toArray)
assert(taskSets.length == 1)

// Finish the first task of the shuffle map stage.
runEvent(makeCompletionEvent(
taskSets(0).tasks(0), Success, makeMapStatus("hostA", 4),
Seq.empty, createFakeTaskInfoWithId(0)))

// The second task of the shuffle map stage failed with FetchFailed.
runEvent(makeCompletionEvent(
taskSets(0).tasks(1),
FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0, 0, "ignored"),
null))

scheduler.resubmitFailedStages()
assert(taskSets.length == 2)
// The first partition has completed already, so the new attempt only need to run 3 tasks.
assert(taskSets(1).tasks.length == 3)

// Finish the first task of the second attempt of the shuffle map stage.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0), Success, makeMapStatus("hostA", 4),
Seq.empty, createFakeTaskInfoWithId(0)))

// Finish the third task of the first attempt of the shuffle map stage.
runEvent(makeCompletionEvent(
taskSets(0).tasks(2), Success, makeMapStatus("hostA", 4),
Seq.empty, createFakeTaskInfoWithId(0)))
assert(tasksMarkedAsCompleted.length == 1)
assert(tasksMarkedAsCompleted.head.partitionId == 2)

// Finish the forth task of the first attempt of the shuffle map stage.
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, makeMapStatus("hostA", 4),
Seq.empty, createFakeTaskInfoWithId(0)))
assert(tasksMarkedAsCompleted.length == 2)
assert(tasksMarkedAsCompleted.last.partitionId == 3)

// Now the shuffle map stage is completed, and the next stage is submitted.
assert(taskSets.length == 3)

// Finish
complete(taskSets(2), Seq((Success, 42), (Success, 42), (Success, 42), (Success, 42)))
assertDataStructuresEmpty()
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
@@ -84,6 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler {
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def killAllTaskAttempts(
stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+ override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -1121,110 +1121,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
}

test("Completions in zombie tasksets update status of non-zombie taskset") {
val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
val valueSer = SparkEnv.get.serializer.newInstance()

def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
val indexInTsm = tsm.partitionToIndex(partition)
val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
}

// Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt,
// two times, so we have three active task sets for one stage. (For this to really happen,
// you'd need the previous stage to also get restarted, and then succeed, in between each
// attempt, but that happens outside what we're mocking here.)
val zombieAttempts = (0 until 2).map { stageAttempt =>
val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
taskScheduler.submitTasks(attempt)
val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
taskScheduler.resourceOffers(offers)
assert(tsm.runningTasks === 10)
// fail attempt
tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
FetchFailed(null, 0, 0, 0, "fetch failed"))
// the attempt is a zombie, but the tasks are still running (this could be true even if
// we actively killed those tasks, as killing is best-effort)
assert(tsm.isZombie)
assert(tsm.runningTasks === 9)
tsm
}

// we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
// the stage, but this time with insufficient resources so not all tasks are active.

val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
taskScheduler.submitTasks(finalAttempt)
val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
finalAttempt.tasks(task.index).partitionId
}.toSet
assert(finalTsm.runningTasks === 5)
assert(!finalTsm.isZombie)

// We simulate late completions from our zombie tasksets, corresponding to all the pending
// partitions in our final attempt. This means we're only waiting on the tasks we've already
// launched.
val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
finalAttemptPendingPartitions.foreach { partition =>
completeTaskSuccessfully(zombieAttempts(0), partition)
}

// If there is another resource offer, we shouldn't run anything. Though our final attempt
// used to have pending tasks, now those tasks have been completed by zombie attempts. The
// remaining tasks to compute are already active in the non-zombie attempt.
assert(
taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)

val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted

// finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
// marked as zombie.
// for each of the remaining tasks, find the tasksets with an active copy of the task, and
// finish the task.
remainingTasks.foreach { partition =>
val tsm = if (partition == 0) {
// we failed this task on both zombie attempts, this one is only present in the latest
// taskset
finalTsm
} else {
// should be active in every taskset. We choose a zombie taskset just to make sure that
// we transition the active taskset correctly even if the final completion comes
// from a zombie.
zombieAttempts(partition % 2)
}
completeTaskSuccessfully(tsm, partition)
}

assert(finalTsm.isZombie)

// no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any())

// finally, lets complete all the tasks. We simulate failures in attempt 1, but everything
// else succeeds, to make sure we get the right updates to the blacklist in all cases.
(zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
val stageAttempt = tsm.taskSet.stageAttemptId
tsm.runningTasksSet.foreach { index =>
if (stageAttempt == 1) {
tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
} else {
val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
}
}

// we update the blacklist for the stage attempts with all successful tasks. Even though
// some tasksets had failures, we still consider them all successful from a blacklisting
// perspective, as the failures weren't from a problem w/ the tasks themselves.
verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), any())
}
}

test("don't schedule for a barrier taskSet if available slots are less than pending tasks") {
val taskCpus = 2
val taskScheduler = setupSchedulerWithMaster(
@@ -1372,7 +1372,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
}

test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") {
test("SPARK-24677: Avoid NoSuchElementException from MedianHeap") {
val conf = new SparkConf().set(config.SPECULATION_ENABLED, true)
sc = new SparkContext("local", "test", conf)
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
@@ -1386,39 +1386,19 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val dagScheduler = new FakeDAGScheduler(sc, sched)
sched.setDAGScheduler(dagScheduler)

- val taskSet1 = FakeTask.createTaskSet(10)
- val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task =>
+ val taskSet = FakeTask.createTaskSet(10)
+ val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
task.metrics.internalAccums
}

- sched.submitTasks(taskSet1)
- sched.resourceOffers(
- (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
-
- val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get
-
- // fail fetch
- taskSetManager1.handleFailedTask(
- taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED,
- FetchFailed(null, 0, 0, 0, "fetch failed"))
-
- assert(taskSetManager1.isZombie)
- assert(taskSetManager1.runningTasks === 9)
-
- val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1)
- sched.submitTasks(taskSet2)
+ sched.submitTasks(taskSet)
sched.resourceOffers(
- (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
-
- // Complete the 2 tasks and leave 8 task in running
- for (id <- Set(0, 1)) {
- taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
- assert(sched.endedTasks(id) === Success)
- }
+ (0 until 8).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })

- val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get
- assert(!taskSetManager2.successfulTaskDurations.isEmpty())
- taskSetManager2.checkSpeculatableTasks(0)
+ val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
+ assert(taskSetManager.runningTasks === 8)
+ taskSetManager.markPartitionCompleted(8)
+ taskSetManager.checkSpeculatableTasks(0)
}

