# [SPARK-50768][CORE] Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption #49413
_Changes from 1 commit._
**BarrierTaskContext.scala**

```scala
@@ -17,6 +17,7 @@
package org.apache.spark

import java.io.Closeable
import java.util.{Properties, TimerTask}
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

@@ -273,6 +274,18 @@ class BarrierTaskContext private[spark] (
  }

  override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties

  override private[spark] def interruptible(): Boolean = taskContext.interruptible()

  override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
    : Unit = {
    taskContext.pendingInterrupt(threadToInterrupt, reason)
  }

  override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
    : T = {
    taskContext.createResourceUninterruptibly(resourceBuilder)
  }
}

@Experimental
```
**TaskContextImpl.scala**

```scala
@@ -17,6 +17,7 @@
package org.apache.spark

import java.io.Closeable
import java.util.{Properties, Stack}
import javax.annotation.concurrent.GuardedBy

@@ -82,6 +83,13 @@ private[spark] class TaskContextImpl(
  // If defined, the corresponding task has been killed and this option contains the reason.
  @volatile private var reasonIfKilled: Option[String] = None

  // The pending interruption request, which is blocked by uninterruptible resource creation.
  // Should be protected by `TaskContext.synchronized`.
  private var pendingInterruptRequest: Option[(Option[Thread], String)] = None

  // Whether this task is able to be interrupted. Should be protected by `TaskContext.synchronized`.
  private var _interruptible = true

  // Whether the task has completed.
  private var completed: Boolean = false
```
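The kill side is not shown in this commit's hunks; as a rough, hypothetical sketch of how a `Task.kill`-style path could consult these two fields (the names `context` and `taskThread` are assumed stand-ins for the task's real state, and the hooks are `private[spark]`):

```scala
package org.apache.spark // interruptible()/pendingInterrupt are private[spark]

// Hypothetical sketch only -- not part of this diff. Both this path and
// createResourceUninterruptibly synchronize on TaskContext, so a kill request
// can never slip between "resource created" and "close listener registered".
object KillPathSketch {
  def kill(
      context: TaskContext,
      taskThread: Thread,
      interruptThread: Boolean,
      reason: String): Unit = {
    TaskContext.synchronized {
      if (context.interruptible()) {
        // No uninterruptible creation in flight: mark and interrupt right away.
        context.markInterrupted(reason)
        if (interruptThread && taskThread != null) {
          taskThread.interrupt()
        }
      } else {
        // Creation in flight: park the request. createResourceUninterruptibly
        // honors it in its `finally` block once the builder finishes.
        val threadToInterrupt = if (interruptThread) Option(taskThread) else None
        context.pendingInterrupt(threadToInterrupt, reason)
      }
    }
  }
}
```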
|
@@ -296,4 +304,42 @@ private[spark] class TaskContextImpl( | |
private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException | ||
|
||
private[spark] override def getLocalProperties: Properties = localProperties | ||
|
||
|
||
override def interruptible(): Boolean = TaskContext.synchronized(_interruptible) | ||
|
||
override def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String): Unit = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This API looks weird if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When |
||
TaskContext.synchronized { | ||
pendingInterruptRequest = Some((threadToInterrupt, reason)) | ||
} | ||
} | ||
|
||
def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T): T = { | ||
|
||
@inline def interruptIfRequired(): Unit = { | ||
pendingInterruptRequest.foreach { case (threadToInterrupt, reason) => | ||
markInterrupted(reason) | ||
threadToInterrupt.foreach(_.interrupt()) | ||
} | ||
killTaskIfInterrupted() | ||
} | ||
|
||
TaskContext.synchronized { | ||
interruptIfRequired() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it a safeguard that the caller may mistakenly call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. This is for the case where resource creation happens after the task has been marked as |
||
|
||
if (_interruptible) { | ||
Ngone51 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_interruptible = false | ||
} | ||
} | ||
try { | ||
val resource = resourceBuilder | ||
addTaskCompletionListener[Unit](_ => resource.close()) | ||
resource | ||
} finally { | ||
TaskContext.synchronized { | ||
_interruptible = true | ||
interruptIfRequired() | ||
} | ||
} | ||
} | ||
} |
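For the calling convention, a minimal hypothetical sketch of a Spark-internal caller (the API is `private[spark]`; the object name and file path below are illustrative, not from this PR):

```scala
package org.apache.spark // createResourceUninterruptibly is private[spark]

import java.io.{File, FileInputStream, InputStream}

// Hypothetical caller -- illustrative only. If a kill request arrives while
// the builder runs, it is deferred; the stream is then closed by the
// task-completion listener that createResourceUninterruptibly registers,
// instead of leaking.
object UninterruptibleOpenSketch {
  def openDataFile(path: String): InputStream = {
    TaskContext.get().createResourceUninterruptibly {
      new FileInputStream(new File(path))
    }
  }
}
```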
**JobCancellationSuite.scala**

```scala
@@ -17,6 +17,7 @@
package org.apache.spark

import java.io.{File, FileOutputStream, InputStream, ObjectOutputStream}
import java.util.concurrent.{Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

@@ -35,7 +36,7 @@ import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Deploy._
import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerExecutorRemoved, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.{ThreadUtils, Utils} // was: import org.apache.spark.util.ThreadUtils

/**
 * Test suite for cancelling running jobs. We run the cancellation tasks for single job action

@@ -712,6 +713,128 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
    assert(executionOfInterruptibleCounter.get() < numElements)
  }

  Seq(true, false).foreach { interruptible =>

    val (hint1, hint2) = if (interruptible) {
      (" not", "")
    } else {
      ("", " not")
    }

    val testName = s"SPARK-50768:$hint1 use TaskContext.createResourceUninterruptibly " +
      s"would$hint2 cause stream leak on task interruption"

    test(testName) {
      import org.apache.spark.JobCancellationSuite._
      withTempDir { dir =>
        // `InterruptionSensitiveInputStream` is designed to easily leak the underlying stream
        // when task thread interruption happens during its initialization.
        class InterruptionSensitiveInputStream(fileHint: String) extends InputStream {
          private var underlying: InputStream = _

          def initialize(): InputStream = {
            val in: InputStream = new InputStream {

              open()

              private def dumpFile(typeName: String): Unit = {
                val file = new File(dir, s"$typeName.$fileHint")
                val out = new FileOutputStream(file)
                val objOut = new ObjectOutputStream(out)
                objOut.writeBoolean(true)
                objOut.close()
              }

              private def open(): Unit = {
                dumpFile("open")
              }

              override def close(): Unit = {
                dumpFile("close")
              }

              override def read(): Int = -1
            }

            // Leave some time for the task to be interrupted during the
            // creation of `InterruptionSensitiveInputStream`.
            Thread.sleep(5000)

            underlying = in
            underlying
          }

          override def read(): Int = -1

          override def close(): Unit = {
            if (underlying != null) {
              underlying.close()
            }
          }
        }

        def createStream(fileHint: String): Unit = {
          if (interruptible) {
            Utils.tryInitializeResource {
              new InterruptionSensitiveInputStream(fileHint)
            } {
              _.initialize()
            }
          } else {
            TaskContext.get().createResourceUninterruptibly[java.io.InputStream] {
              Utils.tryInitializeResource {
                new InterruptionSensitiveInputStream(fileHint)
              } {
                _.initialize()
              }
            }
          }
        }

        sc = new SparkContext("local[2]", "test interrupt streams")

        sc.addSparkListener(new SparkListener {
          override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
            // Sleep some time to ensure task has started
            Thread.sleep(1000)
            taskStartedSemaphore.release()
          }

          override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
            if (taskEnd.reason.isInstanceOf[TaskKilled]) {
              taskCancelledSemaphore.release()
            }
          }
        })

        sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")

        val fileHint = if (interruptible) "interruptible" else "uninterruptible"
        val future = sc.parallelize(1 to 100, 1).mapPartitions { _ =>
          createStream(fileHint)
          Iterator.single(1)
        }.collectAsync()

        taskStartedSemaphore.acquire()
        future.cancel()
        taskCancelledSemaphore.acquire()

        val fileOpen = new File(dir, s"open.$fileHint")
        val fileClose = new File(dir, s"close.$fileHint")
        assert(fileOpen.exists())

        if (interruptible) {
          // The underlying stream leaks when the stream creation is interruptible.
          assert(!fileClose.exists())
        } else {
          // The underlying stream won't leak when the stream creation is uninterruptible.
          assert(fileClose.exists())
        }
      }
    }
  }

  def testCount(): Unit = {
    // Cancel before launching any tasks
    {
```

On the `Thread.sleep(5000)` inside `initialize()`:

> **Reviewer:** How important is this sleep within the task? Could it potentially make the test flaky?
>
> **Reviewer:** Good point.
>
> **Author:** It is necessary to ensure the task is interrupted during the creation of `InterruptionSensitiveInputStream`.
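For reference, `Utils.tryInitializeResource`, used by the test above, closes the just-created resource when its initializer throws; roughly (a from-memory sketch of the helper's shape, not part of this diff):

```scala
import java.io.Closeable

object TryInitializeSketch {
  // Approximate shape of org.apache.spark.util.Utils.tryInitializeResource.
  // It closes the *outer* resource if `initialize` throws (e.g. on interrupt),
  // but a stream opened inside `initialize` before the failure -- like the
  // anonymous stream above, not yet assigned to `underlying` -- still leaks.
  // That remaining window is what createResourceUninterruptibly closes.
  def tryInitializeResource[R <: Closeable, T](create: => R)(initialize: R => T): T = {
    val resource = create
    try {
      initialize(resource)
    } catch {
      case e: Throwable =>
        resource.close()
        throw e
    }
  }
}
```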
> **Reviewer:** nit. Indentation?
>
> **Author:** I'll fix it with a followup. Thanks!
>
> **Author:** Fixed at #49508.