From 80e2568b25780a7094199239da8ad6cfb6efc9f7 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Mon, 20 Jul 2015 10:28:32 -0700
Subject: [PATCH] [SPARK-8103][core] DAGScheduler should not submit multiple
 concurrent attempts for a stage

https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid
Author: Kay Ousterhout
Author: Imran Rashid

Closes #6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
---
 .../apache/spark/scheduler/DAGScheduler.scala |  78 +++++-----
 .../apache/spark/scheduler/ResultTask.scala   |   3 +-
 .../spark/scheduler/ShuffleMapTask.scala      |   5 +-
 .../org/apache/spark/scheduler/Task.scala     |   5 +-
 .../spark/scheduler/TaskSchedulerImpl.scala   |  99 +++++++-----
 .../org/apache/spark/scheduler/TaskSet.scala  |   4 +-
 .../CoarseGrainedSchedulerBackend.scala       |   5 +-
 .../spark/scheduler/DAGSchedulerSuite.scala   | 141 ++++++++++++++++++
 .../org/apache/spark/scheduler/FakeTask.scala |   8 +-
 .../scheduler/NotSerializableFakeTask.scala   |   2 +-
 .../spark/scheduler/TaskContextSuite.scala    |   4 +-
 .../scheduler/TaskSchedulerImplSuite.scala    | 113 +++++++++++++-
 .../spark/scheduler/TaskSetManagerSuite.scala |   2 +-
 13 files changed, 383 insertions(+), 86 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dd55cd8054332..71a219a4f3414 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -857,7 +857,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
 
-
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -918,7 +917,7 @@ class DAGScheduler(
           partitionsToCompute.map { id =>
             val locs = getPreferredLocs(stage.rdd, id)
             val part = stage.rdd.partitions(id)
-            new ShuffleMapTask(stage.id, taskBinary, part, locs)
+            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
           }
 
         case stage: ResultStage =>
@@ -927,7 +926,7 @@ class DAGScheduler(
             val p: Int = job.partitions(id)
             val part = stage.rdd.partitions(p)
             val locs = getPreferredLocs(stage.rdd, p)
-            new ResultTask(stage.id, taskBinary, part, locs, id)
+            new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
           }
       }
     } catch {
@@ -1069,10 +1068,11 @@ class DAGScheduler(
             val execId = status.location.executorId
             logDebug("ShuffleMapTask finished on " + execId)
             if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
             } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
             }
+
             if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
               markStageAsFinished(shuffleStage)
               logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@ class DAGScheduler(
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
 
-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-            s"due to a fetch failure from $mapStage (${mapStage.name})")
-          markStageAsFinished(failedStage, Some(failureMessage))
-        }
+        if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+            s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+        } else {
 
-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is
+          // possible the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
+            logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+              s"due to a fetch failure from $mapStage (${mapStage.name})")
+            markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logDebug(s"Received fetch failure from $task, but it's from $failedStage which is no " +
+              s"longer running")
+          }
+
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
 
-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }
 
       case commitDenied: TaskCommitDenied =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c9a124113961f..9c2606e278c54 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index bd3dd23dfe1ac..14c8c00961487 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD.
   */
  def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
  }
 
  @transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6a86f9d4b8530..76a19aeac4679 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ed3dde0fc3055..1705e7f962de2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.stageAttemptId) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      }
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
         for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
           tasks(i) += task
           val tid = task.taskId
-          taskIdToTaskSetId(tid) = taskSet.taskSet.id
+          taskIdToTaskSetManager(tid) = taskSet
           taskIdToExecutorId(tid) = execId
           executorsByHost(host) += execId
           availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToTaskSetManager.get(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
             }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-              "likely the result of receiving duplicate task finished status updates)")
-              .format(state, tid))
+                "likely the result of receiving duplicate task finished status updates)")
+                .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
 
   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156f53..be8526ba9b94f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-  val id: String = stageId + "." + attempt
+  val id: String = stageId + "." + stageAttemptId
 
   override def toString: String = "TaskSet " + id
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 0e3215d6e9ec8..f14c603ac6891 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4f2b0fa162b72..86728cb2b62af 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +554,140 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }
 
+  /**
+   * This tests the case where another FetchFailed comes in while the map stage is getting
+   * re-run.
+   */
+  test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+    // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
+    // desired event and our check.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
+  /**
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+    "the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+
+    // The reduce stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and reduce stages.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage,
+    // because the FetchFailed should have been ignored.
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0a7cb69416a08..b3ca150195a5f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.TaskContext
 
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
   override def runTask(context: TaskContext): Int = 0
 
   override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
   * locations for each task (given as varargs) if this sequence is not empty.
   */
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createTaskSet(numTasks, 0, prefLocs: _*)
+  }
+
+  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    }
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    }
-    new TaskSet(tasks, 0, 0, 0, null)
+    new TaskSet(tasks, 0, stageAttemptId, 0, null)
  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 9b92f8de56759..383855caefa2f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
  * A Task implementation that fails to serialize.
  */
 private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0) {
 
   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7c1adc1aef1b6..b9b0eccb0d834 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val task = new ResultTask[String, String](
-      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+    val task = new ResultTask[String, String](0, 0,
+      sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0, 0)
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index a6d5232feb8de..c2edd4c317d6e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
   }
 
+  test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+    taskScheduler.setDAGScheduler(dagScheduler)
+    val attempt1 = FakeTask.createTaskSet(1, 0)
+    val attempt2 = FakeTask.createTaskSet(1, 1)
+    taskScheduler.submitTasks(attempt1)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+
+    // OK to submit multiple if previous attempts are all zombie
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt2)
+    val attempt3 = FakeTask.createTaskSet(1, 2)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
+    taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt3)
+  }
+
+  test("don't schedule more tasks after a taskset is zombie") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 1
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // if we schedule another attempt for the same stage, it should get scheduled
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+
+    // submit attempt 2, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt2)
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions3.length)
+    val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+    assert(mgr.taskSet.stageAttemptId === 1)
+  }
+
+  test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 10
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+    mgr1.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // submit attempt 2
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+    taskScheduler.submitTasks(attempt2)
+
+    // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
+    // already submitted, and then they finish)
+    taskScheduler.taskSetFinished(mgr1)
+
+    // now with another resource offer, we should still schedule all the tasks in attempt2
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions3.length)
+
+    taskDescriptions3.foreach { task =>
+      val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+      assert(mgr.taskSet.stageAttemptId === 1)
+    }
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index cdae0d83d01dc..3abb99c4b2b54 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 /**
  * A Task implementation that results in a large serialized task.
  */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
   val random = new Random(0)
   random.nextBytes(randomBuffer)
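
The core of the patch is the conflict check added to TaskSchedulerImpl.submitTasks: a stage may have several TaskSetManagers registered at once (indexed by stage id and attempt id), but at most one of them may be non-zombie. Below is a minimal standalone Scala sketch of that invariant, not Spark source: FakeTaskSet, FakeManager, and ConflictCheckSketch are hypothetical stand-ins for TaskSet and TaskSetManager, kept only to show the shape of the check.

import scala.collection.mutable

// Hypothetical stand-ins for TaskSet / TaskSetManager; field names mirror the patch.
case class FakeTaskSet(stageId: Int, stageAttemptId: Int)

class FakeManager(val taskSet: FakeTaskSet) {
  var isZombie = false // true once a newer attempt supersedes this one
}

object ConflictCheckSketch {
  // Same shape as the patch's taskSetsByStageIdAndAttempt:
  // stage id -> (attempt id -> manager).
  private val taskSetsByStageIdAndAttempt =
    new mutable.HashMap[Int, mutable.HashMap[Int, FakeManager]]

  def submit(taskSet: FakeTaskSet): FakeManager = {
    val stageTaskSets = taskSetsByStageIdAndAttempt
      .getOrElseUpdate(taskSet.stageId, new mutable.HashMap[Int, FakeManager])
    val manager = new FakeManager(taskSet)
    stageTaskSets(taskSet.stageAttemptId) = manager
    // The patch's check: any *other* non-zombie attempt for this stage is a conflict.
    val conflictingTaskSet = stageTaskSets.values.exists { ts =>
      ts.taskSet != taskSet && !ts.isZombie
    }
    if (conflictingTaskSet) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage ${taskSet.stageId}")
    }
    manager
  }

  def main(args: Array[String]): Unit = {
    val first = submit(FakeTaskSet(stageId = 0, stageAttemptId = 0))
    first.isZombie = true                                // attempt 0 superseded
    submit(FakeTaskSet(stageId = 0, stageAttemptId = 1)) // accepted: older attempt is zombie
    println("attempt 1 accepted because attempt 0 was zombie")
  }
}

Keying managers by (stageId, stageAttemptId) instead of the old string task-set id is what lets the scheduler keep zombie attempts around for bookkeeping while still failing fast on a genuinely conflicting second live attempt.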