From 6a35716575253dbadb2d060817eef2c500dcc952 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 25 May 2015 11:25:01 -0700
Subject: [PATCH] Refactor logic for deciding when to bypass

---
 .../sort/BypassMergeSortShuffleWriter.scala |  2 +-
 .../shuffle/sort/SortShuffleWriter.scala    | 27 ++++++++++++-------
 .../util/collection/ExternalSorter.scala    | 21 ++++++---------
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala
index 3220bd419a598..29538bfac6914 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort
 
 import java.io.{File, FileInputStream, FileOutputStream}
 
-import org.apache.spark.{Logging, Partitioner, SparkConf, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, BlockManager, BlockObjectWriter}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 45ae4d2c4a22b..ece79153ab3a9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.sort
 
-import org.apache.spark.{SparkEnv, Logging, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
@@ -49,24 +49,18 @@
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
-    val env = SparkEnv.get
-    val conf = env.conf
-    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val bypassMergeSort = dep.partitioner.numPartitions <= bypassMergeThreshold &&
-      dep.aggregator.isEmpty && dep.keyOrdering.isEmpty
-
     sorter = if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-    } else if (bypassMergeSort) {
+    } else if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dep)) {
       // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
       // need local aggregation and sorting, write numPartitions files directly and just concatenate
       // them at the end. This avoids doing serialization and deserialization twice to merge
       // together the spilled files, which would happen with the normal code path. The downside is
       // having multiple files open at a time and thus more memory allocated to buffers.
       new BypassMergeSortShuffleWriter[K, V](
-        conf, blockManager, dep.partitioner, writeMetrics, dep.serializer)
+        SparkEnv.get.conf, blockManager, dep.partitioner, writeMetrics, dep.serializer)
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -112,3 +106,18 @@ private[spark] class SortShuffleWriter[K, V, C](
     }
   }
 }
+
+private[spark] object SortShuffleWriter {
+  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+    shouldBypassMergeSort(conf, dep.partitioner.numPartitions, dep.aggregator, dep.keyOrdering)
+  }
+
+  def shouldBypassMergeSort(
+      conf: SparkConf,
+      numPartitions: Int,
+      aggregator: Option[Aggregator[_, _, _]],
+      keyOrdering: Option[Ordering[_]]): Boolean = {
+    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 6910ea94f3421..61f8648c365cc 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,7 @@ import com.google.common.io.ByteStreams
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.SortShuffleSorter
+import org.apache.spark.shuffle.sort.{SortShuffleWriter, SortShuffleSorter}
 import org.apache.spark.storage.BlockId
 
 /**
@@ -98,24 +98,19 @@ private[spark] class ExternalSorter[K, V, C](
 
   private val conf = SparkEnv.get.conf
 
-  // The bypassMergeSort optimization is no longer performed as part of this class. As a sanity
-  // check, make sure that we're not handling a shuffle which should have used that path:
-  {
-    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val bypassMergeSort =
-      numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty
-    if (bypassMergeSort) {
-      throw new IllegalArgumentException("ExternalSorter should not have been invoked to handle "
-        + " a sort that the BypassMergeSortShuffleWriter should handle")
-    }
-  }
-
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
   private def getPartition(key: K): Int = {
     if (shouldPartition) partitioner.get.getPartition(key) else 0
   }
 
+  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+      + " a sort that the BypassMergeSortShuffleWriter should handle")
+  }
+
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
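
Note, not part of the patch: the decision rule that this change centralizes in SortShuffleWriter.shouldBypassMergeSort can be read in isolation. The sketch below restates that rule as a minimal, self-contained Scala program; SparkConf is replaced by a hypothetical plain Map[String, String] lookup so the snippet compiles and runs without Spark on the classpath, and the Option[Aggregator] / Option[Ordering] parameters are simplified to Booleans.

object BypassDecisionSketch {
  // Stand-in for SparkConf.getInt(key, default); the real patch reads the
  // threshold from SparkEnv.get.conf instead.
  private def getInt(conf: Map[String, String], key: String, default: Int): Int =
    conf.get(key).map(_.toInt).getOrElse(default)

  // Bypass the sort-based merge path only for small shuffles that need
  // neither map-side aggregation nor key ordering, mirroring the helper
  // introduced by the patch.
  def shouldBypassMergeSort(
      conf: Map[String, String],
      numPartitions: Int,
      hasAggregator: Boolean,
      hasKeyOrdering: Boolean): Boolean = {
    val bypassMergeThreshold = getInt(conf, "spark.shuffle.sort.bypassMergeThreshold", 200)
    numPartitions <= bypassMergeThreshold && !hasAggregator && !hasKeyOrdering
  }

  def main(args: Array[String]): Unit = {
    val conf = Map.empty[String, String]
    // 100 partitions, no aggregation or ordering: bypass applies (true).
    println(shouldBypassMergeSort(conf, 100, hasAggregator = false, hasKeyOrdering = false))
    // 500 partitions exceeds the default threshold of 200: no bypass (false).
    println(shouldBypassMergeSort(conf, 500, hasAggregator = false, hasKeyOrdering = false))
    // A map-side aggregator disables the bypass regardless of size (false).
    println(shouldBypassMergeSort(conf, 100, hasAggregator = true, hasKeyOrdering = false))
  }
}

Keeping the rule in one helper means the dispatch in SortShuffleWriter.write and the defensive check in ExternalSorter cannot drift apart if the threshold or the bypass conditions later change.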