From b7851cc2985a3eff38a59ba621ea95e0afc60a79 Mon Sep 17 00:00:00 2001
From: coderxiang
Date: Tue, 7 Oct 2014 20:12:10 -0700
Subject: [PATCH] Use generic type to represent IDs

---
 .../apache/spark/mllib/evaluation/RankingMetrics.scala | 8 +++++---
 .../spark/mllib/evaluation/RankingMetricsSuite.scala   | 4 ++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index d39985c3d7e16..685e6e27f0ba9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -17,12 +17,14 @@
 package org.apache.spark.mllib.evaluation
 
+import scala.reflect.ClassTag
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
+
 /**
  * ::Experimental::
  * Evaluator for ranking algorithms.
@@ -30,13 +32,13 @@ import org.apache.spark.rdd.RDD
  * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
  */
 @Experimental
-class RankingMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) {
 
   /**
    * Returns the precsion@k for each query
    */
   lazy val precAtK: RDD[Array[Double]] = predictionAndLabels.map {case (pred, lab)=>
-    val labSet : Set[Double] = lab.toSet
+    val labSet = lab.toSet
     val n = pred.length
     val topkPrec = Array.fill[Double](n)(0.0)
     var (i, cnt) = (0, 0)
@@ -55,7 +57,7 @@ class RankingMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
    * Returns the average precision for each query
    */
   lazy val avePrec: RDD[Double] = predictionAndLabels.map {case (pred, lab) =>
-    val labSet: Set[Double] = lab.toSet
+    val labSet = lab.toSet
     var (i, cnt, precSum) = (0, 0, 0.0)
     val n = pred.length
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 4ff7fb33b17c8..1c9c8d7210288 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -25,8 +25,8 @@ class RankingMetricsSuite extends FunSuite with LocalSparkContext {
   test("Ranking metrics: map, ndcg") {
     val predictionAndLabels = sc.parallelize(
      Seq(
-        (Array[Double](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Double](1, 2, 3, 4, 5)),
-        (Array[Double](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Double](1, 2, 3))
+        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3))
       ), 2)
     val eps: Double = 1E-5
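
Usage sketch (illustrative note, not part of the patch): with the type parameter added above, the ranking IDs no longer have to be Doubles, so non-numeric identifiers such as String document IDs can be evaluated directly. This is a minimal sketch that assumes an existing SparkContext named sc and uses only the constructor and the precAtK/avePrec members shown in the diff; the item IDs and values are made up for illustration.

    import org.apache.spark.mllib.evaluation.RankingMetrics

    // Predicted rankings paired with ground-truth sets, keyed by String IDs.
    val predictionAndLabels = sc.parallelize(Seq(
      (Array("d1", "d3", "d2"), Array("d1", "d2")),
      (Array("d4", "d5", "d6"), Array("d5"))
    ), 2)

    // The ClassTag context bound is satisfied automatically for String.
    val metrics = new RankingMetrics[String](predictionAndLabels)
    val precisionAtK = metrics.precAtK.collect()      // one precision@k array per query
    val averagePrecision = metrics.avePrec.collect()  // one average-precision value per query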