diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 757aea32ef94e..2df5b0d02b699 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -110,8 +110,8 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, - dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): - java.util.LinkedList[java.lang.Object] = { + dataBytesJRDD: JavaRDD[Array[Byte]], + initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) @@ -238,9 +238,9 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double) - : java.util.List[java.lang.Object] = - { + def trainNaiveBayes( + dataBytesJRDD: JavaRDD[Array[Byte]], + lambda: Double): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) @@ -256,9 +256,12 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for Python mllib KMeans.train() */ - def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, - maxIterations: Int, runs: Int, initializationMode: String): - java.util.List[java.lang.Object] = { + def trainKMeansModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes))) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() @@ -311,8 +314,12 @@ class PythonMLLibAPI extends Serializable { * needs to be taken in the Python code to ensure it gets freed on exit; see * the Py4J documentation. */ - def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + def trainALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -323,8 +330,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. */ - def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + def trainImplicitALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int, + alpha: Double): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 2591d89b9e0dc..bd10e2e9e10e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,15 +17,19 @@ package org.apache.spark.mllib.classification -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +/** + * Represents a classification model that predicts to which of a set of categories an example + * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. + */ trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return an RDD[Double] where each entry contains the corresponding prediction */ def predict(testData: RDD[Vector]): RDD[Double] @@ -33,7 +37,7 @@ trait ClassificationModel extends Serializable { * Predict values for a single data point using the model trained. * * @param testData array representing a single data point - * @return Int prediction from the trained model + * @return predicted category from the trained model */ def predict(testData: Vector): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index da9995f5879ad..798f3a5c94740 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,15 +17,12 @@ package org.apache.spark.mllib.classification -import scala.math.round - import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Classification model trained using Logistic Regression. @@ -36,13 +33,36 @@ import org.apache.spark.mllib.linalg.Vector class LogisticRegressionModel( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + + private var threshold: Option[Double] = Some(0.5) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.5. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } + + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } override def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - round(1.0/ (1.0 + math.exp(margin * -1))) + val score = 1.0/ (1.0 + math.exp(-margin)) + threshold match { + case Some(t) => if (score < t) 0.0 else 1.0 + case None => score + } } } @@ -55,16 +75,15 @@ class LogisticRegressionWithSGD private ( var numIterations: Int, var regParam: Double, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LogisticRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { val gradient = new LogisticGradient() val updater = new SimpleUpdater() override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) override val validators = List(DataValidators.classificationLabels) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index b854bcab815f0..e31a08899f8bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,12 +18,11 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Model for Support Vector Machines (SVMs). @@ -34,13 +33,35 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} class SVMModel( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + + private var threshold: Option[Double] = Some(0.0) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.0. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } + + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } override def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - if (margin < 0) 0.0 else 1.0 + threshold match { + case Some(t) => if (margin < 0) 0.0 else 1.0 + case None => margin + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b412738e3f00a..a78503df3134d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -42,8 +42,7 @@ class KMeans private ( var runs: Int, var initializationMode: String, var initializationSteps: Int, - var epsilon: Double) - extends Serializable with Logging { + var epsilon: Double) extends Serializable with Logging { def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) /** Set the number of clusters to create (k). Default: 2. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 5d21ed8e6983c..4c47eb5c7506c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -150,6 +150,18 @@ object MLUtils { val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 val normDiff = norm1 - norm2 var sqDist = 0.0 + /* + * The relative error is + *
+     * EPSILON * ( \|a\|_2^2 + \|b\\_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
+     * 
+ * which is bounded by + *
+     * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
+     * 
+ * The bound doesn't need the inner product, so we can use it as a sufficient condition to + * check quickly whether the inner product approach is accurate. + */ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) if (precisionBound1 < precision) { sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 117e5eaa8b78e..4701a5e545020 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification; - import java.io.Serializable; import java.util.List; @@ -28,7 +27,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; public class JavaSVMSuite implements Serializable { @@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() { int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index dc35d2483296d..dfacbfeee6fb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext { } intercept[SparkException] { - val model = SVMWithSGD.train(testRDDInvalid, 100) + SVMWithSGD.train(testRDDInvalid, 100) } // Turning off data validation should not throw an exception - val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) + new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } }