diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 824ce42e494e0..8eb1604a941cd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -362,44 +362,50 @@ def takeSample(self, withReplacement, num, seed=None):
         Return a fixed-size sampled subset of this RDD (currently requires
         numpy).
 
-        >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
-        [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+        >>> rdd = sc.parallelize(range(0, 10))
+        >>> len(rdd.takeSample(True, 20, 1))
+        20
+        >>> len(rdd.takeSample(False, 5, 2))
+        5
+        >>> len(rdd.takeSample(False, 15, 3))
+        10
         """
 
-        numStDev = 10.0
-        initialCount = self.count()
-
         if num < 0:
-            raise ValueError
+            raise ValueError("Sample size cannot be negative.")
+        elif num == 0:
+            return []
 
-        if initialCount == 0 or num == 0:
-            return list()
+        initialCount = self.count()
+        if initialCount == 0:
+            return []
 
         rand = Random(seed)
-        if (not withReplacement) and num > initialCount:
+
+        if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
             samples = self.collect()
-            fraction = float(num) / initialCount
-            num = initialCount
-        else:
-            maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
-            if num > maxSampleSize:
-                raise ValueError
-
-            fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
+            rand.shuffle(samples)
+            return samples
 
+        numStDev = 10.0
+        maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
+        if num > maxSampleSize:
+            raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
+
+        fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
+        samples = self.sample(withReplacement, fraction, seed).collect()
+
+        # If the first sample didn't turn out large enough, keep trying to take samples;
+        # this shouldn't happen often because we use a big multiplier for their initial size.
+        # See: scala/spark/RDD.scala
+        while len(samples) < num:
+            # TODO: add log warning for when more than one iteration was run
+            seed = rand.randint(0, sys.maxint)
             samples = self.sample(withReplacement, fraction, seed).collect()
 
-        # If the first sample didn't turn out large enough, keep trying to take samples;
-        # this shouldn't happen often because we use a big multiplier for their initial size.
-        # See: scala/spark/RDD.scala
-        while len(samples) < num:
-            #TODO add log warning for when more than one iteration was run
-            seed = rand.randint(0, sys.maxint)
-            samples = self.sample(withReplacement, fraction, seed).collect()
+        rand.shuffle(samples)
 
-        sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
-        sampler.shuffle(samples)
         return samples[0:num]
 
     @staticmethod
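
For reviewers who want to see the new control flow in isolation: below is a minimal, self-contained sketch of the same take-sample logic over a plain Python list (sampling without replacement only). The helper names (`take_sample`, `bernoulli_pass`) and the oversampling constants are made up for illustration; the actual patch gets its fraction from `RDD._computeFractionForSampleSize` and its sampling passes from `RDD.sample`.

```python
from random import Random

def take_sample(population, num, seed=None):
    """Sketch of the patched takeSample() control flow, without replacement,
    over a plain list instead of an RDD."""
    if num < 0:
        raise ValueError("Sample size cannot be negative.")
    elif num == 0:
        return []

    population = list(population)
    initial_count = len(population)
    if initial_count == 0:
        return []

    rand = Random(seed)

    # Asking for the whole population (or more) without replacement
    # degenerates to shuffling everything and returning it.
    if num >= initial_count:
        samples = list(population)
        rand.shuffle(samples)
        return samples

    # Deliberately generous stand-in fraction so a single Bernoulli pass
    # almost always yields >= num elements (illustrative constants only).
    fraction = min(1.0, 1.5 * num / initial_count + 10.0 / initial_count)

    def bernoulli_pass(s):
        # One pass over the data, keeping each element with prob. `fraction`.
        r = Random(s)
        return [x for x in population if r.random() < fraction]

    samples = bernoulli_pass(seed)
    # Rare under-fill: retry with fresh seeds, mirroring the while loop
    # in the patch.
    while len(samples) < num:
        samples = bernoulli_pass(rand.randint(0, 2 ** 31 - 1))

    rand.shuffle(samples)
    return samples[0:num]

print(len(take_sample(range(100), 15, seed=42)))  # prints 15
```

The key property this preserves from the patch: a single oversampled pass is cheap and almost always sufficient, and the trailing shuffle-plus-slice makes the returned subset both exact-sized and unbiased in order.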
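
The hunk calls `RDD._computeFractionForSampleSize` but does not show it. As a sketch of how such a fraction can be chosen so that one sampling pass returns at least the requested size with high probability, here is one standard derivation (a Chernoff-style tail bound for Bernoulli sampling, a standard-deviation pad for Poisson sampling); treat the function name `compute_fraction` and the exact constants as assumptions rather than Spark's actual values.

```python
from math import log, sqrt

def compute_fraction(sample_size, total, with_replacement):
    """Pick a sampling fraction large enough that one pass returns at
    least `sample_size` items with high probability (here 1 - 5e-5).
    Constants are illustrative, not necessarily Spark's."""
    fraction = float(sample_size) / total
    if with_replacement:
        # Poisson sampling: pad the expected count by several standard
        # deviations (stddev of Poisson(mean) is sqrt(mean)); use a
        # bigger pad when the requested sample is tiny.
        num_st_dev = 9 if sample_size < 12 else 5
        return fraction + num_st_dev * sqrt(fraction / total)
    else:
        # Bernoulli sampling: choosing gamma = -ln(delta) / total makes
        # P(Binomial(total, p) < sample_size) <= delta for the p below.
        delta = 5e-5
        gamma = -log(delta) / total
        return min(1.0, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))

print(compute_fraction(15, 100, False))  # ~0.45: sample ~45 of 100 to get >= 15 w.h.p.
```

This is why the retry loop in the patch "shouldn't happen often": the fraction is padded well above `num / initialCount`, so an under-sized first pass is a low-probability event rather than the common case.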