Skip to content

Commit

Permalink
[SPARK-4327] [PySpark] Python API for RDD.randomSplit()
Browse files Browse the repository at this point in the history
```
pyspark.RDD.randomSplit(self, weights, seed=None)
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
```

Author: Davies Liu <davies@databricks.com>

Closes apache#3193 from davies/randomSplit and squashes the following commits:

78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain
f5fdf63 [Davies Liu] fix bug with int in weights
4dfa2cd [Davies Liu] refactor
f866bcf [Davies Liu] remove unneeded change
c7a2007 [Davies Liu] switch to python implementation
95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit
0d9b256 [Davies Liu] refactor
1715ee3 [Davies Liu] address comments
41fce54 [Davies Liu] randomSplit()
  • Loading branch information
Davies Liu authored and mengxr committed Nov 19, 2014
1 parent bb46046 commit 7f22fa8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
30 changes: 27 additions & 3 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import warnings
import heapq
import bisect
from random import Random
import random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
Expand All @@ -38,7 +38,7 @@
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
Expand Down Expand Up @@ -316,6 +316,30 @@ def sample(self, withReplacement, fraction, seed=None):
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

def randomSplit(self, weights, seed=None):
    """
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(5), 1)
    >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
    >>> rdd1.collect()
    [1, 3]
    >>> rdd2.collect()
    [0, 2, 4]
    """
    total = float(sum(weights))
    # Build cumulative, normalized range boundaries: 0.0, ..., 1.0.
    # Each adjacent pair (lb, ub) defines one split's sampling range.
    bounds = [0.0]
    running = 0.0
    for w in weights:
        running += w / total
        bounds.append(running)
    if seed is None:
        # All splits must share one seed so every element lands in
        # exactly one range (the per-element draw is identical across splits).
        seed = random.randint(0, 2 ** 32 - 1)
    splits = []
    for lb, ub in zip(bounds, bounds[1:]):
        splits.append(
            self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True))
    return splits

# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed=None):
"""
Expand All @@ -341,7 +365,7 @@ def takeSample(self, withReplacement, num, seed=None):
if initialCount == 0:
return []

rand = Random(seed)
rand = random.Random(seed)

if (not withReplacement) and num >= initialCount:
# shuffle current RDD and return
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/rddsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ def func(self, split, iterator):
yield obj


class RDDRangeSampler(RDDSamplerBase):
    """Yields the elements whose per-element uniform draw falls in
    [lowerBound, upperBound); used by randomSplit() so that disjoint
    ranges with a shared seed partition the RDD."""

    def __init__(self, lowerBound, upperBound, seed=None):
        # Sampling without replacement; seed is shared across splits.
        RDDSamplerBase.__init__(self, False, seed)
        self._use_numpy = False  # no performance gain from numpy
        self._lowerBound = lowerBound
        self._upperBound = upperBound

    def func(self, split, iterator):
        lb = self._lowerBound
        ub = self._upperBound
        for item in iterator:
            # One uniform draw per element decides membership in this range.
            if lb <= self.getUniformSample(split) < ub:
                yield item


class RDDStratifiedSampler(RDDSamplerBase):

def __init__(self, withReplacement, fractions, seed=None):
Expand Down

0 comments on commit 7f22fa8

Please sign in to comment.