[SPARK-50768][SQL][CORE][FOLLOW-UP] Apply TaskContext.createResourceUninterruptibly() to risky resource creations #49508

Closed · wants to merge 2 commits
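Every touched read path below applies the same pattern; a condensed, illustrative sketch (with `file` and `hadoopConf` standing in for the per-partition values each format already has in scope):

// Illustrative only: construct the reader inside an uninterruptible region so a
// concurrent task kill cannot interrupt the thread mid-construction and leak the
// handle, then register the usual completion listener that closes it.
val reader = Utils.createResourceUninterruptiblyIfInTaskThread(
  new HadoopFileLinesReader(file, hadoopConf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))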
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -278,12 +278,12 @@ class BarrierTaskContext private[spark] (
override private[spark] def interruptible(): Boolean = taskContext.interruptible()

override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
: Unit = {
: Unit = {
taskContext.pendingInterrupt(threadToInterrupt, reason)
}

override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
: T = {
: T = {
taskContext.createResourceUninterruptibly(resourceBuilder)
}
}
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -244,11 +244,13 @@ class NewHadoopRDD[K, V](
private var finished = false
private var reader =
try {
Utils.tryInitializeResource(
format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
) { reader =>
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
reader
Utils.createResourceUninterruptiblyIfInTaskThread {
Utils.tryInitializeResource(
format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
) { reader =>
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
reader
}
}
} catch {
case e: FileNotFoundException if ignoreMissingFiles =>
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -3086,6 +3086,18 @@ private[spark] object Utils
files.toSeq
}

/**
* Create a resource uninterruptibly if we are in a task thread (i.e., TaskContext.get() != null).
* Otherwise, create the resource normally. This is mainly used in the situation where we want to
* create a multi-layer resource in a task thread. The uninterruptible behavior ensures we don't
 * leak the underlying resources when there is a task cancellation request.
*/
def createResourceUninterruptiblyIfInTaskThread[R <: Closeable](createResource: => R): R = {
Option(TaskContext.get()).map(_.createResourceUninterruptibly {
createResource
}).getOrElse(createResource)
}

/**
* Return the median number of a long array
*
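For the multi-layer case the new Scaladoc describes, the helper composes with Utils.tryInitializeResource; a sketch with hypothetical openInner() and wrapOuter() builders (both assumed to return Closeable resources):

// Both layers are built inside one uninterruptible region: if wrapOuter throws,
// tryInitializeResource closes the already-opened inner resource, and a task
// kill arriving mid-construction is deferred rather than leaking either layer.
val outer = Utils.createResourceUninterruptiblyIfInTaskThread {
  Utils.tryInitializeResource(openInner()) { inner =>
    wrapOuter(inner) // hypothetical wrapper that owns and later closes inner
  }
}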
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

private[libsvm] class LibSVMOutputWriter(
val path: String,
@@ -156,7 +156,9 @@ private[libsvm] class LibSVMFileFormat
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

(file: PartitionedFile) => {
val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))

val points = linesReader
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* Common functions for parsing CSV files
@@ -99,7 +100,9 @@ object TextInputCSVDataSource extends CSVDataSource {
headerChecker: CSVHeaderChecker,
requiredSchema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -129,7 +129,9 @@ object TextInputJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
val textParser = parser.options.encoding
.map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
@@ -211,7 +213,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
schema: StructType): Iterator[InternalRow] = {
def partitionedFileString(ignored: Any): UTF8String = {
Utils.tryWithResource {
CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
Utils.createResourceUninterruptiblyIfInTaskThread {
CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
}
} { inputStream =>
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
@@ -32,7 +32,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* A data source for reading text files. The text files must be encoded as UTF-8.
@@ -119,10 +119,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {

(file: PartitionedFile) => {
val confValue = conf.value.value
val reader = if (!textOptions.wholeText) {
new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
if (!textOptions.wholeText) {
new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
}
}
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
if (requiredSchema.isEmpty) {
@@ -261,21 +261,23 @@ case class ParquetPartitionReaderFactory(
val int96RebaseSpec = DataSourceUtils.int96RebaseSpec(
footerFileMetaData.getKeyValueMetaData.get,
int96RebaseModeInRead)
Utils.tryInitializeResource(
buildReaderFunc(
file.partitionValues,
pushed,
convertTz,
datetimeRebaseSpec,
int96RebaseSpec)
) { reader =>
reader match {
case vectorizedReader: VectorizedParquetRecordReader =>
vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
case _ =>
reader.initialize(split, hadoopAttemptContext)
Utils.createResourceUninterruptiblyIfInTaskThread {
Utils.tryInitializeResource(
buildReaderFunc(
file.partitionValues,
pushed,
convertTz,
datetimeRebaseSpec,
int96RebaseSpec)
) { reader =>
reader match {
case vectorizedReader: VectorizedParquetRecordReader =>
vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
case _ =>
reader.initialize(split, hadoopAttemptContext)
}
reader
}
reader
}
}

@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.text.TextOptions
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
* A factory used to create Text readers.
@@ -47,10 +47,12 @@ case class TextPartitionReaderFactory(

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val confValue = broadcastedConf.value.value
val reader = if (!options.wholeText) {
new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
if (!options.wholeText) {
new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
}
}
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
val iter = if (readDataSchema.isEmpty) {
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* Common functions for parsing XML files
@@ -97,7 +98,9 @@ object TextInputXmlDataSource extends XmlDataSource {
parser: StaxXmlParser,
schema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, None, conf)
val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, None, conf)
)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.Utils

class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "test"
@@ -96,8 +97,9 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
// Uses a simple projection to simulate column pruning
val projection = new InterpretedProjection(outputAttributes, inputAttributes)

val unsafeRowIterator =
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line =>
val unsafeRowIterator = Utils.createResourceUninterruptiblyIfInTaskThread(
new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
).map { line =>
val record = line.toString
new GenericInternalRow(record.split(",", -1).zip(fieldTypes).map {
case (v, dataType) =>