Skip to content

Commit

Permalink
rename copyValues to inheritValues and make it do the right thing
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 11, 2014
1 parent 51f1c06 commit 8791e8e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
instances.unpersist()
// copy model params
Params.copyValues(this, lrm)
Params.inheritValues(map, this, lrm)
lrm
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
}
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(this, map, scaler)
Params.copyValues(this, model)
Params.inheritValues(map, this, model)
model
}

Expand Down
30 changes: 18 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,10 @@ trait Params extends Identifiable with Serializable {
m.invoke(this).asInstanceOf[Param[Any]]
}

/**
* Internal param map.
*/
protected val paramMap: ParamMap = ParamMap.empty

/**
* Sets a parameter in the own parameter map.
*/
protected def set[T](param: Param[T], value: T): this.type = {
private[ml] def set[T](param: Param[T], value: T): this.type = {
require(param.parent.eq(this))
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
Expand All @@ -160,10 +155,15 @@ trait Params extends Identifiable with Serializable {
/**
* Gets the value of a parameter.
*/
protected def get[T](param: Param[T]): T = {
private[ml] def get[T](param: Param[T]): T = {
require(param.parent.eq(this))
paramMap(param)
}

/**
* Internal param map.
*/
protected val paramMap: ParamMap = ParamMap.empty
}

private[ml] object Params {
Expand All @@ -174,12 +174,18 @@ private[ml] object Params {
val empty: Params = new Params {}

/**
* Copy parameter values that are explicitly set from one Params instance to another.
* Copies parameter values from the parent estimator to the child model it produced.
* @param paramMap the param map that holds parameters of the parent
* @param parent the parent estimator
* @param child the child model
*/
private[ml] def copyValues[F <: Params, T <: F](from: F, to: T): Unit = {
from.params.foreach { param =>
if (from.isSet(param)) {
to.set(to.getParam(param.name), from.get(param))
private[ml] def inheritValues[E <: Params, M <: E](
paramMap: ParamMap,
parent: E,
child: M): Unit = {
parent.params.foreach { param =>
if (paramMap.contains(param)) {
child.set(child.getParam(param.name), paramMap(param))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
val cvModel = new CrossValidatorModel(this, map, bestModel)
Params.copyValues(this, cvModel)
Params.inheritValues(map, this, cvModel)
cvModel
}

Expand Down

0 comments on commit 8791e8e

Please sign in to comment.