Skip to content

Commit

Permalink
implement setters inside each class, add Params.copyValues [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 6, 2014
1 parent fd751fc commit 2d040b3
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params

setMetricName("areaUnderROC")

def setMetricName(value: String): this.type = { set(metricName, value); this }
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
def setLabelCol(value: String): this.type = { set(labelCol, value); this }

override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,75 +20,33 @@ package org.apache.spark.ml.example
import com.github.fommil.netlib.F2jBLAS

import org.apache.spark.ml._
import org.apache.spark.ml.param.{ParamMap, Params, Param}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SchemaRDD

trait HasEstimator extends Params {
class CrossValidator extends Estimator[CrossValidatorModel] with Params {

val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")

def setEstimator(estimator: Estimator[_]): this.type = {
set(this.estimator, estimator)
this
}

def getEstimator: Estimator[_] = {
get(this.estimator)
}
}

trait HasEvaluator extends Params {

val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")

def setEvaluator(evaluator: Evaluator): this.type = {
set(this.evaluator, evaluator)
this
}

def getEvaluator: Evaluator = {
get(evaluator)
}
}
private val f2jBLAS = new F2jBLAS

trait HasEstimatorParamMaps extends Params {
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
def setEstimator(value: Estimator[_]): this.type = { set(estimator, value); this }
def getEstimator: Estimator[_] = get(estimator)

val estimatorParamMaps: Param[Array[ParamMap]] =
new Param(this, "estimatorParamMaps", "param maps for the estimator")

def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type = {
set(this.estimatorParamMaps, estimatorParamMaps)
def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = {
set(estimatorParamMaps, value)
this
}

def getEstimatorParamMaps: Array[ParamMap] = {
get(estimatorParamMaps)
}
}


class CrossValidator extends Estimator[CrossValidatorModel] with Params
with HasEstimator with HasEstimatorParamMaps with HasEvaluator {

private val f2jBLAS = new F2jBLAS

// Overwrite return type for Java users.
override def setEstimator(estimator: Estimator[_]): this.type = super.setEstimator(estimator)
override def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type =
super.setEstimatorParamMaps(estimatorParamMaps)
override def setEvaluator(evaluator: Evaluator): this.type = super.setEvaluator(evaluator)
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
def setEvaluator(value: Evaluator): this.type = { set(evaluator, value); this }
def getEvaluator: Evaluator = get(evaluator)

val numFolds: Param[Int] = new Param(this, "numFolds", "number of folds for cross validation", 3)

def setNumFolds(numFolds: Int): this.type = {
set(this.numFolds, numFolds)
this
}

def getNumFolds: Int = {
get(numFolds)
}
def setNumFolds(value: Int): this.type = { set(numFolds, value); this }
def getNumFolds: Int = get(numFolds)

/**
* Fits a single model to the input data with provided parameter map.
Expand Down Expand Up @@ -135,4 +93,3 @@ class CrossValidatorModel(bestModel: Model, metric: Double) extends Model {
bestModel.transform(dataset, paramMap)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
setRegParam(0.1)
setMaxIter(100)

// Overwrite the return type of setters for Java users.
override def setRegParam(regParam: Double): this.type = super.setRegParam(regParam)
override def setMaxIter(maxIter: Int): this.type = super.setMaxIter(maxIter)
override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol)
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)
def setRegParam(value: Double): this.type = { set(regParam, value); this }
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
def setLabelCol(value: String): this.type = { set(labelCol, value); this }
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }

override final val modelParams: LogisticRegressionModelParams = new LogisticRegressionModelParams {}
override final val modelParams: LogisticRegressionModelParams =
new LogisticRegressionModelParams {}

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
import dataset.sqlContext._
Expand All @@ -58,23 +58,27 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
.setNumIterations(maxIter)
val lrm = new LogisticRegressionModel(lr.run(instances).weights)
instances.unpersist()
this.modelParams.params.foreach { param =>
if (map.contains(param)) {
lrm.paramMap.put(lrm.getParam(param.name), map(param))
}
}
Params.copyValues(modelParams, lrm)
if (!lrm.paramMap.contains(lrm.featuresCol) && map.contains(lrm.featuresCol)) {
lrm.setFeaturesCol(featuresCol)
}
lrm
}

/**
* Validates parameters specified by the input parameter map.
* Raises an exception if any parameter belongs to this object is invalid.
*/
override def validateParams(paramMap: ParamMap): Unit = {
super.validateParams(paramMap)
}
}

trait LogisticRegressionModelParams extends Params with HasThreshold with HasFeaturesCol
with HasScoreCol {
override def setThreshold(threshold: Double): this.type = super.setThreshold(threshold)
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)
override def setScoreCol(scoreCol: String): this.type = super.setScoreCol(scoreCol)
with HasScoreCol {
def setThreshold(value: Double): this.type = { set(threshold, value); this }
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
}

class LogisticRegressionModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,45 @@ import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.expressions.Row

class StandardScaler extends Transformer with Params with HasInputCol with HasOutputCol {
class StandardScaler extends Estimator[StandardScalerModel] with HasInputCol {

override def setInputCol(inputCol: String): this.type = super.setInputCol(inputCol)
override def setOutputCol(outputCol: String): this.type = super.setOutputCol(outputCol)
def setInputCol(value: String): this.type = { set(inputCol, value); this }

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
override val modelParams: StandardScalerModelParams = new StandardScalerModelParams {}

override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
import map.implicitMapping
val input = dataset.select((inputCol: String).attr)
.map { case Row(v: Vector) =>
v
}.cache()
}
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(scaler)
Params.copyValues(modelParams, model)
if (!model.paramMap.contains(model.inputCol)) {
model.setInputCol(inputCol)
}
model
}
}

trait StandardScalerModelParams extends Params with HasInputCol with HasOutputCol {
def setInputCol(value: String): this.type = { set(inputCol, value); this }
def setOutputCol(value: String): this.type = { set(outputCol, value); this }
}

class StandardScalerModel private[ml] (
scaler: feature.StandardScalerModel) extends Model with StandardScalerModelParams {

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
import map.implicitMapping
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
dataset.select(Star(None), scale.call((inputCol: String).attr) as Symbol(outputCol))
dataset.select(Star(None), scale.call((inputCol: String).attr) as outputCol)
}
}
13 changes: 13 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class Param[T] private[param] (
}
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...

class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
extends Param[Double](parent, name, doc, default) {
override def w(value: Double): ParamPair[Double] = ParamPair(this, value)
Expand Down Expand Up @@ -152,6 +154,17 @@ private[ml] object Params {
val empty: Params = new Params {
override def params: Array[Param[_]] = Array.empty
}

/**
* Copy parameter values from one Params instance to another.
*/
def copyValues[F <: Params, T <: F](from: F, to: T): Unit = {
from.params.foreach { param =>
if (from.paramMap.contains(param)) {
to.paramMap.put(to.getParam(param.name), from.paramMap(param))
}
}
}
}

/**
Expand Down
46 changes: 1 addition & 45 deletions mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ trait HasRegParam extends Params {

val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")

def setRegParam(regParam: Double): this.type = {
set(this.regParam, regParam)
this
}

def getRegParam: Double = {
get(regParam)
}
Expand All @@ -35,11 +30,6 @@ trait HasMaxIter extends Params {

val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")

def setMaxIter(maxIter: Int): this.type = {
set(this.maxIter, maxIter)
this
}

def getMaxIter: Int = {
get(maxIter)
}
Expand All @@ -50,11 +40,6 @@ trait HasFeaturesCol extends Params {
val featuresCol: Param[String] =
new Param(this, "featuresCol", "features column name", "features")

def setFeaturesCol(featuresCol: String): this.type = {
set(this.featuresCol, featuresCol)
this
}

def getFeaturesCol: String = {
get(featuresCol)
}
Expand All @@ -64,23 +49,14 @@ trait HasLabelCol extends Params {

val labelCol: Param[String] = new Param(this, "labelCol", "label column name", "label")

def setLabelCol(labelCol: String): this.type = {
set(this.labelCol, labelCol)
this
}

def getLabelCol: String = {
get(labelCol)
}
}

trait HasScoreCol extends Params {
val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", "score")

def setScoreCol(scoreCol: String): this.type = {
set(this.scoreCol, scoreCol)
this
}
val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", "score")

def getScoreCol: String = {
get(scoreCol)
Expand All @@ -91,11 +67,6 @@ trait HasThreshold extends Params {

val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold for prediction")

def setThreshold(threshold: Double): this.type = {
set(this.threshold, threshold)
this
}

def getThreshold: Double = {
get(threshold)
}
Expand All @@ -105,11 +76,6 @@ trait HasMetricName extends Params {

val metricName: Param[String] = new Param(this, "metricName", "metric name for evaluation")

def setMetricName(metricName: String): this.type = {
set(this.metricName, metricName)
this
}

def getMetricName: String = {
get(metricName)
}
Expand All @@ -119,11 +85,6 @@ trait HasInputCol extends Params {

val inputCol: Param[String] = new Param(this, "inputCol", "input column name")

def setInputCol(inputCol: String): this.type = {
set(this.inputCol, inputCol)
this
}

def getInputCol: String = {
get(inputCol)
}
Expand All @@ -133,11 +94,6 @@ trait HasOutputCol extends Params {

val outputCol: Param[String] = new Param(this, "outputCol", "output column name")

def setOutputCol(outputCol: String): this.type = {
set(this.outputCol, outputCol)
this
}

def getOutputCol: String = {
get(outputCol)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ public void logisticRegressionWithCrossValidation() {
@Test
public void logisticRegressionWithPipeline() {
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setInputCol("features");
scaler.modelParams()
.setOutputCol("scaledFeatures");
LogisticRegression lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite {
test("logistic regression with pipeline") {
val scaler = new StandardScaler()
.setInputCol("features")
scaler.modelParams
.setOutputCol("scaledFeatures")
val lr = new LogisticRegression()
.setFeaturesCol("scaledFeatures")
Expand Down

0 comments on commit 2d040b3

Please sign in to comment.