[SPARK-50768][CORE] Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption #49413

Closed · wants to merge 6 commits · Changes from 1 commit
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.Closeable
import java.util.{Properties, TimerTask}
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

@@ -273,6 +274,18 @@ class BarrierTaskContext private[spark] (
}

override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties

override private[spark] def interruptible(): Boolean = taskContext.interruptible()

override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
: Unit = {
Member Author: I'll fix it with a followup. Thanks!

Member Author: Fixed at #49508

taskContext.pendingInterrupt(threadToInterrupt, reason)
}

override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
: T = {
taskContext.createResourceUninterruptibly(resourceBuilder)
}
}

@Experimental
18 changes: 17 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,7 +17,7 @@

package org.apache.spark

import java.io.Serializable
import java.io.Closeable
import java.util.Properties

import org.apache.spark.annotation.{DeveloperApi, Evolving, Since}
@@ -305,4 +305,20 @@ abstract class TaskContext extends Serializable {

/** Gets local properties set upstream in the driver. */
private[spark] def getLocalProperties: Properties


/** Whether the current task is allowed to be interrupted. */
private[spark] def interruptible(): Boolean

/**
 * Holds a pending interruption request until the task becomes interruptible again,
 * i.e., after the uninterruptible resource creation finishes.
 */
private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String): Unit

/**
* Creates a closeable resource uninterruptibly. The task is not allowed to be interrupted in
* this state until the resource creation finishes.
*/
private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T): T
}
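As a rough illustration of how a Spark-internal caller might use the new hook (a minimal sketch, not taken from this PR: the `openBlockStream` helper and its file path are hypothetical, and since the method is `private[spark]`, only Spark code could call it):

```scala
import java.io.{BufferedInputStream, FileInputStream, InputStream}

import org.apache.spark.TaskContext

// Hypothetical helper: defer any concurrent task-kill request until the stream
// is fully constructed and registered for cleanup, instead of letting the
// interrupt land while the stream is half-built (which is how it can leak).
def openBlockStream(path: String): InputStream = {
  val ctx = TaskContext.get()
  if (ctx != null) {
    ctx.createResourceUninterruptibly {
      new BufferedInputStream(new FileInputStream(path))
    }
  } else {
    // Not running inside a task (e.g. on the driver); nothing to defer.
    new BufferedInputStream(new FileInputStream(path))
  }
}
```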
46 changes: 46 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.Closeable
import java.util.{Properties, Stack}
import javax.annotation.concurrent.GuardedBy

@@ -82,6 +83,13 @@ private[spark] class TaskContextImpl(
// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None

// The pending interruption request, which is blocked by uninterruptible resource creation.
// Should be protected by `TaskContext.synchronized`.
private var pendingInterruptRequest: Option[(Option[Thread], String)] = None

// Whether this task is able to be interrupted. Should be protected by `TaskContext.synchronized`.
Contributor: Or maybe mark it as @transient?

Contributor: Are you referring to @volatile?

Member Author: Since we already protect it with TaskContext.synchronized, @volatile would be redundant. And @volatile alone would not be thread-safe, because two threads could still modify _interruptible concurrently.
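A standalone sketch of the race being described (hypothetical example, not Spark code): with @volatile alone, the read and the write of the flag are separate steps, so two threads can both observe `true` and both flip it, whereas a synchronized block makes the check-then-set atomic.

```scala
object CheckThenSetSketch {
  @volatile private var interruptible = true
  private val lock = new Object

  // Unsafe: two threads can both see `interruptible == true` here and both
  // proceed, because the read and the subsequent write are not atomic.
  def unsafeEnterUninterruptible(): Boolean = {
    if (interruptible) { interruptible = false; true } else false
  }

  // Safe: the read and the write happen under one lock, mirroring the
  // TaskContext.synchronized usage in this PR.
  def safeEnterUninterruptible(): Boolean = lock.synchronized {
    if (interruptible) { interruptible = false; true } else false
  }
}
```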

private var _interruptible = true

// Whether the task has completed.
private var completed: Boolean = false

@@ -296,4 +304,42 @@ private[spark] class TaskContextImpl(
private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException

private[spark] override def getLocalProperties: Properties = localProperties


override def interruptible(): Boolean = TaskContext.synchronized(_interruptible)

override def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String): Unit = {
Contributor: This API looks weird if threadToInterrupt is None, as there is nothing to interrupt.

Member Author: When threadToInterrupt is None, the task is still "interrupted" by invoking TaskContext.markInterrupted(); it just won't invoke Thread.interrupt() on the task thread.

TaskContext.synchronized {
pendingInterruptRequest = Some((threadToInterrupt, reason))
}
}

def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T): T = {

@inline def interruptIfRequired(): Unit = {
pendingInterruptRequest.foreach { case (threadToInterrupt, reason) =>
markInterrupted(reason)
threadToInterrupt.foreach(_.interrupt())
}
killTaskIfInterrupted()
}

TaskContext.synchronized {
interruptIfRequired()
Contributor: Is this a safeguard for the case where the caller mistakenly calls pendingInterrupt even though the task is interruptible?

Member Author: No. This is for the case where resource creation happens after the task has already been marked as interrupted. In that case, pendingInterruptRequest is None and reasonIfKilled is not None, and killTaskIfInterrupted() would throw a TaskKilledException to stop the task thread.
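A small sketch of that scenario from the caller's perspective (hypothetical Spark-internal code; the path and helper are illustrative): if the task was already marked as killed before entering createResourceUninterruptibly, the initial interruptIfRequired()/killTaskIfInterrupted() check throws before resourceBuilder ever runs, so there is no stream to leak.

```scala
import java.io.{FileInputStream, InputStream}

import org.apache.spark.{TaskContext, TaskKilledException}

// Hypothetical: the task was killed before this call, so the guard at the top
// of createResourceUninterruptibly throws and the stream is never opened.
def openOrGiveUp(path: String): Option[InputStream] = {
  try {
    Some(TaskContext.get().createResourceUninterruptibly(new FileInputStream(path)))
  } catch {
    case _: TaskKilledException =>
      None // resourceBuilder never ran, so nothing was opened and nothing can leak
  }
}
```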


if (_interruptible) {
_interruptible = false
}
}
try {
val resource = resourceBuilder
addTaskCompletionListener[Unit](_ => resource.close())
resource
} finally {
TaskContext.synchronized {
_interruptible = true
interruptIfRequired()
}
}
}
}
24 changes: 19 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,6 +22,7 @@ import java.util.Properties

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
@@ -70,7 +71,7 @@ private[spark] abstract class Task[T](
val jobId: Option[Int] = None,
val appId: Option[String] = None,
val appAttemptId: Option[String] = None,
val isBarrier: Boolean = false) extends Serializable {
val isBarrier: Boolean = false) extends Serializable with Logging {

@transient lazy val metrics: TaskMetrics =
SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics))
@@ -231,10 +232,23 @@
require(reason != null)
_reasonIfKilled = reason
if (context != null) {
context.markInterrupted(reason)
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
TaskContext.synchronized {
if (context.interruptible()) {
context.markInterrupted(reason)
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
} else {
val threadToInterrupt = if (interruptThread && taskThread != null) {
Some(taskThread)
} else {
None
}
logInfo(log"Task ${MDC(LogKeys.TASK_ID, context.taskAttemptId())} " +
log"is currently not interruptible. ")
context.pendingInterrupt(threadToInterrupt, reason)
}
}
}
}
}
125 changes: 124 additions & 1 deletion core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.{File, FileOutputStream, InputStream, ObjectOutputStream}
import java.util.concurrent.{Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

@@ -35,7 +36,7 @@ import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Deploy._
import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerExecutorRemoved, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.{ThreadUtils, Utils}

/**
* Test suite for cancelling running jobs. We run the cancellation tasks for single job action
@@ -712,6 +713,128 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
assert(executionOfInterruptibleCounter.get() < numElements)
}

Seq(true, false).foreach { interruptible =>

val (hint1, hint2) = if (interruptible) {
(" not", "")
} else {
("", " not")
}

val testName = s"SPARK-50768:$hint1 use TaskContext.createResourceUninterruptibly " +
s"would$hint2 cause stream leak on task interruption"

test(testName) {
import org.apache.spark.JobCancellationSuite._
withTempDir { dir =>

// `InterruptionSensitiveInputStream` is designed to easily leak the underlying stream
// when task thread interruption happens during its initialization.
class InterruptionSensitiveInputStream(fileHint: String) extends InputStream {
private var underlying: InputStream = _

def initialize(): InputStream = {
val in: InputStream = new InputStream {

open()

private def dumpFile(typeName: String): Unit = {
val file = new File(dir, s"$typeName.$fileHint")
val out = new FileOutputStream(file)
val objOut = new ObjectOutputStream(out)
objOut.writeBoolean(true)
objOut.close()
}

private def open(): Unit = {
dumpFile("open")
}

override def close(): Unit = {
dumpFile("close")
}

override def read(): Int = -1
}

// Leave some time for the task to be interrupted during the
// creation of `InterruptionSensitiveInputStream`.
Thread.sleep(5000)
Member: How important is this sleep within the task? Could it potentially make the test flaky?

Contributor: Good point.

Member Author: It is necessary to ensure the task is interrupted during InterruptionSensitiveInputStream#initialize() so that we can test the leaked stream. Increasing the sleep time should make flakiness less likely.


underlying = in
underlying
}

override def read(): Int = -1

override def close(): Unit = {
if (underlying != null) {
underlying.close()
}
}
}

def createStream(fileHint: String): Unit = {
if (interruptible) {
Utils.tryInitializeResource {
new InterruptionSensitiveInputStream(fileHint)
} {
_.initialize()
}
} else {
TaskContext.get().createResourceUninterruptibly[java.io.InputStream] {
Utils.tryInitializeResource {
new InterruptionSensitiveInputStream(fileHint)
} {
_.initialize()
}
}
}
}

sc = new SparkContext("local[2]", "test interrupt streams")

sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
// Sleep some time to ensure task has started
Thread.sleep(1000)
taskStartedSemaphore.release()
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
if (taskEnd.reason.isInstanceOf[TaskKilled]) {
taskCancelledSemaphore.release()
}
}
})

sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")

val fileHint = if (interruptible) "interruptible" else "uninterruptible"
val future = sc.parallelize(1 to 100, 1).mapPartitions { _ =>
createStream(fileHint)
Iterator.single(1)
}.collectAsync()

taskStartedSemaphore.acquire()
future.cancel()
taskCancelledSemaphore.acquire()

val fileOpen = new File(dir, s"open.$fileHint")
val fileClose = new File(dir, s"close.$fileHint")
assert(fileOpen.exists())

if (interruptible) {
// The underlying stream leaks when the stream creation is interruptible.
assert(!fileClose.exists())
} else {
// The underlying stream won't leak when the stream creation is uninterruptible.
assert(fileClose.exists())
}
}
}
}

def testCount(): Unit = {
// Cancel before launching any tasks
{