Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jul 24, 2014
1 parent a6e35d6 commit db58a55
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](

private var ordering = implicitly[Ordering[K]]

@transient private[spark] var singlePass = true // for unit tests

// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
if (partitions == 1) {
Expand All @@ -116,7 +118,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
// This is the sample size we need to have roughly balanced output partitions.
val sampleSize = 20.0 * partitions
// Assume the input partitions are roughly balanced and over-sample a little bit.
val sampleSizePerPartition = math.ceil(5.0 * sampleSize / rdd.partitions.size).toInt
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
val shift = rdd.id
val classTagK = classTag[K]
val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
Expand Down Expand Up @@ -149,9 +151,10 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
}
if (imbalancedPartitions.nonEmpty) {
singlePass = false
val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray
iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray(classTagK)
}
val weight = (1.0 / fraction).toFloat
val resultHandler: (Int, Array[K]) => Unit = { (_, sample) =>
Expand Down
28 changes: 28 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,34 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
partitioner.getPartition(Row(100))
}

test("RangePartitioner should run only one job if data is roughly balanced") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(5000 * i)((random.nextDouble() + i, i))
}.cache()
for (numPartitions <- Seq(10, 20, 40)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
assert(partitioner.singlePass === true)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
assert(counts.max < 2.0 * counts.min)
}
}

test("RangePartitioner should work well on unbalanced data") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
}.cache()
for (numPartitions <- Seq(2, 4, 8)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
assert(partitioner.singlePass === false)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
assert(counts.max < 2.0 * counts.min)
}
}

test("HashPartitioner not equal to RangePartitioner") {
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
Expand Down

0 comments on commit db58a55

Please sign in to comment.