From 5b8f41354e2a1ca5ee29c28552fc47120f1b7078 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 5 Nov 2014 21:14:02 -0800 Subject: [PATCH] rename model to modelParams --- .../scala/org/apache/spark/ml/Estimator.scala | 4 +-- .../org/apache/spark/ml/Transformer.scala | 4 ++- .../spark/ml/example/LogisticRegression.scala | 12 ++++++-- .../example/JavaLogisticRegressionSuite.java | 22 ++++++++++++++- .../ml/example/LogisticRegressionSuite.scala | 28 ++++++++++++++++--- 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index d78cc4802cd82..bd71e56737620 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -64,7 +64,7 @@ abstract class Estimator[M <: Model] extends PipelineStage with Params { } /** - * Parameter for the output model. + * Parameters for the output model. */ - def model: Params = Params.empty + def modelParams: Params = Params.empty } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index bdc4863fac670..dbfed526eda82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -35,7 +35,9 @@ abstract class Transformer extends PipelineStage with Params { */ @varargs def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { - transform(dataset, ParamMap.empty) + val map = new ParamMap() + paramPairs.foreach(map.put(_)) + transform(dataset, map) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala index 20f49dee3efb4..097f44dc079b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala @@ -42,7 +42,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol) override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol) - override final val model: LogisticRegressionModelParams = new LogisticRegressionModelParams {} + override final val modelParams: LogisticRegressionModelParams = new LogisticRegressionModelParams {} override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { import dataset.sqlContext._ @@ -58,7 +58,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] .setNumIterations(maxIter) val lrm = new LogisticRegressionModel(lr.run(instances).weights) instances.unpersist() - this.model.params.foreach { param => + this.modelParams.params.foreach { param => if (map.contains(param)) { lrm.paramMap.put(lrm.getParam(param.name), map(param)) } @@ -71,6 +71,11 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] } trait LogisticRegressionModelParams extends Params with HasThreshold with HasFeaturesCol + with HasScoreCol { + override def setThreshold(threshold: Double): this.type = super.setThreshold(threshold) + override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol) + override def setScoreCol(scoreCol: String): this.type = super.setScoreCol(scoreCol) +} class LogisticRegressionModel( val weights: Vector) @@ -81,6 +86,7 @@ class LogisticRegressionModel( override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { import dataset.sqlContext._ val map = this.paramMap ++ paramMap + println(s"transform called with $map") import map.implicitMapping val score: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) @@ -92,7 +98,7 @@ class LogisticRegressionModel( } dataset.select( Star(None), - score.call((featuresCol: String).attr) as 'score, + score.call((featuresCol: String).attr) as scoreCol, predict.call((featuresCol: String).attr) as 'prediction) } } diff --git a/mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java index 11b1a11aba0e1..8b14116df0385 100644 --- a/mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java @@ -60,10 +60,21 @@ public void tearDown() { @Test public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset.schemaRDD()); + model.transform(dataset.schemaRDD()).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println(r); + } + } + + @Test + public void logisticRegressionWithSetters() { LogisticRegression lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0); - lr.model().setThreshold(0.8); + lr.modelParams().setThreshold(0.8); LogisticRegressionModel model = lr.fit(dataset.schemaRDD()); model.transform(dataset.schemaRDD()).registerTempTable("prediction"); JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); @@ -72,6 +83,15 @@ public void logisticRegression() { } } + @Test + public void chainModelParams() { + LogisticRegression lr = new LogisticRegression(); + lr.modelParams() + .setFeaturesCol("features") + .setScoreCol("score") + .setThreshold(0.5); + } + @Test public void logisticRegressionFitWithVarargs() { LogisticRegression lr = new LogisticRegression(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala index 82d665976f44d..a648aa19135be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala @@ -30,19 +30,39 @@ class LogisticRegressionSuite extends FunSuite { .loadLibSVMFile(sparkContext, "../data/mllib/sample_binary_classification_data.txt") .cache() - test("logistic regression alone") { + test("logistic regression") { + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + .foreach(println) + } + + test("logistic regression with setters") { val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) - model.transform(dataset, lr.model.threshold -> 0.8) // overwrite threshold + model.transform(dataset, lr.modelParams.threshold -> 0.8) // overwrite threshold .select('label, 'score, 'prediction).collect() .foreach(println) } - test("logistic regression fit with varargs") { + test("chain model parameters") { val lr = new LogisticRegression - lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + lr.modelParams + .setFeaturesCol("features") + .setScoreCol("score") + .setThreshold(0.5) + } + + test("logistic regression fit and transform with varargs") { + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .foreach(println) } test("logistic regression with cross validation") {