diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9aa454a5c8b88..6f22486465a8b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} +import org.apache.spark.util.collection.{FlexibleExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -58,14 +58,14 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep] * @param part partitioner used to partition the shuffle output. */ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + extends RDD[(K, Seq[Iterator[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. // CoGroupValue is the intermediate state of each value before being merged in compute. private type CoGroup = ArrayBuffer[Any] private type CoGroupValue = (Any, Int) // Int is dependency number - private type CoGroupCombiner = Seq[CoGroup] + private type CoGroupCombiner = Array[CoGroup] private var serializer: Serializer = null @@ -105,7 +105,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner: Some[Partitioner] = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, Iterator[CoGroup])] = { val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] @@ -141,7 +141,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: getCombiner(kv._1)(depNum) += kv._2 } } - new InterruptibleIterator(context, map.iterator) + // Convert to iterators + val finalMap = new AppendOnlyMap[K, Iterator[CoGroup]](math.max(map.size, 64)) + map.foreach { case (it, k) => + finalMap.update(it, k.iterator) + } + new InterruptibleIterator(context, finalMap.iterator) } else { val map = createExternalMap(numRdds) rddIterators.foreach { case (it, depNum) => @@ -157,7 +162,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } private def createExternalMap(numRdds: Int) - : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = { + : FlexibleExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner, Iterator[CoGroup]] = { val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { val newCombiner = Array.fill(numRdds)(new CoGroup) @@ -169,12 +174,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: value match { case (v, depNum) => combiner(depNum) += v } combiner } - val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = + val mergeCombiners: (CoGroupCombiner, Iterator[CoGroup]) => Iterator[CoGroup] = (combiner1, combiner2) => { - combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + combiner1.toIterator.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } } - new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( - createCombiner, mergeValue, mergeCombiners) + val returnCombiner: (CoGroupCombiner) => Iterator[CoGroup] = + (combiner) => combiner.toIterator + new FlexibleExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner, Iterator[CoGroup]]( + createCombiner, mergeValue, mergeCombiners, returnCombiner) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index caa06d5b445b4..868351b5a2458 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -55,16 +55,25 @@ import org.apache.spark.storage.{BlockId, BlockManager} * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of * this threshold, in case map size estimation is not sufficiently accurate. */ - private[spark] class ExternalAppendOnlyMap[K, V, C]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) - extends Iterable[(K, C)] with Serializable with Logging { + extends FlexibleExternalAppendOnlyMap[K, V, C, C](createCombiner, mergeValue, mergeCombiners, (x => x), + serializer, blockManager) { +} +private[spark] class FlexibleExternalAppendOnlyMap[K, V, C, T]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, T) => T, + returnCombiner: C => T, + serializer: Serializer = SparkEnv.get.serializer, + blockManager: BlockManager = SparkEnv.get.blockManager) + extends Iterable[(K, T)] with Serializable with Logging { - import ExternalAppendOnlyMap._ + import FlexibleExternalAppendOnlyMap._ private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] @@ -263,13 +272,13 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( * If the given buffer contains a value for the given key, merge that value into * baseCombiner and remove the corresponding (K, C) pair from the buffer. */ - private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { + private def mergeIfKeyExists(key: K, baseCombiner: T, buffer: StreamBuffer): T = { var i = 0 while (i < buffer.pairs.length) { val (k, c) = buffer.pairs(i) if (k == key) { buffer.pairs.remove(i) - return mergeCombiners(baseCombiner, c) + return mergeCombiners(c, baseCombiner) } i += 1 } @@ -292,7 +301,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( // Select a key from the StreamBuffer that holds the lowest key hash val minBuffer = mergeHeap.dequeue() val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) - var (minKey, minCombiner) = minPairs.remove(0) + var (minKey, minCombinerC) = minPairs.remove(0) + var minCombiner = returnCombiner(minCombinerC) assert(minKey.hashCode() == minHash) // For all other streams that may have this key (i.e. have the same minimum key hash), @@ -418,7 +428,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( } } -private[spark] object ExternalAppendOnlyMap { +private[spark] object FlexibleExternalAppendOnlyMap { private class KCComparator[K, C] extends Comparator[(K, C)] { def compare(kc1: (K, C), kc2: (K, C)): Int = { kc1._1.hashCode().compareTo(kc2._1.hashCode())