[SPARK-25250] : Addressing Reviews January 2, 2019
pgandhi committed Jan 2, 2019
1 parent ee5bc68 commit 7677aec
Showing 2 changed files with 20 additions and 11 deletions.
TaskSchedulerImpl.scala

@@ -295,16 +295,17 @@ private[spark] class TaskSchedulerImpl(
   override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
       partitionId: Int, stageId: Int): Unit = {
     taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
-      val index: Option[Int] = tsm.partitionToIndex.get(partitionId)
-      if (!index.isEmpty) {
-        tsm.markPartitionIdAsCompletedForTaskAttempt(index.get)
-        val taskInfoList = tsm.taskAttempts(index.get)
-        taskInfoList.foreach { taskInfo =>
-          if (taskInfo.running) {
-            killTaskAttempt(taskInfo.taskId, false, "Corresponding Partition Id " + partitionId +
-              " has been marked as Completed")
+      tsm.partitionToIndex.get(partitionId) match {
+        case Some(index) =>
+          tsm.markPartitionIdAsCompletedForTaskAttempt(index)
+          val taskInfoList = tsm.taskAttempts(index)
+          taskInfoList.filter(_.running).foreach { taskInfo =>
+            killTaskAttempt(taskInfo.taskId, false,
+              s"Corresponding Partition ID $partitionId has been marked as Completed")
           }
-        }
+
+        case None =>
+          logError(s"No corresponding index found for partition ID $partitionId")
       }
     }
   }
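Note: the hunk above replaces the isEmpty/get pair with a pattern match on the Option returned by partitionToIndex.get, so the compiler forces both the present and absent cases to be handled, and the missing-index case now gets a logError instead of being silently skipped. A minimal standalone sketch of the idiom (not Spark code; the names below are illustrative):

object OptionMatchSketch {
  // Illustrative stand-in for tsm.partitionToIndex.
  val partitionToIndex: Map[Int, Int] = Map(2 -> 0, 5 -> 1)

  def handle(partitionId: Int): Unit = {
    // Matching on the Option handles both outcomes explicitly,
    // avoiding the unsafe .get the old code relied on.
    partitionToIndex.get(partitionId) match {
      case Some(index) => println(s"partition $partitionId -> index $index")
      case None        => println(s"no index found for partition $partitionId")
    }
  }

  def main(args: Array[String]): Unit = {
    handle(2)   // hits the Some branch
    handle(7)   // hits the None branch
  }
}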
TaskSchedulerImplSuite.scala

@@ -18,9 +18,9 @@
 package org.apache.spark.scheduler

 import java.nio.ByteBuffer
-import java.util.HashSet

 import scala.collection.mutable.HashMap
+import scala.collection.mutable.Set
 import scala.concurrent.duration._

 import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq}
@@ -40,7 +40,7 @@ class FakeSchedulerBackend extends SchedulerBackend {
   def reviveOffers() {}
   def defaultParallelism(): Int = 1
   def maxNumConcurrentTasks(): Int = 0
-  val killedTaskIds: HashSet[Long] = new HashSet[Long]()
+  val killedTaskIds: Set[Long] = Set[Long]()
   override def killTask(
       taskId: Long,
       executorId: String,
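Note: FakeSchedulerBackend now records killed task ids in a scala.collection.mutable.Set rather than a java.util.HashSet, so test code can apply Scala collection operations to it directly. A small sketch under the same import as the diff (standalone, not the test fixture itself):

import scala.collection.mutable.Set

object KilledTaskIdsSketch {
  // Same shape as the field in FakeSchedulerBackend above.
  val killedTaskIds: Set[Long] = Set[Long]()

  def main(args: Array[String]): Unit = {
    killedTaskIds += 42L                    // record a killed task id
    println(killedTaskIds.contains(42L))    // true
    println(killedTaskIds.filter(_ > 10L))  // Scala ops, no Java converters needed
  }
}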
@@ -1328,22 +1328,30 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test"))
     assert(tsm.isZombie)
   }
+
   test("SPARK-25250 On successful completion of a task attempt on a partition id, kill other" +
     " running task attempts on that same partition") {
     val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+
     val firstAttempt = FakeTask.createTaskSet(10, stageAttemptId = 0)
     taskScheduler.submitTasks(firstAttempt)
+
     val offersFirstAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
     taskScheduler.resourceOffers(offersFirstAttempt)
+
     val tsm0 = taskScheduler.taskSetManagerForAttempt(0, 0).get
     val matchingTaskInfoFirstAttempt = tsm0.taskAttempts(0).head
     tsm0.handleFailedTask(matchingTaskInfoFirstAttempt.taskId, TaskState.FAILED,
       FetchFailed(null, 0, 0, 0, "fetch failed"))
+
     val secondAttempt = FakeTask.createTaskSet(10, stageAttemptId = 1)
     taskScheduler.submitTasks(secondAttempt)
+
     val offersSecondAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
     taskScheduler.resourceOffers(offersSecondAttempt)
+
     taskScheduler.markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(2, 0)
+
     val tsm1 = taskScheduler.taskSetManagerForAttempt(0, 1).get
     val indexInTsm = tsm1.partitionToIndex(2)
     val matchingTaskInfoSecondAttempt = tsm1.taskAttempts.flatten.filter(_.index == indexInTsm).head
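Note: the hunk is truncated before its final assertions. The scenario it sets up: stage attempt 0 hits a fetch failure, a second stage attempt is launched for the same partitions, and markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(2, 0) is then expected to kill the still-running attempt for partition 2 in the second task set. A self-contained sketch of that invariant (illustrative only, not the Spark scheduler):

object KillDuplicateAttemptsSketch {
  final case class Attempt(taskId: Long, partitionId: Int, var running: Boolean)

  // Once any attempt completes a partition, kill every other attempt
  // still running on that same partition; return the killed task ids.
  def markPartitionCompleted(attempts: Seq[Attempt], partitionId: Int): Seq[Long] = {
    val duplicates = attempts.filter(a => a.partitionId == partitionId && a.running)
    duplicates.foreach(_.running = false)
    duplicates.map(_.taskId)
  }

  def main(args: Array[String]): Unit = {
    val attempts = Seq(
      Attempt(taskId = 1L, partitionId = 2, running = true),   // first stage attempt
      Attempt(taskId = 11L, partitionId = 2, running = true))  // retry on second attempt
    println(markPartitionCompleted(attempts, partitionId = 2)) // List(1, 11)
  }
}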
