From 5d087b40e14db51b1eeb44e462e04d5e718338be Mon Sep 17 00:00:00 2001
From: "nate.crosswhite"
Date: Thu, 4 Dec 2014 16:25:49 -0500
Subject: [PATCH] Adding KMeans train with seed and Scala unit test

---
 .../spark/mllib/clustering/KMeans.scala      | 25 +++++++++++++++++++
 .../spark/mllib/clustering/KMeansSuite.scala | 21 ++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 10192cf1d7362..9a5f3583506aa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -358,6 +358,31 @@ object KMeans {
       .run(data)
   }
 
+  /**
+   * Trains a k-means model using the given set of parameters.
+   *
+   * @param data training points stored as `RDD[Vector]`
+   * @param k number of clusters
+   * @param maxIterations max number of iterations
+   * @param runs number of parallel runs, defaults to 1. The best model is returned.
+   * @param initializationMode initialization mode, either "random" or "k-means||" (default).
+   * @param seed seed value for cluster initialization
+   */
+  def train(
+      data: RDD[Vector],
+      k: Int,
+      maxIterations: Int,
+      runs: Int,
+      initializationMode: String,
+      seed: Long): KMeansModel = {
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .setRuns(runs)
+      .setInitializationMode(initializationMode)
+      .setSeed(seed)
+      .run(data)
+  }
+
   /**
    * Trains a k-means model using specified parameters and the default values for unspecified.
    */
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 9ebef8466c831..2ada0d9505fa9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.clusterCenters.size === 3)
   }
 
+  test("deterministic initialization") {
+    // Create a large-ish set of points to cluster
+    val points = List.tabulate(1000)(n => Vectors.dense(n, n))
+    val rdd = sc.parallelize(points, 3)
+
+    for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+      // Create three deterministic models and compare cluster means
+      val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
+      val centers1 = model1.clusterCenters
+
+      val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
+      val centers2 = model2.clusterCenters
+
+      val model3 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42)
+      val centers3 = model3.clusterCenters
+
+      assert(centers1.deep == centers2.deep)
+      assert(centers1.deep == centers3.deep)
+    }
+  }
+
   test("single cluster with big dataset") {
     val smallData = Array(
       Vectors.dense(1.0, 2.0, 6.0),
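
Below is a minimal usage sketch of the seeded overload this patch adds. It is not part of the patch itself: the local SparkContext, the diagonal toy dataset, and the k/maxIterations/seed values are illustrative assumptions. The point it demonstrates is that two train calls with identical parameters and the same seed are expected to return identical cluster centers.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object SeededKMeansExample {
  def main(args: Array[String]): Unit = {
    // Local context and toy data are for illustration only.
    val sc = new SparkContext("local", "SeededKMeansExample")

    // 1000 points on the diagonal, mirroring the unit test above.
    val data = sc.parallelize(List.tabulate(1000)(n => Vectors.dense(n, n)))

    // Same parameters, same seed: the resulting cluster centers should match.
    val modelA = KMeans.train(data, k = 10, maxIterations = 2, runs = 1,
      initializationMode = KMeans.K_MEANS_PARALLEL, seed = 42)
    val modelB = KMeans.train(data, k = 10, maxIterations = 2, runs = 1,
      initializationMode = KMeans.K_MEANS_PARALLEL, seed = 42)

    assert(modelA.clusterCenters.deep == modelB.clusterCenters.deep)

    sc.stop()
  }
}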