Skip to content

Commit

Permalink
update Lasso and RidgeRegression to parse the weights correctly from GLM
Browse files Browse the repository at this point in the history
mark createModel protected
mark predictPoint protected
  • Loading branch information
mengxr committed Mar 26, 2014
1 parent d7f629f commit 0e57aa4
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
* @param weightMatrix Column vector containing the weights of the model
* @param intercept Intercept of the model.
*/
def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
protected def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double): Double

/**
Expand Down
20 changes: 14 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class LassoModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -66,7 +68,7 @@ class LassoWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -77,10 +79,16 @@ class LassoWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
override def setIntercept(addIntercept: Boolean): this.type = {
// TODO: Support adding intercept.
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
this
}

override protected def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

new LassoModel(weightsScaled.data, interceptScaled)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {

override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand All @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
Expand All @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
override protected def createModel(weights: Array[Double], intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class RidgeRegressionModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override protected def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept in RidgeRegression, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
override def setIntercept(addIntercept: Boolean): this.type = {
// TODO: Support adding intercept.
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
this
}

override protected def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

Expand Down

0 comments on commit 0e57aa4

Please sign in to comment.