From 320c7c7ee72339166bf1f8b743471df4fae63de7 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 2 Apr 2014 15:54:29 -0700 Subject: [PATCH] Notify SparkListeners when stages fail or are cancelled. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, when stages fail or get cancelled, the SparkListener is only notified indirectly through the SparkListenerJobEnd, where we sometimes pass in a single stage that failed. This worked before job cancellation, because jobs would only fail due to a single stage failure. However, with job cancellation, multiple running stages can fail when a job gets cancelled. Right now, this is not handled correctly, which results in stages that get stuck in the “Running Stages” window in the UI even though they’re dead. This PR changes the SparkListenerStageCompleted event to a SparkListenerStageEnded event, and uses this event to tell SparkListeners when stages fail in addition to when they complete successfully. This change is NOT publicly backward compatible for two reasons. First, it changes the SparkListener interface. We could alternately add a new event, SparkListenerStageFailed, and keep the existing SparkListenerStageCompleted. However, this is less consistent with the listener events for tasks / jobs ending, and will result in some code duplication for listeners (because failed and completed stages are handled in similar ways). Note that I haven’t finished updating the JSON code to correctly handle the new event because I’m waiting for feedback on whether this is a good or bad idea (hence the “WIP”). It is also not backwards compatible because it changes the publicly visible JobWaiter.jobFailed() method to no longer include a stage that caused the failure. I think this change should definitely stay, because with cancellation (as described above), a failure isn’t necessarily caused by a single stage. --- .../scala/org/apache/spark/FutureAction.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 122 +++++++++++------- .../scheduler/EventLoggingListener.scala | 2 +- .../apache/spark/scheduler/JobLogger.scala | 14 +- .../apache/spark/scheduler/JobResult.scala | 3 +- .../apache/spark/scheduler/JobWaiter.scala | 2 +- .../spark/scheduler/SparkListener.scala | 8 +- .../spark/scheduler/SparkListenerBus.scala | 4 +- .../apache/spark/scheduler/StageInfo.scala | 9 ++ .../spark/ui/jobs/JobProgressListener.scala | 27 ++-- .../spark/ui/storage/BlockManagerUI.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 18 ++- .../spark/scheduler/DAGSchedulerSuite.scala | 47 ++++++- .../spark/scheduler/SparkListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 9 +- 15 files changed, 171 insertions(+), 100 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index f2decd14ef6d9..2eec09cd1c795 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -141,7 +141,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: private def awaitResult(): Try[T] = { jobWaiter.awaitResult() match { case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception, _) => scala.util.Failure(e) + case JobFailed(e: Exception) => scala.util.Failure(e) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c96d7435a7ed4..8e4674125a729 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -342,22 +342,24 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for - * stages that were removed. The associated tasks for those stages need to be cancelled if we - * got here via job cancellation. + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. + * @param resultStage Specifies the result stage for the job; if set to None, this method + * searches resultStagesToJob to find and cleanup the appropriate result stage. */ - private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { - val registeredStages = jobIdToStageIds(jobId) - val independentStages = new HashSet[Int]() - if (registeredStages.isEmpty) { - logError("No stages registered for job " + jobId) + private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) { + val registeredStages = jobIdToStageIds.get(job.jobId) + if (registeredStages.isEmpty || registeredStages.get.isEmpty) { + logError("No stages registered for job " + job.jobId) } else { - stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { case (stageId, jobSet) => - if (!jobSet.contains(jobId)) { + if (!jobSet.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" - .format(jobId, stageId)) + .format(job.jobId, stageId)) } else { def removeStage(stageId: Int) { // data structures based on Stage @@ -394,23 +396,28 @@ class DAGScheduler( .format(stageId, stageIdToStage.size)) } - jobSet -= jobId + jobSet -= job.jobId if (jobSet.isEmpty) { // no other job needs this stage - independentStages += stageId removeStage(stageId) } } } } - independentStages.toSet - } + jobIdToStageIds -= job.jobId + jobIdToActiveJob -= job.jobId + activeJobs -= job - private def jobIdToStageIdsRemove(jobId: Int) { - if (!jobIdToStageIds.contains(jobId)) { - logDebug("Trying to remove unregistered job " + jobId) + if (resultStage.isEmpty) { + // Clean up result stages. + val resultStagesForJob = resultStageToJob.keySet.filter( + stage => resultStageToJob(stage).jobId == job.jobId) + if (resultStagesForJob.size != 1) { + logWarning( + s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") + } + resultStageToJob --= resultStagesForJob } else { - removeJobAndIndependentStages(jobId) - jobIdToStageIds -= jobId + resultStageToJob -= resultStage.get } } @@ -460,7 +467,7 @@ class DAGScheduler( val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} - case JobFailed(exception: Exception, _) => + case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception } @@ -606,7 +613,16 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, -1))) + // Tell the listeners the all of the stages have ended. Don't bother cancelling the + // stages because if the DAG scheduler is stopped, the entire application is in the + // process of getting stopped. + val stageFailedMessage = "Stage cancelled because SparkContext was shut down" + runningStages.foreach { stage => + val info = stageToInfos(stage) + info.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageEnded(info)) + } + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } return true } @@ -676,7 +692,7 @@ class DAGScheduler( } } catch { case e: Exception => - jobResult = JobFailed(e, job.finalStage.id) + jobResult = JobFailed(e) job.listener.jobFailed(e) } finally { val s = job.finalStage @@ -807,7 +823,7 @@ class DAGScheduler( } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + listenerBus.post(SparkListenerStageEnded(stageToInfos(stage))) runningStages -= stage } event.reason match { @@ -826,11 +842,9 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - jobIdToActiveJob -= stage.jobId - activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - jobIdToStageIdsRemove(job.jobId) + cleanupStateForJobAndIndependentStages(job, Some(stage)) listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -982,7 +996,7 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled") + failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled", None) } } @@ -999,7 +1013,8 @@ class DAGScheduler( stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", + Some(resultStage)) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1008,28 +1023,45 @@ class DAGScheduler( /** * Fails a job and all stages that are only used by that job, and cleans up relevant state. + * + * @param resultStage The result stage for the job, if known. Used to cleanup state for the job + * slightly more efficiently than when not specified. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { + private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, + resultStage: Option[Stage]) { val error = new SparkException(failureReason) job.listener.jobFailed(error) - // Cancel all tasks in independent stages. - val independentStages = removeJobAndIndependentStages(job.jobId) - independentStages.foreach(taskScheduler.cancelTasks) - - // Clean up remaining state we store for the job. - jobIdToActiveJob -= job.jobId - activeJobs -= job - jobIdToStageIds -= job.jobId - val resultStagesForJob = resultStageToJob.keySet.filter( - stage => resultStageToJob(stage).jobId == job.jobId) - if (resultStagesForJob.size != 1) { - logWarning( - s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") + // Cancel all independent, running stages. + val stages = jobIdToStageIds(job.jobId) + if (stages.isEmpty) { + logError("No stages registered for job " + job.jobId) } - resultStageToJob --= resultStagesForJob + stages.foreach { stageId => + val jobsForStage = stageIdToJobIds.get(stageId) + if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else if (jobsForStage.get.size == 1) { + if (!stageIdToStage.contains(stageId)) { + // This is the only job that uses this stage, so fail the stage if it is running. + logError("Missing Stage for stage with id $stageId") + } else { + val stage = stageIdToStage(stageId) + if (runningStages.contains(stage)) { + taskScheduler.cancelTasks(stageId) + val stageInfo = stageToInfos(stage) + stageInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageEnded(stageToInfos(stage))) + } + } + } + } + + cleanupStateForJobAndIndependentStages(job, None) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id))) + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 217f8825c2ae9..dc8f88d16fadd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -78,7 +78,7 @@ private[spark] class EventLoggingListener(appName: String, conf: SparkConf) logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = + override def onStageEnded(event: SparkListenerStageEnded) = logEvent(event, flushLogger = true) override def onJobStart(event: SparkListenerJobStart) = logEvent(event, flushLogger = true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 5cecf9416b32c..be361bd2231ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -187,11 +187,15 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener /** * When stage is completed, record stage completion status - * @param stageCompleted Stage completed event + * @param stageEnded Stage ended event */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - val stageId = stageCompleted.stageInfo.stageId - stageLogInfo(stageId, "STAGE_ID=%d STATUS=COMPLETED".format(stageId)) + override def onStageEnded(stageEnded: SparkListenerStageEnded) { + val stageId = stageEnded.stageInfo.stageId + if (stageEnded.stageInfo.failureReason.isEmpty) { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") + } else { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") + } } /** @@ -227,7 +231,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener var info = "JOB_ID=" + jobId jobEnd.jobResult match { case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => + case JobFailed(exception) => info += " STATUS=FAILED REASON=" exception.getMessage.split("\\s+").foreach(info += _ + "_") case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala index 3cf4e3077e4a4..047bd27056120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala @@ -24,5 +24,4 @@ private[spark] sealed trait JobResult private[spark] case object JobSucceeded extends JobResult -// A failed stage ID of -1 means there is not a particular stage that caused the failure -private[spark] case class JobFailed(exception: Exception, failedStageId: Int) extends JobResult +private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 8007b5418741e..e9bfee2248e5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -64,7 +64,7 @@ private[spark] class JobWaiter[T]( override def jobFailed(exception: Exception): Unit = synchronized { _jobFinished = true - jobResult = JobFailed(exception, -1) + jobResult = JobFailed(exception) this.notifyAll() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index d4eb0ac88d8e8..53f180db9c403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -32,7 +32,7 @@ sealed trait SparkListenerEvent case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) extends SparkListenerEvent -case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent +case class SparkListenerStageEnded(stageInfo: StageInfo) extends SparkListenerEvent case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent @@ -71,9 +71,9 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent */ trait SparkListener { /** - * Called when a stage is completed, with information on the completed stage + * Called when a stage completes or fails, with information on the completed stage */ - def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } + def onStageEnded(stageEnded: SparkListenerStageEnded) { } /** * Called when a stage is submitted @@ -144,7 +144,7 @@ class StatsReportListener extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { + override def onStageEnded(stageCompleted: SparkListenerStageEnded) { implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 729e120497571..7a9615f4355d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -41,8 +41,8 @@ private[spark] trait SparkListenerBus { event match { case stageSubmitted: SparkListenerStageSubmitted => sparkListeners.foreach(_.onStageSubmitted(stageSubmitted)) - case stageCompleted: SparkListenerStageCompleted => - sparkListeners.foreach(_.onStageCompleted(stageCompleted)) + case stageEnded: SparkListenerStageEnded => + sparkListeners.foreach(_.onStageEnded(stageEnded)) case jobStart: SparkListenerJobStart => sparkListeners.foreach(_.onJobStart(jobStart)) case jobEnd: SparkListenerJobEnd => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 8115a7ed7896d..eec409b182ac6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -26,8 +26,17 @@ private[spark] class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None + /** Time when all tasks in the stage completed or when the stage was cancelled. */ var completionTime: Option[Long] = None + /** If the stage failed, the reason why. */ + var failureReason: Option[String] = None + var emittedTaskSizeWarning = false + + def stageFailed(reason: String) { + failureReason = Some(reason) + completionTime = Some(System.currentTimeMillis) + } } private[spark] diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 048f671c8788f..fe7b4e5ca9290 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -68,14 +68,19 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { def blockManagerIds = executorIdToBlockManagerId.values.toSeq - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { - val stage = stageCompleted.stageInfo + override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized { + val stage = stageEnded.stageInfo val stageId = stage.stageId // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage poolToActiveStages(stageIdToPool(stageId)).remove(stageId) activeStages.remove(stageId) - completedStages += stage - trimIfNecessary(completedStages) + if (stage.failureReason.isEmpty) { + completedStages += stage + trimIfNecessary(completedStages) + } else { + failedStages += stage + trimIfNecessary(failedStages) + } } /** If stages is too large, remove and garbage collect old stages */ @@ -215,20 +220,6 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { - jobEnd.jobResult match { - case JobFailed(_, stageId) => - activeStages.get(stageId).foreach { s => - // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage - activeStages.remove(s.stageId) - poolToActiveStages(stageIdToPool(stageId)).remove(s.stageId) - failedStages += s - trimIfNecessary(failedStages) - } - case _ => - } - } - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val schedulingModeName = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index a7b24ff695214..b3e819044ab51 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -87,7 +87,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo) } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized { // Remove all partitions that are no longer cached _rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2155a8888c85c..18f8b4161e458 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -22,7 +22,7 @@ import java.util.{Properties, UUID} import scala.collection.JavaConverters._ import scala.collection.Map -import org.json4s.DefaultFormats +import org.json4s.{DefaultFormats, JsonAST} import org.json4s.JsonDSL._ import org.json4s.JsonAST._ @@ -42,7 +42,7 @@ private[spark] object JsonProtocol { event match { case stageSubmitted: SparkListenerStageSubmitted => stageSubmittedToJson(stageSubmitted) - case stageCompleted: SparkListenerStageCompleted => + case stageCompleted: SparkListenerStageEnded => stageCompletedToJson(stageCompleted) case taskStart: SparkListenerTaskStart => taskStartToJson(taskStart) @@ -76,7 +76,7 @@ private[spark] object JsonProtocol { ("Properties" -> properties) } - def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = { + def stageCompletedToJson(stageCompleted: SparkListenerStageEnded): JValue = { val stageInfo = stageInfoToJson(stageCompleted.stageInfo) ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~ ("Stage Info" -> stageInfo) @@ -260,8 +260,7 @@ private[spark] object JsonProtocol { case JobSucceeded => Utils.emptyJson case jobFailed: JobFailed => val exception = exceptionToJson(jobFailed.exception) - ("Exception" -> exception) ~ - ("Failed Stage ID" -> jobFailed.failedStageId) + JsonAST.JObject("Exception" -> exception) } ("Result" -> result) ~ json } @@ -336,7 +335,7 @@ private[spark] object JsonProtocol { def sparkEventFromJson(json: JValue): SparkListenerEvent = { val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted) - val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted) + val stageCompleted = Utils.getFormattedClassName(SparkListenerStageEnded) val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart) val taskGettingResult = Utils.getFormattedClassName(SparkListenerTaskGettingResult) val taskEnd = Utils.getFormattedClassName(SparkListenerTaskEnd) @@ -368,9 +367,9 @@ private[spark] object JsonProtocol { SparkListenerStageSubmitted(stageInfo, properties) } - def stageCompletedFromJson(json: JValue): SparkListenerStageCompleted = { + def stageCompletedFromJson(json: JValue): SparkListenerStageEnded = { val stageInfo = stageInfoFromJson(json \ "Stage Info") - SparkListenerStageCompleted(stageInfo) + SparkListenerStageEnded(stageInfo) } def taskStartFromJson(json: JValue): SparkListenerTaskStart = { @@ -561,8 +560,7 @@ private[spark] object JsonProtocol { case `jobSucceeded` => JobSucceeded case `jobFailed` => val exception = exceptionFromJson(json \ "Exception") - val failedStageId = (json \ "Failed Stage ID").extract[Int] - new JobFailed(exception, failedStageId) + new JobFailed(exception) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2e3026bffba2f..c3d718c23c579 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -64,6 +64,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def defaultParallelism() = 2 } + /** Length of time to wait while draining listener events. */ + val WAIT_TIMEOUT_MILLIS = 10000 + val sparkListener = new SparkListener() { + val successfulStages = new HashSet[Int]() + val failedStages = new HashSet[Int]() + override def onStageEnded(stageEnded: SparkListenerStageEnded) { + val stageInfo = stageEnded.stageInfo + if (stageInfo.failureReason.isEmpty) { + successfulStages += stageInfo.stageId + } else { + failedStages += stageInfo.stageId + } + } + } + var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null @@ -89,13 +104,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont /** The list of results that DAGScheduler has collected. */ val results = new HashMap[Int, Any]() var failure: Exception = _ - val listener = new JobListener() { + val jobListener = new JobListener() { override def taskSucceeded(index: Int, result: Any) = results.put(index, result) override def jobFailed(exception: Exception) = { failure = exception } } before { sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.successfulStages.clear() + sparkListener.failedStages.clear() + sc.addSparkListener(sparkListener) taskSets.clear() cancelledStages.clear() cacheLocations.clear() @@ -187,7 +205,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, - listener: JobListener = listener): Int = { + listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) return jobId @@ -231,7 +249,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, jobListener)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } @@ -262,6 +280,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty } @@ -270,6 +291,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty } @@ -354,6 +378,13 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) assert(failure.getMessage === s"Job aborted due to stage failure: $stageFailureMessage") + + // Listener bus should get told about the map stage failing, but not the reduce stage + // (since the reduce stage hasn't been started yet). + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty } @@ -398,8 +429,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) - + assert(cancelledStages.contains(1)) + + // Make sure the listeners got told about both failed stages. + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.successfulStages.isEmpty) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.contains(3)) + assert(sparkListener.failedStages.size === 2) + assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assertDataStructuresEmpty diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 7c843772bc2e0..349c864935b3e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -257,7 +257,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } - override def onStageCompleted(stage: SparkListenerStageCompleted) { + override def onStageEnded(stage: SparkListenerStageEnded) { stageInfos(stage.stageInfo) = taskInfoMetrics taskInfoMetrics = mutable.Buffer.empty } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 7bab7da8fed68..06d349c9512d1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -34,7 +34,7 @@ class JsonProtocolSuite extends FunSuite { test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) - val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L)) + val stageEnded = SparkListenerStageEnded(makeStageInfo(101, 201, 301, 401L, 501L)) val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 444L)) val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 3000L)) val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, @@ -54,7 +54,7 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) testEvent(stageSubmitted, stageSubmittedJsonString) - testEvent(stageCompleted, stageCompletedJsonString) + testEvent(stageEnded, stageCompletedJsonString) testEvent(taskStart, taskStartJsonString) testEvent(taskGettingResult, taskGettingResultJsonString) testEvent(taskEnd, taskEndJsonString) @@ -89,7 +89,7 @@ class JsonProtocolSuite extends FunSuite { // JobResult val exception = new Exception("Out of Memory! Please restock film.") exception.setStackTrace(stackTrace) - val jobFailed = JobFailed(exception, 2) + val jobFailed = JobFailed(exception) testJobResult(JobSucceeded) testJobResult(jobFailed) @@ -180,7 +180,7 @@ class JsonProtocolSuite extends FunSuite { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => assert(e1.properties === e2.properties) assertEquals(e1.stageInfo, e2.stageInfo) - case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) => + case (e1: SparkListenerStageEnded, e2: SparkListenerStageEnded) => assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) => assert(e1.stageId === e2.stageId) @@ -294,7 +294,6 @@ class JsonProtocolSuite extends FunSuite { (result1, result2) match { case (JobSucceeded, JobSucceeded) => case (r1: JobFailed, r2: JobFailed) => - assert(r1.failedStageId === r2.failedStageId) assertEquals(r1.exception, r2.exception) case _ => fail("Job results don't match in types!") }