Skip to content

Commit

Permalink
[SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `regParam…
Browse files Browse the repository at this point in the history
…` to (Streaming)LinearRegressionWithSGD

## What changes were proposed in this pull request?

`LinearRegressionWithSGD` and `StreamingLinearRegressionWithSGD` do not have `regParam` as a constructor argument. They just depend on GradientDescent's default `regParam` value.
To be consistent with other algorithms, we should add it. The same default value is used.

## How was this patch tested?

Passes the existing unit tests.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #11527 from dongjoon-hyun/SPARK-13686.
  • Loading branch information
dongjoon-hyun authored and mengxr committed Mar 14, 2016
1 parent 23385e8 commit a48296f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
class LinearRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var regParam: Double,
private var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

Expand All @@ -98,14 +99,15 @@ class LinearRegressionWithSGD private[mllib] (
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)

/**
* Construct a LinearRegression object with default parameters: {stepSize: 1.0,
* numIterations: 100, regParam: 0.0, miniBatchFraction: 1.0}.
*/
@Since("0.8.0")
def this() = this(1.0, 100, 1.0)
def this() = this(1.0, 100, 0.0, 1.0)

override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
new LinearRegressionModel(weights, intercept)
Expand Down Expand Up @@ -141,7 +143,7 @@ object LinearRegressionWithSGD {
stepSize: Double,
miniBatchFraction: Double,
initialWeights: Vector): LinearRegressionModel = {
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
.run(input, initialWeights)
}

Expand All @@ -163,7 +165,7 @@ object LinearRegressionWithSGD {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double): LinearRegressionModel = {
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input)
new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(input)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.apache.spark.mllib.linalg.Vector
class StreamingLinearRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var regParam: Double,
private var miniBatchFraction: Double)
extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD]
with Serializable {
Expand All @@ -54,10 +55,10 @@ class StreamingLinearRegressionWithSGD private[mllib] (
* (see `StreamingLinearAlgorithm`)
*/
@Since("1.1.0")
def this() = this(0.1, 50, 1.0)
def this() = this(0.1, 50, 0.0, 1.0)

@Since("1.1.0")
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction)

protected var model: Option[LinearRegressionModel] = None

Expand All @@ -71,8 +72,17 @@ class StreamingLinearRegressionWithSGD private[mllib] (
}

/**
* Set the number of iterations of gradient descent to run per update. Default: 50.
* Set the regularization parameter. Default: 0.0.
*/
@Since("2.0.0")
def setRegParam(regParam: Double): this.type = {
  // Forward the value to the optimizer of the wrapped batch algorithm,
  // then hand back this instance so setter calls can be chained.
  val sgd = algorithm.optimizer
  sgd.setRegParam(regParam)
  this
}

/**
* Set the number of iterations of gradient descent to run per update. Default: 50.
*/
@Since("1.1.0")
def setNumIterations(numIterations: Int): this.type = {
this.algorithm.optimizer.setNumIterations(numIterations)
Expand Down
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
) ++ Seq(
// [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `regParam` to (Streaming)LinearRegressionWithSGD
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this")
)
case v if v.startsWith("1.6") =>
Seq(
Expand Down

0 comments on commit a48296f

Please sign in to comment.