Skip to content

Commit

Permalink
rename model to modelParams
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 6, 2014
1 parent 9d2d35d commit 5b8f413
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class Estimator[M <: Model] extends PipelineStage with Params {
}

/**
* Parameter for the output model.
* Parameters for the output model.
*/
def model: Params = Params.empty
def modelParams: Params = Params.empty
}
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ abstract class Transformer extends PipelineStage with Params {
*/
@varargs
def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
transform(dataset, ParamMap.empty)
val map = new ParamMap()
paramPairs.foreach(map.put(_))
transform(dataset, map)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol)
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)

override final val model: LogisticRegressionModelParams = new LogisticRegressionModelParams {}
override final val modelParams: LogisticRegressionModelParams = new LogisticRegressionModelParams {}

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
import dataset.sqlContext._
Expand All @@ -58,7 +58,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
.setNumIterations(maxIter)
val lrm = new LogisticRegressionModel(lr.run(instances).weights)
instances.unpersist()
this.model.params.foreach { param =>
this.modelParams.params.foreach { param =>
if (map.contains(param)) {
lrm.paramMap.put(lrm.getParam(param.name), map(param))
}
Expand All @@ -71,6 +71,11 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
}

/**
 * Parameter set for the logistic regression model.
 *
 * Mixes `HasThreshold`, `HasFeaturesCol` and `HasScoreCol` into `Params` and re-declares
 * three inherited setters. Each one forwards its argument to the inherited implementation
 * and keeps `this.type` as its declared result.
 *
 * NOTE(review): presumably re-declared so that chained setter calls stay typed as this
 * trait (e.g. for Java callers) -- confirm against the `Has*` trait definitions.
 */
trait LogisticRegressionModelParams
  extends Params
  with HasThreshold
  with HasFeaturesCol
  with HasScoreCol {

  /** Forwards `threshold` to the inherited setter. */
  override def setThreshold(threshold: Double): this.type = {
    super.setThreshold(threshold)
  }

  /** Forwards `featuresCol` to the inherited setter. */
  override def setFeaturesCol(featuresCol: String): this.type = {
    super.setFeaturesCol(featuresCol)
  }

  /** Forwards `scoreCol` to the inherited setter. */
  override def setScoreCol(scoreCol: String): this.type = {
    super.setScoreCol(scoreCol)
  }
}

class LogisticRegressionModel(
val weights: Vector)
Expand All @@ -81,6 +86,7 @@ class LogisticRegressionModel(
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
println(s"transform called with $map")
import map.implicitMapping
val score: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
Expand All @@ -92,7 +98,7 @@ class LogisticRegressionModel(
}
dataset.select(
Star(None),
score.call((featuresCol: String).attr) as 'score,
score.call((featuresCol: String).attr) as scoreCol,
predict.call((featuresCol: String).attr) as 'prediction)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,21 @@ public void tearDown() {

/**
 * Fits a logistic regression without calling any setter, transforms the same dataset with
 * the fitted model, and prints the label, score and prediction columns of every row.
 */
@Test
public void logisticRegression() {
  LogisticRegression estimator = new LogisticRegression();
  LogisticRegressionModel fitted = estimator.fit(dataset.schemaRDD());
  // Register the transformed data under a table name so it can be queried with SQL below.
  fitted.transform(dataset.schemaRDD()).registerTempTable("prediction");
  JavaSchemaRDD scored = jsql.sql("SELECT label, score, prediction FROM prediction");
  for (Row row : scored.collect()) {
    System.out.println(row);
  }
}

@Test
public void logisticRegressionWithSetters() {
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0);
lr.model().setThreshold(0.8);
lr.modelParams().setThreshold(0.8);
LogisticRegressionModel model = lr.fit(dataset.schemaRDD());
model.transform(dataset.schemaRDD()).registerTempTable("prediction");
JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
Expand All @@ -72,6 +83,15 @@ public void logisticRegression() {
}
}

// Calls three setters in a single chain on the object returned by modelParams(); each
// setter is invoked directly on the result of the previous call. Nothing is fitted and
// nothing is asserted at runtime.
@Test
public void chainModelParams() {
LogisticRegression lr = new LogisticRegression();
// NOTE(review): the chain presumably acts as a compile-time check that every setter
// returns a type which still exposes the next setter -- confirm intent before reshaping.
lr.modelParams()
.setFeaturesCol("features")
.setScoreCol("score")
.setThreshold(0.5);
}

@Test
public void logisticRegressionFitWithVarargs() {
LogisticRegression lr = new LogisticRegression();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,39 @@ class LogisticRegressionSuite extends FunSuite {
.loadLibSVMFile(sparkContext, "../data/mllib/sample_binary_classification_data.txt")
.cache()

test("logistic regression alone") {
test("logistic regression") {
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
.select('label, 'prediction)
.collect()
.foreach(println)
}

test("logistic regression with setters") {
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, lr.model.threshold -> 0.8) // overwrite threshold
model.transform(dataset, lr.modelParams.threshold -> 0.8) // overwrite threshold
.select('label, 'score, 'prediction).collect()
.foreach(println)
}

test("logistic regression fit with varargs") {
test("chain model parameters") {
val lr = new LogisticRegression
lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
lr.modelParams
.setFeaturesCol("features")
.setScoreCol("score")
.setThreshold(0.5)
}

// Fits with parameters passed as varargs pairs, then transforms with varargs pairs that set
// the threshold to 0.8 and name the score column "probability", and prints the selected
// columns of every row.
test("logistic regression fit and transform with varargs") {
  val lr = new LogisticRegression
  val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
  // Collect before printing, as the other tests in this suite do, so the rows are printed
  // by the driver rather than inside the distributed foreach.
  model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
    .select('label, 'probability, 'prediction)
    .collect()
    .foreach(println)
}

test("logistic regression with cross validation") {
Expand Down

0 comments on commit 5b8f413

Please sign in to comment.