diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 0b033484c1a9d..ac77f0bb5d8f7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -30,31 +30,36 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = { - assert(left.zip(right).forall(areWithinEpsilon)) + assert(left.zip(right).forall(areWithinEpsilon)) } - private def assertTupleSequencesMatch(left: Seq[(Double, Double)], right: Seq[(Double, Double)]): Unit = { + private def assertTupleSequencesMatch(left: Seq[(Double, Double)], + right: Seq[(Double, Double)]): Unit = { assert(left.zip(right).forall(pairsWithinEpsilon)) } private def validateMetrics(metrics: BinaryClassificationMetrics, - expectedThresholds: Seq[Double], - expectedROCCurve: Seq[(Double, Double)], - expectedPRCurve: Seq[(Double, Double)], - expectedFMeasures1: Seq[Double], - expectedFmeasures2: Seq[Double], - expectedPrecisions: Seq[Double], - expectedRecalls: Seq[Double]) = { + expectedThresholds: Seq[Double], + expectedROCCurve: Seq[(Double, Double)], + expectedPRCurve: Seq[(Double, Double)], + expectedFMeasures1: Seq[Double], + expectedFmeasures2: Seq[Double], + expectedPrecisions: Seq[Double], + expectedRecalls: Seq[Double]) = { assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5) assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve) assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5) - assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), expectedThresholds.zip(expectedFMeasures1)) - assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), expectedThresholds.zip(expectedFmeasures2)) - assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), expectedThresholds.zip(expectedPrecisions)) - assertTupleSequencesMatch(metrics.recallByThreshold().collect(), expectedThresholds.zip(expectedRecalls)) + assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), + expectedThresholds.zip(expectedFMeasures1)) + assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), + expectedThresholds.zip(expectedFmeasures2)) + assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), + expectedThresholds.zip(expectedPrecisions)) + assertTupleSequencesMatch(metrics.recallByThreshold().collect(), + expectedThresholds.zip(expectedRecalls)) } test("binary evaluation metrics") { @@ -80,9 +85,9 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } - test("binary evaluation metrics for All Positive RDD") { + test("binary evaluation metrics for RDD where all examples have positive label") { val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) - val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) val thresholds = Seq(0.5) val precisions = Seq(1.0) @@ -97,9 +102,9 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } - test("binary evaluation metrics for All Negative RDD") { + test("binary evaluation metrics for RDD where all examples have negative label") { val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2) - val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) val thresholds = Seq(0.5) val precisions = Seq(0.0)