diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 993583e2f4119..3073d489bad4a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -338,6 +338,7 @@ def contains_file(self, filename): python_test_goals=[ "pyspark.ml.feature", "pyspark.ml.classification", + "pyspark.ml.clustering", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala new file mode 100644 index 0000000000000..dc192add6ca13 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.Utils + + +/** + * Common params for KMeans and KMeansModel + */ +private[clustering] trait KMeansParams + extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + def getK: Int = $(k) + + /** + * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm + * this many times with random starting conditions (configured by the initialization mode), then + * return the best clustering found over any run. Must be >= 1. Default: 1. + * @group param + */ + final val runs = new IntParam(this, "runs", + "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) + + /** @group getParam */ + def getRuns: Int = $(runs) + + /** + * Param the distance threshold within which we've consider centers to have converged. + * If all centers move less than this Euclidean distance, we stop iterating one run. + * Must be >= 0.0. Default: 1e-4 + * @group param + */ + final val epsilon = new DoubleParam(this, "epsilon", + "distance threshold within which we've consider centers to have converge", + (value: Double) => value >= 0.0) + + /** @group getParam */ + def getEpsilon: Double = $(epsilon) + + /** + * Param for the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * @group expertParam + */ + final val initMode = new Param[String](this, "initMode", "initialization algorithm", + (value: String) => MLlibKMeans.validateInitMode(value)) + + /** @group expertGetParam */ + def getInitMode: String = $(initMode) + + /** + * Param for the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. + * @group expertParam + */ + final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", + (value: Int) => value > 0) + + /** @group expertGetParam */ + def getInitSteps: Int = $(initSteps) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Model fitted by KMeans. + * + * @param parentModel a model trained by spark.mllib.clustering.KMeans. + */ +@Experimental +class KMeansModel private[ml] ( + override val uid: String, + private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + + override def copy(extra: ParamMap): KMeansModel = { + val copied = new KMeansModel(uid, parentModel) + copyValues(copied, extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + def clusterCenters: Array[Vector] = parentModel.clusterCenters +} + +/** + * :: Experimental :: + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + * they are executed together with joint passes over the data for efficiency. + */ +@Experimental +class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { + + setDefault( + k -> 2, + maxIter -> 20, + runs -> 1, + initMode -> MLlibKMeans.K_MEANS_PARALLEL, + initSteps -> 5, + epsilon -> 1e-4) + + override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + + def this() = this(Identifiable.randomUID("kmeans")) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group expertSetParam */ + def setInitSteps(value: Int): this.type = set(initSteps, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setRuns(value: Int): this.type = set(runs, value) + + /** @group setParam */ + def setEpsilon(value: Double): this.type = set(epsilon, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + override def fit(dataset: DataFrame): KMeansModel = { + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + + val algo = new MLlibKMeans() + .setK($(k)) + .setInitializationMode($(initMode)) + .setInitializationSteps($(initSteps)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setEpsilon($(epsilon)) + .setRuns($(runs)) + val parentModel = algo.run(rdd) + val model = new KMeansModel(uid, parentModel) + copyValues(model) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 68297130a7b03..0a65403f4ec95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -85,9 +85,7 @@ class KMeans private ( * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ def setInitializationMode(initializationMode: String): this.type = { - if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { - throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) - } + KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode this } @@ -550,6 +548,14 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } + + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java new file mode 100644 index 0000000000000..d09fa7fd5637c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaKMeansSuite implements Serializable { + + private transient int k = 5; + private transient JavaSparkContext sc; + private transient DataFrame dataset; + private transient SQLContext sql; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaKMeansSuite"); + sql = new SQLContext(sc); + + dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void fitAndTransform() { + KMeans kmeans = new KMeans().setK(k).setSeed(1); + KMeansModel model = kmeans.fit(dataset); + + Vector[] centers = model.clusterCenters(); + assertEquals(k, centers.length); + + DataFrame transformed = model.transform(dataset); + List columns = Arrays.asList(transformed.columns()); + List expectedColumns = Arrays.asList("features", "prediction"); + for (String column: expectedColumns) { + assertTrue(columns.contains(column)); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala new file mode 100644 index 0000000000000..1f15ac02f4008 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +private[clustering] case class TestRow(features: Vector) + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } +} + +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val kmeans = new KMeans() + + assert(kmeans.getK === 2) + assert(kmeans.getFeaturesCol === "features") + assert(kmeans.getPredictionCol === "prediction") + assert(kmeans.getMaxIter === 20) + assert(kmeans.getRuns === 1) + assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) + assert(kmeans.getInitSteps === 5) + assert(kmeans.getEpsilon === 1e-4) + } + + test("set parameters") { + val kmeans = new KMeans() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setRuns(7) + .setInitMode(MLlibKMeans.RANDOM) + .setInitSteps(3) + .setSeed(123) + .setEpsilon(1e-3) + + assert(kmeans.getK === 9) + assert(kmeans.getFeaturesCol === "test_feature") + assert(kmeans.getPredictionCol === "test_prediction") + assert(kmeans.getMaxIter === 33) + assert(kmeans.getRuns === 7) + assert(kmeans.getInitMode === MLlibKMeans.RANDOM) + assert(kmeans.getInitSteps === 3) + assert(kmeans.getSeed === 123) + assert(kmeans.getEpsilon === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new KMeans().setK(1) + } + intercept[IllegalArgumentException] { + new KMeans().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new KMeans().setInitSteps(0) + } + intercept[IllegalArgumentException] { + new KMeans().setRuns(0) + } + } + + test("fit & transform") { + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = kmeans.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4291b0be2a616..12828547d7077 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -481,8 +481,8 @@ object Unidoc { "mllib.tree.impurity", "mllib.tree.model", "mllib.util", "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", - "ml.recommendation", "ml.regression", "ml.tuning" + "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature", + "ml.param", "ml.recommendation", "ml.regression", "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 518b8e774dd5f..86d4186a2c798 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -33,6 +33,14 @@ pyspark.ml.classification module :undoc-members: :inherited-members: +pyspark.ml.clustering module +---------------------------- + +.. automodule:: pyspark.ml.clustering + :members: + :undoc-members: + :inherited-members: + pyspark.ml.recommendation module -------------------------------- diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py new file mode 100644 index 0000000000000..b5e9b6549d9f1 --- /dev/null +++ b/python/pyspark/ml/clustering.py @@ -0,0 +1,206 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * +from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['KMeans', 'KMeansModel'] + + +class KMeansModel(JavaModel): + """ + Model fitted by KMeans. + """ + + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy arrays.""" + return [c.toArray() for c in self._call_java("clusterCenters")] + + +@inherit_doc +class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): + """ + K-means Clustering + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> model = kmeans.fit(df) + >>> centers = model.clusterCenters() + >>> len(centers) + 2 + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[0].prediction == rows[1].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + """ + + # a placeholder to make it appear in the generated doc + k = Param(Params._dummy(), "k", "number of clusters to create") + epsilon = Param(Params._dummy(), "epsilon", + "distance threshold within which " + + "we've consider centers to have converged") + runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") + initMode = Param(Params._dummy(), "initMode", + "the initialization algorithm. This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") + + @keyword_only + def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + super(KMeans, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) + self.k = Param(self, "k", "number of clusters to create") + self.epsilon = Param(self, "epsilon", + "distance threshold within which " + + "we've consider centers to have converged") + self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") + self.seed = Param(self, "seed", "random seed") + self.initMode = Param(self, "initMode", + "the initialization algorithm. This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") + self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def _create_model(self, java_model): + return KMeansModel(java_model) + + @keyword_only + def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + """ + setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + + Sets params for KMeans. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + + >>> algo = KMeans().setK(10) + >>> algo.getK() + 10 + """ + self._paramMap[self.k] = value + return self + + def getK(self): + """ + Gets the value of `k` + """ + return self.getOrDefault(self.k) + + def setEpsilon(self, value): + """ + Sets the value of :py:attr:`epsilon`. + + >>> algo = KMeans().setEpsilon(1e-5) + >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 + True + """ + self._paramMap[self.epsilon] = value + return self + + def getEpsilon(self): + """ + Gets the value of `epsilon` + """ + return self.getOrDefault(self.epsilon) + + def setRuns(self, value): + """ + Sets the value of :py:attr:`runs`. + + >>> algo = KMeans().setRuns(10) + >>> algo.getRuns() + 10 + """ + self._paramMap[self.runs] = value + return self + + def getRuns(self): + """ + Gets the value of `runs` + """ + return self.getOrDefault(self.runs) + + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + + >>> algo = KMeans() + >>> algo.getInitMode() + 'k-means||' + >>> algo = algo.setInitMode("random") + >>> algo.getInitMode() + 'random' + """ + self._paramMap[self.initMode] = value + return self + + def getInitMode(self): + """ + Gets the value of `initMode` + """ + return self.getOrDefault(self.initMode) + + def setInitSteps(self, value): + """ + Sets the value of :py:attr:`initSteps`. + + >>> algo = KMeans().setInitSteps(10) + >>> algo.getInitSteps() + 10 + """ + self._paramMap[self.initSteps] = value + return self + + def getInitSteps(self): + """ + Gets the value of `initSteps` + """ + return self.getOrDefault(self.initSteps) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.clustering tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1)