Skip to content

Commit

Permalink
[SYMAN-4256] Extract BinaryClassificationMetricsSuite assertions into…
Browse files Browse the repository at this point in the history
… private method
  • Loading branch information
Andrew Bullen committed Nov 10, 2014
1 parent 4d2f79a commit 36b0533
Showing 1 changed file with 40 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.apache.spark.mllib.util.TestingUtils._

class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {

def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5

def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)

private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
Expand All @@ -37,72 +37,76 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
assert(left.zip(right).forall(pairsWithinEpsilon))
}

/**
 * Runs the shared set of assertions for a [[BinaryClassificationMetrics]] instance.
 *
 * The checks run in this order: thresholds, ROC curve and the area under it, PR curve and
 * the area under it, then the by-threshold F-measure (no-argument call and the `2.0` call),
 * precision and recall series. Both areas are compared against `AreaUnderCurve.of` applied
 * to the corresponding expected curve, using an absolute tolerance of 1E-5.
 *
 * @param metrics            metrics object under test
 * @param expectedThresholds expected contents of `metrics.thresholds()`; also zipped onto
 *                           every expected by-threshold series before comparison
 * @param expectedROCCurve   expected contents of `metrics.roc()`
 * @param expectedPRCurve    expected contents of `metrics.pr()`
 * @param expectedFMeasures1 expected values for `metrics.fMeasureByThreshold()`
 * @param expectedFmeasures2 expected values for `metrics.fMeasureByThreshold(2.0)`
 * @param expectedPrecisions expected values for `metrics.precisionByThreshold()`
 * @param expectedRecalls    expected values for `metrics.recallByThreshold()`
 */
private def validateMetrics(metrics: BinaryClassificationMetrics,
    expectedThresholds: Seq[Double],
    expectedROCCurve: Seq[(Double, Double)],
    expectedPRCurve: Seq[(Double, Double)],
    expectedFMeasures1: Seq[Double],
    expectedFmeasures2: Seq[Double],
    expectedPrecisions: Seq[Double],
    expectedRecalls: Seq[Double]) = {

  // Pairs each expected per-threshold value with the threshold it belongs to.
  def byThreshold(expectedValues: Seq[Double]): Seq[(Double, Double)] =
    expectedThresholds.zip(expectedValues)

  val actualThresholds = metrics.thresholds().collect()
  assertSequencesMatch(actualThresholds, expectedThresholds)

  // ROC curve, then the area under it.
  val actualROCCurve = metrics.roc().collect()
  assertTupleSequencesMatch(actualROCCurve, expectedROCCurve)
  assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)

  // Precision-recall curve, then the area under it.
  val actualPRCurve = metrics.pr().collect()
  assertTupleSequencesMatch(actualPRCurve, expectedPRCurve)
  assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)

  // Per-threshold series: each one is compared as (threshold, value) pairs.
  assertTupleSequencesMatch(
    metrics.fMeasureByThreshold().collect(), byThreshold(expectedFMeasures1))
  assertTupleSequencesMatch(
    metrics.fMeasureByThreshold(2.0).collect(), byThreshold(expectedFmeasures2))
  assertTupleSequencesMatch(
    metrics.precisionByThreshold().collect(), byThreshold(expectedPrecisions))
  assertTupleSequencesMatch(
    metrics.recallByThreshold().collect(), byThreshold(expectedRecalls))
}

test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val threshold = Seq(0.8, 0.6, 0.4, 0.1)
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives = Seq(1, 3, 3, 4)
val numFalsePositives = Seq(0, 1, 2, 3)
val numPositives = 4
val numNegatives = 3
val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
val recall = numTruePositives.map(t => t.toDouble / numPositives)
val recalls = numTruePositives.map(t => t.toDouble / numPositives)
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics for All Positive RDD") {
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)

val threshold = Seq(0.5)
val precision = Seq(1.0)
val recall = Seq(1.0)
val thresholds = Seq(0.5)
val precisions = Seq(1.0)
val recalls = Seq(1.0)
val fpr = Seq(0.0)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics for All Negative RDD") {
val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)

val threshold = Seq(0.5)
val precision = Seq(0.0)
val recall = Seq(0.0)
val thresholds = Seq(0.5)
val precisions = Seq(0.0)
val recalls = Seq(0.0)
val fpr = Seq(1.0)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map {
case (0, 0) => 0.0
Expand All @@ -113,14 +117,6 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
}

assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}
}

0 comments on commit 36b0533

Please sign in to comment.