
Commit

fixed bug in Pipeline (typo from last commit). updated examples for CV and Params for spark.ml
jkbradley committed Dec 4, 2014
1 parent c38469c commit d393b5c
Showing 7 changed files with 19 additions and 19 deletions.
@@ -107,8 +107,6 @@ public static void main(String[] args) {

// Run cross-validation, and choose the best set of parameters.
CrossValidatorModel cvModel = crossval.fit(training);
- // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
- Model lrModel = cvModel.bestModel();

// Prepare test documents, which are unlabeled.
List<Document> localTest = Lists.newArrayList(
@@ -47,8 +47,8 @@ public static void main(String[] args) {
JavaSQLContext jsql = new JavaSQLContext(jsc);

// Prepare training data.
- // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
- // into SchemaRDDs, where it uses the case class metadata to infer the schema.
+ // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
+ // into SchemaRDDs, where it uses the bean metadata to infer the schema.
List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -75,13 +75,13 @@ public static void main(String[] args) {

// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
- paramMap.put(lr.maxIter(), 20); // Specify 1 Param.
+ paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
- paramMap.put(lr.regParam(), 0.1);
+ paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.

// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
- paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name.
+ paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

// Now learn a new model using the paramMapCombined parameters.
@@ -91,8 +91,6 @@ object CrossValidatorExample {

// Run cross-validation, and choose the best set of parameters.
val cvModel = crossval.fit(training)
- // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
- val lrModel = cvModel.bestModel

// Prepare test documents, which are unlabeled.
val test = sparkContext.parallelize(Seq(
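In both the Java example above and this Scala CrossValidatorExample, the deleted lines pulled the best LogisticRegression model out of the fitted CrossValidatorModel before scoring. That step is unnecessary: CrossValidatorModel.transform delegates to the best model found over the parameter grid, so the examples now apply cvModel directly. A minimal sketch of the resulting flow, reusing the example's crossval, training and test values (the select/println part is illustrative, not taken from this diff):

// Fit runs the grid search and keeps the best model internally.
val cvModel = crossval.fit(training)
// transform() scores the test documents with that best model; there is no need
// to fetch cvModel.bestModel by hand.
cvModel.transform(test)
  .select('id, 'text, 'score, 'prediction)
  .collect()
  .foreach(println)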
@@ -18,6 +18,7 @@
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
+ import org.apache.spark.SparkContext._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -40,8 +41,8 @@ object SimpleParamsExample {
import sqlContext._

// Prepare training data.
- // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
- // into SchemaRDDs, where it uses the case class metadata to infer the schema.
+ // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
+ // into SchemaRDDs, where it uses the bean metadata to infer the schema.
val training = sparkContext.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -69,10 +70,10 @@ object SimpleParamsExample {
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
- paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params.
+ paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.

// One can also combine ParamMaps.
- val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name.
+ val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
val paramMapCombined = paramMap ++ paramMap2

// Now learn a new model using the paramMapCombined parameters.
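Between them, the Java and Scala Params examples now show the same ways of specifying parameters: setter methods on the estimator, param -> value pairs, and param.w(value) ParamPairs (the form the Java example uses, since the -> syntax and Scala varargs are awkward to call from Java). A small, self-contained sketch of that API surface, assuming the Spark 1.2-era spark.ml classes used in these examples (the wrapper object and println are illustrative):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

object ParamMapSketch {
  def main(args: Array[String]): Unit = {
    val lr = new LogisticRegression()   // just the estimator; no SparkContext needed here

    // Build a ParamMap, then overwrite and extend it.
    val paramMap = ParamMap(lr.maxIter -> 20)
    paramMap.put(lr.maxIter, 30)                            // overwrites maxIter = 20
    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55)  // several params at once

    // w() builds a ParamPair explicitly -- the style the Java example relies on,
    // since the -> syntax is not convenient from Java.
    val paramMap2 = new ParamMap()
    paramMap2.put(lr.scoreCol.w("probability"))             // change the output column name

    // ++ merges two maps; on a conflict the right-hand map wins.
    val paramMapCombined = paramMap ++ paramMap2
    println(paramMapCombined)
  }
}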
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
import scala.beans.BeanInfo

import org.apache.spark.{SparkConf, SparkContext}
+ import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -163,14 +163,14 @@ class PipelineModel private[ml] (

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
transformSchema(dataset.schema, map, logging = true)
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}

private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}
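This is the typo the commit message refers to: the old expression merged fittingParamMap into the result twice, so the paramMap argument passed to transform and transformSchema was silently ignored. Because Scala's ++ lets the right-hand operand win on conflicting keys, ending the expression with ++ paramMap yields the documented precedence paramMap > this.paramMap > fittingParamMap. A self-contained sketch of that rule using plain Scala maps as stand-ins for the three ParamMaps:

object PrecedenceSketch extends App {
  val fittingParamMap = Map("maxIter" -> 10, "regParam" -> 0.01)  // params used at fit time
  val thisParamMap    = Map("regParam" -> 0.1)                    // params stored on the model
  val paramMap        = Map("maxIter" -> 30)                      // params passed to transform()

  // Old (buggy) expression: the caller's paramMap never takes effect.
  val buggy = (fittingParamMap ++ thisParamMap) ++ fittingParamMap
  // Fixed expression: the right-most map wins, so paramMap has the highest precedence.
  val fixed = (fittingParamMap ++ thisParamMap) ++ paramMap

  println(buggy("maxIter"))  // 10 -- the caller's 30 is lost
  println(fixed("maxIter"))  // 30
}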
10 changes: 6 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -17,13 +17,12 @@

package org.apache.spark.ml.param

- import java.lang.reflect.Modifier

- import org.apache.spark.annotation.AlphaComponent

import scala.annotation.varargs
import scala.collection.mutable

+ import java.lang.reflect.Modifier

+ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Identifiable

/**
@@ -223,6 +222,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Puts a list of param pairs (overwrites if the input params exists).
* Not usable from Java
*/
+ @varargs
def put(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
put(p.param.asInstanceOf[Param[Any]], p.value)
@@ -283,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* where the latter overwrites this if there exists conflicts.
*/
def ++(other: ParamMap): ParamMap = {
+ // TODO: Provide a better method name for Java users.
new ParamMap(this.map ++ other.map)
}

@@ -291,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Adds all parameters from the input param map into this param map.
*/
def ++=(other: ParamMap): this.type = {
+ // TODO: Provide a better method name for Java users.
this.map ++= other.map
this
}
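The remaining params.scala changes are aimed at Java callers. put(paramPairs: ParamPair[_]*) uses a Scala repeated parameter, which compiles to a Seq and is hard to call from Java; @varargs makes the compiler emit an additional Java-style array-varargs overload. The new TODOs note that symbolic names such as ++ have no friendly Java spelling, which is why the Java example above has to write paramMap.$plus$plus(paramMap2). A small self-contained sketch (not spark.ml code) of both points:

import scala.annotation.varargs
import scala.collection.mutable

class Bag {
  private val items = mutable.ListBuffer.empty[String]

  // Without @varargs, Java callers would have to construct a Scala Seq by hand;
  // with it, the compiler also emits a Bag put(String... values) overload.
  @varargs
  def put(values: String*): this.type = {
    items ++= values
    this
  }

  // A symbolic method name survives only in mangled form: from Java this is
  // invoked as bag.$plus$plus(other), just like ParamMap.++ in the example above.
  def ++(other: Bag): Bag = {
    val merged = new Bag
    merged.items ++= this.items
    merged.items ++= other.items
    merged
  }
}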
