Skip to content

Commit

Permalink
randomSplit()
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Nov 10, 2014
1 parent c6f4e70 commit 41fce54
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,19 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}

/**
 * Convert a java.util.List[Double] received over Py4J into a primitive
 * Array[Double], preserving element order.
 *
 * @param list boxed doubles (typically built by Py4J's ListConverter)
 * @return a new Array[Double] with the same elements in the same order
 */
def listToArrayDouble(list: JList[Double]): Array[Double] = {
  val r = new Array[Double](list.size)
  // Iterate explicitly rather than relying on the implicit Java->Scala
  // collection conversion needed by `list.zipWithIndex`; an iterator also
  // stays O(n) for non-random-access lists (e.g. LinkedList), unlike
  // repeated indexed get(i), and avoids building an intermediate zipped
  // collection.
  val it = list.iterator()
  var i = 0
  while (it.hasNext) {
    r(i) = it.next()
    i += 1
  }
  r
}
}

private
Expand Down
28 changes: 28 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,34 @@ def sample(self, withReplacement, fraction, seed=None):
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

def randomSplit(self, weights, seed=None):
    """
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: a list of the split RDDs

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
    """
    # Re-serialize one element per batch so the JVM-side split sees
    # individual rows rather than pickled batches.
    serializer = BatchedSerializer(PickleSerializer(), 1)
    reserialized = self._reserialize(serializer)
    # Ship the weights over Py4J as a java.util.List[Double], then convert
    # them to the Array[Double] that randomSplit expects on the JVM side.
    weight_list = ListConverter().convert([float(w) for w in weights],
                                          self.ctx._gateway._gateway_client)
    weight_array = self.ctx._jvm.PythonRDD.listToArrayDouble(weight_list)
    if seed is None:
        jrdds = reserialized._jrdd.randomSplit(weight_array)
    else:
        jrdds = reserialized._jrdd.randomSplit(weight_array, seed)
    return [RDD(jrdd, self.ctx, serializer) for jrdd in jrdds]

# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed=None):
"""
Expand Down

0 comments on commit 41fce54

Please sign in to comment.