diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 36d7f6c428987..f8283fbbb980d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -310,8 +310,7 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { - require(fraction >= 0 && fraction <= Double.MaxValue, - "Invalid fraction value: " + fraction) + require(fraction >= 0.0, "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 7124cfc742e79..ddf2214c397dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -457,10 +457,10 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) - val emptySet = data.filter(_ => false) + val emptySet = data.mapPartitions { iter => Iterator.empty } val sample = emptySet.takeSample(false, 20, 1) - assert(sample.size === 0) + assert(sample.length === 0) for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements