diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6d549b40e5698..f3b432ff248a9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -268,6 +268,7 @@ def sample(self, withReplacement, fraction, seed): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ + assert fraction >= 0.0, "Invalid fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -288,6 +289,9 @@ def takeSample(self, withReplacement, num, seed): if (num < 0): raise ValueError + if (initialCount == 0): + return list() + if initialCount > sys.maxint - 1: maxSelected = sys.maxint - 1 else: