From 35c188463798729b65ca74549984cb765ac1e9c9 Mon Sep 17 00:00:00 2001
From: "nate.crosswhite"
Date: Thu, 4 Dec 2014 14:12:29 -0500
Subject: [PATCH] Add kmeans initial seed to pyspark API

---
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  6 +++++-
 .../apache/spark/mllib/clustering/KMeans.scala  | 12 +++++++++---
 python/pyspark/mllib/clustering.py              |  4 ++--
 python/pyspark/mllib/tests.py                   | 17 +++++++++++++++++
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9f20cd5d00dcd..481084f413da6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
       k: Int,
       maxIterations: Int,
       runs: Int,
-      initializationMode: String): KMeansModel = {
+      initializationMode: String,
+      seed: java.lang.Long): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
+
+    if (seed != null) kMeansAlg.setSeed(seed)
+
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
     } finally {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 54c301d3e9e14..10192cf1d7362 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -43,7 +43,8 @@ class KMeans private (
     private var runs: Int,
     private var initializationMode: String,
     private var initializationSteps: Int,
-    private var epsilon: Double) extends Serializable with Logging {
+    private var epsilon: Double,
+    private var seed: Long = System.nanoTime()) extends Serializable with Logging {
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
@@ -51,6 +52,11 @@ class KMeans private (
    */
   def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
 
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
   /** Set the number of clusters to create (k). Default: 2. */
   def setK(k: Int): this.type = {
     this.k = k
@@ -255,7 +261,7 @@ class KMeans private (
   private def initRandom(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
-    val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+    val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
       new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
     }.toArray)
@@ -273,7 +279,7 @@ class KMeans private (
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Initialize each run's center to a random point
-    val seed = new XORShiftRandom().nextInt()
+    val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
     val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
 
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492eef5bd6a..6b713aa39374e 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ def predict(self, x):
 class KMeans(object):
 
     @classmethod
-    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
         """Train a k-means clustering model."""
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode)
+                              runs, initializationMode, seed)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
 
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 8332f8e061f48..20e451a7f93f8 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -129,6 +129,23 @@ def test_clustering(self):
         self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
         self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
 
+    def test_clustering_deterministic(self):
+        from pyspark.mllib.clustering import KMeans
+        X = range(0, 100, 10)
+        Y = range(0, 100, 10)
+        data = [[x, y] for x, y in zip(X, Y)]
+        clusters1 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        clusters2 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        clusters3 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        centers1 = array(clusters1.centers).flatten().tolist()
+        centers2 = array(clusters2.centers).flatten().tolist()
+        centers3 = array(clusters3.centers).flatten().tolist()
+        self.assertListEqual(centers1, centers2)
+        self.assertListEqual(centers1, centers3)
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree
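
Usage note (illustrative sketch, not part of the patch): with this change,
passing the same seed to KMeans.train from PySpark should reproduce the same
sampled initial centers and therefore the same cluster centers, while the
default seed=None keeps the previous nondeterministic behavior. In the
snippet below, "sc" is assumed to be an already running SparkContext and the
data points are made up.

    from numpy import array
    from pyspark.mllib.clustering import KMeans

    # Hypothetical toy data; sc is assumed to be a live SparkContext.
    data = sc.parallelize([[0.0, 0.0], [10.0, 10.0], [20.0, 20.0], [30.0, 30.0]])

    # Same seed -> same sampled initial centers -> identical cluster centers.
    model_a = KMeans.train(data, 2, initializationMode="k-means||", seed=42)
    model_b = KMeans.train(data, 2, initializationMode="k-means||", seed=42)
    assert array(model_a.centers).tolist() == array(model_b.centers).tolist()

    # Omitting seed (seed=None) falls back to a time-derived seed on the
    # Scala side, so repeated runs may pick different initial centers.
    model_c = KMeans.train(data, 2)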