diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index f2decd14ef6d9..2eec09cd1c795 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -141,7 +141,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: private def awaitResult(): Try[T] = { jobWaiter.awaitResult() match { case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception, _) => scala.util.Failure(e) + case JobFailed(e: Exception) => scala.util.Failure(e) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c96d7435a7ed4..8e4674125a729 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -342,22 +342,24 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for - * stages that were removed. The associated tasks for those stages need to be cancelled if we - * got here via job cancellation. + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. + * @param resultStage Specifies the result stage for the job; if set to None, this method + * searches resultStagesToJob to find and cleanup the appropriate result stage. */ - private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { - val registeredStages = jobIdToStageIds(jobId) - val independentStages = new HashSet[Int]() - if (registeredStages.isEmpty) { - logError("No stages registered for job " + jobId) + private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) { + val registeredStages = jobIdToStageIds.get(job.jobId) + if (registeredStages.isEmpty || registeredStages.get.isEmpty) { + logError("No stages registered for job " + job.jobId) } else { - stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { case (stageId, jobSet) => - if (!jobSet.contains(jobId)) { + if (!jobSet.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" - .format(jobId, stageId)) + .format(job.jobId, stageId)) } else { def removeStage(stageId: Int) { // data structures based on Stage @@ -394,23 +396,28 @@ class DAGScheduler( .format(stageId, stageIdToStage.size)) } - jobSet -= jobId + jobSet -= job.jobId if (jobSet.isEmpty) { // no other job needs this stage - independentStages += stageId removeStage(stageId) } } } } - independentStages.toSet - } + jobIdToStageIds -= job.jobId + jobIdToActiveJob -= job.jobId + activeJobs -= job - private def jobIdToStageIdsRemove(jobId: Int) { - if (!jobIdToStageIds.contains(jobId)) { - logDebug("Trying to remove unregistered job " + jobId) + if (resultStage.isEmpty) { + // Clean up result stages. + val resultStagesForJob = resultStageToJob.keySet.filter( + stage => resultStageToJob(stage).jobId == job.jobId) + if (resultStagesForJob.size != 1) { + logWarning( + s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") + } + resultStageToJob --= resultStagesForJob } else { - removeJobAndIndependentStages(jobId) - jobIdToStageIds -= jobId + resultStageToJob -= resultStage.get } } @@ -460,7 +467,7 @@ class DAGScheduler( val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} - case JobFailed(exception: Exception, _) => + case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception } @@ -606,7 +613,16 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, -1))) + // Tell the listeners the all of the stages have ended. Don't bother cancelling the + // stages because if the DAG scheduler is stopped, the entire application is in the + // process of getting stopped. + val stageFailedMessage = "Stage cancelled because SparkContext was shut down" + runningStages.foreach { stage => + val info = stageToInfos(stage) + info.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageEnded(info)) + } + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } return true } @@ -676,7 +692,7 @@ class DAGScheduler( } } catch { case e: Exception => - jobResult = JobFailed(e, job.finalStage.id) + jobResult = JobFailed(e) job.listener.jobFailed(e) } finally { val s = job.finalStage @@ -807,7 +823,7 @@ class DAGScheduler( } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + listenerBus.post(SparkListenerStageEnded(stageToInfos(stage))) runningStages -= stage } event.reason match { @@ -826,11 +842,9 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - jobIdToActiveJob -= stage.jobId - activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - jobIdToStageIdsRemove(job.jobId) + cleanupStateForJobAndIndependentStages(job, Some(stage)) listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -982,7 +996,7 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled") + failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled", None) } } @@ -999,7 +1013,8 @@ class DAGScheduler( stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", + Some(resultStage)) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1008,28 +1023,45 @@ class DAGScheduler( /** * Fails a job and all stages that are only used by that job, and cleans up relevant state. + * + * @param resultStage The result stage for the job, if known. Used to cleanup state for the job + * slightly more efficiently than when not specified. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { + private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, + resultStage: Option[Stage]) { val error = new SparkException(failureReason) job.listener.jobFailed(error) - // Cancel all tasks in independent stages. - val independentStages = removeJobAndIndependentStages(job.jobId) - independentStages.foreach(taskScheduler.cancelTasks) - - // Clean up remaining state we store for the job. - jobIdToActiveJob -= job.jobId - activeJobs -= job - jobIdToStageIds -= job.jobId - val resultStagesForJob = resultStageToJob.keySet.filter( - stage => resultStageToJob(stage).jobId == job.jobId) - if (resultStagesForJob.size != 1) { - logWarning( - s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") + // Cancel all independent, running stages. + val stages = jobIdToStageIds(job.jobId) + if (stages.isEmpty) { + logError("No stages registered for job " + job.jobId) } - resultStageToJob --= resultStagesForJob + stages.foreach { stageId => + val jobsForStage = stageIdToJobIds.get(stageId) + if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else if (jobsForStage.get.size == 1) { + if (!stageIdToStage.contains(stageId)) { + // This is the only job that uses this stage, so fail the stage if it is running. + logError("Missing Stage for stage with id $stageId") + } else { + val stage = stageIdToStage(stageId) + if (runningStages.contains(stage)) { + taskScheduler.cancelTasks(stageId) + val stageInfo = stageToInfos(stage) + stageInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageEnded(stageToInfos(stage))) + } + } + } + } + + cleanupStateForJobAndIndependentStages(job, None) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id))) + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 217f8825c2ae9..dc8f88d16fadd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -78,7 +78,7 @@ private[spark] class EventLoggingListener(appName: String, conf: SparkConf) logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = + override def onStageEnded(event: SparkListenerStageEnded) = logEvent(event, flushLogger = true) override def onJobStart(event: SparkListenerJobStart) = logEvent(event, flushLogger = true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 5cecf9416b32c..be361bd2231ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -187,11 +187,15 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener /** * When stage is completed, record stage completion status - * @param stageCompleted Stage completed event + * @param stageEnded Stage ended event */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - val stageId = stageCompleted.stageInfo.stageId - stageLogInfo(stageId, "STAGE_ID=%d STATUS=COMPLETED".format(stageId)) + override def onStageEnded(stageEnded: SparkListenerStageEnded) { + val stageId = stageEnded.stageInfo.stageId + if (stageEnded.stageInfo.failureReason.isEmpty) { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") + } else { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") + } } /** @@ -227,7 +231,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener var info = "JOB_ID=" + jobId jobEnd.jobResult match { case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => + case JobFailed(exception) => info += " STATUS=FAILED REASON=" exception.getMessage.split("\\s+").foreach(info += _ + "_") case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala index 3cf4e3077e4a4..047bd27056120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala @@ -24,5 +24,4 @@ private[spark] sealed trait JobResult private[spark] case object JobSucceeded extends JobResult -// A failed stage ID of -1 means there is not a particular stage that caused the failure -private[spark] case class JobFailed(exception: Exception, failedStageId: Int) extends JobResult +private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 8007b5418741e..e9bfee2248e5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -64,7 +64,7 @@ private[spark] class JobWaiter[T]( override def jobFailed(exception: Exception): Unit = synchronized { _jobFinished = true - jobResult = JobFailed(exception, -1) + jobResult = JobFailed(exception) this.notifyAll() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index d4eb0ac88d8e8..53f180db9c403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -32,7 +32,7 @@ sealed trait SparkListenerEvent case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) extends SparkListenerEvent -case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent +case class SparkListenerStageEnded(stageInfo: StageInfo) extends SparkListenerEvent case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent @@ -71,9 +71,9 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent */ trait SparkListener { /** - * Called when a stage is completed, with information on the completed stage + * Called when a stage completes or fails, with information on the completed stage */ - def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } + def onStageEnded(stageEnded: SparkListenerStageEnded) { } /** * Called when a stage is submitted @@ -144,7 +144,7 @@ class StatsReportListener extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { + override def onStageEnded(stageCompleted: SparkListenerStageEnded) { implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 729e120497571..7a9615f4355d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -41,8 +41,8 @@ private[spark] trait SparkListenerBus { event match { case stageSubmitted: SparkListenerStageSubmitted => sparkListeners.foreach(_.onStageSubmitted(stageSubmitted)) - case stageCompleted: SparkListenerStageCompleted => - sparkListeners.foreach(_.onStageCompleted(stageCompleted)) + case stageEnded: SparkListenerStageEnded => + sparkListeners.foreach(_.onStageEnded(stageEnded)) case jobStart: SparkListenerJobStart => sparkListeners.foreach(_.onJobStart(jobStart)) case jobEnd: SparkListenerJobEnd => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 8115a7ed7896d..eec409b182ac6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -26,8 +26,17 @@ private[spark] class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None + /** Time when all tasks in the stage completed or when the stage was cancelled. */ var completionTime: Option[Long] = None + /** If the stage failed, the reason why. */ + var failureReason: Option[String] = None + var emittedTaskSizeWarning = false + + def stageFailed(reason: String) { + failureReason = Some(reason) + completionTime = Some(System.currentTimeMillis) + } } private[spark] diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 048f671c8788f..fe7b4e5ca9290 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -68,14 +68,19 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { def blockManagerIds = executorIdToBlockManagerId.values.toSeq - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { - val stage = stageCompleted.stageInfo + override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized { + val stage = stageEnded.stageInfo val stageId = stage.stageId // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage poolToActiveStages(stageIdToPool(stageId)).remove(stageId) activeStages.remove(stageId) - completedStages += stage - trimIfNecessary(completedStages) + if (stage.failureReason.isEmpty) { + completedStages += stage + trimIfNecessary(completedStages) + } else { + failedStages += stage + trimIfNecessary(failedStages) + } } /** If stages is too large, remove and garbage collect old stages */ @@ -215,20 +220,6 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { - jobEnd.jobResult match { - case JobFailed(_, stageId) => - activeStages.get(stageId).foreach { s => - // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage - activeStages.remove(s.stageId) - poolToActiveStages(stageIdToPool(stageId)).remove(s.stageId) - failedStages += s - trimIfNecessary(failedStages) - } - case _ => - } - } - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val schedulingModeName = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index a7b24ff695214..b3e819044ab51 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -87,7 +87,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo) } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized { // Remove all partitions that are no longer cached _rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2155a8888c85c..18f8b4161e458 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -22,7 +22,7 @@ import java.util.{Properties, UUID} import scala.collection.JavaConverters._ import scala.collection.Map -import org.json4s.DefaultFormats +import org.json4s.{DefaultFormats, JsonAST} import org.json4s.JsonDSL._ import org.json4s.JsonAST._ @@ -42,7 +42,7 @@ private[spark] object JsonProtocol { event match { case stageSubmitted: SparkListenerStageSubmitted => stageSubmittedToJson(stageSubmitted) - case stageCompleted: SparkListenerStageCompleted => + case stageCompleted: SparkListenerStageEnded => stageCompletedToJson(stageCompleted) case taskStart: SparkListenerTaskStart => taskStartToJson(taskStart) @@ -76,7 +76,7 @@ private[spark] object JsonProtocol { ("Properties" -> properties) } - def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = { + def stageCompletedToJson(stageCompleted: SparkListenerStageEnded): JValue = { val stageInfo = stageInfoToJson(stageCompleted.stageInfo) ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~ ("Stage Info" -> stageInfo) @@ -260,8 +260,7 @@ private[spark] object JsonProtocol { case JobSucceeded => Utils.emptyJson case jobFailed: JobFailed => val exception = exceptionToJson(jobFailed.exception) - ("Exception" -> exception) ~ - ("Failed Stage ID" -> jobFailed.failedStageId) + JsonAST.JObject("Exception" -> exception) } ("Result" -> result) ~ json } @@ -336,7 +335,7 @@ private[spark] object JsonProtocol { def sparkEventFromJson(json: JValue): SparkListenerEvent = { val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted) - val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted) + val stageCompleted = Utils.getFormattedClassName(SparkListenerStageEnded) val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart) val taskGettingResult = Utils.getFormattedClassName(SparkListenerTaskGettingResult) val taskEnd = Utils.getFormattedClassName(SparkListenerTaskEnd) @@ -368,9 +367,9 @@ private[spark] object JsonProtocol { SparkListenerStageSubmitted(stageInfo, properties) } - def stageCompletedFromJson(json: JValue): SparkListenerStageCompleted = { + def stageCompletedFromJson(json: JValue): SparkListenerStageEnded = { val stageInfo = stageInfoFromJson(json \ "Stage Info") - SparkListenerStageCompleted(stageInfo) + SparkListenerStageEnded(stageInfo) } def taskStartFromJson(json: JValue): SparkListenerTaskStart = { @@ -561,8 +560,7 @@ private[spark] object JsonProtocol { case `jobSucceeded` => JobSucceeded case `jobFailed` => val exception = exceptionFromJson(json \ "Exception") - val failedStageId = (json \ "Failed Stage ID").extract[Int] - new JobFailed(exception, failedStageId) + new JobFailed(exception) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2e3026bffba2f..c3d718c23c579 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -64,6 +64,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def defaultParallelism() = 2 } + /** Length of time to wait while draining listener events. */ + val WAIT_TIMEOUT_MILLIS = 10000 + val sparkListener = new SparkListener() { + val successfulStages = new HashSet[Int]() + val failedStages = new HashSet[Int]() + override def onStageEnded(stageEnded: SparkListenerStageEnded) { + val stageInfo = stageEnded.stageInfo + if (stageInfo.failureReason.isEmpty) { + successfulStages += stageInfo.stageId + } else { + failedStages += stageInfo.stageId + } + } + } + var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null @@ -89,13 +104,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont /** The list of results that DAGScheduler has collected. */ val results = new HashMap[Int, Any]() var failure: Exception = _ - val listener = new JobListener() { + val jobListener = new JobListener() { override def taskSucceeded(index: Int, result: Any) = results.put(index, result) override def jobFailed(exception: Exception) = { failure = exception } } before { sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.successfulStages.clear() + sparkListener.failedStages.clear() + sc.addSparkListener(sparkListener) taskSets.clear() cancelledStages.clear() cacheLocations.clear() @@ -187,7 +205,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, - listener: JobListener = listener): Int = { + listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) return jobId @@ -231,7 +249,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, jobListener)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } @@ -262,6 +280,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty } @@ -270,6 +291,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty } @@ -354,6 +378,13 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) assert(failure.getMessage === s"Job aborted due to stage failure: $stageFailureMessage") + + // Listener bus should get told about the map stage failing, but not the reduce stage + // (since the reduce stage hasn't been started yet). + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty } @@ -398,8 +429,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) - + assert(cancelledStages.contains(1)) + + // Make sure the listeners got told about both failed stages. + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.successfulStages.isEmpty) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.contains(3)) + assert(sparkListener.failedStages.size === 2) + assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assertDataStructuresEmpty diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 7c843772bc2e0..349c864935b3e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -257,7 +257,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } - override def onStageCompleted(stage: SparkListenerStageCompleted) { + override def onStageEnded(stage: SparkListenerStageEnded) { stageInfos(stage.stageInfo) = taskInfoMetrics taskInfoMetrics = mutable.Buffer.empty } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 7bab7da8fed68..06d349c9512d1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -34,7 +34,7 @@ class JsonProtocolSuite extends FunSuite { test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) - val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L)) + val stageEnded = SparkListenerStageEnded(makeStageInfo(101, 201, 301, 401L, 501L)) val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 444L)) val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 3000L)) val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, @@ -54,7 +54,7 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) testEvent(stageSubmitted, stageSubmittedJsonString) - testEvent(stageCompleted, stageCompletedJsonString) + testEvent(stageEnded, stageCompletedJsonString) testEvent(taskStart, taskStartJsonString) testEvent(taskGettingResult, taskGettingResultJsonString) testEvent(taskEnd, taskEndJsonString) @@ -89,7 +89,7 @@ class JsonProtocolSuite extends FunSuite { // JobResult val exception = new Exception("Out of Memory! Please restock film.") exception.setStackTrace(stackTrace) - val jobFailed = JobFailed(exception, 2) + val jobFailed = JobFailed(exception) testJobResult(JobSucceeded) testJobResult(jobFailed) @@ -180,7 +180,7 @@ class JsonProtocolSuite extends FunSuite { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => assert(e1.properties === e2.properties) assertEquals(e1.stageInfo, e2.stageInfo) - case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) => + case (e1: SparkListenerStageEnded, e2: SparkListenerStageEnded) => assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) => assert(e1.stageId === e2.stageId) @@ -294,7 +294,6 @@ class JsonProtocolSuite extends FunSuite { (result1, result2) match { case (JobSucceeded, JobSucceeded) => case (r1: JobFailed, r2: JobFailed) => - assert(r1.failedStageId === r2.failedStageId) assertEquals(r1.exception, r2.exception) case _ => fail("Job results don't match in types!") }