From c443def131fb36a4c915448581a2486802e9ee67 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Wed, 10 Jun 2015 14:20:40 -0500
Subject: [PATCH] better fix and simpler test case

---
 .../apache/spark/scheduler/DAGScheduler.scala      | 65 ++++++++++---------
 .../DAGSchedulerFailureRecoverySuite.scala         | 51 ++++++++-------
 2 files changed, 60 insertions(+), 56 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 026fef7a6618f..e7226a63793a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1102,44 +1102,47 @@ class DAGScheduler(
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
+        if (failedStage.attemptId - 1 > task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId}, which has already failed")
+        } else {
 
-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          if (failedStage.attemptId - 1 > task.stageAttemptId) {
-            logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
-              s" ${task.stageAttemptId}, which has already failed")
-          } else {
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is possible
+          // the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
             logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
               s"due to a fetch failure from $mapStage (${mapStage.name})")
             markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logInfo(s"Ignoring fetch failure from $task as it's from $failedStage, " +
+              s"which is no longer running")
           }
-        }
 
-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
 
-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }
 
       case commitDenied: TaskCommitDenied =>
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFailureRecoverySuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFailureRecoverySuite.scala
index fe4ef2deb735d..f330ef622f7d4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFailureRecoverySuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFailureRecoverySuite.scala
@@ -26,35 +26,33 @@ import org.apache.spark._
 
 class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
 
-  // TODO we should run this with a matrix of configurations: different shufflers,
-  // external shuffle service, etc. But that is really pushing the question of how to run
-  // such a long test ...
-
-  ignore("no concurrent retries for stage attempts (SPARK-7308)") {
-    // see SPARK-7308 for a detailed description of the conditions this is trying to recreate.
-    // note that this is somewhat convoluted for a test case, but isn't actually very unusual
-    // under a real workload. We only fail the first attempt of stage 2, but that
-    // could be enough to cause havoc.
-
-    (0 until 100).foreach { idx =>
-      println(new Date() + "\ttrial " + idx)
+  test("no concurrent retries for stage attempts (SPARK-8103)") {
+    // make sure that if we get fetch failures after the retry has started, we ignore them,
+    // and so don't end up submitting multiple concurrent attempts for the same stage
+
+    (0 until 20).foreach { idx => logInfo(new Date() + "\ttrial " + idx)
       val conf = new SparkConf().set("spark.executor.memory", "100m")
-      val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf)
+      val clusterSc = new SparkContext("local-cluster[2,2,100]", "test-cluster", conf)
       val bms = ArrayBuffer[BlockManagerId]()
       val stageFailureCount = HashMap[Int, Int]()
+      val stageSubmissionCount = HashMap[Int, Int]()
       clusterSc.addSparkListener(new SparkListener {
         override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = {
          bms += bmAdded.blockManagerId
         }
 
+        override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+          val stage = stageSubmitted.stageInfo.stageId
+          stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
+        }
+
+
         override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
           if (stageCompleted.stageInfo.failureReason.isDefined) {
             val stage = stageCompleted.stageInfo.stageId
             stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1
-            val reason = stageCompleted.stageInfo.failureReason.get
-            println("stage " + stage + " failed: " + stageFailureCount(stage))
           }
         }
       })
@@ -66,7 +64,7 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
         // to avoid broadcast failures
         val someBlockManager = bms.filter{!_.isDriver}(0)
 
-        val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) =>
+        val shuffled = rawData.groupByKey(20).mapPartitionsWithIndex { case (idx, itr) =>
           // we want one failure quickly, and more failures after stage 0 has finished its
           // second attempt
           val stageAttemptId = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
@@ -74,26 +72,29 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
             if (idx == 0) {
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else if (idx > 0 && math.random < 0.2) {
-              Thread.sleep(5000)
+            } else if (idx == 1) {
+              Thread.sleep(2000)
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else {
-              // want to make sure plenty of these finish after task 0 fails, and some even finish
-              // after the previous stage is retried and this stage retry is started
-              Thread.sleep((500 + math.random * 5000).toLong)
             }
+          } else {
+            // just to make sure the second attempt doesn't finish before we trigger more failures
+            // from the first attempt
+            Thread.sleep(2000)
           }
           itr.map { x => ((x._1 + 5) % 100) -> x._2 }
         }
 
-        val data = shuffled.mapPartitions { itr => itr.flatMap(_._2) }.collect()
+        val data = shuffled.mapPartitions { itr =>
+          itr.flatMap(_._2)
+        }.cache().collect()
         val count = data.size
         assert(count === 1e6.toInt)
         assert(data.toSet === (1 to 1e6.toInt).toSet)
         assert(stageFailureCount.getOrElse(1, 0) === 0)
-        assert(stageFailureCount.getOrElse(2, 0) == 1)
-        assert(stageFailureCount.getOrElse(3, 0) == 0)
+        assert(stageFailureCount.getOrElse(2, 0) === 1)
+        assert(stageSubmissionCount.getOrElse(1, 0) <= 2)
+        assert(stageSubmissionCount.getOrElse(2, 0) === 2)
       } finally {
         clusterSc.stop()
       }