Skip to content

Commit

Permalink
Modify the test case and add a single-run (runs == 1) condition to setInitialModel
Browse files Browse the repository at this point in the history
  • Loading branch information
FlytxtRnD committed Jun 19, 2015
1 parent cd5dc5c commit 3f5fc8e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,14 @@ class KMeans private (
// random or k-means|| initializationMode
private var initialModel: Option[KMeansModel] = None

/**
 * Set the initial starting point, bypassing the random initialization or k-means||.
 * The condition model.k == this.k must be met, and only one run is allowed;
 * failure in either case will result in an IllegalArgumentException.
 *
 * NOTE(review): both conditions are evaluated only when this setter is called, against
 * the values `k` and `runs` hold at that moment; later changes to either are not
 * re-validated here -- confirm the intended call order.
 *
 * @param model the model to use as the starting point
 * @return this instance, so that calls can be chained
 * @throws IllegalArgumentException if `model.k != k` or `runs != 1`
 */
def setInitialModel(model: KMeansModel): this.type = {
  require(model.k == k, "mismatched cluster count")
  require(runs == 1, "can only run once with given initial model")
  initialModel = Some(model)
  this
}
Expand Down Expand Up @@ -499,25 +501,6 @@ object KMeans {
train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
}

/**
 * Trains a k-means model from the given parameters, starting from the cluster
 * centers of an existing model.
 *
 * @param data training points stored as `RDD[Vector]`
 * @param k number of clusters
 * @param maxIterations max number of iterations
 * @param initialModel an existing set of cluster centers
 * @return the model returned by running the configured `KMeans` instance on `data`
 */
def train(
    data: RDD[Vector],
    k: Int,
    maxIterations: Int,
    initialModel: KMeansModel): KMeansModel = {
  // Configure the algorithm first, then run it as a separate step.
  val configured = new KMeans()
    .setK(k)
    .setMaxIterations(maxIterations)
    .setInitialModel(initialModel)
  configured.run(data)
}

/**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,28 +282,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
// Checks that KMeans can be started from caller-supplied cluster centers.
// NOTE(review): this block appears to interleave lines from two versions of the test
// (a save/load + KMeans.train variant and a setInitialModel variant) -- confirm which
// lines belong to the current file before relying on it.
test("Initialize using given cluster centers") {
val points = Seq(
Vectors.dense(0.0, 0.0),
Vectors.dense(0.0, 0.1),
Vectors.dense(0.1, 0.0),
Vectors.dense(9.0, 0.0),
Vectors.dense(9.0, 0.2),
// NOTE(review): the next element has no trailing comma although more elements follow.
Vectors.dense(9.2, 0.0)
Vectors.dense(1.0, 0.0),
Vectors.dense(0.0, 1.0),
Vectors.dense(1.0, 1.0)
)
// Parallelize the points with 3 slices.
val rdd = sc.parallelize(points, 3)
// Train a model with k = 2, maxIterations = 2 and runs = 1.
val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1)

// Save the trained model under a temporary directory and load it back.
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
model.save(sc, path)
val loadedModel = KMeansModel.load(sc, path)

// Train again with the loaded model as the initial model, then predict for every point.
val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel)
val predicts = newModel.predict(rdd).collect()

// Points 0-2 must share one prediction, points 3-5 another, and the two must differ.
assert(predicts(0) === predicts(1))
assert(predicts(0) === predicts(2))
assert(predicts(3) === predicts(4))
assert(predicts(3) === predicts(5))
assert(predicts(0) != predicts(3))
// Two hand-built models, each constructed from two of the points
// (presumably used as cluster centers -- confirm against KMeansModel's constructor).
val m1 = new KMeansModel(Array(points(0), points(2)))
val m2 = new KMeansModel(Array(points(1), points(3)))

// Run KMeans with k = 2 and maxIterations = 1, once from each initial model.
val modelM1 = new KMeans()
.setK(2)
.setMaxIterations(1)
.setInitialModel(m1)
.run(rdd)
val modelM2 = new KMeans()
.setK(2)
.setMaxIterations(1)
.setInitialModel(m2)
.run(rdd)

val predicts1 = modelM1.predict(rdd).collect()
val predicts2 = modelM2.predict(rdd).collect()

// For both results, points 0 and 1 must share a prediction, as must points 2 and 3.
assert(predicts1(0) === predicts1(1))
assert(predicts1(2) === predicts1(3))
assert(predicts2(0) === predicts2(1))
assert(predicts2(2) === predicts2(3))
}

}
Expand Down

0 comments on commit 3f5fc8e

Please sign in to comment.