diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 08d047402625f..50535d2711708 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@
 import warnings
 import heapq
 import bisect
-from random import Random
+import random
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -316,6 +316,30 @@ def sample(self, withReplacement, fraction, seed=None):
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
+    def randomSplit(self, weights, seed=None):
+        """
+        Randomly splits this RDD with the provided weights.
+
+        :param weights: weights for splits, will be normalized if they don't sum to 1
+        :param seed: random seed
+        :return: split RDDs in a list
+
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+        >>> rdd1.collect()
+        [1, 3]
+        >>> rdd2.collect()
+        [0, 2, 4]
+        """
+        s = float(sum(weights))
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
+        if seed is None:
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
+
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
         """
@@ -341,7 +365,7 @@ def takeSample(self, withReplacement, num, seed=None):
         if initialCount == 0:
             return []
 
-        rand = Random(seed)
+        rand = random.Random(seed)
 
         if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index f5c3cfd259a5b..558dcfd12d46f 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -115,6 +115,20 @@ def func(self, split, iterator):
                     yield obj
 
 
+class RDDRangeSampler(RDDSamplerBase):
+
+    def __init__(self, lowerBound, upperBound, seed=None):
+        RDDSamplerBase.__init__(self, False, seed)
+        self._use_numpy = False  # no performance gain from numpy
+        self._lowerBound = lowerBound
+        self._upperBound = upperBound
+
+    def func(self, split, iterator):
+        for obj in iterator:
+            if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
+                yield obj
+
+
 class RDDStratifiedSampler(RDDSamplerBase):
 
     def __init__(self, withReplacement, fractions, seed=None):
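
Note on the approach (commentary, not part of the patch): randomSplit normalizes the weights into cumulative bounds and then samples each disjoint [lb, ub) range of the unit interval with the same seed. A minimal local sketch of that logic, with the hypothetical name random_split_local standing in for the RDD machinery:

    import random

    def random_split_local(items, weights, seed=17):
        # Normalize weights into cumulative bounds, e.g. [2, 3] -> [0.0, 0.4, 1.0],
        # mirroring the cweights list built in randomSplit.
        s = float(sum(weights))
        bounds = [0.0]
        for w in weights:
            bounds.append(bounds[-1] + w / s)
        # One uniform draw per item decides which single [lb, ub) range it
        # falls into, so the splits are disjoint and together cover all items.
        rng = random.Random(seed)
        splits = [[] for _ in weights]
        for item in items:
            x = rng.random()
            for i, (lb, ub) in enumerate(zip(bounds, bounds[1:])):
                if lb <= x < ub:
                    splits[i].append(item)
                    break
        return splits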
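
Why every RDDRangeSampler must receive the same seed: each split is a separate pass over the parent RDD, and reusing the seed makes getUniformSample reproduce the identical draw for a given element on every pass, so each element passes exactly one of the [lb, ub) filters. A sketch of that property for a single partition, assuming a fresh random.Random(seed) stands in for the sampler's per-partition generator:

    import random

    def range_filter(items, lb, ub, seed):
        # Rough analogue of RDDRangeSampler.func for one partition: one seeded
        # uniform draw per element, kept only if it lands in [lb, ub).
        rng = random.Random(seed)
        for obj in items:
            if lb <= rng.random() < ub:
                yield obj

    data = list(range(5))
    left = list(range_filter(data, 0.0, 0.4, seed=17))
    right = list(range_filter(data, 0.4, 1.0, seed=17))
    # Same seed on both passes -> identical draws -> disjoint, exhaustive splits.
    assert sorted(left + right) == data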