Commit b01df54

allow to change or clear threshold in LR and SVM
add more comments to MLUtils.fastSquaredDistance
mengxr committed Mar 31, 2014
1 parent 4addc50 commit b01df54
Showing 8 changed files with 108 additions and 44 deletions.
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -110,8 +110,8 @@ class PythonMLLibAPI extends Serializable {

private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
java.util.LinkedList[java.lang.Object] = {
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
@@ -238,9 +238,9 @@ class PythonMLLibAPI extends Serializable {
/**
* Java stub for NaiveBayes.train()
*/
def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double)
: java.util.List[java.lang.Object] =
{
def trainNaiveBayes(
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
@@ -256,9 +256,12 @@ class PythonMLLibAPI extends Serializable {
/**
* Java stub for Python mllib KMeans.train()
*/
def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
maxIterations: Int, runs: Int, initializationMode: String):
java.util.List[java.lang.Object] = {
def trainKMeansModel(
dataBytesJRDD: JavaRDD[Array[Byte]],
k: Int,
maxIterations: Int,
runs: Int,
initializationMode: String): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes)))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
@@ -311,8 +314,12 @@ class PythonMLLibAPI extends Serializable {
* needs to be taken in the Python code to ensure it gets freed on exit; see
* the Py4J documentation.
*/
def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
def trainALSModel(
ratingsBytesJRDD: JavaRDD[Array[Byte]],
rank: Int,
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
ALS.train(ratings, rank, iterations, lambda, blocks)
}
@@ -323,8 +330,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
def trainImplicitALSModel(
ratingsBytesJRDD: JavaRDD[Array[Byte]],
rank: Int,
iterations: Int,
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,23 +17,27 @@

package org.apache.spark.mllib.classification

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* Represents a classification model that predicts to which of a set of categories an example
* belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
*/
trait ClassificationModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
* @param testData RDD representing data points to be predicted
* @return RDD[Int] where each entry contains the corresponding prediction
* @return an RDD[Double] where each entry contains the corresponding prediction
*/
def predict(testData: RDD[Vector]): RDD[Double]

/**
* Predict values for a single data point using the model trained.
*
* @param testData array representing a single data point
* @return Int prediction from the trained model
* @return predicted category from the trained model
*/
def predict(testData: Vector): Double
}
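
For context, a minimal usage sketch of this interface follows (not part of the commit; an existing SparkContext `sc` and the toy data below are assumed):

// Hypothetical usage sketch -- assumes an existing SparkContext `sc`.
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(2.0, 3.0)),
  LabeledPoint(0.0, Vectors.dense(-1.0, -2.0))))
val model = LogisticRegressionWithSGD.train(training, 100)

// Bulk prediction: RDD[Vector] => RDD[Double] of 0.0/1.0 labels.
val bulkPredictions = model.predict(training.map(_.features))

// Single-point prediction: Vector => Double.
val singlePrediction = model.predict(Vectors.dense(1.5, 2.5))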
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,15 +17,12 @@

package org.apache.spark.mllib.classification

import scala.math.round

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.{DataValidators, MLUtils}
import org.apache.spark.rdd.RDD

/**
* Classification model trained using Logistic Regression.
@@ -36,13 +33,36 @@ import org.apache.spark.mllib.linalg.Vector
class LogisticRegressionModel(
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with ClassificationModel with Serializable {
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {

private var threshold: Option[Double] = Some(0.5)

/**
* Sets the threshold that separates positive predictions from negative predictions. An example
* with prediction score greater than or equal to this threshold is identified as positive,
* and negative otherwise. The default value is 0.5.
*/
def setThreshold(threshold: Double): this.type = {
this.threshold = Some(threshold)
this
}

/**
* Clears the threshold so that `predict` will output raw prediction scores.
*/
def clearThreshold(): this.type = {
threshold = None
this
}

override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
round(1.0/ (1.0 + math.exp(margin * -1)))
val score = 1.0/ (1.0 + math.exp(-margin))
threshold match {
case Some(t) => if (score < t) 0.0 else 1.0
case None => score
}
}
}
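
A brief usage sketch of the new threshold controls (hypothetical, not part of this commit; `lrModel` is assumed to be a trained LogisticRegressionModel and `features` an RDD[Vector]):

// Hypothetical usage sketch -- `lrModel` and `features` are assumed to exist.
lrModel.setThreshold(0.3)           // scores >= 0.3 map to 1.0, lower scores to 0.0
val labels = lrModel.predict(features)

lrModel.clearThreshold()            // predict now returns the raw logistic score in (0, 1)
val scores = lrModel.predict(features)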

Expand All @@ -55,16 +75,15 @@ class LogisticRegressionWithSGD private (
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {

val gradient = new LogisticGradient()
val updater = new SimpleUpdater()
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
.setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
override val validators = List(DataValidators.classificationLabels)

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -18,12 +17,11 @@
package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{DataValidators, MLUtils}
import org.apache.spark.rdd.RDD

/**
* Model for Support Vector Machines (SVMs).
@@ -34,13 +33,35 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
class SVMModel(
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with ClassificationModel with Serializable {
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {

private var threshold: Option[Double] = Some(0.0)

/**
* Sets the threshold that separates positive predictions from negative predictions. An example
* with prediction score greater than or equal to this threshold is identified as positive,
* and negative otherwise. The default value is 0.0.
*/
def setThreshold(threshold: Double): this.type = {
this.threshold = Some(threshold)
this
}

/**
* Clears the threshold so that `predict` will output raw prediction scores.
*/
def clearThreshold(): this.type = {
threshold = None
this
}

override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
if (margin < 0) 0.0 else 1.0
threshold match {
case Some(t) => if (margin < t) 0.0 else 1.0
case None => margin
}
}
}
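
The SVM model gains the same controls; a hedged sketch (assuming a trained `svmModel` and an RDD[Vector] `features`):

// Hypothetical usage sketch -- `svmModel` and `features` are assumed to exist.
svmModel.clearThreshold()           // predict returns the raw margin w^T x + intercept
val margins = svmModel.predict(features)

svmModel.setThreshold(0.0)          // restore the default cutoff at 0.0
val labels = svmModel.predict(features)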

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -42,8 +42,7 @@ class KMeans private (
var runs: Int,
var initializationMode: String,
var initializationSteps: Int,
var epsilon: Double)
extends Serializable with Logging {
var epsilon: Double) extends Serializable with Logging {
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)

/** Set the number of clusters to create (k). Default: 2. */
12 changes: 12 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -150,6 +150,18 @@ object MLUtils {
val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
val normDiff = norm1 - norm2
var sqDist = 0.0
/*
* The relative error is
* <pre>
* EPSILON * ( \|a\|_2^2 + \|b\|_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
* </pre>
* which is bounded by
* <pre>
* 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
* </pre>
* The bound doesn't need the inner product, so we can use it as a sufficient condition to
* check quickly whether the inner product approach is accurate.
*/
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
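
To make the bound concrete, here is a small illustrative check (a sketch only, not the MLUtils implementation; the EPSILON and precision values are assumed):

// Illustrative sketch of the precision check; norm1 = \|a\|_2 and norm2 = \|b\|_2
// are assumed to be precomputed.
val EPSILON = 2.2e-16   // roughly machine epsilon for Double (assumed value)
val precision = 1e-6    // requested relative precision (assumed value)

def useFastSquaredDistance(norm1: Double, norm2: Double): Boolean = {
  val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
  val normDiff = norm1 - norm2
  // Upper bound on the relative error of computing \|a - b\|_2^2 as
  // \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
  val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
  precisionBound1 < precision
}

useFastSquaredDistance(10.0, 1.0)          // true: norms are well separated, the bound is tiny
useFastSquaredDistance(1.0, 1.0 + 1e-12)   // false: nearly equal norms blow up the bound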
mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,7 +17,6 @@

package org.apache.spark.mllib.classification;


import java.io.Serializable;
import java.util.List;

@@ -28,7 +27,6 @@

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import org.apache.spark.mllib.regression.LabeledPoint;

public class JavaSVMSuite implements Serializable {
@@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() {
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}

}
mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext {
}

intercept[SparkException] {
val model = SVMWithSGD.train(testRDDInvalid, 100)
SVMWithSGD.train(testRDDInvalid, 100)
}

// Turning off data validation should not throw an exception
val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
