Skip to content

Commit

Permalink
fix java tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 30, 2015
1 parent e3c0e44 commit 37ba24e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class IsotonicRegressionModel (
* @return Predicted labels.
*/
def predict(testData: JavaDoubleRDD): JavaDoubleRDD = {
JavaDoubleRDD.fromRDD(predict(testData.rdd.asInstanceOf[RDD[Double]]))
JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]]))
}

/**
Expand Down Expand Up @@ -194,7 +194,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
* @return Isotonic regression model.
*/
def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = {
run(input.rdd.asInstanceOf[RDD[(Double, Double, Double)]])
run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]])
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,36 @@
package org.apache.spark.mllib.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaDoubleRDD;
import scala.Tuple3;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class JavaIsotonicRegressionSuite implements Serializable {
private transient JavaSparkContext sc;

private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
List<Tuple3<Double, Double, Double>> input = new ArrayList<>();
List<Tuple3<Double, Double, Double>> input = Lists.newArrayList();

for(int i = 1; i <= labels.length; i++) {
input.add(new Tuple3(labels[i-1], (double)i, 1d));
for (int i = 1; i <= labels.length; i++) {
input.add(new Tuple3<Double, Double, Double>(labels[i-1], (double) i, 1d));
}

return input;
}

private double difference(List<Tuple3<Double, Double, Double>> expected, IsotonicRegressionModel model) {
double diff = 0;

for(int i = 0; i < model.predictions().length; i++) {
Tuple3<Double, Double, Double> exp = expected.get(i);
diff += Math.abs(model.predict(exp._2()) - exp._1());
}

return diff;
}

private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
sc.parallelize(generateIsotonicInput(labels)).cache();
sc.parallelize(generateIsotonicInput(labels), 2).cache();

return new IsotonicRegression().run(trainRDD);
}
Expand All @@ -80,20 +68,16 @@ public void testIsotonicRegressionJavaRDD() {
IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});

List<Tuple3<Double, Double, Double>> expected =
generateIsotonicInput(new double[] {1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12});

Assert.assertTrue(difference(expected, model) == 0);
Assert.assertArrayEquals(
new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14);
}

@Test
public void testIsotonicRegressionPredictionsJavaRDD() {
IsotonicRegressionModel model =
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});

JavaDoubleRDD testRDD =
sc.parallelizeDoubles(Arrays.asList(new Double[] {0.0, 1.0, 9.5, 12.0, 13.0}));

JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0));
List<Double> predictions = model.predict(testRDD).collect();

Assert.assertTrue(predictions.get(0) == 1d);
Expand All @@ -102,4 +86,4 @@ public void testIsotonicRegressionPredictionsJavaRDD() {
Assert.assertTrue(predictions.get(3) == 12d);
Assert.assertTrue(predictions.get(4) == 12d);
}
}
}

0 comments on commit 37ba24e

Please sign in to comment.