Skip to content

Commit

Permalink
merge glm
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 26, 2014
2 parents 3f346ba + 0e57aa4 commit 135ab72
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
* @param weightMatrix Column vector containing the weights of the model
* @param intercept Intercept of the model.
*/
def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double

/**
* Predict values for the given data set using the model trained.
Expand Down Expand Up @@ -116,6 +116,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
run(input, initialWeights)
}

/** Prepends one to the input vector. */
private def prependOne(vector: Vector): Vector = {
val vectorWithIntercept = vector match {
case dv: BDV[Double] => BDV.vertcat(BDV.ones(1), dv)
Expand Down Expand Up @@ -154,8 +155,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
val intercept = if (addIntercept) brzWeightsWithIntercept(0) else 0.0
val brzWeights = if (addIntercept) brzWeightsWithIntercept(1 to -1) else brzWeightsWithIntercept

val model = createModel(Vectors.fromBreeze(brzWeights), intercept)

model
createModel(Vectors.fromBreeze(brzWeights), intercept)
}
}
20 changes: 14 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class LassoModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -66,7 +68,7 @@ class LassoWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -77,10 +79,16 @@ class LassoWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
override def setIntercept(addIntercept: Boolean): this.type = {
// TODO: Support adding intercept.
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
this
}

override protected def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

new LassoModel(weightsScaled.data, interceptScaled)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {

override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand All @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
Expand All @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
override protected def createModel(weights: Array[Double], intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ class RidgeRegressionModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -66,7 +68,7 @@ class RidgeRegressionWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept in RidgeRegression, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -77,8 +79,14 @@ class RidgeRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
override def setIntercept(addIntercept: Boolean): this.type = {
// TODO: Support adding intercept.
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
this
}

override protected def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.mllib.regression

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
Expand Down Expand Up @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X2
test("linear regression without intercept") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
val linReg = new LinearRegressionWithSGD().setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)

val model = linReg.run(testRDD)

assert(model.intercept === 0.0)
assert(model.weights.length === 2)
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)

val validationData = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 17)
val validationRDD = sc.parallelize(validationData, 2).cache()

// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)

// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}

0 comments on commit 135ab72

Please sign in to comment.