specialize methods/types for Java
mengxr committed Nov 5, 2014
1 parent df293ed commit 1ef26e0
Showing 9 changed files with 163 additions and 11 deletions.
5 changes: 4 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ml
import org.apache.spark.ml.param.{ParamMap, Params, ParamPair}
import org.apache.spark.sql.SchemaRDD

import scala.annotation.varargs

/**
* Abstract class for estimators that fit models to data.
*/
@@ -52,7 +54,8 @@ abstract class Estimator[M <: Model] extends Identifiable with Params with Pipel
* @param otherParamPairs other parameters
* @return fitted model
*/
def fit(
@varargs
def fit[T](
dataset: SchemaRDD,
firstParamPair: ParamPair[_],
otherParamPairs: ParamPair[_]*): M = {
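With @varargs, scalac emits a Java varargs bridge for fit, so Java callers can pass extra ParamPairs inline instead of constructing a Scala Seq. A minimal sketch of the call this enables (hypothetical snippet, reusing the LogisticRegression estimator and dataset from the Java suite at the end of this commit):

  LogisticRegression lr = new LogisticRegression();
  // Each w(...) builds a ParamPair; the trailing arguments flow into otherParamPairs.
  LogisticRegressionModel model = lr.fit(dataset.baseSchemaRDD(),
      lr.regParam().w(0.1), lr.maxIter().w(50));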
mllib/src/main/scala/org/apache/spark/ml/example/BinaryClassificationEvaluator.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.example

import org.apache.spark.ml._
import org.apache.spark.ml.api.param.HasMetricName
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.SchemaRDD
mllib/src/main/scala/org/apache/spark/ml/example/CrossValidator.scala
@@ -73,6 +73,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with Params

private val f2jBLAS = new F2jBLAS

// Override the return type for Java users.
override def setEstimator(estimator: Estimator[_]): this.type = super.setEstimator(estimator)
override def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type =
super.setEstimatorParamMaps(estimatorParamMaps)
override def setEvaluator(evaluator: Evaluator): this.type = super.setEvaluator(evaluator)

val numFolds: Param[Int] = new Param(this, "numFolds", "number of folds for cross validation", 3)

def setNumFolds(numFolds: Int): this.type = {
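These overrides exist because a setter inherited from a trait returns this.type, which Java sees as the declaring trait's type, so chained calls lose access to the subclass's own methods; redeclaring each setter narrows the Java-visible return type to CrossValidator. LogisticRegression and StandardScaler below repeat the same pattern. The chained Java calls this enables, as exercised in the test suite at the end of this commit:

  CrossValidator cv = new CrossValidator()
      .setEstimator(lr)                    // returns CrossValidator, not the base trait
      .setEstimatorParamMaps(lrParamMaps)
      .setEvaluator(eval)
      .setNumFolds(3);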
mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.example

import org.apache.spark.ml._
import org.apache.spark.ml.api.param._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector}
@@ -37,6 +36,12 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
setRegParam(0.1)
setMaxIter(100)

// Override the return type of the setters for Java users.
override def setRegParam(regParam: Double): this.type = super.setRegParam(regParam)
override def setMaxIter(maxIter: Int): this.type = super.setMaxIter(maxIter)
override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol)
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)

override final val model: LogisticRegressionModelParams = new LogisticRegressionModelParams {}

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
mllib/src/main/scala/org/apache/spark/ml/example/StandardScaler.scala
@@ -18,8 +18,7 @@
package org.apache.spark.ml.example

import org.apache.spark.ml._
import org.apache.spark.ml.api.param.HasOutputCol
import org.apache.spark.ml.param.{ParamMap, Params, HasOutputCol, HasInputCol}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.SchemaRDD
@@ -29,6 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.Row

class StandardScaler extends Transformer with Params with HasInputCol with HasOutputCol {

override def setInputCol(inputCol: String): this.type = super.setInputCol(inputCol)
override def setOutputCol(outputCol: String): this.type = super.setOutputCol(outputCol)

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
32 changes: 30 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,6 +22,8 @@ import java.lang.reflect.Modifier
import scala.collection.mutable
import scala.language.implicitConversions

import org.apache.spark.ml.Identifiable

/**
* A param with self-contained documentation and optionally a default value.
*
@@ -30,7 +32,7 @@ import scala.language.implicitConversions
* @param doc documentation
* @tparam T param value type
*/
class Param[T] private (
class Param[T] private[param] (
val parent: Params,
val name: String,
val doc: String,
@@ -75,6 +77,16 @@ class Param[T] private (
}
}

class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
extends Param[Double](parent, name, doc, default) {
override def w(value: Double): ParamPair[Double] = ParamPair(this, value)
}

class IntParam(parent: Params, name: String, doc: String, default: Option[Int] = None)
extends Param[Int](parent, name, doc, default) {
override def w(value: Int): ParamPair[Int] = ParamPair(this, value)
}
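Java sees Param[Double] and Param[Int] as Param&lt;Object&gt; after erasure, so their w methods and setters would force boxed arguments; DoubleParam and IntParam expose primitive signatures instead. A hedged sketch of the Java-side difference, assuming the shared traits below (HasRegParam, HasMaxIter) declare their params with these specialized types:

  lr.regParam().w(0.1);  // DoubleParam.w(double): no (Object) cast or boxing
  lr.maxIter().w(50);    // IntParam.w(int)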

/**
* A param and its value.
*/
@@ -199,7 +211,7 @@ class ParamMap private[ml] (
/**
* Filter this param map for the given parent.
*/
def filter(parent: Identifiable): ParamMap = {
def filter(parent: Params): ParamMap = {
val map = params.filterKeys(_.parent == parent)
new ParamMap(map.asInstanceOf[mutable.Map[Param[Any], Any]])
}
@@ -260,6 +272,22 @@ class ParamGridBuilder {
this
}

/**
* Specialize for Java users.
*/
def addMulti(param: DoubleParam, values: Array[Double]): this.type = {
paramGrid.put(param, values)
this
}

/**
* Specialize for Java users.
*/
def addMulti(param: IntParam, values: Array[Int]): this.type = {
paramGrid.put(param, values)
this
}

def build(): Array[ParamMap] = {
var paramSets = Array(new ParamMap)
paramGrid.foreach { case (param, values) =>
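The addMulti overloads pair the specialized param types with primitive arrays, letting Java callers sweep parameter values without boxing or Scala collections. Usage, as in the Java suite at the end of this commit:

  ParamMap[] lrParamMaps = new ParamGridBuilder()
      .addMulti(lr.regParam(), new double[] {0.1, 100.0})
      .addMulti(lr.maxIter(), new int[] {0, 5})
      .build();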
6 changes: 2 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
@@ -17,11 +17,9 @@

package org.apache.spark.ml.param

import org.apache.spark.ml.Params

trait HasRegParam extends Params {

val regParam: Param[Double] = new Param(this, "regParam", "regularization parameter")
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")

def setRegParam(regParam: Double): this.type = {
set(this.regParam, regParam)
@@ -35,7 +33,7 @@ trait HasRegParam {

trait HasMaxIter extends Params {

val maxIter: Param[Int] = new Param(this, "maxIter", "max number of iterations")
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")

def setMaxIter(maxIter: Int): this.type = {
set(this.maxIter, maxIter)
mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,6 +17,8 @@

package org.apache.spark.mllib.regression

import scala.beans.BeanInfo

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
* @param label Label for this data point.
* @param features List of features for this data point.
*/
@BeanInfo
case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
"(%s,%s)".format(label, features)
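@BeanInfo makes scalac generate a java.beans.BeanInfo class describing label and features as bean properties, which is what lets Spark SQL's reflective schema inference accept this Scala case class from Java. It is the piece that makes this call from the new Java suite work:

  JavaRDD<LabeledPoint> points = MLUtils.loadLibSVMFile(
      jsc.sc(), "../data/mllib/sample_binary_classification_data.txt").toJavaRDD();
  JavaSchemaRDD dataset = jsql.applySchema(points, LabeledPoint.class);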
mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.example;

import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.param.ParamGridBuilder;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

public class JavaLogisticRegressionSuite implements Serializable {

  private transient JavaSparkContext jsc;
  private transient JavaSQLContext jsql;
  private transient JavaSchemaRDD dataset;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
    jsql = new JavaSQLContext(jsc);
    JavaRDD<LabeledPoint> points =
        MLUtils.loadLibSVMFile(jsc.sc(), "../data/mllib/sample_binary_classification_data.txt")
            .toJavaRDD();
    dataset = jsql.applySchema(points, LabeledPoint.class);
  }

  @After
  public void tearDown() {
    jsc.stop();
    jsc = null;
  }

  @Test
  public void logisticRegression() {
    LogisticRegression lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(1.0);
    lr.model().setThreshold(0.8);
    // JavaSchemaRDD wraps the Scala SchemaRDD; fit takes the underlying one via baseSchemaRDD().
    LogisticRegressionModel model = lr.fit(dataset.baseSchemaRDD());
    model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
    for (Row r: predictions.collect()) {
      System.out.println(r);
    }
  }

  @Test
  public void logisticRegressionWithCrossValidation() {
    LogisticRegression lr = new LogisticRegression();
    ParamMap[] lrParamMaps = new ParamGridBuilder()
        .addMulti(lr.regParam(), new double[] {0.1, 100.0})
        .addMulti(lr.maxIter(), new int[] {0, 5})
        .build();
    BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
    CrossValidator cv = new CrossValidator()
        .setEstimator(lr)
        .setEstimatorParamMaps(lrParamMaps)
        .setEvaluator(eval)
        .setNumFolds(3);
    CrossValidatorModel bestModel = cv.fit(dataset.baseSchemaRDD());
  }

  @Test
  public void logisticRegressionWithPipeline() {
    StandardScaler scaler = new StandardScaler()
        .setInputCol("features")
        .setOutputCol("scaledFeatures");
    LogisticRegression lr = new LogisticRegression()
        .setFeaturesCol("scaledFeatures");
    Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] {scaler, lr});
    PipelineModel model = pipeline.fit(dataset.baseSchemaRDD());
    model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
    for (Row r: predictions.collect()) {
      System.out.println(r);
    }
  }
}
