
Refactor logic for deciding when to bypass
JoshRosen committed May 25, 2015
1 parent 4b03539 commit 6a35716
Showing 3 changed files with 27 additions and 23 deletions.
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort
 
 import java.io.{File, FileInputStream, FileOutputStream}
 
-import org.apache.spark.{Logging, Partitioner, SparkConf, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, BlockManager, BlockObjectWriter}

@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.sort
 
-import org.apache.spark.{SparkEnv, Logging, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
@@ -49,24 +49,18 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
-    val env = SparkEnv.get
-    val conf = env.conf
-    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val bypassMergeSort = dep.partitioner.numPartitions <= bypassMergeThreshold &&
-      dep.aggregator.isEmpty && dep.keyOrdering.isEmpty
-
     sorter = if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-    } else if (bypassMergeSort) {
+    } else if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dep)) {
       // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
       // need local aggregation and sorting, write numPartitions files directly and just concatenate
       // them at the end. This avoids doing serialization and deserialization twice to merge
       // together the spilled files, which would happen with the normal code path. The downside is
       // having multiple files open at a time and thus more memory allocated to buffers.
       new BypassMergeSortShuffleWriter[K, V](
-        conf, blockManager, dep.partitioner, writeMetrics, dep.serializer)
+        SparkEnv.get.conf, blockManager, dep.partitioner, writeMetrics, dep.serializer)
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -112,3 +106,18 @@ private[spark] class SortShuffleWriter[K, V, C](
     }
   }
 }
+
+private[spark] object SortShuffleWriter {
+  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+    shouldBypassMergeSort(conf, dep.partitioner.numPartitions, dep.aggregator, dep.keyOrdering)
+  }
+
+  def shouldBypassMergeSort(
+      conf: SparkConf,
+      numPartitions: Int,
+      aggregator: Option[Aggregator[_, _, _]],
+      keyOrdering: Option[Ordering[_]]): Boolean = {
+    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  }
+}
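
The SortShuffleWriter companion object introduced above centralizes the bypass decision that write() previously computed inline and that ExternalSorter re-checks below. As a rough standalone illustration of that predicate (plain Scala with no Spark dependency; the object and method names here are invented for the sketch and are not part of the commit), the decision is just a three-way conjunction:

// Mirrors the logic of SortShuffleWriter.shouldBypassMergeSort: take the bypass path only
// when the shuffle has at most `bypassMergeThreshold` output partitions and needs neither
// map-side aggregation nor a key ordering.
object BypassDecisionSketch {
  def shouldBypass(
      bypassMergeThreshold: Int,  // value of spark.shuffle.sort.bypassMergeThreshold (default 200)
      numPartitions: Int,
      hasAggregator: Boolean,
      hasKeyOrdering: Boolean): Boolean = {
    numPartitions <= bypassMergeThreshold && !hasAggregator && !hasKeyOrdering
  }

  def main(args: Array[String]): Unit = {
    println(shouldBypass(200, 100, hasAggregator = false, hasKeyOrdering = false))  // true
    println(shouldBypass(200, 400, hasAggregator = false, hasKeyOrdering = false))  // false: too many partitions
    println(shouldBypass(200, 100, hasAggregator = true, hasKeyOrdering = false))   // false: map-side combine needed
  }
}
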
@@ -29,7 +29,7 @@ import com.google.common.io.ByteStreams
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.SortShuffleSorter
+import org.apache.spark.shuffle.sort.{SortShuffleWriter, SortShuffleSorter}
 import org.apache.spark.storage.BlockId
 
 /**
@@ -98,24 +98,19 @@ private[spark] class ExternalSorter[K, V, C](
 
   private val conf = SparkEnv.get.conf
 
-  // The bypassMergeSort optimization is no longer performed as part of this class. As a sanity
-  // check, make sure that we're not handling a shuffle which should have used that path:
-  {
-    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val bypassMergeSort =
-      numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty
-    if (bypassMergeSort) {
-      throw new IllegalArgumentException("ExternalSorter should not have been invoked to handle "
-        + " a sort that the BypassMergeSortShuffleWriter should handle")
-    }
-  }
-
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
   private def getPartition(key: K): Int = {
     if (shouldPartition) partitioner.get.getPartition(key) else 0
   }
 
+  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+      + " a sort that the BypassMergeSortShuffleWriter should handle")
+  }
+
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
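
The check added to ExternalSorter consults the same predicate so that a shuffle which should have taken the bypass path fails fast instead of silently running through the sort-based path. A self-contained sketch of that fail-fast guard (the SorterSketch class and its boolean parameters are hypothetical stand-ins; the real code calls the private[spark] helper and throws IllegalArgumentException directly, which is also what require throws here):

// Hypothetical illustration of the constructor-time guard: re-check the routing predicate
// and reject a shuffle that the bypass writer should have handled.
class SorterSketch(
    numPartitions: Int,
    hasAggregator: Boolean,
    hasOrdering: Boolean,
    bypassMergeThreshold: Int = 200) {
  require(
    !(numPartitions <= bypassMergeThreshold && !hasAggregator && !hasOrdering),
    "SorterSketch should not be used for a sort that the bypass writer should handle")
  // ...the rest of the sorter would follow here
}

object SorterSketchDemo extends App {
  new SorterSketch(numPartitions = 300, hasAggregator = false, hasOrdering = false)  // fine: above the threshold
  new SorterSketch(numPartitions = 10, hasAggregator = true, hasOrdering = false)    // fine: needs map-side combine
  new SorterSketch(numPartitions = 10, hasAggregator = false, hasOrdering = false)   // throws IllegalArgumentException
}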
