[SPARK-32613][CORE] Fix regressions in DecommissionWorkerSuite #29422

Closed · wants to merge 4 commits
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -294,10 +294,15 @@ private[spark] class CoarseGrainedExecutorBackend(
override def run(): Unit = {
var lastTaskRunningTime = System.nanoTime()
val sleep_time = 1000 // 1s

// This config is internal and only used by unit tests to force an executor
// to hang around for longer when decommissioned.
val initialSleepMillis = env.conf.getInt(
"spark.test.executor.decommission.initial.sleep.millis", sleep_time)
if (initialSleepMillis > 0) {
Thread.sleep(initialSleepMillis)
}
while (true) {
logInfo("Checking to see if we can shutdown.")
Thread.sleep(sleep_time)
if (executor == null || executor.numRunningTasks == 0) {
if (env.conf.get(STORAGE_DECOMMISSION_ENABLED)) {
logInfo("No running tasks, checking migrations")
@@ -323,6 +328,7 @@ private[spark] class CoarseGrainedExecutorBackend(
// move forward.
lastTaskRunningTime = System.nanoTime()
}
Thread.sleep(sleep_time)
}
}
}
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1877,6 +1877,16 @@ package object config {
.timeConf(TimeUnit.SECONDS)
.createOptional

private[spark] val DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL =
ConfigBuilder("spark.executor.decommission.removed.infoCacheTTL")
.doc("Duration for which a decommissioned executor's information will be kept after its" +
"removal. Keeping the decommissioned info after removal helps pinpoint fetch failures to " +
"decommissioning even after the mapper executor has been decommissioned. This allows " +
"eager recovery from fetch failures caused by decommissioning, increasing job robustness.")
.version("3.1.0")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("5m")

private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
.doc("Staging directory used while submitting applications.")
.version("2.0.0")
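As a hedged usage note (not part of this diff): since the new entry is a timeConf in seconds with a "5m" default, it accepts Spark's usual duration strings. The value below is illustrative only.

```scala
import org.apache.spark.SparkConf

// Illustrative only: widen the window during which removed executors are still
// remembered as decommissioned (the default above is "5m").
val conf = new SparkConf()
  .set("spark.executor.decommission.removed.infoCacheTTL", "10m")
```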
41 changes: 29 additions & 12 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1846,7 +1846,14 @@ private[spark] class DAGScheduler(
execId = bmAddress.executorId,
fileLost = true,
hostToUnregisterOutputs = hostToUnregisterOutputs,
maybeEpoch = Some(task.epoch))
maybeEpoch = Some(task.epoch),
// shuffleFileLostEpoch is ignored when a host is decommissioned because some
// decommissioned executors on that host might have been removed before this fetch
// failure and might have bumped up the shuffleFileLostEpoch. We ignore that, and
// proceed with unconditional removal of shuffle outputs from all executors on that
// host, including from those that we still haven't confirmed as lost due to heartbeat
// delays.
ignoreShuffleFileLostEpoch = isHostDecommissioned)
}
}

@@ -2012,7 +2019,8 @@ private[spark] class DAGScheduler(
execId: String,
fileLost: Boolean,
hostToUnregisterOutputs: Option[String],
maybeEpoch: Option[Long] = None): Unit = {
maybeEpoch: Option[Long] = None,
ignoreShuffleFileLostEpoch: Boolean = false): Unit = {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
logDebug(s"Considering removal of executor $execId; " +
s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2030,25 @@ private[spark] class DAGScheduler(
blockManagerMaster.removeExecutor(execId)
clearCacheLocs()
}
if (fileLost &&
(!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
shuffleFileLostEpoch(execId) = currentEpoch
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
if (fileLost) {
val remove = if (ignoreShuffleFileLostEpoch) {
true
} else if (!shuffleFileLostEpoch.contains(execId) ||
shuffleFileLostEpoch(execId) < currentEpoch) {
shuffleFileLostEpoch(execId) = currentEpoch
true
} else {
false
}
if (remove) {
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
}
}
}
}
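To make the epoch gating above easier to follow, here is a minimal standalone sketch of the same decision, detached from DAGScheduler; the map and method names are stand-ins for illustration, not Spark APIs.

```scala
import scala.collection.mutable

// Stand-in for DAGScheduler.shuffleFileLostEpoch: executor id -> epoch at which that
// executor's shuffle output was last unregistered.
val shuffleFileLostEpoch = mutable.HashMap[String, Long]()

// Returns true when the shuffle output tracked for `execId` should be unregistered at
// `currentEpoch`. When the host is decommissioned the per-executor epoch is ignored,
// because another executor on the same host may already have bumped it before this
// fetch failure was observed.
def shouldRemoveOutputs(execId: String, currentEpoch: Long, ignoreEpoch: Boolean): Boolean = {
  if (ignoreEpoch) {
    true
  } else if (!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch) {
    shuffleFileLostEpoch(execId) = currentEpoch
    true
  } else {
    false
  }
}
```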
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -26,6 +26,9 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap, HashSet}
import scala.util.Random

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.ExecutorMetrics
@@ -136,7 +139,21 @@ private[spark] class TaskSchedulerImpl(
// IDs of the tasks running on each executor
private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
// We add executors here when we first receive a decommission notification for them. Executors
// can continue to run even after being asked to decommission, but they will eventually exit.
val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]

// When they eventually exit and we learn of it (via heartbeat failure), we move them into this
// cache. The cache is consulted to determine whether a fetch failure happened because the
// source executor had been decommissioned.
lazy val decommissionedExecutorsRemoved = CacheBuilder.newBuilder()
.expireAfterWrite(
conf.get(DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL), TimeUnit.SECONDS)
.ticker(new Ticker{
override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
})
.build[String, ExecutorDecommissionInfo]()
.asMap()

def runningTasksByExecutors: Map[String, Int] = synchronized {
executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
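The removed-executor cache above pairs Guava's expireAfterWrite with a custom Ticker so expiry can be driven by Spark's Clock in tests. Below is a self-contained sketch of that pattern, with plain String values and a hand-rolled fake clock standing in for Spark's types; it is illustrative only.

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

// Manually advanced clock standing in for Spark's Clock abstraction.
val nowMillis = new AtomicLong(0L)

val removedExecutors = CacheBuilder.newBuilder()
  .expireAfterWrite(300, TimeUnit.SECONDS) // 5 minutes, matching the config default
  .ticker(new Ticker {
    // Guava expects nanoseconds; convert from the fake millisecond clock.
    override def read(): Long = TimeUnit.MILLISECONDS.toNanos(nowMillis.get())
  })
  .build[String, String]()
  .asMap()

removedExecutors.put("exec-1", "host decommissioned")
nowMillis.addAndGet(301 * 1000L)               // advance the fake clock past the TTL
assert(removedExecutors.get("exec-1") == null) // expired entries read as absent
```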
@@ -910,7 +927,7 @@ private[spark] class TaskSchedulerImpl(
// if isHostDecommissioned was ever reported as true, then we keep that info since it most
// likely came from the cluster manager and is thus authoritative
val oldDecomInfo = executorsPendingDecommission.get(executorId)
if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
executorsPendingDecommission(executorId) = decommissionInfo
}
}
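A quick equivalence check for the rewritten condition above, using a local stand-in for ExecutorDecommissionInfo and made-up values; Option.exists collapses the old two-clause test into one.

```scala
// Local stand-in for org.apache.spark.scheduler.ExecutorDecommissionInfo, for illustration only.
case class DecomInfo(message: String, isHostDecommissioned: Boolean)

val samples: Seq[Option[DecomInfo]] = Seq(
  None, // nothing recorded yet
  Some(DecomInfo("executor only", isHostDecommissioned = false)),
  Some(DecomInfo("whole host", isHostDecommissioned = true)))

samples.foreach { oldDecomInfo =>
  val oldForm = oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned
  val newForm = !oldDecomInfo.exists(_.isHostDecommissioned)
  // Overwriting is allowed only when no host-level decommission was previously recorded.
  assert(oldForm == newForm)
}
```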
@@ -921,7 +938,9 @@ private[spark] class TaskSchedulerImpl(

override def getExecutorDecommissionInfo(executorId: String)
: Option[ExecutorDecommissionInfo] = synchronized {
executorsPendingDecommission.get(executorId)
executorsPendingDecommission
.get(executorId)
.orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
}

override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1046,9 @@ private[spark] class TaskSchedulerImpl(
}
}

executorsPendingDecommission -= executorId

val decomInfo = executorsPendingDecommission.remove(executorId)
decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))

if (reason != LossReasonPending) {
executorIdToHost -= executorId
core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
}
}

// Unlike TestUtils.withListener, this helper also waits for the job to finish.
def withListener(sc: SparkContext, listener: RootStageAwareListener)
(body: SparkListener => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
sc.listenerBus.waitUntilEmpty()
listener.waitForJobDone()
} finally {
sc.listenerBus.removeListener(listener)
}
}

test("decommission workers should not result in job failure") {
val maxTaskFailures = 2
val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
Thread.sleep(5 * 1000L); 1
}.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
val sleepTimeSeconds = if (pid == 0) 1 else 10
Thread.sleep(sleepTimeSeconds * 1000L)
@@ -190,10 +203,11 @@ class DecommissionWorkerSuite
}
}

test("decommission workers ensure that fetch failures lead to rerun") {
def testFetchFailures(initialSleepMillis: Int): Unit = {
createWorkers(2)
sc = createSparkContext(
config.Tests.TEST_NO_STAGE_RETRY.key -> "false",
"spark.test.executor.decommission.initial.sleep.millis" -> initialSleepMillis.toString,
config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true")
val executorIdToWorkerInfo = getExecutorToWorkerAssignments
val executorToDecom = executorIdToWorkerInfo.keysIterator.next
@@ -212,22 +226,29 @@ class DecommissionWorkerSuite
override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskInfo = taskEnd.taskInfo
if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
taskEnd.stageAttemptId == 0) {
taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
decommissionWorkerOnMaster(workerToDecom,
"decommission worker after task on it is done")
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
val executorId = SparkEnv.get.executorId
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
val context = TaskContext.get()
// Only sleep in the first attempt to create the required window for decommissioning.
// Subsequent attempts don't need to be delayed to speed up the test.
if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
}
List(1).iterator
}, preservesPartitioning = true)
.repartition(1).mapPartitions(iter => {
val context = TaskContext.get()
if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
// Wait a bit for the decommissioning to be triggered in the listener
Thread.sleep(5000)
// MapIndex is explicitly -1 to force the entire host to be decommissioned
// However, this will cause both the tasks in the preceding stage since the host here is
// "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -246,6 +267,14 @@ class DecommissionWorkerSuite
assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen")
}

test("decommission stalled workers ensure that fetch failures lead to rerun") {
testFetchFailures(3600 * 1000)
}

test("decommission eager workers ensure that fetch failures lead to rerun") {
Member commented:
Quick note: this test seems flaky:

sbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedException: 9 did not equal 6 Expected 6 tasks but got List(0:0:0:0-SUCCESS, 0:0:1:0-SUCCESS, 1:0:0:0-FAILED, 0:1:0:0-SUCCESS, 0:1:1:0-SUCCESS, 1:1:0:0-FAILED, 0:2:0:0-SUCCESS, 0:2:1:0-SUCCESS, 1:2:0:0-SUCCESS)
	at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472)
	at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471)
	at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1231)
	at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:1295)
	at org.apache.spark.deploy.DecommissionWorkerSuite.testFetchFailures(DecommissionWorkerSuite.scala:267)
	at org.apache.spark.deploy.DecommissionWorkerSuite.$anonfun$new$14(DecommissionWorkerSuite.scala:275)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
	at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
	at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
	at org.scalatest.Transformer.apply(Transformer.scala:22)
	at org.scalatest.Transformer.apply(Transformer.scala:20)
	at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:189)
	at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:176)
	at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:187)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:199)
	at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:199)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:181)
	at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:61)
	at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
	at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
	at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:61)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:232)
	at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
	at scala.collection.immutable.List.foreach(List.scala:392)
	at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
	at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)

I'll file a JIRA if I see it more often.

Member replied:
@HyukjinKwon Do you have a link for the test job? I want to take a look at it.

HyukjinKwon replied (Sep 16, 2020):
I retriggered the job ... I will let you know if I see it one more time.

testFetchFailures(0)
}

private abstract class RootStageAwareListener extends SparkListener {
private var rootStageId: Option[Int] = None
private val tasksFinished = new ConcurrentLinkedQueue[String]()
@@ -265,23 +294,31 @@ class DecommissionWorkerSuite
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEnd.jobResult match {
case JobSucceeded => jobDone.set(true)
case JobFailed(exception) => logError(s"Job failed", exception)
}
}

protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
String = {
s"${stageId}:${stageAttemptId}:" +
s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
}

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
logInfo(s"Task started: $signature")
if (isRootStageId(taskStart.stageId)) {
rootTasksStarted.add(taskStart.taskInfo)
handleRootTaskStart(taskStart)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
logInfo(s"Task End $taskSignature")
tasksFinished.add(taskSignature)
if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +328,13 @@ class DecommissionWorkerSuite
}

def getTasksFinished(): Seq[String] = {
assert(jobDone.get(), "Job isn't successfully done yet")
tasksFinished.asScala.toSeq
tasksFinished.asScala.toList
}

def waitForJobDone(): Unit = {
eventually(timeout(10.seconds), interval(100.milliseconds)) {
assert(jobDone.get(), "Job isn't successfully done yet")
}
}
}
