diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f4228fe5e7522..924ab43f26e06 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -20,8 +20,10 @@ package org.apache.spark.mllib.classification
 import scala.collection.mutable
 
 import org.jblas.DoubleMatrix
+import breeze.linalg.{Vector => BV}
 
 import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.util.MLUtils
@@ -76,7 +78,13 @@ class NaiveBayes private (var lambda: Double)
    * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    */
   def run(data: RDD[LabeledPoint]) = {
-    runRaw(data.map(v => (v.label, v.features.toArray)))
+    val agg = data.map(p => (p.label, p.features)).combineByKey[(Long, BV[Double])](
+      createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
+      mergeValue = (c: (Long, BV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+      mergeCombiners = (c1: (Long, BV[Double]), c2: (Long, BV[Double])) =>
+        (c1._1 + c2._1, c1._2 += c2._2)
+    ).collect()
+    val numLabels = agg.size
   }
 
   /**