remove normalization from RidgeRegression and update tests
mengxr committed Mar 31, 2014
1 parent d088552 commit f04fe8a
Showing 3 changed files with 36 additions and 83 deletions.
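
In short: RidgeRegressionWithSGD previously overrode run to center and scale the input (via MLUtils.computeStats) and then un-scaled the learned weights in createModel; this commit deletes that machinery, so the algorithm now optimizes directly on the data it is given. Callers who relied on the old behavior can standardize features themselves before training. Below is a minimal sketch of that, assuming the 2014-era MLlib API shown in this diff; the object name, toy data, and hyperparameter values are illustrative, not part of the commit:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, RidgeRegressionWithSGD}

object StandardizeThenTrain {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "StandardizeThenTrain")

    // Toy data; real input would come from a file or generator.
    val raw = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(10.0, 200.0)),
      LabeledPoint(2.0, Vectors.dense(12.0, 180.0)),
      LabeledPoint(3.0, Vectors.dense(14.0, 220.0)))).cache()

    // Per-column mean and population standard deviation via plain RDD ops.
    val n = raw.count().toDouble
    val sums = raw.map(_.features.toArray)
      .reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
    val means = sums.map(_ / n)
    val sqSums = raw.map(_.features.toArray.map(x => x * x))
      .reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
    val stds = sqSums.zip(means).map { case (sq, m) => math.sqrt(sq / n - m * m) }

    // (x - mean) / std per column; assumes no zero-variance columns.
    val standardized = raw.map { p =>
      val scaled = p.features.toArray.zipWithIndex
        .map { case (x, i) => (x - means(i)) / stds(i) }
      LabeledPoint(p.label, Vectors.dense(scaled))
    }

    val model = RidgeRegressionWithSGD.train(standardized, 200, 1.0, 0.01)
    println("weights: " + model.weights + ", intercept: " + model.intercept)
    sc.stop()
  }
}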
mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.mllib.regression
 
-import breeze.linalg.{Vector => BV}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.Vector
 
 /**
  * Regression model trained using RidgeRegression.
@@ -58,8 +56,7 @@ class RidgeRegressionWithSGD private (
     var numIterations: Int,
     var regParam: Double,
     var miniBatchFraction: Double)
-  extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
-  with Serializable {
+  extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable {
 
   val gradient = new LeastSquaresGradient()
   val updater = new SquaredL2Updater()
@@ -72,10 +69,6 @@
   // We don't want to penalize the intercept in RidgeRegression, so set this to false.
   super.setIntercept(false)
 
-  private var yMean = 0.0
-  private var xColMean: BV[Double] = _
-  private var xColSd: BV[Double] = _
-
   /**
    * Construct a RidgeRegression object with default parameters
    */
@@ -88,35 +81,7 @@
   }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    val weightsMat = weights.toBreeze
-    val weightsScaled = weightsMat :/ xColSd
-    val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)
-
-    new RidgeRegressionModel(Vectors.fromBreeze(weightsScaled), interceptScaled)
-  }
-
-  override def run(
-      input: RDD[LabeledPoint],
-      initialWeights: Vector)
-    : RidgeRegressionModel =
-  {
-    val nfeatures: Int = input.first().features.size
-    val nexamples: Long = input.count()
-
-    // To avoid penalizing the intercept, we center and scale the data.
-    val stats = MLUtils.computeStats(input, nfeatures, nexamples)
-    yMean = stats._1
-    xColMean = stats._2.toBreeze
-    xColSd = stats._3.toBreeze
-
-    val normalizedData = input.map { point =>
-      val yNormalized = point.label - yMean
-      val featuresMat = point.features.toBreeze
-      val featuresNormalized = (featuresMat - xColMean) :/ xColSd
-      LabeledPoint(yNormalized, Vectors.fromBreeze(featuresNormalized))
-    }
-
-    super.run(normalizedData, initialWeights)
+    new RidgeRegressionModel(weights, intercept)
   }
 }

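For reference, the deleted run/createModel pair implemented the usual back-transformation for a model fit on standardized data. With per-column feature means $\mu$, standard deviations $\sigma$, label mean $\bar{y}$, and weights $w'$ learned on the transformed points $((x - \mu) \oslash \sigma,\ y - \bar{y})$, the original-space parameters are

$$ w = w' \oslash \sigma, \qquad b = \bar{y} - w' \cdot (\mu \oslash \sigma), $$

where $\oslash$ denotes elementwise division; this is exactly what weightsScaled and interceptScaled computed above. After this change, createModel returns the learned weights and intercept untouched.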
@@ -145,9 +110,7 @@ object RidgeRegressionWithSGD {
       stepSize: Double,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeights: Vector)
-    : RidgeRegressionModel =
-  {
+      initialWeights: Vector): RidgeRegressionModel = {
     new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(
       input, initialWeights)
   }
@@ -168,9 +131,7 @@ object RidgeRegressionWithSGD {
       numIterations: Int,
       stepSize: Double,
       regParam: Double,
-      miniBatchFraction: Double)
-    : RidgeRegressionModel =
-  {
+      miniBatchFraction: Double): RidgeRegressionModel = {
     new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
   }

@@ -189,9 +150,7 @@ object RidgeRegressionWithSGD {
       input: RDD[LabeledPoint],
       numIterations: Int,
       stepSize: Double,
-      regParam: Double)
-    : RidgeRegressionModel =
-  {
+      regParam: Double): RidgeRegressionModel = {
     train(input, numIterations, stepSize, regParam, 1.0)
   }

@@ -206,9 +165,7 @@
    */
  def train(
      input: RDD[LabeledPoint],
-      numIterations: Int)
-    : RidgeRegressionModel =
-  {
+      numIterations: Int): RidgeRegressionModel = {
    train(input, numIterations, 1.0, 1.0, 1.0)
  }
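
The train-overload hunks above are pure style cleanup: each overload's result type and opening brace move onto the parameter list, with no behavior change. The defaulting chain is also unchanged; as a small sketch, assuming an existing RDD[LabeledPoint] named input (the helper names here are illustrative):

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.{LabeledPoint, RidgeRegressionModel, RidgeRegressionWithSGD}

// The two-argument overload fills in stepSize = 1.0, regParam = 1.0, and
// miniBatchFraction = 1.0, so these two calls are equivalent.
def trainWithDefaults(input: RDD[LabeledPoint]): RidgeRegressionModel =
  RidgeRegressionWithSGD.train(input, 200)

def trainExplicitly(input: RDD[LabeledPoint]): RidgeRegressionModel =
  RidgeRegressionWithSGD.train(input, 200, 1.0, 1.0, 1.0)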
mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -55,30 +55,27 @@ public void tearDown() {
     return errorSum / validationData.size();
   }
 
-  List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+  List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
     org.jblas.util.Random.seed(42);
     // Pick weights as random values distributed uniformly in [-0.5, 0.5]
-    DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
-    // Set first two weights to eps
-    w.put(0, 0, eps);
-    w.put(1, 0, eps);
-    return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+    DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
+    return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
   }
 
   @Test
   public void runRidgeRegressionUsingConstructor() {
-    int nexamples = 200;
-    int nfeatures = 20;
-    double eps = 10.0;
-    List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+    int numExamples = 50;
+    int numFeatures = 20;
+    List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
-    List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+    List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
 
     RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
-    ridgeSGDImpl.optimizer().setStepSize(1.0)
-        .setRegParam(0.0)
-        .setNumIterations(200);
+    ridgeSGDImpl.optimizer()
+      .setStepSize(1.0)
+      .setRegParam(0.0)
+      .setNumIterations(200);
     RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
     double unRegularizedErr = predictionError(validationData, model);
@@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() {
 
   @Test
   public void runRidgeRegressionUsingStaticMethods() {
-    int nexamples = 200;
-    int nfeatures = 20;
-    double eps = 10.0;
-    List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+    int numExamples = 50;
+    int numFeatures = 20;
+    List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
-    List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+    List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
 
     RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
     double unRegularizedErr = predictionError(validationData, model);
mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -31,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     }.reduceLeft(_ + _) / predictions.size
   }
 
-  test("regularization with skewed weights") {
-    val nexamples = 200
-    val nfeatures = 20
-    val eps = 10
+  test("ridge regression can help avoid overfitting") {
+
+    // For small number of examples and large variance of error distribution,
+    // ridge regression should give smaller generalization error than linear regression.
+
+    val numExamples = 50
+    val numFeatures = 20
 
     org.jblas.util.Random.seed(42)
     // Pick weights as random values distributed uniformly in [-0.5, 0.5]
-    val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
-    // Set first two weights to eps
-    w.put(0, 0, eps)
-    w.put(1, 0, eps)
+    val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)
 
     // Use half of data for training and other half for validation
-    val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
-    val testData = data.take(nexamples)
-    val validationData = data.takeRight(nexamples)
+    val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
+    val testData = data.take(numExamples)
+    val validationData = data.takeRight(numExamples)
 
     val testRDD = sc.parallelize(testData, 2).cache()
     val validationRDD = sc.parallelize(validationData, 2).cache()
@@ -68,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     val ridgeErr = predictionError(
       ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
 
-    // Ridge CV-error should be lower than linear regression
+    // Ridge validation error should be lower than linear regression.
     assert(ridgeErr < linearErr,
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
   }

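Both test suites now train on 50 examples instead of 200 and no longer inflate the first two weights, so the assertion rests on the bias-variance intuition spelled out in the new test comment: with few examples and noise standard deviation 10.0, the unregularized least-squares fit overfits, while the L2 penalty shrinks the weights and generalizes better. Given the LeastSquaresGradient and SquaredL2Updater wired into RidgeRegressionWithSGD above (and the intercept disabled via setIntercept(false)), the optimized objective is the usual ridge loss, sketched here with regParam as $\lambda$; the precise constant factors live in the optimizer and are not shown in this diff:

$$ \min_{w}\ \frac{1}{n} \sum_{i=1}^{n} \left( w \cdot x_i - y_i \right)^2 + \lambda \lVert w \rVert_2^2 $$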