diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 7036f3c6424c3..3b156fa0482fc 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -107,8 +107,6 @@ public static void main(String[] args) {
 
     // Run cross-validation, and choose the best set of parameters.
     CrossValidatorModel cvModel = crossval.fit(training);
-    // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
-    Model lrModel = cvModel.bestModel();
 
     // Prepare test documents, which are unlabeled.
     List<Document> localTest = Lists.newArrayList(
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index f13e6c12307b2..cf58f4dfaa15b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -47,8 +47,8 @@ public static void main(String[] args) {
     JavaSQLContext jsql = new JavaSQLContext(jsc);
 
     // Prepare training data.
-    // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
-    // into SchemaRDDs, where it uses the case class metadata to infer the schema.
+    // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
+    // into SchemaRDDs, where it uses the bean metadata to infer the schema.
     List<LabeledPoint> localTraining = Lists.newArrayList(
        new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -75,13 +75,13 @@ public static void main(String[] args) {
 
     // We may alternatively specify parameters using a ParamMap.
     ParamMap paramMap = new ParamMap();
-    paramMap.put(lr.maxIter(), 20); // Specify 1 Param.
+    paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
     paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
-    paramMap.put(lr.regParam(), 0.1);
+    paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
 
     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
-    paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name.
+    paramMap2.put(lr.scoreCol().w("probability")); // Change the output column name.
     ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
 
     // Now learn a new model using the paramMapCombined parameters.
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index 3e46b3a8d4b1e..ce6bc066bd70d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -91,8 +91,6 @@ object CrossValidatorExample {
 
     // Run cross-validation, and choose the best set of parameters.
     val cvModel = crossval.fit(training)
-    // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
-    val lrModel = cvModel.bestModel
 
     // Prepare test documents, which are unlabeled.
     val test = sparkContext.parallelize(Seq(
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 071af5de33379..44d5b084c269a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.examples.ml
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -40,8 +41,8 @@ object SimpleParamsExample {
     import sqlContext._
 
     // Prepare training data.
-    // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
-    // into SchemaRDDs, where it uses the case class metadata to infer the schema.
+    // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
+    // into SchemaRDDs, where it uses the bean metadata to infer the schema.
     val training = sparkContext.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
       LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -69,10 +70,10 @@ object SimpleParamsExample {
     // which supports several methods for specifying parameters.
     val paramMap = ParamMap(lr.maxIter -> 20)
     paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
-    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params.
+    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
 
     // One can also combine ParamMaps.
-    val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name.
+    val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change the output column name.
     val paramMapCombined = paramMap ++ paramMap2
 
     // Now learn a new model using the paramMapCombined parameters.
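Note: taken together, the example changes above settle on a few equivalent ways to specify parameters. In Scala, params go into a ParamMap at construction (ParamMap(lr.maxIter -> 20)), one at a time (put(lr.maxIter, 30)), or several at once through the now-@varargs put(paramPairs: ParamPair[_]*); from Java, lr.maxIter().w(20) builds the same ParamPair. A condensed sketch of the Scala style, assuming an already-constructed LogisticRegression as in SimpleParamsExample:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression()                       // the estimator from the example

    val paramMap = ParamMap(lr.maxIter -> 20)               // set at construction
    paramMap.put(lr.maxIter, 30)                            // single Param; overwrites the 20
    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55)  // multiple ParamPairs at once

    val paramMap2 = ParamMap(lr.scoreCol -> "probability")  // change the output column name
    val paramMapCombined = paramMap ++ paramMap2            // right-hand side wins on conflicts

All of these calls appear verbatim in the hunks above; only the standalone framing is added here.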
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index d50ef885200aa..92895a05e479a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
 import scala.beans.BeanInfo
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 7ff2191848494..081a574beea5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -163,14 +163,14 @@ class PipelineModel private[ml] (
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
     transformSchema(dataset.schema, map, logging = true)
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 5c8ae7e2b7b47..4b4340af543b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.ml.param
 
-import java.lang.reflect.Modifier
-
-import org.apache.spark.annotation.AlphaComponent
-
 import scala.annotation.varargs
 import scala.collection.mutable
 
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.Identifiable
 
 /**
@@ -223,6 +222,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
    * Puts a list of param pairs (overwrites if the input params exists).
    * Not usable from Java
    */
+  @varargs
   def put(paramPairs: ParamPair[_]*): this.type = {
     paramPairs.foreach { p =>
       put(p.param.asInstanceOf[Param[Any]], p.value)
@@ -283,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
    * where the latter overwrites this if there exists conflicts.
    */
   def ++(other: ParamMap): ParamMap = {
+    // TODO: Provide a better method name for Java users.
     new ParamMap(this.map ++ other.map)
   }
 
@@ -291,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
   * Adds all parameters from the input param map into this param map.
   */
  def ++=(other: ParamMap): this.type = {
+    // TODO: Provide a better method name for Java users.
     this.map ++= other.map
     this
   }
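Note: the PipelineModel fix works because ParamMap.++ is right-biased, as its doc says ("the latter overwrites this if there exists conflicts"). The old code appended fittingParamMap a second time, so a paramMap passed to transform() could never override anything. A minimal sketch of the precedence logic, using plain Scala Maps as hypothetical stand-ins for ParamMap:

    val fittingParamMap = Map("maxIter" -> 20)  // params captured when the model was fit
    val modelParamMap   = Map("maxIter" -> 30)  // this.paramMap, embedded in the model
    val callParamMap    = Map("maxIter" -> 40)  // paramMap supplied at transform() time

    // Before the fix: fittingParamMap was appended last, so it won every conflict
    // and the caller's map was silently ignored.
    val buggy = (fittingParamMap ++ modelParamMap) ++ fittingParamMap
    assert(buggy("maxIter") == 20)

    // After the fix: the call-time map is appended last, matching the documented
    // precedence paramMap > this.paramMap > fittingParamMap.
    val fixed = (fittingParamMap ++ modelParamMap) ++ callParamMap
    assert(fixed("maxIter") == 40)

The same right-bias is why the examples' paramMap ++ paramMap2 lets paramMap2 rename the score column even if paramMap had set it.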