Skip to content

Commit

Permalink
minor updates to NB
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 30, 2014
1 parent b11659c commit 0f8759b
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ package org.apache.spark.mllib.classification
import scala.collection.mutable

import org.jblas.DoubleMatrix
import breeze.linalg.{Vector => BV}

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
Expand Down Expand Up @@ -76,7 +78,13 @@ class NaiveBayes private (var lambda: Double)
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
runRaw(data.map(v => (v.label, v.features.toArray)))
val agg = data.map(p => (p.label, p.features)).combineByKey[(Long, BV[Double])](
createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
mergeValue = (c: (Long, BV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
mergeCombiners = (c1: (Long, BV[Double]), c2: (Long, BV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
val numLabels = agg.size
}

/**
Expand Down

0 comments on commit 0f8759b

Please sign in to comment.