Skip to content

Commit

Permalink
Add KMeans.setPredictionCol
Browse files Browse the repository at this point in the history
  • Loading branch information
yu-iskw committed Jul 1, 2015
1 parent aa9469d commit 85d92b1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean
/**
 * Sets the `featuresCol` param to the given column name.
 *
 * @param value column name to store in `featuresCol`
 * @return this instance, so setter calls can be chained
 * @group setParam
 */
def setFeaturesCol(value: String): this.type = {
  set(featuresCol, value)
}

/**
 * Sets the `predictionCol` param to the given column name.
 *
 * @param value column name to store in `predictionCol`
 * @return this instance, so setter calls can be chained
 * @group setParam
 */
def setPredictionCol(value: String): this.type = {
  set(predictionCol, value)
}

/**
 * Sets the `k` param to the given value.
 *
 * @param value integer to store in `k`
 * @return this instance, so setter calls can be chained
 * @group setParam
 */
def setK(value: Int): this.type = {
  set(k, value)
}

Expand Down Expand Up @@ -187,7 +190,8 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean
.setMaxIterations(map(maxIter))
.setSeed(map(seed))
val parentModel = algo.run(rdd)
new KMeansModel(uid, map, parentModel)
val model = new KMeansModel(uid, map, parentModel)
copyValues(model)
}

override def transformSchema(schema: StructType): StructType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

assert(kmeans.getK === 2)
assert(kmeans.getFeaturesCol === "features")
assert(kmeans.getPredictionCol === "prediction")
assert(kmeans.getMaxIter === 20)
assert(kmeans.getRuns === 1)
assert(kmeans.getInitializationMode === KMeans.K_MEANS_PARALLEL)
Expand All @@ -58,14 +59,16 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}

test("fit & transform") {
  // Configure a custom prediction column name so the test checks that the
  // value passed to setPredictionCol is the one used by the fitted model.
  val predictionColName = "kmeans_prediction"
  val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName)
  val model = kmeans.fit(dataset)
  assert(model.clusterCenters.length === k)

  // The transformed output must expose the configured column name.
  val transformed = model.transform(dataset)
  assert(transformed.columns === Array("features", predictionColName))

  // Collect the distinct cluster ids written to the prediction column.
  val clusters = transformed.select(predictionColName)
    .map(row => row.apply(0)).distinct().collect().toSet
  assert(clusters.size === 5)
  assert(clusters === Set(0, 1, 2, 3, 4))
}
}

0 comments on commit 85d92b1

Please sign in to comment.