Check sample size and move computeFraction
Check that the sample size is within the supported range. Moved
computeFraction into a private util class in util.random.
dorx committed Jun 9, 2014
1 parent e3fd6a6 commit 9bdd36e
Showing 4 changed files with 108 additions and 55 deletions.
44 changes: 12 additions & 32 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -43,7 +43,7 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, SerializableHyperLogLog, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -400,12 +400,21 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}

if (!withReplacement && num > initialCount) {
throw new IllegalArgumentException("Cannot create sample larger than the original when " +
"sampling without replacement")
}

if (initialCount == 0) {
return new Array[T](0)
}

if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
if (num > maxSelected) {
throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " +
"5.0 * math.sqrt(Integer.MAX_VALUE)")
}
} else {
maxSelected = initialCount.toInt
}
@@ -415,7 +424,7 @@
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
} else {
fraction = computeFraction(num, initialCount, withReplacement)
fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
total = num
}

@@ -431,35 +440,6 @@
Utils.randomizeInPlace(samples, rand).take(total)
}

/**
* Let p = num / total, where num is the sample size and total is the total number of
* datapoints in the RDD. We're trying to compute q > p such that
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
* i.e. the failure rate of not having a sufficiently large sample < 0.0001.
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
* num > 12, but we need a slightly larger q (9 empirically determined).
* - when sampling without replacement, we're drawing each datapoint with prob_i
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
*
* @param num sample size
* @param total size of RDD
* @param withReplacement whether sampling with replacement
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
*/
private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
val fraction = num.toDouble / total
if (withReplacement) {
val numStDev = if (num < 12) 9 else 5
fraction + numStDev * math.sqrt(fraction / total)
} else {
val delta = 1e-4
val gamma = - math.log(delta) / total
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
}
}

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
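A minimal usage sketch of the tightened takeSample checks above (hypothetical values; assumes an existing SparkContext `sc`). Requesting more elements than the RDD holds without replacement is now rejected up front, while oversampling with replacement still works:

// Hypothetical illustration of the new takeSample validation; `sc` is assumed
// to be an existing SparkContext and the numbers are arbitrary.
val rdd = sc.parallelize(1 to 1000)

// Allowed: the requested size does not exceed the RDD size.
val small = rdd.takeSample(false, 100, 42)

// Now rejected early: 2000 distinct elements cannot be drawn from 1000
// without replacement, so this throws IllegalArgumentException.
// rdd.takeSample(false, 2000, 42)

// With replacement, oversampling is fine; the sampling rate comes from
// SamplingUtils.computeFraction and the result is trimmed to exactly `num`.
val big = rdd.takeSample(true, 2000, 42)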
50 changes: 50 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.random

private[spark] object SamplingUtils {

/**
* Let p = num / total, where num is the sample size and total is the total number of
* datapoints in the RDD. We're trying to compute q > p such that
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
* i.e. the failure rate of not having a sufficiently large sample < 0.0001.
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
* num > 12, but we need a slightly larger q (9 empirically determined).
* - when sampling without replacement, we're drawing each datapoint with prob_i
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
*
* @param num sample size
* @param total size of RDD
* @param withReplacement whether sampling with replacement
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
*/
def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
val fraction = num.toDouble / total
if (withReplacement) {
val numStDev = if (num < 12) 9 else 5
fraction + numStDev * math.sqrt(fraction / total)
} else {
val delta = 1e-4
val gamma = - math.log(delta) / total
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
}
}
}
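A rough worked example of the formulas documented above (the numbers and value names are made up; since SamplingUtils is private[spark], code like this would live under org.apache.spark, e.g. in a test):

// Worked example with arbitrary numbers: draw at least 100 elements out of 100000.
val num = 100
val total = 100000L
val p = num.toDouble / total                      // target ratio = 0.001

// Without replacement: p + gamma + sqrt(gamma^2 + 2 * gamma * p),
// with gamma = -ln(1e-4) / total, which comes out to roughly 0.0015 here.
val qWithout = SamplingUtils.computeFraction(num, total, withReplacement = false)

// With replacement: p + 5 * sqrt(p / total) for num >= 12 (multiplier 9 below that),
// also roughly 0.0015 with these numbers.
val qWith = SamplingUtils.computeFraction(num, total, withReplacement = true)

// Both rates overshoot p so the sample is large enough with probability >= 0.9999.
assert(qWithout > p && qWith > p)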
23 changes: 0 additions & 23 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -497,29 +497,6 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}

test("computeFraction") {
// test that the computed fraction guarantees enough datapoints
// in the sample with a failure rate <= 0.0001
val data = new EmptyRDD[Int](sc)
val n = 100000

for (s <- 1 to 15) {
val frac = data.computeFraction(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(20, 100, 1000)) {
val frac = data.computeFraction(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
val frac = data.computeFraction(s, n, false)
val binomial = new BinomialDistribution(n, frac)
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
}
}

test("takeSample") {
val n = 1000000
val data = sc.parallelize(1 to n, 2)
46 changes: 46 additions & 0 deletions core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.random

import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
import org.scalatest.FunSuite

class SamplingUtilsSuite extends FunSuite {

test("computeFraction") {
// test that the computed fraction guarantees enough datapoints
// in the sample with a failure rate <= 0.0001
val n = 100000

for (s <- 1 to 15) {
val frac = SamplingUtils.computeFraction(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(20, 100, 1000)) {
val frac = SamplingUtils.computeFraction(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
val frac = SamplingUtils.computeFraction(s, n, false)
val binomial = new BinomialDistribution(n, frac)
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
}
}
}
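For context, a hedged sketch of how such a fraction could feed the samplers that takeSample relies on (BernoulliSampler and PoissonSampler from org.apache.spark.util.random, as imported in RDD.scala); the data and seed below are made up, and because SamplingUtils is private[spark] this would also have to live inside the Spark source tree:

import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

// Without replacement: one Bernoulli trial per element at the computed rate.
val fraction = SamplingUtils.computeFraction(100, 100000L, withReplacement = false)
val bernoulli = new BernoulliSampler[Int](fraction)
bernoulli.setSeed(42L)
val kept = bernoulli.sample((1 to 100000).iterator).toArray
// With probability >= 0.9999 at least 100 elements survive; takeSample then
// shuffles the result and takes exactly the requested number.

// With replacement: one Poisson draw per element at the computed mean.
val mean = SamplingUtils.computeFraction(100, 100000L, withReplacement = true)
val poisson = new PoissonSampler[Int](mean)
poisson.setSeed(42L)
val drawn = poisson.sample((1 to 100000).iterator).toArray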
