Skip to content

Commit

Permalink
Convert the input RDD into an RDD of Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Oct 21, 2014
1 parent 342b57d commit 0871576
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/pyspark/mllib/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from functools import wraps

from pyspark import PickleSerializer
from pyspark.mllib.linalg import _to_java_object_rdd
from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd


__all__ = ['MultivariateStatisticalSummary', 'Statistics']
Expand Down Expand Up @@ -107,7 +107,7 @@ def colStats(rdd):
array([ 2., 0., 0., -2.])
"""
sc = rdd.ctx
jrdd = _to_java_object_rdd(rdd)
jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
return MultivariateStatisticalSummary(sc, cStats)

Expand Down Expand Up @@ -163,13 +163,14 @@ def corr(x, y=None, method=None):
if type(y) == str:
raise TypeError("Use 'method=' to specify method name.")

jx = _to_java_object_rdd(x)
if not y:
jx = _to_java_object_rdd(x.map(_convert_to_vector))
resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
bytes = sc._jvm.SerDe.dumps(resultMat)
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
jx = _to_java_object_rdd(x)
jy = _to_java_object_rdd(y)
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)

Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


Expand Down Expand Up @@ -202,6 +204,16 @@ def test_regression(self):
self.assertTrue(dt_model.predict(features[3]) > 0)


class StatTests(PySparkTestCase):
    """Regression tests for pyspark.mllib.stat (SPARK-4023)."""

    def test_col_with_random_rdd(self):
        # colStats must accept an RDD whose rows were produced by
        # RandomRDDs and are not already pyspark Vectors (SPARK-4023).
        rdd = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
        stats = Statistics.colStats(rdd)
        self.assertEqual(1000, stats.count())
        # Sample means of 1000 standard-normal draws should sit near 0.
        for v in stats.mean():
            self.assertTrue(abs(v) < 0.1)


@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):

Expand Down

0 comments on commit 0871576

Please sign in to comment.