diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000000..66037da9c6ab8
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class handles sort-based shuffle's `bypassMergeSort` write path, which is used for shuffles
+ * for which no Ordering and no Aggregator is given and the number of partitions is
+ * less than `spark.shuffle.sort.bypassMergeThreshold`.
+ *
+ * This path used to be part of [[ExternalSorter]] but was refactored into its own class in order to
+ * reduce code complexity; see SPARK-7855 for more details.
+ *
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+  private final int fileBufferSize;
+  private final boolean transferToEnabled;
+  private final int numPartitions;
+  private final BlockManager blockManager;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final Serializer serializer;
+
+  /** Array of file writers, one for each partition */
+  private BlockObjectWriter[] partitionWriters;
+
+  public BypassMergeSortShuffleWriter(
+      SparkConf conf,
+      BlockManager blockManager,
+      Partitioner partitioner,
+      ShuffleWriteMetrics writeMetrics,
+      Serializer serializer) {
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+    this.numPartitions = partitioner.numPartitions();
+    this.blockManager = blockManager;
+    this.partitioner = partitioner;
+    this.writeMetrics = writeMetrics;
+    this.serializer = serializer;
+  }
+
+  @Override
+  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+    assert (partitionWriters == null);
+    if (!records.hasNext()) {
+      return;
+    }
+    final SerializerInstance serInstance = serializer.newInstance();
+    final long openStartTime = System.nanoTime();
+    partitionWriters = new BlockObjectWriter[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+        blockManager.diskBlockManager().createTempShuffleBlock();
+      final File file = tempShuffleBlockIdPlusFile._2();
+      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+      partitionWriters[i] =
+        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+    }
+    // Creating the file to write to and creating a disk writer both involve interacting with
+    // the disk, and can take a long time in aggregate when we open many files, so should be
+    // included in the shuffle write time.
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+    while (records.hasNext()) {
+      final Product2<K, V> record = records.next();
+      final K key = record._1();
+      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+    }
+  }
+
+  @Override
+  public long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException {
+    // Track location of the partition starts in the output file
+    final long[] lengths = new long[numPartitions];
+    if (partitionWriters == null) {
+      // We were passed an empty iterator
+      return lengths;
+    }
+    for (BlockObjectWriter writer : partitionWriters) {
+      writer.commitAndClose();
+    }
+    final FileOutputStream out = new FileOutputStream(outputFile, true);
+    final long writeStartTime = System.nanoTime();
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < numPartitions; i++) {
+        final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+        boolean copyThrewException = true;
+        try {
+          lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+          copyThrewException = false;
+        } finally {
+          Closeables.close(in, copyThrewException);
+        }
+        if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+          logger.error("Unable to delete file for partition {}", i);
+        }
+      }
+      threwException = false;
+    } finally {
+      Closeables.close(out, threwException);
+      Option<ShuffleWriteMetrics> maybeWriteMetrics = context.taskMetrics().shuffleWriteMetrics();
+      if (maybeWriteMetrics.isDefined()) {
+        maybeWriteMetrics.get().incShuffleWriteTime(System.nanoTime() - writeStartTime);
+      }
+    }
+    return lengths;
+  }
+
+  @Override
+  public void stop() throws IOException {
+    if (partitionWriters != null) {
+      try {
+        final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+        for (BlockObjectWriter writer : partitionWriters) {
+          // This method explicitly does _not_ throw exceptions:
+          writer.revertPartialWritesAndClose();
+          if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+            logger.error("Error while deleting file for block {}", writer.blockId());
+          }
+        }
+      } finally {
+        partitionWriters = null;
+      }
+    }
+  }
+}
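For reviewers who want the shape of the new Java write path without reading it inline in the diff: hash each record's key to a partition, append the record to that partition's own temporary file, then concatenate the temporary files in partition order while recording each partition's byte length (the lengths array is what feeds the map output tracker). Below is a minimal, self-contained sketch of that same pattern in plain java.io; the class and method names are illustrative only, and it deliberately uses none of Spark's BlockManager/BlockObjectWriter machinery.

```java
import java.io.*;
import java.nio.file.Files;
import java.util.List;
import java.util.Map;

// Illustrative stand-in for the bypass-merge-sort write path: one temp file per
// partition during the write phase, then a straight concatenation at the end.
public class BypassMergeSketch {

  public static long[] writePartitioned(
      List<Map.Entry<String, String>> records,
      int numPartitions,
      File outputFile) throws IOException {
    // Phase 1: open one writer per partition and append records by key hash.
    File[] partitionFiles = new File[numPartitions];
    DataOutputStream[] writers = new DataOutputStream[numPartitions];
    for (int i = 0; i < numPartitions; i++) {
      partitionFiles[i] = File.createTempFile("temp_shuffle_" + i + "_", ".tmp");
      writers[i] = new DataOutputStream(
          new BufferedOutputStream(new FileOutputStream(partitionFiles[i])));
    }
    for (Map.Entry<String, String> record : records) {
      int partition = Math.floorMod(record.getKey().hashCode(), numPartitions);
      writers[partition].writeUTF(record.getKey());
      writers[partition].writeUTF(record.getValue());
    }
    for (DataOutputStream writer : writers) {
      writer.close();
    }

    // Phase 2: concatenate the per-partition files into the single output file,
    // recording each partition's length (mirrors writePartitionedFile()'s result).
    long[] lengths = new long[numPartitions];
    try (FileOutputStream out = new FileOutputStream(outputFile, true)) {
      for (int i = 0; i < numPartitions; i++) {
        lengths[i] = Files.copy(partitionFiles[i].toPath(), out);
        if (!partitionFiles[i].delete()) {
          System.err.println("Unable to delete temp file for partition " + i);
        }
      }
    }
    return lengths;
  }
}
```

As in the real writer, nothing is buffered or sorted in memory; the cost is one open file and write buffer per reduce partition, which is why this path is only taken when the partition count is below spark.shuffle.sort.bypassMergeThreshold.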
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleFileWriter.scala b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
similarity index 70%
rename from core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleFileWriter.scala
rename to core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
index 48233e7be0ff2..b61245834820d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleFileWriter.scala
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -15,20 +15,23 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.sort
+package org.apache.spark.shuffle.sort;
 
-import java.io.{IOException, File}
+import java.io.File;
+import java.io.IOException;
 
-import org.apache.spark.TaskContext
-import org.apache.spark.storage.BlockId
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
 
 /**
- * Interface for objects that [[SortShuffleWriter]] uses to write its output files.
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
  */
-private[spark] trait SortShuffleFileWriter[K, V] {
+interface SortShuffleFileWriter<K, V> {
 
-  @throws[IOException]
-  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit
+  void insertAll(Iterator<Product2<K, V>> records) throws IOException;
 
   /**
    * Write all the data added into this shuffle sorter into a file in the disk store. This is
@@ -39,12 +42,10 @@ def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  @throws[IOException]
-  def writePartitionedFile(
-      blockId: BlockId,
-      context: TaskContext,
-      outputFile: File): Array[Long]
-
-  @throws[IOException]
-  def stop(): Unit
+  long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException;
+
+  void stop() throws IOException;
 }
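The trait-turned-interface is small, so its contract is perhaps easiest to see with a toy implementation. The sketch below uses simplified stand-in types (a plain java.util.Iterator of map entries instead of scala.collection.Iterator<Product2<K, V>>, and no BlockId/TaskContext), so it is not the real Spark API; it only illustrates the expected lifecycle: insertAll(), then writePartitionedFile(), then stop() for cleanup.

```java
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

// Simplified stand-in for the SortShuffleFileWriter contract (not the real Spark API).
interface SimpleShuffleFileWriter<K, V> {
  void insertAll(Iterator<Map.Entry<K, V>> records) throws IOException;
  long[] writePartitionedFile(File outputFile) throws IOException;
  void stop() throws IOException;
}

// A deliberately naive single-partition implementation that buffers in memory;
// it exists only to show the insertAll -> writePartitionedFile -> stop call order.
final class InMemoryShuffleFileWriter implements SimpleShuffleFileWriter<String, String> {
  private final List<String> buffered = new ArrayList<>();

  @Override
  public void insertAll(Iterator<Map.Entry<String, String>> records) {
    while (records.hasNext()) {
      Map.Entry<String, String> record = records.next();
      buffered.add(record.getKey() + "\t" + record.getValue());
    }
  }

  @Override
  public long[] writePartitionedFile(File outputFile) throws IOException {
    try (PrintWriter out = new PrintWriter(outputFile, "UTF-8")) {
      for (String line : buffered) {
        out.println(line);
      }
    }
    // Single "partition", so a single length entry.
    return new long[] { outputFile.length() };
  }

  @Override
  public void stop() {
    buffered.clear();
  }
}
```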
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala
deleted file mode 100644
index afdda429cbe4e..0000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort
-
-import java.io.{File, FileInputStream, FileOutputStream}
-
-import org.apache.spark._
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, BlockManager, BlockObjectWriter}
-import org.apache.spark.util.Utils
-import org.apache.spark.util.collection._
-
-/**
- * This class handles sort-based shuffle's `bypassMergeSort` write path, which is used for shuffles
- * for which no Ordering and no Aggregator is given and the number of partitions is
- * less than `spark.shuffle.sort.bypassMergeThreshold`.
- *
- * This path used to be part of [[ExternalSorter]] but was refactored into its own class in order to
- * reduce code complexity; see SPARK-7855 for more details.
- *
- * There have been proposals to completely remove this code path; see SPARK-6026 for details.
- */
-private[spark] class BypassMergeSortShuffleWriter[K, V](
-    conf: SparkConf,
-    blockManager: BlockManager,
-    partitioner: Partitioner,
-    writeMetrics: ShuffleWriteMetrics,
-    serializer: Serializer)
-  extends Logging with SortShuffleFileWriter[K, V] {
-
-  private[this] val numPartitions = partitioner.numPartitions
-
-  /** Array of file writers for each partition */
-  private[this] var partitionWriters: Array[BlockObjectWriter] = _
-
-  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    assert (partitionWriters == null)
-    if (records.hasNext) {
-      val serInstance = serializer.newInstance()
-      // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
-      val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
-      val openStartTime = System.nanoTime
-      partitionWriters = Array.fill(numPartitions) {
-        val (blockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
-        val writer =
-          blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics)
-        writer.open()
-      }
-      // Creating the file to write to and creating a disk writer both involve interacting with
-      // the disk, and can take a long time in aggregate when we open many files, so should be
-      // included in the shuffle write time.
-      writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
-
-      while (records.hasNext) {
-        val record = records.next()
-        val key: K = record._1
-        partitionWriters(partitioner.getPartition(key)).write(key, record._2)
-      }
-    }
-  }
-
-  /**
-   * Write all the data added into this writer into a single file in the disk store. This is
-   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
-   * the per-partition binary files.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
-   */
-  def writePartitionedFile(blockId: BlockId, context: TaskContext, file: File): Array[Long] = {
-    if (partitionWriters == null) {
-      // We were passed an empty iterator
-      Array.fill(numPartitions)(0L)
-    } else {
-      partitionWriters.foreach(_.commitAndClose())
-
-      // Track location of each range in the output file
-      val lengths = new Array[Long](numPartitions)
-
-      val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
-      val out = new FileOutputStream(file, true)
-      val writeStartTime = System.nanoTime
-      Utils.tryWithSafeFinally {
-        for (i <- 0 until numPartitions) {
-          val in = new FileInputStream(partitionWriters(i).fileSegment().file)
-          Utils.tryWithSafeFinally {
-            lengths(i) = Utils.copyStream(in, out, closeStreams = false, transferToEnabled)
-          } {
-            in.close()
-          }
-          if (!blockManager.diskBlockManager.getFile(partitionWriters(i).blockId).delete()) {
-            logError("Unable to delete file for partition i. ")
-          }
-        }
-      } {
-        out.close()
-        context.taskMetrics().shuffleWriteMetrics.foreach { m =>
-          m.incShuffleWriteTime(System.nanoTime - writeStartTime)
-        }
-      }
-
-      lengths
-    }
-  }
-
-  def stop(): Unit = {
-    if (partitionWriters != null) {
-      partitionWriters.foreach { w =>
-        w.revertPartialWritesAndClose()
-        blockManager.diskBlockManager.getFile(w.blockId).delete()
-      }
-      partitionWriters = null
-    }
-  }
-}
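One behavioural detail of the port worth noting: the deleted Scala code relied on Utils.tryWithSafeFinally so that a failure in close() could not mask the primary exception, while the new Java code gets the same effect from Guava's Closeables.close(closeable, swallowIOException) driven by a threwException flag. A minimal, self-contained illustration of that flag pattern (file name and method are made up for the example):

```java
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import com.google.common.io.Closeables;

public class CloseablesPatternDemo {
  // Mirrors the pattern in writePartitionedFile(): if the body throws, we pass
  // swallowIOException = true so an IOException from close() is only logged and the
  // original exception propagates; if the body succeeds, a close() failure is rethrown.
  public static void writeGreeting(String path) throws IOException {
    OutputStream out = new FileOutputStream(path);
    boolean threwException = true;
    try {
      out.write("hello".getBytes("UTF-8"));
      threwException = false;
    } finally {
      Closeables.close(out, threwException);
    }
  }

  public static void main(String[] args) throws IOException {
    writeGreeting("greeting.txt");
  }
}
```

If the body throws, threwException stays true and a secondary IOException from close() is swallowed and logged by Guava; if the body succeeds, close() failures surface, so a partially flushed file is not silently ignored.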