diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 759aa57662436..cade839db118f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -246,16 +246,13 @@ class PythonMLLibAPI extends Serializable {
    * Java stub for Python mllib KMeans.train()
    */
   def trainKMeansModel(
-      dataBytesJRDD: JavaRDD[Array[Byte]],
+      dataJRDD: JavaRDD[Any],
       k: Int,
       maxIterations: Int,
       runs: Int,
-      initializationMode: String): java.util.List[java.lang.Object] = {
-    val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes))
-    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
-    val ret = new java.util.LinkedList[java.lang.Object]()
-    ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
-    ret
+      initializationMode: String): KMeansModel = {
+    val data = dataJRDD.rdd.map(_.asInstanceOf[Vector])
+    KMeans.train(data, k, maxIterations, runs, initializationMode)
   }
 
   /**
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 5ba38850b42e3..a11d74fb34274 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -15,15 +15,8 @@
 # limitations under the License.
 #
 
-from numpy import array, dot
-from math import sqrt
-from pyspark import SparkContext
-from pyspark.mllib._common import \
-    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _squared_distance, \
-    _serialize_double_matrix, _deserialize_double_matrix, \
-    _serialize_double_vector, _deserialize_double_vector, \
-    _get_initial_weights, _regression_train_wrapper
-from pyspark.mllib.linalg import SparseVector
+from pyspark import SparkContext, PickleSerializer
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['KMeansModel', 'KMeans']
 
@@ -32,6 +25,7 @@ class KMeansModel(object):
 
     """A clustering model derived from the k-means method.
 
+    >>> from numpy import array
     >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
     >>> model = KMeans.train(
     ...     sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
@@ -71,8 +65,9 @@ def predict(self, x):
         """Find the cluster to which x belongs in this model."""
         best = 0
         best_distance = float("inf")
-        for i in range(0, len(self.centers)):
-            distance = _squared_distance(x, self.centers[i])
+        x = _convert_to_vector(x)
+        for i in xrange(len(self.centers)):
+            distance = x.squared_distance(self.centers[i])
             if distance < best_distance:
                 best = i
                 best_distance = distance
@@ -82,19 +77,15 @@ def predict(self, x):
 class KMeans(object):
 
     @classmethod
-    def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
         """Train a k-means clustering model."""
-        sc = data.context
-        dataBytes = _get_unmangled_double_vector_rdd(data)
-        ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(
-            dataBytes._jrdd, k, maxIterations, runs, initializationMode)
-        if len(ans) != 1:
-            raise RuntimeError("JVM call result had unexpected length")
-        elif type(ans[0]) != bytearray:
-            raise RuntimeError("JVM call result had first element of type "
-                               + type(ans[0]) + " which is not bytearray")
-        matrix = _deserialize_double_matrix(ans[0])
-        return KMeansModel([row for row in matrix])
+        sc = rdd.context
+        jrdd = rdd.map(_convert_to_vector)._to_java_object_rdd().cache()
+        model = sc._jvm.PythonMLLibAPI().trainKMeansModel(
+            jrdd, k, maxIterations, runs, initializationMode)
+        bytes = sc._jvm.SerDe.dumps(model.clusterCenters())
+        centers = PickleSerializer().loads(str(bytes))
+        return KMeansModel([c.toArray() for c in centers])
 
 
 def _test():
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 2fe90fada674b..10bd30aaee7b3 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -34,7 +34,7 @@
 def _convert_to_vector(l):
     if isinstance(l, Vector):
         return l
-    elif type(l) in (array.array, np.array, list):
+    elif type(l) in (array.array, np.array, np.ndarray, list):
         return DenseVector(l)
     elif _have_scipy and _scipy_issparse(l):
         assert l.shape[1] == 1, "Expected column vector"
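
A minimal usage sketch of the pickle-based path this diff introduces (not part
of the patch itself). It assumes a live SparkContext bound to `sc`, as in the
doctests above, and exercises only the public API touched by the change:

    from numpy import array
    from pyspark.mllib.clustering import KMeans

    # Rows may be NumPy arrays, lists, or SparseVectors; KMeans.train() maps
    # them through _convert_to_vector() before shipping the RDD to the JVM,
    # then unpickles the returned cluster centers back into NumPy arrays.
    data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2)
    model = KMeans.train(sc.parallelize(data), 2, maxIterations=10,
                         runs=30, initializationMode="random")
    model.predict(array([0.9, 0.9]))  # same cluster as [0.0, 0.0] and [1.0, 1.0]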