SPARK-2978. Transformation with MR shuffle semantics #2274

Closed · wants to merge 9 commits
26 changes: 26 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    rdd.saveAsHadoopDataset(conf)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = {
    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
    repartitionAndSortWithinPartitions(partitioner, comp)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])
    : JavaPairRDD[K, V] = {
    implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
    fromRDD(
      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner))
  }

  /**
   * Sort the RDD by key, so that each partition contains a sorted range of the elements in
   * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
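Note: to see what the new operator buys over the hand-rolled pattern the doc comment mentions, here is a minimal spark-shell-style Scala sketch. The sample data, partitioner, and `local[2]` context are invented for illustration and are not part of this patch; the `SparkContext._` import supplies the pair-RDD implicits on Spark 1.x.

```scala
import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits (needed on Spark 1.x)

val sc = new SparkContext("local[2]", "repartition-sort-sketch")
val pairs = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)))
val part = new HashPartitioner(2)

// Hand-rolled equivalent: shuffle first, then sort each partition in a
// second pass that materializes the whole partition in memory.
val manual = pairs.partitionBy(part).mapPartitions(
  iter => iter.toArray.sortBy(_._1).iterator, preservesPartitioning = true)

// New operator: the key ordering is attached to the ShuffledRDD, so the
// sort runs inside the shuffle machinery instead of as a separate pass.
val pushed = pairs.repartitionAndSortWithinPartitions(part)

// Both yield partitions sorted by key; with this data, partition 0 is
// [(0,5), (0,8), (2,6)] and partition 1 is [(1,3), (3,8), (3,8)].
pushed.glom().collect().foreach(p => println(p.mkString(", ")))
sc.stop()
```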
core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Logging, RangePartitioner}
import org.apache.spark.{Logging, Partitioner, RangePartitioner}
import org.apache.spark.annotation.DeveloperApi

/**
@@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    new ShuffledRDD[K, V, V](self, part)
      .setKeyOrdering(if (ascending) ordering else ordering.reverse)
  }

  /**
   * Repartition the RDD according to the given partitioner and, within each resulting partition,
   * sort records by their keys.
   *
   * This is more efficient than calling `repartition` and then sorting within each partition
   * because it can push the sorting down into the shuffle machinery.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
    new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
  }

}
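Note: this Scala-side method is also the natural building block for MapReduce-style secondary sort: partition on a grouping field while the key ordering sorts on a composite key, so one shuffle both groups and orders the data. A hypothetical spark-shell-style sketch follows; the `Event` type, `UserPartitioner`, and sample data are invented for illustration and are not part of this patch.

```scala
import org.apache.spark.{Partitioner, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits (needed on Spark 1.x)

// Hypothetical composite key: group by userId, order by time within a user.
case class Event(userId: Int, time: Long)

// Shuffle-side sort order: by user first, then by timestamp.
implicit val eventOrdering: Ordering[Event] =
  Ordering.by((e: Event) => (e.userId, e.time))

// Partition on the grouping field only, so every event for a user lands in
// the same partition regardless of its timestamp.
class UserPartitioner(override val numPartitions: Int) extends Partitioner {
  def getPartition(key: Any): Int =
    math.abs(key.asInstanceOf[Event].userId) % numPartitions
}

val sc = new SparkContext("local[2]", "secondary-sort-sketch")
val events = sc.parallelize(Seq(
  (Event(1, 30L), "c"), (Event(2, 10L), "x"),
  (Event(1, 10L), "a"), (Event(1, 20L), "b")))

// One shuffle both groups each user's events and time-orders them.
val sorted = events.repartitionAndSortWithinPartitions(new UserPartitioner(2))
sorted.glom().collect().foreach(p => println(p.mkString(", ")))
sc.stop()
```

A downstream `mapPartitions` can then stream over each user's events in timestamp order without a separate `groupByKey`, which is the usual motivation for this pattern.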
30 changes: 30 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -189,6 +189,36 @@ public void sortByKey() {
    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
  }

  @Test
  public void repartitionAndSortWithinPartitions() {
    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
    pairs.add(new Tuple2<Integer, Integer>(0, 5));
    pairs.add(new Tuple2<Integer, Integer>(3, 8));
    pairs.add(new Tuple2<Integer, Integer>(2, 6));
    pairs.add(new Tuple2<Integer, Integer>(0, 8));
    pairs.add(new Tuple2<Integer, Integer>(3, 8));
    pairs.add(new Tuple2<Integer, Integer>(1, 3));

    JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);

    Partitioner partitioner = new Partitioner() {
      public int numPartitions() {
        return 2;
      }
      public int getPartition(Object key) {
        return ((Integer)key).intValue() % 2;
      }
    };

    JavaPairRDD<Integer, Integer> repartitioned =
      rdd.repartitionAndSortWithinPartitions(partitioner);
    List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
    Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2<Integer, Integer>(0, 5),
      new Tuple2<Integer, Integer>(0, 8), new Tuple2<Integer, Integer>(2, 6)));
    Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2<Integer, Integer>(1, 3),
      new Tuple2<Integer, Integer>(3, 8), new Tuple2<Integer, Integer>(3, 8)));
  }

  @Test
  public void emptyRDD() {
    JavaRDD<String> rdd = sc.emptyRDD();
14 changes: 14 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -682,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
    assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
  }

  test("repartitionAndSortWithinPartitions") {
    val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)

    val partitioner = new Partitioner {
      def numPartitions: Int = 2
      def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
    }

    val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
    val partitions = repartitioned.glom().collect()
    assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
    assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
  }

  test("intersection") {
    val all = sc.parallelize(1 to 10)
    val evens = sc.parallelize(2 to 10 by 2)
24 changes: 24 additions & 0 deletions python/pyspark/rdd.py
@@ -520,6 +520,30 @@ def __add__(self, other):
            raise TypeError
        return self.union(other)

    def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
                                           ascending=True, keyfunc=lambda x: x):
        """
        Repartition the RDD according to the given partitioner and, within each
        resulting partition, sort records by their keys.

        >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
        >>> rdd2.glom().collect()
        [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
        """
        if numPartitions is None:
            numPartitions = self._defaultReducePartitions()

        spill = self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
        serializer = self._jrdd_deserializer

        # The sort runs in the Python worker after the shuffle, spilling to
        # disk through ExternalSorter when spark.shuffle.spill is enabled.
        def sortPartition(iterator):
            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
            return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))

        return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)

    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
        """
        Sorts this RDD, which is assumed to consist of (key, value) pairs.
8 changes: 8 additions & 0 deletions python/pyspark/tests.py
@@ -545,6 +545,14 @@ def test_histogram(self):
        self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
        self.assertRaises(TypeError, lambda: rdd.histogram(2))

    def test_repartitionAndSortWithinPartitions(self):
        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
        partitions = repartitioned.glom().collect()
        self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
        self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


class TestSQL(PySparkTestCase):
