SPARK-2978. Transformation with MR shuffle semantics
I didn't add this to the transformations list in the docs because it's kind of obscure, but would be happy to do so if others think it would be helpful.

Author: Sandy Ryza <sandy@cloudera.com>

Closes apache#2274 from sryza/sandy-spark-2978 and squashes the following commits:

4a5332a [Sandy Ryza] Fix Java test
c04b447 [Sandy Ryza] Fix Python doc and add back deleted code
433ad5b [Sandy Ryza] Add Java test
4c25a54 [Sandy Ryza] Add s at the end and a couple other fixes
9b0ba99 [Sandy Ryza] Fix compilation
36e0571 [Sandy Ryza] Fix import ordering
48c12c2 [Sandy Ryza] Add Java version and additional doc
e5381cd [Sandy Ryza] Fix python style warnings
f147634 [Sandy Ryza] SPARK-2978. Transformation with MR shuffle semantics
sryza authored and mateiz committed Sep 8, 2014
1 parent e16a8e7 commit 16a73c2
Showing 6 changed files with 115 additions and 1 deletion.
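
Before the file-by-file diff, a minimal usage sketch of the new transformation (illustration only, not part of this commit; the object and app names are made up, and the `SparkContext._` import is what supplies the pair-RDD conversions on Spark 1.x):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object RepartitionAndSortExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("raswp-example"))
    val pairs = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)))

    // One shuffle both partitions the data by key and sorts each partition by key,
    // mirroring the "sorted reducer input" contract of a MapReduce shuffle.
    val shuffled = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))

    // Within a partition, records with equal keys are now adjacent, so a reducer-style
    // pass can stream over groups without buffering a whole group in memory.
    shuffled.glom().collect().foreach(partition => println(partition.mkString(", ")))
    sc.stop()
  }
}

The Scala and Python tests added below assert exactly this layout: [(0, 5), (0, 8), (2, 6)] in one partition and [(1, 3), (3, 8), (3, 8)] in the other.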
26 changes: 26 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    rdd.saveAsHadoopDataset(conf)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = {
    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
    repartitionAndSortWithinPartitions(partitioner, comp)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])
    : JavaPairRDD[K, V] = {
    implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
    fromRDD(
      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner))
  }

  /**
   * Sort the RDD by key, so that each partition contains a sorted range of the elements in
   * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
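
An aside on the second overload above: the `implicit val ordering = comp` line leans on the standard library's implicit derivation of a `scala.math.Ordering` from a `java.util.Comparator`. A small sketch of that mechanism (illustration only, not part of this commit; the object and helper names are made up):

import java.util.Comparator

object ComparatorToOrdering {
  // Putting a Comparator[K] into implicit scope lets the compiler derive an Ordering[K]
  // via Ordering.comparatorToOrdering -- the same trick the overload above relies on.
  def orderingFrom[K](comp: Comparator[K]): Ordering[K] = {
    implicit val c: Comparator[K] = comp
    implicitly[Ordering[K]]
  }

  def main(args: Array[String]): Unit = {
    val byDescending: Comparator[Integer] = java.util.Collections.reverseOrder[Integer]()
    // Prints List(3, 2, 1): the derived Ordering reflects the Comparator's order.
    println(List(1, 3, 2).map(Integer.valueOf).sorted(orderingFrom(byDescending)))
  }
}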
14 changes: 13 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Logging, RangePartitioner}
import org.apache.spark.{Logging, Partitioner, RangePartitioner}
import org.apache.spark.annotation.DeveloperApi

/**
@@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    new ShuffledRDD[K, V, V](self, part)
      .setKeyOrdering(if (ascending) ordering else ordering.reverse)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
    new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
  }

}
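
The scaladoc above (and on the Java wrappers) claims a win over shuffling and then sorting separately. Continuing the usage sketch near the top of this page, the two-step alternative it is roughly being compared against would look like this (illustrative user code, not part of this commit):

    // Shuffle first, then sort each partition in a second pass on the executors.
    // repartitionAndSortWithinPartitions produces the same per-partition ordering in a
    // single step, letting the shuffle machinery itself take care of the sort.
    val slower = pairs
      .partitionBy(new HashPartitioner(2))
      .mapPartitions(iter => iter.toSeq.sortBy(_._1).iterator, preservesPartitioning = true)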
30 changes: 30 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -189,6 +189,36 @@ public void sortByKey() {
    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
  }

  @Test
  public void repartitionAndSortWithinPartitions() {
    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
    pairs.add(new Tuple2<Integer, Integer>(0, 5));
    pairs.add(new Tuple2<Integer, Integer>(3, 8));
    pairs.add(new Tuple2<Integer, Integer>(2, 6));
    pairs.add(new Tuple2<Integer, Integer>(0, 8));
    pairs.add(new Tuple2<Integer, Integer>(3, 8));
    pairs.add(new Tuple2<Integer, Integer>(1, 3));

    JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);

    Partitioner partitioner = new Partitioner() {
      public int numPartitions() {
        return 2;
      }
      public int getPartition(Object key) {
        return ((Integer)key).intValue() % 2;
      }
    };

    JavaPairRDD<Integer, Integer> repartitioned =
      rdd.repartitionAndSortWithinPartitions(partitioner);
    List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
    Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2<Integer, Integer>(0, 5),
      new Tuple2<Integer, Integer>(0, 8), new Tuple2<Integer, Integer>(2, 6)));
    Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2<Integer, Integer>(1, 3),
      new Tuple2<Integer, Integer>(3, 8), new Tuple2<Integer, Integer>(3, 8)));
  }

  @Test
  public void emptyRDD() {
    JavaRDD<String> rdd = sc.emptyRDD();
14 changes: 14 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -682,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
    assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
  }

  test("repartitionAndSortWithinPartitions") {
    val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)

    val partitioner = new Partitioner {
      def numPartitions: Int = 2
      def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
    }

    val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
    val partitions = repartitioned.glom().collect()
    assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
    assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
  }

  test("intersection") {
    val all = sc.parallelize(1 to 10)
    val evens = sc.parallelize(2 to 10 by 2)
24 changes: 24 additions & 0 deletions python/pyspark/rdd.py
@@ -520,6 +520,30 @@ def __add__(self, other):
            raise TypeError
        return self.union(other)

    def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
                                           ascending=True, keyfunc=lambda x: x):
        """
        Repartition the RDD according to the given partitioner and, within each resulting
        partition, sort records by their keys.

        >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
        >>> rdd2.glom().collect()
        [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
        """
        if numPartitions is None:
            numPartitions = self._defaultReducePartitions()

        # Sort each partition in the Python worker, spilling to disk via ExternalSorter
        # when spark.shuffle.spill is enabled; otherwise sort entirely in memory.
        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true")
        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
        serializer = self._jrdd_deserializer

        def sortPartition(iterator):
            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
            return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))

        # preservesPartitioning=True keeps the partitioner set by partitionBy, since sorting
        # within a partition does not move records across partitions.
        return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)

    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
        """
        Sorts this RDD, which is assumed to consist of (key, value) pairs.
8 changes: 8 additions & 0 deletions python/pyspark/tests.py
@@ -545,6 +545,14 @@ def test_histogram(self):
        self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
        self.assertRaises(TypeError, lambda: rdd.histogram(2))

    def test_repartitionAndSortWithinPartitions(self):
        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
        partitions = repartitioned.glom().collect()
        self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
        self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


class TestSQL(PySparkTestCase):
