Skip to content

Commit

Permalink
remove normalization from Lasso and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 31, 2014
1 parent f04fe8a commit 4ca5b1b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 50 deletions.
39 changes: 4 additions & 35 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@

package org.apache.spark.mllib.regression

import breeze.linalg.{Vector => BV}

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
* Regression model trained using Lasso.
Expand Down Expand Up @@ -58,8 +56,7 @@ class LassoWithSGD private (
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
extends GeneralizedLinearAlgorithm[LassoModel] with Serializable {

val gradient = new LeastSquaresGradient()
val updater = new L1Updater()
Expand All @@ -71,10 +68,6 @@ class LassoWithSGD private (
// We don't want to penalize the intercept, so turn the intercept term off here.
super.setIntercept(false)

private var yMean = 0.0
private var xColMean: BV[Double] = _
private var xColSd: BV[Double] = _

/**
* Construct a Lasso object with default parameters
*/
Expand All @@ -87,31 +80,7 @@ class LassoWithSGD private (
}

override protected def createModel(weights: Vector, intercept: Double) = {
val weightsMat = weights.toBreeze
val weightsScaled = weightsMat :/ xColSd
val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)

new LassoModel(Vectors.fromBreeze(weightsScaled), interceptScaled)
}

override def run(input: RDD[LabeledPoint], initialWeights: Vector): LassoModel = {
val nfeatures: Int = input.first.features.size
val nexamples: Long = input.count()

// To avoid penalizing the intercept, we center and scale the data.
val stats = MLUtils.computeStats(input, nfeatures, nexamples)
yMean = stats._1
xColMean = stats._2.toBreeze
xColSd = stats._3.toBreeze

val normalizedData = input.map { point =>
val yNormalized = point.label - yMean
val featuresMat = point.features.toBreeze
val featuresNormalized = (featuresMat - xColMean) :/ xColSd
LabeledPoint(yNormalized, Vectors.fromBreeze(featuresNormalized))
}

super.run(normalizedData, initialWeights)
new LassoModel(weights, intercept)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
@Test
public void runLassoUsingConstructor() {
int nPoints = 10000;
double A = 2.0;
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};

JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
Expand All @@ -80,7 +80,7 @@ public void runLassoUsingConstructor() {
@Test
public void runLassoUsingStaticMethods() {
int nPoints = 10000;
double A = 2.0;
double A = 0.0;
double[] weights = {-1.5, 1.0e-2};

JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,27 @@ class LassoSuite extends FunSuite with LocalSparkContext {
val B = -1.5
val C = 1.0e-2

val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)

val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}
val testRDD = sc.parallelize(testData, 2).cache()

val ls = new LassoWithSGD()
ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)

val model = ls.run(testRDD)
val weight0 = model.weights(0)
val weight1 = model.weights(1)
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
val weight2 = model.weights(2)
assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")

val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand All @@ -73,25 +78,32 @@ class LassoSuite extends FunSuite with LocalSparkContext {
val C = 1.0e-2

val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}

val initialA = -1.0
val initialB = -1.0
val initialC = -1.0
val initialWeights = Vectors.dense(initialB, initialC)
val initialWeights = Vectors.dense(initialA, initialB, initialC)

val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val testRDD = sc.parallelize(testData, 2).cache()

val ls = new LassoWithSGD()
ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)

val model = ls.run(testRDD, initialWeights)
val weight0 = model.weights(0)
val weight1 = model.weights(1)
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
val weight2 = model.weights(2)
assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")

val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}
val validationRDD = sc.parallelize(validationData,2)

// Test prediction on RDD.
Expand Down

0 comments on commit 4ca5b1b

Please sign in to comment.