From 0b632e601c6c49d20f3b2eac6bea9642bff39ae9 Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Tue, 25 Nov 2014 19:23:18 -0800
Subject: [PATCH] kmeans

---
 .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 34ea0de706f08..be4ce5e891f27 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,11 +19,10 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
+import breeze.linalg.{DenseVector => BDV, Vector => BV}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
@@ -125,7 +124,7 @@ class KMeans private (
     }
 
     // Compute squared norms and cache them.
-    val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
+    val norms = data.map(_.norm(2.0))
     norms.persist()
     val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
       new BreezeVectorWithNorm(v, norm)
@@ -425,7 +424,7 @@ object KMeans {
 private[clustering] class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double)
   extends Serializable {
 
-  def this(vector: BV[Double]) = this(vector, breezeNorm(vector, 2.0))
+  def this(vector: BV[Double]) = this(vector, Vectors.fromBreeze(vector).norm(2.0))
 
   def this(array: Array[Double]) = this(new BDV[Double](array))
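
Note: a minimal sketch of the equivalence this patch relies on, namely that
MLlib's own L2-norm computation matches the Breeze path it replaces. It uses
Vectors.norm as the public counterpart of the v.norm(2.0) instance call in the
diff (an assumption; that static method exists in org.apache.spark.mllib.linalg
from Spark 1.3 onward), and the object name NormCheck is illustrative only.

    import breeze.linalg.{DenseVector => BDV, norm => breezeNorm}
    import org.apache.spark.mllib.linalg.Vectors

    object NormCheck {
      def main(args: Array[String]): Unit = {
        val v = Vectors.dense(3.0, 4.0)

        // Old path removed by this patch: convert to a Breeze vector,
        // then take breeze.linalg.norm with p = 2.
        val viaBreeze = breezeNorm(new BDV(v.toArray), 2.0)

        // New path: MLlib computes the L2 norm directly, skipping the
        // per-vector Breeze conversion.
        val viaMLlib = Vectors.norm(v, 2.0)

        println(s"breeze = $viaBreeze, mllib = $viaMLlib") // both print 5.0
      }
    }

Skipping the conversion matters here because the norm is computed once per
input vector and cached, so the per-element cost sits on the hot path of the
k-means distance computation.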