[SPARK-25250] : On successful completion of a task attempt on a partition id, kill other running task attempts on that same partition

The fix in this PR is as follows:
Whenever any ResultTask completes successfully, we mark the corresponding partition id as completed in all task set attempts for that stage and kill any other task attempts still running on that partition. As a result, killed tasks caused by TaskCommitDenied exceptions no longer show up in the UI. The new method uses hash maps and arrays for its lookups, so its running time does not depend on the number of tasks in the stage, which keeps it efficient.
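The mechanism described above can be illustrated with a minimal, self-contained Scala sketch. This is not Spark's scheduler code: AttemptState, TaskSetAttemptSketch, PartitionCompletionSketch and markPartitionCompleted are hypothetical stand-ins, and the println takes the place of the real killTaskAttempt call. It only shows why the per-completion cost stays flat: the partition id is resolved to a task index through a hash map, so the work done per successful task does not grow with the number of tasks in the stage.

import scala.collection.mutable

// One task attempt on some partition (hypothetical stand-in for Spark's TaskInfo).
case class AttemptState(taskId: Long, var running: Boolean)

// Stand-in for a single TaskSetManager attempt of a stage.
class TaskSetAttemptSketch(numPartitions: Int) {
  // Hash-map lookup from partition id to task index; the cost of resolving one
  // completion does not depend on how many tasks the stage has.
  val partitionToIndex: mutable.HashMap[Int, Int] =
    mutable.HashMap((0 until numPartitions).map(p => p -> p): _*)
  val successful: Array[Boolean] = Array.fill(numPartitions)(false)
  val taskAttempts: Array[List[AttemptState]] = Array.fill(numPartitions)(Nil)
}

class PartitionCompletionSketch {
  // All task set attempts (original plus retries) known for each stage id.
  val taskSetsByStage = mutable.HashMap.empty[Int, Seq[TaskSetAttemptSketch]]

  // On a successful result task: mark the partition done in every attempt of the
  // stage and "kill" whatever is still running on that partition.
  def markPartitionCompleted(partitionId: Int, stageId: Int): Unit = {
    taskSetsByStage.getOrElse(stageId, Seq.empty).foreach { tsm =>
      tsm.partitionToIndex.get(partitionId).foreach { index =>
        tsm.successful(index) = true
        tsm.taskAttempts(index).filter(_.running).foreach { attempt =>
          // Stand-in for killTaskAttempt(taskId, interruptThread = false, reason = ...)
          println(s"killing task ${attempt.taskId}: partition $partitionId already completed")
          attempt.running = false
        }
      }
    }
  }
}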
pgandhi committed Oct 23, 2018
1 parent 78c8bd2 commit 5ad6efd
Showing 7 changed files with 63 additions and 0 deletions.
@@ -1360,6 +1360,8 @@ private[spark] class DAGScheduler(
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
taskScheduler.markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
task.partitionId, task.stageId)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
markStageAsFinished(resultStage)
@@ -109,4 +109,7 @@ private[spark] trait TaskScheduler {
*/
def applicationAttemptId(): Option[String]

def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
partitionId: Int, stageId: Int): Unit

}
@@ -281,6 +281,23 @@ private[spark] class TaskSchedulerImpl(
}
}

override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
partitionId: Int, stageId: Int): Unit = {
taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
val index: Option[Int] = tsm.partitionToIndex.get(partitionId)
if (!index.isEmpty) {
tsm.markPartitionIdAsCompletedForTaskAttempt(index.get)
val taskInfoList = tsm.taskAttempts(index.get)
taskInfoList.foreach { taskInfo =>
if (taskInfo.running) {
killTaskAttempt(taskInfo.taskId, false, "Corresponding Partition Id " + partitionId +
" has been marked as Completed")
}
}
}
}
}

/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -1091,6 +1091,10 @@ private[spark] class TaskSetManager(
def executorAdded() {
recomputeLocality()
}

def markPartitionIdAsCompletedForTaskAttempt(index: Int): Unit = {
successful(index) = true
}
}

private[spark] object TaskSetManager {
@@ -160,6 +160,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
partitionId: Int, stageId: Int): Unit = {}
}

/** Length of time to wait while draining listener events. */
@@ -667,6 +669,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
override def applicationAttemptId(): Option[String] = None
override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
partitionId: Int, stageId: Int): Unit = {}
}
val noKillScheduler = new DAGScheduler(
sc,
@@ -95,4 +95,6 @@ private class DummyTaskScheduler extends TaskScheduler {
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
blockManagerId: BlockManagerId,
executorMetrics: ExecutorMetrics): Boolean = true
override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
partitionId: Int, stageId: Int): Unit = {}
}
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.HashSet

import scala.collection.mutable.HashMap

@@ -37,6 +38,14 @@ class FakeSchedulerBackend extends SchedulerBackend {
def reviveOffers() {}
def defaultParallelism(): Int = 1
def maxNumConcurrentTasks(): Int = 0
val killedTaskIds: HashSet[Long] = new HashSet[Long]()
override def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit = {
killedTaskIds.add(taskId)
}
}

class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach
@@ -1136,4 +1145,26 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test"))
assert(tsm.isZombie)
}
test("SPARK-25250 On successful completion of a task attempt on a partition id, kill other" +
" running task attempts on that same partition") {
val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
val firstAttempt = FakeTask.createTaskSet(10, stageAttemptId = 0)
taskScheduler.submitTasks(firstAttempt)
val offersFirstAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
taskScheduler.resourceOffers(offersFirstAttempt)
val tsm0 = taskScheduler.taskSetManagerForAttempt(0, 0).get
val matchingTaskInfoFirstAttempt = tsm0.taskAttempts(0).head
tsm0.handleFailedTask(matchingTaskInfoFirstAttempt.taskId, TaskState.FAILED,
FetchFailed(null, 0, 0, 0, "fetch failed"))
val secondAttempt = FakeTask.createTaskSet(10, stageAttemptId = 1)
taskScheduler.submitTasks(secondAttempt)
val offersSecondAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
taskScheduler.resourceOffers(offersSecondAttempt)
taskScheduler.markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(2, 0)
val tsm1 = taskScheduler.taskSetManagerForAttempt(0, 1).get
val indexInTsm = tsm1.partitionToIndex(2)
val matchingTaskInfoSecondAttempt = tsm1.taskAttempts.flatten.filter(_.index == indexInTsm).head
assert(taskScheduler.backend.asInstanceOf[FakeSchedulerBackend].killedTaskIds.contains(
matchingTaskInfoSecondAttempt.taskId))
}
}
