diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ca2ac820fab13..5ed6477bae3b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -75,7 +75,7 @@ class IsotonicRegressionModel ( * @return Predicted labels. */ def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { - JavaDoubleRDD.fromRDD(predict(testData.rdd.asInstanceOf[RDD[Double]])) + JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) } /** @@ -194,7 +194,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @return Isotonic regression model. */ def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { - run(input.rdd.asInstanceOf[RDD[(Double, Double, Double)]]) + run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 046b359ea3eb6..d38fc91ace3cf 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -18,18 +18,17 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import org.apache.spark.api.java.JavaDoubleRDD; import scala.Tuple3; +import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -37,29 +36,18 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; private List> generateIsotonicInput(double[] labels) { - List> input = new ArrayList<>(); + List> input = Lists.newArrayList(); - for(int i = 1; i <= labels.length; i++) { - input.add(new Tuple3(labels[i-1], (double)i, 1d)); + for (int i = 1; i <= labels.length; i++) { + input.add(new Tuple3(labels[i-1], (double) i, 1d)); } return input; } - private double difference(List> expected, IsotonicRegressionModel model) { - double diff = 0; - - for(int i = 0; i < model.predictions().length; i++) { - Tuple3 exp = expected.get(i); - diff += Math.abs(model.predict(exp._2()) - exp._1()); - } - - return diff; - } - private IsotonicRegressionModel runIsotonicRegression(double[] labels) { JavaRDD> trainRDD = - sc.parallelize(generateIsotonicInput(labels)).cache(); + sc.parallelize(generateIsotonicInput(labels), 2).cache(); return new IsotonicRegression().run(trainRDD); } @@ -80,10 +68,8 @@ public void testIsotonicRegressionJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - List> expected = - generateIsotonicInput(new double[] {1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12}); - - Assert.assertTrue(difference(expected, model) == 0); + Assert.assertArrayEquals( + new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); } @Test @@ -91,9 +77,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = - sc.parallelizeDoubles(Arrays.asList(new Double[] {0.0, 1.0, 9.5, 12.0, 13.0})); - + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertTrue(predictions.get(0) == 1d); @@ -102,4 +86,4 @@ public void testIsotonicRegressionPredictionsJavaRDD() { Assert.assertTrue(predictions.get(3) == 12d); Assert.assertTrue(predictions.get(4) == 12d); } -} \ No newline at end of file +}