From 0e57aa43f61a62a70faf27aed58dea201b494809 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Mar 2014 11:44:48 -0700 Subject: [PATCH] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected --- .../GeneralizedLinearAlgorithm.scala | 2 +- .../apache/spark/mllib/regression/Lasso.scala | 20 +++++++++++++------ .../mllib/regression/LinearRegression.scala | 20 +++++++++---------- .../mllib/regression/RidgeRegression.scala | 18 ++++++++++++----- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 3e1ed91bf6729..2166c6bb6b443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -44,7 +44,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + protected def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, intercept: Double): Double /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fb2bc9b92a51c..e397a573079e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -36,8 +36,10 @@ class LassoModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class LassoWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,10 +79,16 @@ class LassoWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) new LassoModel(weightsScaled.data, interceptScaled) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8ee40addb25d9..b4aafbe8bcaff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SimpleUpdater() @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Array[Double], intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c504d3d40c773..325e78c8f2233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -36,8 +36,10 @@ class RidgeRegressionModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)