Commit f1544c4

refactor clustering

davies committed Sep 13, 2014
1 parent 52d1350 commit f1544c4

Showing 3 changed files with 19 additions and 31 deletions.
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala (4 additions, 7 deletions)

@@ -246,16 +246,13 @@ class PythonMLLibAPI extends Serializable {
    * Java stub for Python mllib KMeans.train()
    */
   def trainKMeansModel(
-      dataBytesJRDD: JavaRDD[Array[Byte]],
+      dataJRDD: JavaRDD[Any],
       k: Int,
       maxIterations: Int,
       runs: Int,
-      initializationMode: String): java.util.List[java.lang.Object] = {
-    val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes))
-    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
-    val ret = new java.util.LinkedList[java.lang.Object]()
-    ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
-    ret
+      initializationMode: String): KMeansModel = {
+    val data = dataJRDD.rdd.map(_.asInstanceOf[Vector])
+    KMeans.train(data, k, maxIterations, runs, initializationMode)
   }

   /**
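The stub now takes a JavaRDD[Any] and casts each element straight to Vector; that is safe because the Python wrapper (clustering.py, below) converts every row with _convert_to_vector and materializes the pickled rows as JVM objects before the call. A rough sketch of that handshake from the Python side, mirroring the wrapper code shown later in this diff (assumes an active SparkContext sc and an RDD vectors of pyspark.mllib.linalg vectors):

    jrdd = vectors._to_java_object_rdd()              # unpickle rows into JVM Vector objects
    jmodel = sc._jvm.PythonMLLibAPI().trainKMeansModel(
        jrdd, 2, 10, 1, "k-means||")                  # Py4J handle to the Scala KMeansModel
    jcenters = jmodel.clusterCenters()                # stays in the JVM until explicitly pulled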
python/pyspark/mllib/clustering.py (14 additions, 23 deletions)
@@ -15,15 +15,8 @@
 # limitations under the License.
 #

-from numpy import array, dot
-from math import sqrt
-from pyspark import SparkContext
-from pyspark.mllib._common import \
-    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _squared_distance, \
-    _serialize_double_matrix, _deserialize_double_matrix, \
-    _serialize_double_vector, _deserialize_double_vector, \
-    _get_initial_weights, _regression_train_wrapper
-from pyspark.mllib.linalg import SparseVector
+from pyspark import SparkContext, PickleSerializer
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector

 __all__ = ['KMeansModel', 'KMeans']

@@ -32,6 +25,7 @@ class KMeansModel(object):

"""A clustering model derived from the k-means method.
>>> from numpy import array
>>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
>>> model = KMeans.train(
... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
@@ -71,8 +65,9 @@ def predict(self, x):
"""Find the cluster to which x belongs in this model."""
best = 0
best_distance = float("inf")
for i in range(0, len(self.centers)):
distance = _squared_distance(x, self.centers[i])
x = _convert_to_vector(x)
for i in xrange(len(self.centers)):
distance = x.squared_distance(self.centers[i])
if distance < best_distance:
best = i
best_distance = distance
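Since predict now funnels its argument through _convert_to_vector, any input that helper recognizes should work. A small sketch, reusing the model from the doctest above (return values omitted because cluster numbering is arbitrary):

    from numpy import array
    from pyspark.mllib.linalg import SparseVector

    model.predict([0.9, 0.9])                         # plain Python list
    model.predict(array([8.5, 9.0]))                  # numpy array
    model.predict(SparseVector(2, {0: 9.0, 1: 8.0}))  # sparse input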
@@ -82,19 +77,15 @@ def predict(self, x):
 class KMeans(object):

     @classmethod
-    def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
         """Train a k-means clustering model."""
-        sc = data.context
-        dataBytes = _get_unmangled_double_vector_rdd(data)
-        ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(
-            dataBytes._jrdd, k, maxIterations, runs, initializationMode)
-        if len(ans) != 1:
-            raise RuntimeError("JVM call result had unexpected length")
-        elif type(ans[0]) != bytearray:
-            raise RuntimeError("JVM call result had first element of type "
-                               + type(ans[0]) + " which is not bytearray")
-        matrix = _deserialize_double_matrix(ans[0])
-        return KMeansModel([row for row in matrix])
+        sc = rdd.context
+        jrdd = rdd.map(_convert_to_vector)._to_java_object_rdd().cache()
+        model = sc._jvm.PythonMLLibAPI().trainKMeansModel(
+            jrdd, k, maxIterations, runs, initializationMode)
+        bytes = sc._jvm.SerDe.dumps(model.clusterCenters())
+        centers = PickleSerializer().loads(str(bytes))
+        return KMeansModel([c.toArray() for c in centers])


 def _test():
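Two details of the new train() are worth calling out: the input is converted and shipped to the JVM once, then cached, because k-means makes repeated passes over the data; and the trained model never crosses the language boundary wholesale, only its cluster centers come back, in a single pickle produced by SerDe.dumps. A minimal end-to-end sketch under those assumptions (sc is a live SparkContext):

    from pyspark.mllib.clustering import KMeans

    data = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
    model = KMeans.train(data, 2, maxIterations=10, runs=30,
                         initializationMode="random")
    model.centers   # two numpy arrays pulled back from the JVM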
python/pyspark/mllib/linalg.py (1 addition, 1 deletion)
@@ -34,7 +34,7 @@ def _convert_to_vector(l):

     if isinstance(l, Vector):
         return l
-    elif type(l) in (array.array, np.array, list):
+    elif type(l) in (array.array, np.array, np.ndarray, list):
         return DenseVector(l)
     elif _have_scipy and _scipy_issparse(l):
         assert l.shape[1] == 1, "Expected column vector"
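The one-token linalg.py change carries more weight than it looks: np.array is a factory function, not a type, so type(l) could never equal it and genuine numpy arrays fell through this branch. Adding np.ndarray is what lets rdd.map(_convert_to_vector) in train() above accept numpy rows. A quick sketch of the behavior after the fix:

    import numpy as np
    from pyspark.mllib.linalg import DenseVector, _convert_to_vector

    v = _convert_to_vector(np.array([1.0, 2.0]))   # matches np.ndarray now
    assert isinstance(v, DenseVector)
    _convert_to_vector([1.0, 2.0])                 # lists worked before and still do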
