Skip to content

Commit

Permalink
set the default value of AddIntercept to false
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 29, 2014
1 parent 03389c0 commit c81807f
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
/** The optimizer to solve the problem. */
def optimizer: Optimizer

/** Whether to add intercept (default: true). */
protected var addIntercept: Boolean = true
/** Whether to add intercept (default: false). */
protected var addIntercept: Boolean = false

protected var validateData: Boolean = true

Expand All @@ -94,7 +94,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
protected def createModel(weights: Vector, intercept: Double): M

/**
* Set if the algorithm should add an intercept. Default true.
* Set if the algorithm should add an intercept. Default false.
* We set the default to false because adding the intercept will cause memory allocation.
*/
def setIntercept(addIntercept: Boolean): this.type = {
this.addIntercept = addIntercept
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ class LassoWithSGD private (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept, so set this to false.
super.setIntercept(false)

/**
* Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
* regParam: 1.0, miniBatchFraction: 1.0}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ class RidgeRegressionWithSGD private (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept in RidgeRegression, so set this to false.
super.setIntercept(false)

/**
* Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
* regParam: 1.0, miniBatchFraction: 1.0}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul

val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val lr = new LogisticRegressionWithSGD()
val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer.setStepSize(10.0).setNumIterations(20)

val model = lr.run(testRDD)
Expand Down Expand Up @@ -118,7 +118,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
testRDD.cache()

// Use half as many iterations as the previous test.
val lr = new LogisticRegressionWithSGD()
val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer.setStepSize(10.0).setNumIterations(10)

val model = lr.run(testRDD, initialWeights)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class SVMSuite extends FunSuite with LocalSparkContext {
assert(numOffPredictions < input.length / 5)
}


test("SVM using local random SGD") {
val nPoints = 10000

Expand All @@ -83,7 +82,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val svm = new SVMWithSGD()
val svm = new SVMWithSGD().setIntercept(true)
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)

val model = svm.run(testRDD)
Expand Down Expand Up @@ -115,7 +114,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val svm = new SVMWithSGD()
val svm = new SVMWithSGD().setIntercept(true)
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)

val model = svm.run(testRDD, initialWeights)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
test("linear regression") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 42), 2).cache()
val linReg = new LinearRegressionWithSGD()
val linReg = new LinearRegressionWithSGD().setIntercept(true)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)

val model = linReg.run(testRDD)
Expand Down

0 comments on commit c81807f

Please sign in to comment.