[SPARK-50768][SQL][CORE][FOLLOW-UP] Apply TaskContext.createResourceUninterruptibly() to risky resource creations

### What changes were proposed in this pull request?

This is a follow-up to #49413. It applies `TaskContext.createResourceUninterruptibly()` to resource creations that risk leaking the underlying resource when the task is cancelled.
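For reference, the pattern applied throughout this PR looks roughly like the sketch below, using the new `Utils.createResourceUninterruptiblyIfInTaskThread` helper added to `Utils.scala` in this diff. The enclosing package, object, and `MyReader` class are hypothetical; the sketch assumes it compiles inside an `org.apache.spark` subpackage so the `private[spark]` helper is visible.

```scala
package org.apache.spark.examples

import java.io.Closeable

import org.apache.spark.TaskContext
import org.apache.spark.util.Utils

object CreateResourceSketch {
  // Hypothetical Closeable resource standing in for HadoopFileLinesReader,
  // record readers, input streams, etc.
  final class MyReader extends Closeable {
    override def close(): Unit = ()
  }

  def openReader(): MyReader = {
    // On a task thread (TaskContext.get() != null) the block runs
    // uninterruptibly, so a task-kill request that arrives mid-construction is
    // deferred until the resource exists and can be closed; off a task thread
    // the resource is simply created directly.
    val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
      new MyReader()
    }
    // Same cleanup pattern as the files touched in this diff.
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
    reader
  }
}
```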

### Why are the changes needed?

Avoid leaking resources when a task is cancelled while a resource is being created.
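As an illustration of the leak being avoided (a sketch, not code from this commit): the readers touched below are `Closeable` resources whose `close()` is only registered after construction, so a cancellation interrupt that fires during or right after construction can leave the underlying stream open. The names `LeakWindowSketch` and `openReader` are hypothetical.

```scala
import java.io.Closeable

import org.apache.spark.TaskContext

object LeakWindowSketch {
  // Hypothetical stand-in for HadoopFileLinesReader and similar Closeable readers.
  def openReader(): Closeable = new Closeable { override def close(): Unit = () }

  def readPathBefore(): Closeable = {
    // Leak window: a task-cancellation interrupt inside openReader(), or between
    // these two statements, can leave an already-opened file handle that nothing
    // will ever close. Wrapping the creation uninterruptibly closes this window.
    val reader = openReader()
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
    reader
  }
}
```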

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

n/a

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49508 from Ngone51/SPARK-50768-followup.

Authored-by: Yi Wu <yi.wu@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
Ngone51 authored and dongjoon-hyun committed Jan 16, 2025
1 parent 90801c2 commit 9b32334
Showing 11 changed files with 73 additions and 39 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -278,12 +278,12 @@ class BarrierTaskContext private[spark] (
override private[spark] def interruptible(): Boolean = taskContext.interruptible()

override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
: Unit = {
: Unit = {
taskContext.pendingInterrupt(threadToInterrupt, reason)
}

override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
: T = {
: T = {
taskContext.createResourceUninterruptibly(resourceBuilder)
}
}
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -244,11 +244,13 @@ class NewHadoopRDD[K, V](
private var finished = false
private var reader =
try {
Utils.tryInitializeResource(
format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
) { reader =>
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
reader
Utils.createResourceUninterruptiblyIfInTaskThread {
Utils.tryInitializeResource(
format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
) { reader =>
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
reader
}
}
} catch {
case e: FileNotFoundException if ignoreMissingFiles =>
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -3086,6 +3086,18 @@ private[spark] object Utils
files.toSeq
}

/**
* Create a resource uninterruptibly if we are in a task thread (i.e., TaskContext.get() != null).
* Otherwise, create the resource normally. This is mainly used in the situation where we want to
* create a multi-layer resource in a task thread. The uninterruptible behavior ensures we don't
* leak the underlying resources when there is a task cancellation request.
*/
def createResourceUninterruptiblyIfInTaskThread[R <: Closeable](createResource: => R): R = {
Option(TaskContext.get()).map(_.createResourceUninterruptibly {
createResource
}).getOrElse(createResource)
}

/**
* Return the median number of a long array
*
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

private[libsvm] class LibSVMOutputWriter(
val path: String,
@@ -156,7 +156,9 @@ private[libsvm] class LibSVMFileFormat
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

(file: PartitionedFile) => {
val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))

val points = linesReader
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* Common functions for parsing CSV files
@@ -99,7 +100,9 @@ object TextInputCSVDataSource extends CSVDataSource {
headerChecker: CSVHeaderChecker,
requiredSchema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -129,7 +129,9 @@ object TextInputJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
val textParser = parser.options.encoding
.map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
@@ -211,7 +213,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
schema: StructType): Iterator[InternalRow] = {
def partitionedFileString(ignored: Any): UTF8String = {
Utils.tryWithResource {
CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
Utils.createResourceUninterruptiblyIfInTaskThread {
CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
}
} { inputStream =>
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
@@ -32,7 +32,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* A data source for reading text files. The text files must be encoded as UTF-8.
@@ -119,10 +119,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {

(file: PartitionedFile) => {
val confValue = conf.value.value
val reader = if (!textOptions.wholeText) {
new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
if (!textOptions.wholeText) {
new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
}
}
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
if (requiredSchema.isEmpty) {
@@ -261,21 +261,23 @@ case class ParquetPartitionReaderFactory(
val int96RebaseSpec = DataSourceUtils.int96RebaseSpec(
footerFileMetaData.getKeyValueMetaData.get,
int96RebaseModeInRead)
Utils.tryInitializeResource(
buildReaderFunc(
file.partitionValues,
pushed,
convertTz,
datetimeRebaseSpec,
int96RebaseSpec)
) { reader =>
reader match {
case vectorizedReader: VectorizedParquetRecordReader =>
vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
case _ =>
reader.initialize(split, hadoopAttemptContext)
Utils.createResourceUninterruptiblyIfInTaskThread {
Utils.tryInitializeResource(
buildReaderFunc(
file.partitionValues,
pushed,
convertTz,
datetimeRebaseSpec,
int96RebaseSpec)
) { reader =>
reader match {
case vectorizedReader: VectorizedParquetRecordReader =>
vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
case _ =>
reader.initialize(split, hadoopAttemptContext)
}
reader
}
reader
}
}

@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.text.TextOptions
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* A factory used to create Text readers.
@@ -47,10 +47,12 @@ case class TextPartitionReaderFactory(

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val confValue = broadcastedConf.value.value
val reader = if (!options.wholeText) {
new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
if (!options.wholeText) {
new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
}
}
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
val iter = if (readDataSchema.isEmpty) {
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* Common functions for parsing XML files
@@ -97,7 +98,9 @@ object TextInputXmlDataSource extends XmlDataSource {
parser: StaxXmlParser,
schema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, None, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, None, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.Utils

class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "test"
@@ -96,8 +97,9 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
// Uses a simple projection to simulate column pruning
val projection = new InterpretedProjection(outputAttributes, inputAttributes)

val unsafeRowIterator =
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line =>
val unsafeRowIterator = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
).map { line =>
val record = line.toString
new GenericInternalRow(record.split(",", -1).zip(fieldTypes).map {
case (v, dataType) =>
