
Commit

style fix
dorx committed Jul 9, 2014
1 parent a2bf756 commit a10e68d
Showing 3 changed files with 37 additions and 28 deletions.
@@ -96,9 +96,12 @@ private[spark] object PoissonBounds {
}

def getMinCount(lmbd: Double): Double = {
- if (lmbd == 0) return 0
- val poisson = new PoissonDistribution(lmbd, epsilon)
- poisson.inverseCumulativeProbability(delta)
+ if (lmbd == 0) {
+   0
+ } else {
+   val poisson = new PoissonDistribution(lmbd, epsilon)
+   poisson.inverseCumulativeProbability(delta)
+ }
}
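For readers skimming the diff, a self-contained sketch of what this helper computes: the delta-quantile of a Poisson(lambda) distribution, i.e. a count that a Poisson(lambda) draw falls below with probability at most delta. The constant values below are illustrative assumptions standing in for the delta and epsilon defined elsewhere in PoissonBounds.

import org.apache.commons.math3.distribution.PoissonDistribution

object MinCountSketch {
  // Hypothetical stand-ins for the constants defined in PoissonBounds.
  val delta = 1e-4     // acceptable failure probability
  val epsilon = 1e-15  // convergence tolerance passed to commons-math

  def getMinCount(lmbd: Double): Double = {
    if (lmbd == 0) {
      0
    } else {
      val poisson = new PoissonDistribution(lmbd, epsilon)
      poisson.inverseCumulativeProbability(delta)
    }
  }

  def main(args: Array[String]): Unit = {
    // For lmbd = 100 this prints a value well below the mean
    // (several standard deviations of sqrt(100) = 10 lower).
    println(getMinCount(100.0))
  }
}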

/**
@@ -116,10 +116,12 @@ private[spark] object StratifiedSampler extends Logging {
// We use the streaming version of the algorithm for sampling without replacement to avoid
// using an extra pass over the RDD for computing the count.
// Hence, acceptBound and waitListBound change on every iteration.
- val g1 = - math.log(delta) / stratum.numItems // gamma1
- val g2 = (2.0 / 3.0) * g1 // gamma 2
- stratum.acceptBound = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
- stratum.waitListBound = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
+ val gamma1 = - math.log(delta) / stratum.numItems
+ val gamma2 = (2.0 / 3.0) * gamma1
+ stratum.acceptBound = math.max(0,
+   fraction + gamma2 - math.sqrt(gamma2 * gamma2 + 3 * gamma2 * fraction))
+ stratum.waitListBound = math.min(1,
+   fraction + gamma1 + math.sqrt(gamma1 * gamma1 + 2 * gamma1 * fraction))
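For reference, the thresholds assigned above can be written out explicitly (an editorial reconstruction from the code, not text from the commit; q is the stratum's sampling fraction, \delta the acceptable failure probability, and n the number of items seen so far in the stratum):

\[ \gamma_1 = \frac{-\ln \delta}{n}, \qquad \gamma_2 = \tfrac{2}{3}\,\gamma_1, \]

\[ \text{acceptBound} = \max\!\Bigl(0,\; q + \gamma_2 - \sqrt{\gamma_2^2 + 3\,\gamma_2\,q}\Bigr), \qquad \text{waitListBound} = \min\!\Bigl(1,\; q + \gamma_1 + \sqrt{\gamma_1^2 + 2\,\gamma_1\,q}\Bigr). \]

Items whose uniform draw falls below acceptBound are accepted outright, and items falling between the two bounds are waitlisted, so that with high probability the exact per-stratum sample size can be reached later by choosing a final threshold from the waitlist.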

val x = rng.nextUniform(0.0, 1.0)
if (x < stratum.acceptBound) {
@@ -137,20 +139,20 @@ private[spark] object StratifiedSampler extends Logging {
* Returns the function used to combine results returned by seqOp from different partitions.
*/
def getCombOp[K]: (MMap[K, Stratum], MMap[K, Stratum]) => MMap[K, Stratum] = {
- (r1: MMap[K, Stratum], r2: MMap[K, Stratum]) => {
+ (result1: MMap[K, Stratum], result2: MMap[K, Stratum]) => {
// take union of both key sets in case one partition doesn't contain all keys
- for (key <- r1.keySet.union(r2.keySet)) {
-   // Use r2 to keep the combined result since r1 is usually empty
-   val entry1 = r1.get(key)
-   if (r2.contains(key)) {
-     r2(key).merge(entry1)
+ for (key <- result1.keySet.union(result2.keySet)) {
+   // Use result2 to keep the combined result since r1 is usually empty
+   val entry1 = result1.get(key)
+   if (result2.contains(key)) {
+     result2(key).merge(entry1)
} else {
if (entry1.isDefined) {
- r2 += (key -> entry1.get)
+ result2 += (key -> entry1.get)
}
}
}
- r2
+ result2
}
}
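To make the merge contract above concrete, here is a simplified, hypothetical stand-alone sketch (Acc stands in for the real Stratum type, which is not shown in this diff): per-partition maps of per-key accumulators are merged key by key, with result2 carrying the combined output.

import scala.collection.mutable.{Map => MMap}

object CombOpSketch {
  // Hypothetical accumulator standing in for Stratum.
  case class Acc(var numItems: Long) {
    def merge(other: Option[Acc]): Unit = other.foreach(o => numItems += o.numItems)
  }

  val combOp: (MMap[String, Acc], MMap[String, Acc]) => MMap[String, Acc] =
    (result1, result2) => {
      // Take the union of both key sets in case one partition misses some keys.
      for (key <- result1.keySet.union(result2.keySet)) {
        val entry1 = result1.get(key)
        if (result2.contains(key)) {
          result2(key).merge(entry1)
        } else if (entry1.isDefined) {
          result2 += (key -> entry1.get)
        }
      }
      result2
    }

  def main(args: Array[String]): Unit = {
    val merged = combOp(MMap("a" -> Acc(2)), MMap("a" -> Acc(3), "b" -> Acc(1)))
    println(merged) // a -> Acc(5), b -> Acc(1)
  }
}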

@@ -237,10 +239,9 @@ private[spark] object StratifiedSampler extends Logging {
rng.reSeed(seed + idx)
iter.flatMap { item =>
val key = item._1
- val q1 = finalResult(key).acceptBound
- val q2 = finalResult(key).waitListBound
- val copiesAccepted = if (q1 == 0) 0L else rng.nextPoisson(q1)
- val copiesWailisted = rng.nextPoisson(q2).toInt
+ val acceptBound = finalResult(key).acceptBound
+ val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
+ val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound).toInt
val copiesInSample = copiesAccepted +
(0 until copiesWailisted).count(i => rng.nextUniform(0.0, 1.0) < thresholdByKey(key))
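An editorial note on the arithmetic above (not part of the commit): if p is the probability that a single uniform draw falls below thresholdByKey(key), then independently thinning the Poisson(waitListBound) waitlisted copies with probability p leaves a Poisson(p * waitListBound) count, so per item

\[ \text{copiesInSample} \sim \text{Poisson}\bigl(\text{acceptBound} + p \cdot \text{waitListBound}\bigr), \]

and the per-key threshold is presumably chosen so that this combined rate matches the key's target sampling rate.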
if (copiesInSample > 0) {
@@ -88,8 +88,11 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
(x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
}

- def checkSize(exact: Boolean, withReplacement: Boolean,
-     expected: Long, actual: Long, p: Double): Boolean = {
+ def checkSize(exact: Boolean,
+     withReplacement: Boolean,
+     expected: Long,
+     actual: Long,
+     p: Double): Boolean = {
if (exact) {
return expected == actual
}
@@ -110,8 +113,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
- assert(sampleCounts.forall({case(k,v) =>
-   checkSize(exact, false, expectedSampleSize(k), v, samplingRate)}))
+ assert(sampleCounts.forall {case(k,v) =>
+   checkSize(exact, false, expectedSampleSize(k), v, samplingRate)})
assert(takeSample.size === takeSample.toSet.size)
assert(takeSample.forall(x => 1 <= x._2 && x._2 <= n), s"elements not in [1, $n]")
}
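For orientation, a hedged usage sketch of the API exercised by this test, using the four-argument sampleByKey(withReplacement, fractions, exact, seed) form that appears here (released Spark versions may expose a different signature; the key layout and fraction values below are illustrative assumptions, not taken from the suite's setup code):

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // pair-RDD functions in Spark 1.x

object SampleByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "sampleByKey-sketch")
    val n = 100
    // Roughly 30% of the keys are "1", the rest "0", mirroring the key function
    // above with fractionPositive = 0.3.
    val data = sc.parallelize(1 to n, 2).map(x => (if (x % 10 < 3) "1" else "0", x))
    val fractions = Map("0" -> 0.1, "1" -> 0.1) // per-key sampling rates (assumed values)
    // withReplacement = false, exact = true: per-key sample sizes hit the expectation exactly.
    val sample = data.sampleByKey(false, fractions, true, 1L)
    println(sample.countByKey())
    sc.stop()
  }
}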
@@ -128,9 +131,9 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
- assert(sampleCounts.forall({case(k,v) =>
-   checkSize(exact, true, expectedSampleSize(k), v, samplingRate)}))
- val groupedByKey = takeSample.groupBy({case(k, v) => k})
+ assert(sampleCounts.forall {case(k,v) =>
+   checkSize(exact, true, expectedSampleSize(k), v, samplingRate)})
+ val groupedByKey = takeSample.groupBy {case(k, v) => k}
for ((key, v) <- groupedByKey) {
if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
// sample large enough for there to be repeats with high likelihood
@@ -146,8 +149,10 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
assert(takeSample.forall(x => 1 <= x._2 && x._2 <= n), s"elements not in [1, $n]")
}

- def checkAllCombos(stratifiedData: RDD[(String, Int)], samplingRate: Double,
-     seed: Long, n: Long) {
+ def checkAllCombos(stratifiedData: RDD[(String, Int)],
+     samplingRate: Double,
+     seed: Long,
+     n: Long) = {
takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n)
takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n)
takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)
