Notify SparkListeners when stages fail or are cancelled.
Previously, when stages failed or were cancelled, SparkListeners were only notified
indirectly through the SparkListenerJobEnd event, where we sometimes passed in a single
stage that failed.  This worked before job cancellation, because jobs would only fail
due to a single stage failure.  However, with job cancellation, multiple running stages
can fail when a job gets cancelled.  Right now, this is not handled correctly, which
results in stages that get stuck in the “Running Stages” window in the UI even
though they’re dead.

This PR changes the SparkListenerStageCompleted event to a SparkListenerStageEnded
event, and uses this event to tell SparkListeners when stages fail in addition to when
they complete successfully.  This change is NOT publicly backward compatible for two
reasons.  First, it changes the SparkListener interface.  We could alternatively add a new event,
SparkListenerStageFailed, and keep the existing SparkListenerStageCompleted.  However,
this is less consistent with the listener events for tasks / jobs ending, and will result in some
code duplication for listeners (because failed and completed stages are handled in similar
ways).  Note that I haven’t finished updating the JSON code to correctly handle the new event
because I’m waiting for feedback on whether this is a good or bad idea (hence the “WIP”).
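As a rough sketch of what listeners look like against the unified event (hypothetical, not part of this patch; the class name is made up, and it sits in org.apache.spark.scheduler only because StageInfo is still private[spark]), a listener handles both outcomes in a single onStageEnded callback and tells them apart via the new StageInfo.failureReason field, mirroring the JobLogger change below:

package org.apache.spark.scheduler

// Hypothetical listener: one callback now covers both completed and failed/cancelled stages.
private[spark] class StageOutcomeListener extends SparkListener {
  override def onStageEnded(stageEnded: SparkListenerStageEnded) {
    val info = stageEnded.stageInfo
    info.failureReason match {
      case Some(reason) => println(s"Stage ${info.stageId} failed or was cancelled: $reason")
      case None => println(s"Stage ${info.stageId} completed successfully")
    }
  }
}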

It is also not backwards compatible because it changes the publicly visible JobWaiter.jobFailed()
method to no longer include a stage that caused the failure.  I think this change should definitely
stay, because with cancellation (as described above), a failure isn’t necessarily caused by a
single stage.
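The only change callers of the job result see is in pattern matches: JobFailed(exception, failedStageId) becomes JobFailed(exception), as in the FutureAction and JobLogger updates below. A hypothetical snippet to illustrate (the object and method names are made up; it lives in the scheduler package because JobResult is private[spark]):

package org.apache.spark.scheduler

// Hypothetical helper showing the shape of the JobFailed API change.
private[spark] object JobResultDescriber {
  def describe(result: JobResult): String = result match {
    case JobSucceeded => "SUCCESS"
    // Before this patch: case JobFailed(exception, failedStageId) => ...
    case JobFailed(exception) => "FAILED: " + exception.getMessage
  }
}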
kayousterhout committed Apr 8, 2014
1 parent 6dc5f58 commit 320c7c7
Showing 15 changed files with 171 additions and 100 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -141,7 +141,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
private def awaitResult(): Try[T] = {
jobWaiter.awaitResult() match {
case JobSucceeded => scala.util.Success(resultFunc)
case JobFailed(e: Exception, _) => scala.util.Failure(e)
case JobFailed(e: Exception) => scala.util.Failure(e)
}
}
}
122 changes: 77 additions & 45 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -342,22 +342,24 @@ class DAGScheduler(
}

/**
* Removes job and any stages that are not needed by any other job. Returns the set of ids for
* stages that were removed. The associated tasks for those stages need to be cancelled if we
* got here via job cancellation.
* Removes state for job and any stages that are not needed by any other job. Does not
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
*
* @param job The job whose state to cleanup.
* @param resultStage Specifies the result stage for the job; if set to None, this method
* searches resultStagesToJob to find and cleanup the appropriate result stage.
*/
private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
val registeredStages = jobIdToStageIds(jobId)
val independentStages = new HashSet[Int]()
if (registeredStages.isEmpty) {
logError("No stages registered for job " + jobId)
private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
val registeredStages = jobIdToStageIds.get(job.jobId)
if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
logError("No stages registered for job " + job.jobId)
} else {
stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
case (stageId, jobSet) =>
if (!jobSet.contains(jobId)) {
if (!jobSet.contains(job.jobId)) {
logError(
"Job %d not registered for stage %d even though that stage was registered for the job"
.format(jobId, stageId))
.format(job.jobId, stageId))
} else {
def removeStage(stageId: Int) {
// data structures based on Stage
@@ -394,23 +396,28 @@ class DAGScheduler(
.format(stageId, stageIdToStage.size))
}

jobSet -= jobId
jobSet -= job.jobId
if (jobSet.isEmpty) { // no other job needs this stage
independentStages += stageId
removeStage(stageId)
}
}
}
}
independentStages.toSet
}
jobIdToStageIds -= job.jobId
jobIdToActiveJob -= job.jobId
activeJobs -= job

private def jobIdToStageIdsRemove(jobId: Int) {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to remove unregistered job " + jobId)
if (resultStage.isEmpty) {
// Clean up result stages.
val resultStagesForJob = resultStageToJob.keySet.filter(
stage => resultStageToJob(stage).jobId == job.jobId)
if (resultStagesForJob.size != 1) {
logWarning(
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
}
resultStageToJob --= resultStagesForJob
} else {
removeJobAndIndependentStages(jobId)
jobIdToStageIds -= jobId
resultStageToJob -= resultStage.get
}
}

@@ -460,7 +467,7 @@ class DAGScheduler(
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
}
@@ -606,7 +613,16 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, -1)))
// Tell the listeners that all of the stages have ended. Don't bother cancelling the
// stages because if the DAG scheduler is stopped, the entire application is in the
// process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
runningStages.foreach { stage =>
val info = stageToInfos(stage)
info.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageEnded(info))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
return true
}
@@ -676,7 +692,7 @@ class DAGScheduler(
}
} catch {
case e: Exception =>
jobResult = JobFailed(e, job.finalStage.id)
jobResult = JobFailed(e)
job.listener.jobFailed(e)
} finally {
val s = job.finalStage
@@ -807,7 +823,7 @@ class DAGScheduler(
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
listenerBus.post(SparkListenerStageEnded(stageToInfos(stage)))
runningStages -= stage
}
event.reason match {
@@ -826,11 +842,9 @@
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
jobIdToActiveJob -= stage.jobId
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
jobIdToStageIdsRemove(job.jobId)
cleanupStateForJobAndIndependentStages(job, Some(stage))
listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
}
job.listener.taskSucceeded(rt.outputId, event.result)
@@ -982,7 +996,7 @@ class DAGScheduler(
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled")
failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled", None)
}
}

@@ -999,7 +1013,8 @@
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
Some(resultStage))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -1008,28 +1023,45 @@

/**
* Fails a job and all stages that are only used by that job, and cleans up relevant state.
*
* @param resultStage The result stage for the job, if known. Used to cleanup state for the job
* slightly more efficiently than when not specified.
*/
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
resultStage: Option[Stage]) {
val error = new SparkException(failureReason)
job.listener.jobFailed(error)

// Cancel all tasks in independent stages.
val independentStages = removeJobAndIndependentStages(job.jobId)
independentStages.foreach(taskScheduler.cancelTasks)

// Clean up remaining state we store for the job.
jobIdToActiveJob -= job.jobId
activeJobs -= job
jobIdToStageIds -= job.jobId
val resultStagesForJob = resultStageToJob.keySet.filter(
stage => resultStageToJob(stage).jobId == job.jobId)
if (resultStagesForJob.size != 1) {
logWarning(
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
// Cancel all independent, running stages.
val stages = jobIdToStageIds(job.jobId)
if (stages.isEmpty) {
logError("No stages registered for job " + job.jobId)
}
resultStageToJob --= resultStagesForJob
stages.foreach { stageId =>
val jobsForStage = stageIdToJobIds.get(stageId)
if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
logError(
"Job %d not registered for stage %d even though that stage was registered for the job"
.format(job.jobId, stageId))
} else if (jobsForStage.get.size == 1) {
if (!stageIdToStage.contains(stageId)) {
logError(s"Missing Stage for stage with id $stageId")
} else {
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
taskScheduler.cancelTasks(stageId)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageEnded(stageToInfos(stage)))
}
}
}
}

cleanupStateForJobAndIndependentStages(job, None)

listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id)))
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}

/**
@@ -78,7 +78,7 @@ private[spark] class EventLoggingListener(appName: String, conf: SparkConf)
logEvent(event)

// Events that trigger a flush
override def onStageCompleted(event: SparkListenerStageCompleted) =
override def onStageEnded(event: SparkListenerStageEnded) =
logEvent(event, flushLogger = true)
override def onJobStart(event: SparkListenerJobStart) =
logEvent(event, flushLogger = true)
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -187,11 +187,15 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener

/**
* When stage is completed, record stage completion status
* @param stageCompleted Stage completed event
* @param stageEnded Stage ended event
*/
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageId = stageCompleted.stageInfo.stageId
stageLogInfo(stageId, "STAGE_ID=%d STATUS=COMPLETED".format(stageId))
override def onStageEnded(stageEnded: SparkListenerStageEnded) {
val stageId = stageEnded.stageInfo.stageId
if (stageEnded.stageInfo.failureReason.isEmpty) {
stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED")
} else {
stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED")
}
}

/**
@@ -227,7 +231,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
var info = "JOB_ID=" + jobId
jobEnd.jobResult match {
case JobSucceeded => info += " STATUS=SUCCESS"
case JobFailed(exception, _) =>
case JobFailed(exception) =>
info += " STATUS=FAILED REASON="
exception.getMessage.split("\\s+").foreach(info += _ + "_")
case _ =>
@@ -24,5 +24,4 @@ private[spark] sealed trait JobResult

private[spark] case object JobSucceeded extends JobResult

// A failed stage ID of -1 means there is not a particular stage that caused the failure
private[spark] case class JobFailed(exception: Exception, failedStageId: Int) extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult
@@ -64,7 +64,7 @@ private[spark] class JobWaiter[T](

override def jobFailed(exception: Exception): Unit = synchronized {
_jobFinished = true
jobResult = JobFailed(exception, -1)
jobResult = JobFailed(exception)
this.notifyAll()
}

@@ -32,7 +32,7 @@ sealed trait SparkListenerEvent
case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null)
extends SparkListenerEvent

case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent
case class SparkListenerStageEnded(stageInfo: StageInfo) extends SparkListenerEvent

case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent

@@ -71,9 +71,9 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent
*/
trait SparkListener {
/**
* Called when a stage is completed, with information on the completed stage
* Called when a stage completes or fails, with information on the completed stage
*/
def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
def onStageEnded(stageEnded: SparkListenerStageEnded) { }

/**
* Called when a stage is submitted
@@ -144,7 +144,7 @@ class StatsReportListener extends SparkListener with Logging {
}
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
override def onStageEnded(stageCompleted: SparkListenerStageEnded) {
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
@@ -41,8 +41,8 @@ private[spark] trait SparkListenerBus {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
case stageCompleted: SparkListenerStageCompleted =>
sparkListeners.foreach(_.onStageCompleted(stageCompleted))
case stageEnded: SparkListenerStageEnded =>
sparkListeners.foreach(_.onStageEnded(stageEnded))
case jobStart: SparkListenerJobStart =>
sparkListeners.foreach(_.onJobStart(jobStart))
case jobEnd: SparkListenerJobEnd =>
@@ -26,8 +26,17 @@ private[spark]
class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
/** Time when all tasks in the stage completed or when the stage was cancelled. */
var completionTime: Option[Long] = None
/** If the stage failed, the reason why. */
var failureReason: Option[String] = None

var emittedTaskSizeWarning = false

def stageFailed(reason: String) {
failureReason = Some(reason)
completionTime = Some(System.currentTimeMillis)
}
}

private[spark]
@@ -68,14 +68,19 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {

def blockManagerIds = executorIdToBlockManagerId.values.toSeq

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
val stage = stageCompleted.stageInfo
override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized {
val stage = stageEnded.stageInfo
val stageId = stage.stageId
// Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
poolToActiveStages(stageIdToPool(stageId)).remove(stageId)
activeStages.remove(stageId)
completedStages += stage
trimIfNecessary(completedStages)
if (stage.failureReason.isEmpty) {
completedStages += stage
trimIfNecessary(completedStages)
} else {
failedStages += stage
trimIfNecessary(failedStages)
}
}

/** If stages is too large, remove and garbage collect old stages */
@@ -215,20 +220,6 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {
}
}

override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
jobEnd.jobResult match {
case JobFailed(_, stageId) =>
activeStages.get(stageId).foreach { s =>
// Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
activeStages.remove(s.stageId)
poolToActiveStages(stageIdToPool(stageId)).remove(s.stageId)
failedStages += s
trimIfNecessary(failedStages)
}
case _ =>
}
}

override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
synchronized {
val schedulingModeName =
@@ -87,7 +87,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe
_rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo)
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
override def onStageEnded(stageEnded: SparkListenerStageEnded) = synchronized {
// Remove all partitions that are no longer cached
_rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 }
}
