Skip to content

Commit

Permalink
add a test for sparse linear regression
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 27, 2014
1 parent 44733e1 commit f0fe616
Showing 1 changed file with 34 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
import org.apache.spark.mllib.linalg.Vectors

class LinearRegressionSuite extends FunSuite with LocalSparkContext {

Expand Down Expand Up @@ -84,4 +85,37 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X10000
test("sparse linear regression without intercept") {
val denseRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2)
val sparseRDD = denseRDD.map { case LabeledPoint(label, v) =>
val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
LabeledPoint(label, sv)
}.cache()
val linReg = new LinearRegressionWithSGD().setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)

val model = linReg.run(sparseRDD)

assert(model.intercept === 0.0)

val weights = model.weights
assert(weights.size === 10000)
assert(weights(0) >= 9.0 && weights(0) <= 11.0)
assert(weights(9999) >= 9.0 && weights(9999) <= 11.0)

val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
val sparseValidationData = validationData.map { case LabeledPoint(label, v) =>
val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
LabeledPoint(label, sv)
}
val sparseValidationRDD = sc.parallelize(sparseValidationData, 2)

// Test prediction on RDD.
validatePrediction(model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData)

// Test prediction on Array.
validatePrediction(sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
}

0 comments on commit f0fe616

Please sign in to comment.