Skip to content

Commit

Permalink
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Browse files Browse the repository at this point in the history
  • Loading branch information
dorx committed May 29, 2014
1 parent 60b89fe commit 1441977
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 22 deletions.
4 changes: 4 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
32 changes: 29 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag](
}.toArray
}

def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
{
/**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
*/
def takeSample(withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val multiplier = 3.0
Expand All @@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag](
}

if (num > initialCount && !withReplacement) {
// special case not covered in computeFraction
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
} else {
fraction = multiplier * (num + 1) / initialCount
fraction = computeFraction(num, initialCount, withReplacement)
total = num
}

Expand All @@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag](
Utils.randomizeInPlace(samples, rand).take(total)
}

/**
 * Compute the sampling fraction to pass to `sample` so that the draw is expected to
 * contain at least `num` elements. The naive rate `num / total` is padded upwards; the
 * accompanying unit test checks the result against a failure rate of at most 0.0001.
 *
 * @param num desired sample size
 * @param total number of elements in the RDD being sampled
 * @param withReplacement whether sampling is done with replacement
 * @return the padded sampling fraction; capped at 1 when sampling without replacement,
 *         uncapped when sampling with replacement
 */
private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
  val baseRate = num.toDouble / total
  if (!withReplacement) {
    // Pad the base rate by a term derived from delta, then cap at 1.
    val delta = 0.00005
    val gamma = - math.log(delta) / total
    val padded = baseRate + gamma + math.sqrt(gamma * gamma + 2 * gamma * baseRate)
    math.min(1, padded)
  } else {
    // Small requests (num < 12) get a wider margin to guarantee the sample size.
    val numStDev = if (num < 12) 9 else 5
    baseRate + numStDev * math.sqrt(baseRate / total)
  }
}

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}

/**
* Return a sampler with is the complement of the range specified of the current sampler.
* Return a sampler which is the complement of the range specified of the current sampler.
*/
def cloneComplement(): BernoulliSampler[T] = {
  // Keep the same lb/ub bounds and invert the complement flag.
  val inverted = !complement
  new BernoulliSampler[T](lb, ub, inverted)
}

Expand Down
63 changes: 46 additions & 17 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.reflect.ClassTag

import org.scalatest.FunSuite

import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
Expand Down Expand Up @@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}

test("computeFraction") {
  // Check that the computed fraction yields enough datapoints in the sample,
  // with a failure rate <= 0.0001.
  val data = new EmptyRDD[Int](sc)
  val n = 100000

  // For every requested size, the 0.0001 quantile of a Poisson with mean frac * n
  // must be at least the requested size.
  def checkFractions(sampleSizes: Seq[Int], withReplacement: Boolean): Unit = {
    sampleSizes.foreach { s =>
      val frac = data.computeFraction(s, n, withReplacement)
      val qpois = new PoissonDistribution(frac * n)
      assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
    }
  }

  // Dense sweep over small sizes, then a sparse sweep over orders of magnitude.
  checkFractions(1 to 15, withReplacement = true)
  checkFractions(1 to 15, withReplacement = false)
  checkFractions(List(1, 10, 100, 1000), withReplacement = true)
  checkFractions(List(1, 10, 100, 1000), withReplacement = false)
}

// Exercises takeSample with and without replacement: checks the returned size, the value
// range, and distinctness (without replacement) or the presence of repeats (with replacement).
// NOTE(review): this span shows pre- and post-change lines together (e.g. `data` is
// defined twice); the trailing comments below describe the adjacent assertion only.
test("takeSample") {
val data = sc.parallelize(1 to 100, 2)
val n = 1000000
val data = sc.parallelize(1 to n, 2)

for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 200, seed)
val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got exactly 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=100)
assert(sample.size === 100) // Got exactly 100 elements
val sample = data.takeSample(withReplacement=true, num=n)
assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < n
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 100, seed)
assert(sample.size === 100) // Got exactly 100 elements
val sample = data.takeSample(withReplacement=true, n, seed)
assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < n
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 200, seed)
assert(sample.size === 200) // Got exactly 200 elements
val sample = data.takeSample(withReplacement=true, 2*n, seed)
assert(sample.size === 2*n) // Got exactly 2*n elements
// Chance of getting all distinct elements is still quite low, so test we got < n
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}

Expand Down
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@
<artifactId>commons-codec</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.2</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
1 change: 1 addition & 0 deletions project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"com.google.guava" % "guava" % "14.0.1",
"org.apache.commons" % "commons-lang3" % "3.3.2",
"org.apache.commons" % "commons-math3" % "3.2",
"com.google.code.findbugs" % "jsr305" % "1.3.9",
"log4j" % "log4j" % "1.2.17",
"org.slf4j" % "slf4j-api" % slf4jVersion,
Expand Down
15 changes: 14 additions & 1 deletion python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import warnings
import heapq
from random import Random
from math import sqrt, log

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
Expand Down Expand Up @@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None):
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
else:
fraction = multiplier * (num + 1) / initialCount
fraction = self._computeFraction(num, initialCount, withReplacement)
total = num

samples = self.sample(withReplacement, fraction, seed).collect()
Expand All @@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None):
sampler.shuffle(samples)
return samples[0:total]

def _computeFraction(self, num, total, withReplacement):
fraction = float(num)/total
if withReplacement:
numStDev = 5
if (num < 12):
numStDev = 9
return fraction + numStDev * sqrt(fraction/total)
else:
delta = 0.00005
gamma = - log(delta)/total
return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction))

def union(self, other):
"""
Return the union of this RDD and another one.
Expand Down

0 comments on commit 1441977

Please sign in to comment.