From f411e7034800d1eead88de2025983250df70a775 Mon Sep 17 00:00:00 2001 From: Andrew Bullen Date: Mon, 10 Nov 2014 12:04:35 -0800 Subject: [PATCH] [SPARK-4256] Define precision as 1.0 when there are no positive examples; update code formatting per pull request comments --- .../BinaryClassificationMetricComputers.scala | 25 +++++++++++-------- .../BinaryClassificationMetricsSuite.scala | 8 +++--- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 3b573db115c25..be3319d60ce25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -24,38 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl def apply(c: BinaryConfusionMatrix): Double } -/** Precision. */ +/** Precision. Defined as 1.0 when there are no positive examples. */ private[evaluation] object Precision extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - if (c.numTruePositives + c.numFalsePositives == 0) { - 0.0 + override def apply(c: BinaryConfusionMatrix): Double = { + val totalPositives = c.numTruePositives + c.numFalsePositives + if (totalPositives == 0) { + 1.0 } else { - c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) + c.numTruePositives.toDouble / totalPositives } + } } -/** False positive rate. */ +/** False positive rate. Defined as 0.0 when there are no negative examples. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = + override def apply(c: BinaryConfusionMatrix): Double = { if (c.numNegatives == 0) { 0.0 } else { c.numFalsePositives.toDouble / c.numNegatives } + } } -/** Recall. */ +/** Recall. Defined as 0.0 when there are no positive examples. */ private[evaluation] object Recall extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = + override def apply(c: BinaryConfusionMatrix): Double = { if (c.numPositives == 0) { 0.0 } else { c.numTruePositives.toDouble / c.numPositives } + } } /** - * F-Measure. + * F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples + * are false positives. * @param beta the beta constant in F-Measure * @see http://en.wikipedia.org/wiki/F1_score */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index a7f2e63b77ba9..54d481940bca6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -61,7 +61,7 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { } test("binary evaluation metrics for All Positive RDD") { - val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0)), 2) + val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels) val threshold = Seq(0.5) @@ -86,7 +86,7 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { } test("binary evaluation metrics for All Negative RDD") { - val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0)), 2) + val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2) val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels) val threshold = Seq(0.5) @@ -97,11 +97,11 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { - case (0,0) => 0.0 + case (0, 0) => 0.0 case (r, p) => 2.0 * (p * r) / (p + r) } val f2 = pr.map { - case (0,0) => 0.0 + case (0, 0) => 0.0 case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }