Skip to content

Commit

Permalink
[SPARK-4256] Refactor classification metrics tests - extract comparison functions in test
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Bullen committed Nov 10, 2014
1 parent f411e70 commit 4d2f79a
Showing 1 changed file with 31 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ import org.apache.spark.mllib.util.TestingUtils._

class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {

def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
/** Returns true when both members of the pair are equal within an absolute tolerance of 1E-5. */
def areWithinEpsilon(x: (Double, Double)): Boolean = x match {
  case (first, second) => first ~= (second absTol 1E-5)
}

def cond2(x: ((Double, Double), (Double, Double))): Boolean =
/**
 * Returns true when two (Double, Double) points agree coordinate by coordinate,
 * each coordinate being compared with an absolute tolerance of 1E-5.
 */
def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean = x match {
  case ((leftFirst, leftSecond), (rightFirst, rightSecond)) =>
    (leftFirst ~= (rightFirst absTol 1E-5)) && (leftSecond ~= (rightSecond absTol 1E-5))
}

/**
 * Asserts that two sequences of doubles have the same length and that every
 * pair of elements at the same position satisfies `areWithinEpsilon`.
 *
 * @param left  the first sequence to compare
 * @param right the second sequence to compare
 */
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
  // zip stops at the shorter input, so without this check a sequence with
  // missing (or extra) trailing elements would be reported as a match.
  assert(left.size == right.size,
    s"Sequence lengths differ: ${left.size} vs ${right.size}")
  assert(left.zip(right).forall(areWithinEpsilon))
}

/**
 * Asserts that two sequences of (Double, Double) points have the same length
 * and that every pair of points at the same position satisfies
 * `pairsWithinEpsilon`.
 *
 * @param left  the first sequence of points to compare
 * @param right the second sequence of points to compare
 */
private def assertTupleSequencesMatch(
    left: Seq[(Double, Double)],
    right: Seq[(Double, Double)]): Unit = {
  // zip stops at the shorter input, so without this check a curve with
  // missing (or extra) trailing points would be reported as a match.
  assert(left.size == right.size,
    s"Sequence lengths differ: ${left.size} vs ${right.size}")
  assert(left.zip(right).forall(pairsWithinEpsilon))
}

test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
Expand All @@ -49,15 +57,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
}

test("binary evaluation metrics for All Positive RDD") {
Expand All @@ -74,15 +82,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
}

test("binary evaluation metrics for All Negative RDD") {
Expand All @@ -105,14 +113,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
}

assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
assertSequencesMatch(metrics.thresholds().collect(), threshold)
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
}
}

0 comments on commit 4d2f79a

Please sign in to comment.