diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 2ecc48a6b0566..f34e98f86b86b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -239,8 +239,10 @@ class DAGScheduler(
     if (mapOutputTracker.has(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedLocations(shuffleDep.shuffleId)
       val locs = mapOutputTracker.deserializeStatuses(serLocs)
-      for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
-      stage.numAvailableOutputs = locs.size
+      for (i <- 0 until locs.size) {
+        stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing
+      }
+      stage.numAvailableOutputs = locs.count(_ != null)
     } else {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
@@ -337,24 +339,26 @@ class DAGScheduler(
     } else {
       def removeStage(stageId: Int) {
         // data structures based on Stage
-        stageIdToStage.get(stageId).foreach { s =>
-          if (running.contains(s)) {
+        for (stage <- stageIdToStage.get(stageId)) {
+          if (running.contains(stage)) {
             logDebug("Removing running stage %d".format(stageId))
-            running -= s
+            running -= stage
+          }
+          stageToInfos -= stage
+          for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
+            shuffleToMapStage.remove(k)
           }
-          stageToInfos -= s
-          shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
-          if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+          if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
             logDebug("Removing pending status for stage %d".format(stageId))
           }
-          pendingTasks -= s
-          if (waiting.contains(s)) {
+          pendingTasks -= stage
+          if (waiting.contains(stage)) {
             logDebug("Removing stage %d from waiting set.".format(stageId))
-            waiting -= s
+            waiting -= stage
           }
-          if (failed.contains(s)) {
+          if (failed.contains(stage)) {
             logDebug("Removing stage %d from failed set.".format(stageId))
-            failed -= s
+            failed -= stage
           }
         }
         // data structures based on StageId
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index af448fcb37a1f..b7e95b639032c 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -81,6 +81,19 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
+  // Run a map-reduce job in which the map stage always fails.
+  test("failure in a map stage") {
+    sc = new SparkContext("local", "test")
+    val data = sc.makeRDD(1 to 3).map(x => { throw new Exception; (x, x) }).groupByKey(3)
+    intercept[SparkException] {
+      data.collect()
+    }
+    // Make sure that running new jobs with the same map stage also fails
+    intercept[SparkException] {
+      data.collect()
+    }
+  }
+
   test("failure because task results are not serializable") {
     sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
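
As a minimal standalone sketch (plain Scala; the locs array and the PatchIdioms object are hypothetical stand-ins, not Spark APIs), the two idioms the patch leans on behave as follows: Option(x) is Some(x) for non-null x and None for null, so Option(locs(i)).toList yields an empty location list for a missing map output instead of List(null); and a for comprehension over an Option runs its body only in the Some case, which is what replaces the .foreach { s => ... } calls in removeStage.

    object PatchIdioms {
      def main(args: Array[String]): Unit = {
        // Stand-in for a deserialized status array; slot 1 is missing (null).
        val locs: Array[String] = Array("hostA", null, "hostC")

        // Option(null) is None, so a missing slot becomes Nil, not List(null).
        val outputLocs: Array[List[String]] = locs.map(Option(_).toList)
        val numAvailableOutputs = locs.count(_ != null)
        println(outputLocs.toList)   // List(List(hostA), List(), List(hostC))
        println(numAvailableOutputs) // 2

        // A for comprehension over an Option runs only when a value is present,
        // equivalent to stageIdToStage.get(stageId).foreach { stage => ... }.
        val stageIdToStage = Map(0 -> "stage 0")
        for (stage <- stageIdToStage.get(1)) {
          println(s"not reached for an unknown stage id: $stage")
        }
      }
    }

Run as written, this prints the two values and skips the body of the final loop, mirroring how the rebuilt stage now records only the map outputs that actually exist.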