diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index b5c6033bd9da4..c38d552a27aa5 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -278,12 +278,12 @@ class BarrierTaskContext private[spark] (
   override private[spark] def interruptible(): Boolean = taskContext.interruptible()
 
   override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
-  : Unit = {
+    : Unit = {
     taskContext.pendingInterrupt(threadToInterrupt, reason)
   }
 
   override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
-  : T = {
+    : T = {
     taskContext.createResourceUninterruptibly(resourceBuilder)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index d619602305890..032ee3cfbca9c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -244,11 +244,13 @@ class NewHadoopRDD[K, V](
       private var finished = false
       private var reader =
         try {
-          Utils.tryInitializeResource(
-            format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
-          ) { reader =>
-            reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-            reader
+          Utils.createResourceUninterruptiblyIfInTaskThread {
+            Utils.tryInitializeResource(
+              format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext)
+            ) { reader =>
+              reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+              reader
+            }
           }
         } catch {
           case e: FileNotFoundException if ignoreMissingFiles =>
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1efe181a8c38a..c15a1f0d173ca 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -3086,6 +3086,18 @@ private[spark] object Utils
     files.toSeq
   }
 
+  /**
+   * Create a resource uninterruptibly if we are in a task thread (i.e., TaskContext.get() != null).
+   * Otherwise, create the resource normally. This is mainly used in the situation where we want to
+   * create a multi-layer resource in a task thread. The uninterruptible behavior ensures we don't
+   * leak the underlying resources when there is a task cancellation request.
+   */
+  def createResourceUninterruptiblyIfInTaskThread[R <: Closeable](createResource: => R): R = {
+    Option(TaskContext.get()).map(_.createResourceUninterruptibly {
+      createResource
+    }).getOrElse(createResource)
+  }
+
   /**
    * Return the median number of a long array
    *
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 3581693c05be3..9e641161d9b97 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[libsvm] class LibSVMOutputWriter(
     val path: String,
@@ -156,7 +156,9 @@ private[libsvm] class LibSVMFileFormat
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
     (file: PartitionedFile) => {
-      val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
 
       val points = linesReader
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index a8730c20dbcb5..655632aa6d9b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Common functions for parsing CSV files
@@ -99,7 +100,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       headerChecker: CSVHeaderChecker,
       requiredSchema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index cb4c4f5290880..e9b31875bd7b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -129,7 +129,9 @@ object TextInputJsonDataSource extends JsonDataSource {
       file: PartitionedFile,
       parser: JacksonParser,
       schema: StructType): Iterator[InternalRow] = {
-    val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+    val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+      new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
+    )
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
     val textParser = parser.options.encoding
       .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
@@ -211,7 +213,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
       schema: StructType): Iterator[InternalRow] = {
     def partitionedFileString(ignored: Any): UTF8String = {
       Utils.tryWithResource {
-        CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
+        Utils.createResourceUninterruptiblyIfInTaskThread {
+          CodecStreams.createInputStreamWithCloseResource(conf, file.toPath)
+        }
       } { inputStream =>
         UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 5de51e55816e7..3f2024126717d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StringType, StructType}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
  * A data source for reading text files. The text files must be encoded as UTF-8.
@@ -119,10 +119,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
     (file: PartitionedFile) => {
       val confValue = conf.value.value
-      val reader = if (!textOptions.wholeText) {
-        new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
-      } else {
-        new HadoopFileWholeTextReader(file, confValue)
+      val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
+        if (!textOptions.wholeText) {
+          new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
+        } else {
+          new HadoopFileWholeTextReader(file, confValue)
+        }
       }
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
       if (requiredSchema.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index d3643f7426db0..ac15456f0c3d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -261,21 +261,23 @@ case class ParquetPartitionReaderFactory(
     val int96RebaseSpec = DataSourceUtils.int96RebaseSpec(
       footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
-    Utils.tryInitializeResource(
-      buildReaderFunc(
-        file.partitionValues,
-        pushed,
-        convertTz,
-        datetimeRebaseSpec,
-        int96RebaseSpec)
-    ) { reader =>
-      reader match {
-        case vectorizedReader: VectorizedParquetRecordReader =>
-          vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
-        case _ =>
-          reader.initialize(split, hadoopAttemptContext)
+    Utils.createResourceUninterruptiblyIfInTaskThread {
+      Utils.tryInitializeResource(
+        buildReaderFunc(
+          file.partitionValues,
+          pushed,
+          convertTz,
+          datetimeRebaseSpec,
+          int96RebaseSpec)
+      ) { reader =>
+        reader match {
+          case vectorizedReader: VectorizedParquetRecordReader =>
+            vectorizedReader.initialize(split, hadoopAttemptContext, Option.apply(fileFooter))
+          case _ =>
+            reader.initialize(split, hadoopAttemptContext)
+        }
+        reader
       }
-      reader
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextPartitionReaderFactory.scala
index 6542c1c2c3e93..7ee0fa878a8e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextPartitionReaderFactory.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.text.TextOptions
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
  * A factory used to create Text readers.
@@ -47,10 +47,12 @@ case class TextPartitionReaderFactory(
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val confValue = broadcastedConf.value.value
-    val reader = if (!options.wholeText) {
-      new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
-    } else {
-      new HadoopFileWholeTextReader(file, confValue)
+    val reader = Utils.createResourceUninterruptiblyIfInTaskThread {
+      if (!options.wholeText) {
+        new HadoopFileLinesReader(file, options.lineSeparatorInRead, confValue)
+      } else {
+        new HadoopFileWholeTextReader(file, confValue)
+      }
     }
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
     val iter = if (readDataSchema.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
index 8a179afb0f357..da86dd63cfd62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Common functions for parsing XML files
@@ -97,7 +98,9 @@ object TextInputXmlDataSource extends XmlDataSource {
       parser: StaxXmlParser,
       schema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, None, conf)
+      val linesReader = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, None, conf)
+      )
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9aa4cb77b226a..4060cc7172e60 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.Utils
 
 class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
   override def shortName(): String = "test"
@@ -96,8 +97,9 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
       // Uses a simple projection to simulate column pruning
       val projection = new InterpretedProjection(outputAttributes, inputAttributes)
 
-      val unsafeRowIterator =
-        new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line =>
+      val unsafeRowIterator = Utils.createResourceUninterruptiblyIfInTaskThread(
+        new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+      ).map { line =>
           val record = line.toString
           new GenericInternalRow(record.split(",", -1).zip(fieldTypes).map { case (v, dataType) =>
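Usage note (editorial sketch, not part of the patch): every call site changed above follows the same shape -- build the Closeable reader through the new Utils.createResourceUninterruptiblyIfInTaskThread helper, then register close() as a task-completion listener. The snippet below restates that pattern in isolation. FakeLinesReader and the package name are made-up stand-ins; only Utils.createResourceUninterruptiblyIfInTaskThread, TaskContext.get(), and addTaskCompletionListener come from the patch and the existing Spark API. Since Utils is private[spark], the sketch assumes the code lives under org.apache.spark, as all of the changed call sites do.

// Hypothetical package so that the private[spark] Utils object is visible.
package org.apache.spark.sketch

import java.io.Closeable

import org.apache.spark.TaskContext
import org.apache.spark.util.Utils

// Made-up stand-in for a reader such as HadoopFileLinesReader.
class FakeLinesReader extends Closeable {
  override def close(): Unit = ()
}

object ReaderPatternSketch {
  def openReader(): FakeLinesReader = {
    // On a task thread TaskContext.get() is non-null, so the builder runs through
    // TaskContext.createResourceUninterruptibly; per the helper's doc comment this is
    // meant to keep a task cancellation request from leaking the half-built reader.
    // Off a task thread the builder simply runs directly.
    val reader = Utils.createResourceUninterruptiblyIfInTaskThread(new FakeLinesReader)
    // Same cleanup hook the changed call sites use: close the reader when the task ends.
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
    reader
  }
}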