diff --git a/core/pom.xml b/core/pom.xml
index bab50f5ce2888..6cb58dbd291c4 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -67,6 +67,10 @@
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-lang3</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+    </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index aa03e9276fb34..2fdf45a0c8b8e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
-  {
+  /**
+   * Return a fixed-size sampled subset of this RDD in an array.
+   *
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return sample of specified size in an array
+   */
+  def takeSample(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): Array[T] = {
     var fraction = 0.0
     var total = 0
     val multiplier = 3.0
@@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag](
     }

     if (num > initialCount && !withReplacement) {
+      // special case not covered in computeFraction
       total = maxSelected
       fraction = multiplier * (maxSelected + 1) / initialCount
     } else {
-      fraction = multiplier * (num + 1) / initialCount
+      fraction = computeFraction(num, initialCount, withReplacement)
       total = num
     }

@@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag](
     Utils.randomizeInPlace(samples, rand).take(total)
   }

+  private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
+    val fraction = num.toDouble / total
+    if (withReplacement) {
+      var numStDev = 5
+      if (num < 12) {
+        // special case to guarantee sample size for small s
+        numStDev = 9
+      }
+      fraction + numStDev * math.sqrt(fraction / total)
+    } else {
+      val delta = 0.00005
+      val gamma = -math.log(delta) / total
+      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+    }
+  }
+
   /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
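A note on the bound that computeFraction implements: the helper picks a sampling fraction just large enough that a single sampling pass should return at least num records, targeting a small failure rate (the test below uses 0.0001). With replacement it pads the expected sample size by 5 standard deviations (9 when num < 12); without replacement it uses a tail bound with delta = 0.00005. The following is a minimal standalone Scala sketch of the same arithmetic, not part of the patch; the object name and example inputs are illustrative.

object ComputeFractionSketch {
  // Same formula as computeFraction above, copied here only so it can run outside Spark.
  def boundedFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
    val fraction = num.toDouble / total
    if (withReplacement) {
      // Sample size concentrates around fraction * total; pad that mean by k standard deviations,
      // expressed as a fraction: (mean + k * sqrt(mean)) / total = fraction + k * sqrt(fraction / total).
      val numStDev = if (num < 12) 9 else 5
      fraction + numStDev * math.sqrt(fraction / total)
    } else {
      // Tail bound with target failure probability delta.
      val delta = 0.00005
      val gamma = -math.log(delta) / total
      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
    }
  }

  def main(args: Array[String]): Unit = {
    for ((num, total) <- Seq((5, 100000L), (100, 100000L))) {
      println(f"num=$num total=$total naive=${num.toDouble / total}%.6f " +
        f"withRepl=${boundedFraction(num, total, true)}%.6f " +
        f"withoutRepl=${boundedFraction(num, total, false)}%.6f")
    }
  }
}
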
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 4dc8ada00a3e8..e53103755b279 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
   }

   /**
-   * Return a sampler with is the complement of the range specified of the current sampler.
+   * Return a sampler which is the complement of the range specified by the current sampler.
    */
   def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e686068f7a99a..5bdcb9bef6d62 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag

 import org.scalatest.FunSuite

+import org.apache.commons.math3.distribution.PoissonDistribution
 import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd._
@@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(sortedTopK === nums.sorted(ord).take(5))
   }

+  test("computeFraction") {
+    // computed fraction should guarantee enough data points in the sample (failure rate <= 0.0001)
+    val data = new EmptyRDD[Int](sc)
+    val n = 100000
+
+    for (s <- 1 to 15) {
+      val frac = data.computeFraction(s, n, true)
+      val qpois = new PoissonDistribution(frac * n)
+      assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- 1 to 15) {
+      val frac = data.computeFraction(s, n, false)
+      val qpois = new PoissonDistribution(frac * n)
+      assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(1, 10, 100, 1000)) {
+      val frac = data.computeFraction(s, n, true)
+      val qpois = new PoissonDistribution(frac * n)
+      assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(1, 10, 100, 1000)) {
+      val frac = data.computeFraction(s, n, false)
+      val qpois = new PoissonDistribution(frac * n)
+      assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+  }
+
   test("takeSample") {
-    val data = sc.parallelize(1 to 100, 2)
+    val n = 1000000
+    val data = sc.parallelize(1 to n, 2)
     for (num <- List(5, 20, 100)) {
       val sample = data.takeSample(withReplacement=false, num=num)
       assert(sample.size === num)        // Got exactly num elements
       assert(sample.toSet.size === num)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
       assert(sample.toSet.size === 20)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=false, 200, seed)
+      val sample = data.takeSample(withReplacement=false, 100, seed)
       assert(sample.size === 100)        // Got only 100 elements
       assert(sample.toSet.size === 100)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     {
       val sample = data.takeSample(withReplacement=true, num=20)
       assert(sample.size === 20)  // Got exactly 100 elements
       assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     {
-      val sample = data.takeSample(withReplacement=true, num=100)
-      assert(sample.size === 100)  // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, num=n)
+      assert(sample.size === n)  // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 100, seed)
-      assert(sample.size === 100)  // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, n, seed)
+      assert(sample.size === n)  // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 200, seed)
-      assert(sample.size === 200)  // Got exactly 200 elements
+      val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+      assert(sample.size === 2 * n)  // Got exactly 2 * n elements
       // Chance of getting all distinct elements is still quite low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
   }
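The computeFraction test checks the guarantee distributionally: at sampling fraction frac, the number of records kept out of n concentrates around frac * n, and the test models it with a Poisson distribution of that mean for both the with- and without-replacement cases. Requiring the 0.0001 quantile of that distribution to be at least s means fewer than s records are drawn with probability below about 1e-4. Below is a self-contained sketch of the same check against commons-math3, not part of the patch; the concrete values are illustrative.

import org.apache.commons.math3.distribution.PoissonDistribution

object PoissonQuantileCheck {
  def main(args: Array[String]): Unit = {
    val n = 100000        // population size, matching the test above
    val s = 10            // requested sample size (illustrative)
    val frac = 3.85e-4    // roughly what computeFraction(10, 100000, withReplacement = true) returns
    // The sampled count concentrates around frac * n; if the 0.0001 quantile is at least s,
    // then drawing fewer than s records has probability below 0.0001.
    val qpois = new PoissonDistribution(frac * n)
    val lowQuantile = qpois.inverseCumulativeProbability(0.0001)
    println(s"0.0001 quantile of Poisson(${frac * n}) = $lowQuantile; >= $s ? ${lowQuantile >= s}")
  }
}
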
diff --git a/pom.xml b/pom.xml
index 7bf9f135fd340..01d6eef32be63 100644
--- a/pom.xml
+++ b/pom.xml
@@ -245,6 +245,11 @@
         <artifactId>commons-codec</artifactId>
         <version>1.5</version>
       </dependency>
+      <dependency>
+        <groupId>org.apache.commons</groupId>
+        <artifactId>commons-math3</artifactId>
+        <version>3.2</version>
+      </dependency>
       <dependency>
         <groupId>com.google.code.findbugs</groupId>
         <artifactId>jsr305</artifactId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8ef1e91f609fb..a6b6c26a49395 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -331,6 +331,7 @@ object SparkBuild extends Build {
     libraryDependencies ++= Seq(
       "com.google.guava" % "guava" % "14.0.1",
       "org.apache.commons" % "commons-lang3" % "3.3.2",
+      "org.apache.commons" % "commons-math3" % "3.2",
      "com.google.code.findbugs" % "jsr305" % "1.3.9",
      "log4j" % "log4j" % "1.2.17",
      "org.slf4j" % "slf4j-api" % slf4jVersion,
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8d937fc..b400404ad97c7 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@
 import warnings
 import heapq
 from random import Random
+from math import sqrt, log

 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None):
             total = maxSelected
             fraction = multiplier * (maxSelected + 1) / initialCount
         else:
-            fraction = multiplier * (num + 1) / initialCount
+            fraction = self._computeFraction(num, initialCount, withReplacement)
             total = num

         samples = self.sample(withReplacement, fraction, seed).collect()
@@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None):
         sampler.shuffle(samples)
         return samples[0:total]

+    def _computeFraction(self, num, total, withReplacement):
+        fraction = float(num) / total
+        if withReplacement:
+            numStDev = 5
+            if num < 12:
+                numStDev = 9
+            return fraction + numStDev * sqrt(fraction / total)
+        else:
+            delta = 0.00005
+            gamma = -log(delta) / total
+            return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
+
     def union(self, other):
         """
         Return the union of this RDD and another one.
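As a rough numeric sanity check of the Python helper, which mirrors the Scala computeFraction (values approximate, for illustration only): for num = 1000 out of total = 1,000,000, the naive fraction is 0.001. Without replacement, gamma = -log(0.00005) / 1,000,000 is about 9.9e-6 and _computeFraction returns about 0.00115; with replacement (num >= 12, so 5 standard deviations) it returns about 0.00116. In both cases the over-sample is around 15 percent, compared with the flat 3x multiplier that the old fraction = multiplier * (num + 1) / initialCount produced.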