From 8d6befe21e26cc843fc96e4c2934a15c0797ce51 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 00:45:22 -0700 Subject: [PATCH 01/14] initial commit --- .../apache/spark/mllib/feature/Word2Vec.scala | 353 ++++++++++++++++++ .../spark/mllib/feature/Word2VecSuite.scala | 40 ++ 2 files changed, 393 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala new file mode 100644 index 0000000000000..9461d0cc1ba2d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -0,0 +1,353 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* Add a comment to this line +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.mllib.feature + +import scala.util._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark._ +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.HashPartitioner + +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen:Int +) + +class Word2Vec( + val size: Int, + val startingAlpha: Double, + val window: Int, + val minCount: Int) + extends Serializable with Logging { + + private val EXP_TABLE_SIZE = 1000 + private val MAX_EXP = 6 + private val MAX_CODE_LENGTH = 40 + private val MAX_SENTENCE_LENGTH = 1000 + private val layer1Size = size + + private var trainWordsCount = 0 + private var vocabSize = 0 + private var vocab: Array[VocabWord] = null + private var vocabHash = mutable.HashMap.empty[String, Int] + private var alpha = startingAlpha + + private def learnVocab(dataset: RDD[String]) { + vocab = dataset.flatMap(line => line.split(" ")) + .map(w => (w, 1)) + .reduceByKey(_ + _) + .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) + .filter(_.cn >= minCount) + .collect() + .sortWith((a, b)=> a.cn > b.cn) + + vocabSize = vocab.length + var a = 0 + while (a < vocabSize) { + vocabHash += vocab(a).word -> a + trainWordsCount += vocab(a).cn + a += 1 + } + logInfo("trainWordsCount = " + trainWordsCount) + } + + private def createExpTable(): Array[Double] = { + val expTable = new Array[Double](EXP_TABLE_SIZE) + var i = 0 + while (i < EXP_TABLE_SIZE) { + val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) + expTable(i) = tmp / (tmp + 1) + i += 1 + } + expTable + } + + private def 
createBinaryTree() { + val count = new Array[Long](vocabSize * 2 + 1) + val binary = new Array[Int](vocabSize * 2 + 1) + val parentNode = new Array[Int](vocabSize * 2 + 1) + val code = new Array[Int](MAX_CODE_LENGTH) + val point = new Array[Int](MAX_CODE_LENGTH) + var a = 0 + while (a < vocabSize) { + count(a) = vocab(a).cn + a += 1 + } + while (a < 2 * vocabSize) { + count(a) = 1e9.toInt + a += 1 + } + var pos1 = vocabSize - 1 + var pos2 = vocabSize + + var min1i = 0 + var min2i = 0 + + a = 0 + while (a < vocabSize - 1) { + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min1i = pos1 + pos1 -= 1 + } else { + min1i = pos2 + pos2 += 1 + } + } else { + min1i = pos2 + pos2 += 1 + } + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min2i = pos1 + pos1 -= 1 + } else { + min2i = pos2 + pos2 += 1 + } + } else { + min2i = pos2 + pos2 += 1 + } + count(vocabSize + a) = count(min1i) + count(min2i) + parentNode(min1i) = vocabSize + a + parentNode(min2i) = vocabSize + a + binary(min2i) = 1 + a += 1 + } + // Now assign binary code to each vocabulary word + var i = 0 + a = 0 + while (a < vocabSize) { + var b = a + i = 0 + while (b != vocabSize * 2 - 2) { + code(i) = binary(b) + point(i) = b + i += 1 + b = parentNode(b) + } + vocab(a).codeLen = i + vocab(a).point(0) = vocabSize - 2 + b = 0 + while (b < i) { + vocab(a).code(i - b - 1) = code(b) + vocab(a).point(i - b) = point(b) - vocabSize + b += 1 + } + a += 1 + } + } + + /** + * Computes the vector representation of each word in + * vocabulary + * @param dataset an RDD of strings + */ + + def fit(dataset:RDD[String]): Word2VecModel = { + + learnVocab(dataset) + + createBinaryTree() + + val sc = dataset.context + + val expTable = sc.broadcast(createExpTable()) + val V = sc.broadcast(vocab) + val VHash = sc.broadcast(vocabHash) + + val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions { + iter => { new Iterator[Array[Int]] { + def hasNext = iter.hasNext + def next = { + var sentence = new ArrayBuffer[Int] + var sentenceLength = 0 + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { + val word = VHash.value.get(iter.next) + word match { + case Some(w) => { + sentence += w + sentenceLength += 1 + } + case None => + } + } + sentence.toArray + } + } + } + } + + val newSentences = sentences.repartition(1).cache() + val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) + val (aggSyn0, _, _, _) = + // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor + newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))( + seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => + var lwc = lastWordCount + var wc = wordCount + if (wordCount - lastWordCount > 10000) { + lwc = wordCount + alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + } + wc += sentence.size + var pos = 0 + while (pos < sentence.size) { + val word = sentence(pos) + // TODO: fix random seed + val b = Random.nextInt(window) + // Train Skip-gram + var a = b + while (a < window * 2 + 1 - b) { + if (a != window) { + val c = pos - window + a + if (c >= 0 && c < sentence.size) { + val lastWord = sentence(c) + val l1 = lastWord * layer1Size + val neu1e = new Array[Double](layer1Size) + //HS + var d = 0 + while (d < vocab(word).codeLen) { + val l2 = vocab(word).point(d) * 
layer1Size + // Propagate hidden -> output + var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) + if (f > -MAX_EXP && f < MAX_EXP) { + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt + f = expTable.value(ind) + val g = (1 - vocab(word).code(d) - f) * alpha + blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + } + d += 1 + } + blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1) + } + } + a += 1 + } + pos += 1 + } + (syn0, syn1, lwc, wc) + }, + combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) + blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) + (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2) + }) + + val wordMap = new Array[(String, Array[Double])](vocabSize) + var i = 0 + while (i < vocabSize) { + val word = vocab(i).word + val vector = new Array[Double](layer1Size) + Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size) + wordMap(i) = (word, vector) + i += 1 + } + val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100)) + new Word2VecModel(modelRDD) + } +} + +class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable { + + val model = _model + + private def distance(v1: Array[Double], v2: Array[Double]): Double = { + require(v1.length == v2.length, "Vectors should have the same length") + val n = v1.length + val norm1 = blas.dnrm2(n, v1, 1) + val norm2 = blas.dnrm2(n, v2, 1) + if (norm1 == 0 || norm2 == 0) return 0.0 + blas.ddot(n, v1, 1, v2,1) / norm1 / norm2 + } + + def transform(word: String): Array[Double] = { + val result = model.lookup(word) + if (result.isEmpty) Array[Double]() + else result(0) + } + + def transform(dataset: RDD[String]): RDD[Array[Double]] = { + dataset.map(word => transform(word)) + } + + def findSynonyms(word: String, num: Int): Array[(String, Double)] = { + val vector = transform(word) + if (vector.isEmpty) Array[(String, Double)]() + else findSynonyms(vector,num) + } + + def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = { + require(num > 0, "Number of similar words should > 0") + val topK = model.map( + {case(w, vec) => (distance(vector, vec), w)}) + .sortByKey(ascending = false) + .take(num + 1) + .map({case (dist, w) => (w, dist)}).drop(1) + + topK + } +} + +object Word2Vec extends Serializable with Logging { + def train( + input: RDD[String], + size: Int, + startingAlpha: Double, + window: Int, + minCount: Int): Word2VecModel = { + new Word2Vec(size,startingAlpha, window, minCount).fit(input) + } + + def main(args: Array[String]) { + if (args.length < 6) { + println("Usage: word2vec input size startingAlpha window minCount num") + sys.exit(1) + } + val conf = new SparkConf() + .setAppName("word2vec") + + val sc = new SparkContext(conf) + val input = sc.textFile(args(0)) + val size = args(1).toInt + val startingAlpha = args(2).toDouble + val window = args(3).toInt + val minCount = args(4).toInt + val num = args(5).toInt + val model = train(input, size, startingAlpha, window, minCount) + val vec = model.findSynonyms("china", num) + for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString) + sc.stop() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..6e8cb94d44726 --- /dev/null +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* Add a comment to this line +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LocalSparkContext + +class Word2VecSuite extends FunSuite with LocalSparkContext { + test("word2vec") { + val num = 2 + val localModel = Seq( + ("china" , Array(0.50, 0.50, 0.50, 0.50)), + ("japan" , Array(0.40, 0.50, 0.50, 0.50)), + ("taiwan", Array(0.60, 0.50, 0.50, 0.50)), + ("korea" , Array(0.45, 0.60, 0.60, 0.60)) + ) + val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val synons = model.findSynonyms("china", num) + assert(synons.length == num) + assert(synons(0)._1 == "taiwan") + assert(synons(1)._1 == "japan") + } +} \ No newline at end of file From 0aafb1b02a19fe4f1689543baf1882a49a7ff11a Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 08:34:11 -0700 Subject: [PATCH 02/14] Add comments, minor fixes --- .../apache/spark/mllib/feature/Word2Vec.scala | 69 ++++++++++++------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9461d0cc1ba2d..18f507c2f1b46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.HashPartitioner +/** + * Entry in vocabulary + */ private case class VocabWord( var word: String, var cn: Int, @@ -39,6 +42,9 @@ private case class VocabWord( var codeLen:Int ) +/** + * Vector representation of word + */ class Word2Vec( val size: Int, val startingAlpha: Double, @@ -51,7 +57,8 @@ class Word2Vec( private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 private val layer1Size = size - + private val modelPartitionNum = 100 + private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null @@ -169,6 +176,7 @@ class Word2Vec( * Computes the vector representation of each word in * vocabulary * @param dataset an RDD of strings + * @return a Word2VecModel */ def fit(dataset:RDD[String]): Word2VecModel = { @@ -274,11 +282,14 @@ class Word2Vec( wordMap(i) = (word, vector) i += 1 } - val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100)) + val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum)) new Word2VecModel(modelRDD) } } +/** +* Word2Vec model +*/ class Word2VecModel (val _model:RDD[(String, 
Array[Double])]) extends Serializable { val model = _model @@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab blas.ddot(n, v1, 1, v2,1) / norm1 / norm2 } + /** + * Transforms a word to its vector representation + * @param word a word + * @return vector representation of word + */ + def transform(word: String): Array[Double] = { val result = model.lookup(word) if (result.isEmpty) Array[Double]() else result(0) } + /** + * Transforms an RDD to its vector representation + * @param dataset a an RDD of words + * @return RDD of vector representation + */ + def transform(dataset: RDD[String]): RDD[Array[Double]] = { dataset.map(word => transform(word)) } + /** + * Find synonyms of a word + * @param word a word + * @param num number of synonyms to find + * @return array of (word, similarity) + */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) if (vector.isEmpty) Array[(String, Double)]() else findSynonyms(vector,num) } + /** + * Find synonyms of the vector representation of a word + * @param vector vector representation of a word + * @param num number of synonyms to find + * @return array of (word, similarity) + */ def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") val topK = model.map( @@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab } object Word2Vec extends Serializable with Logging { + /** + * Train Word2Vec model + * @param input RDD of words + * @param size vectoer dimension + * @param startingAlpha initial learning rate + * @param window context words from [-window, window] + * @param minCount minimum frequncy to consider a vocabulary word + * @return Word2Vec model + */ def train( input: RDD[String], size: Int, @@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging { minCount: Int): Word2VecModel = { new Word2Vec(size,startingAlpha, window, minCount).fit(input) } - - def main(args: Array[String]) { - if (args.length < 6) { - println("Usage: word2vec input size startingAlpha window minCount num") - sys.exit(1) - } - val conf = new SparkConf() - .setAppName("word2vec") - - val sc = new SparkContext(conf) - val input = sc.textFile(args(0)) - val size = args(1).toInt - val startingAlpha = args(2).toDouble - val window = args(3).toInt - val minCount = args(4).toInt - val num = args(5).toInt - val model = train(input, size, startingAlpha, window, minCount) - val vec = model.findSynonyms("china", num) - for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString) - sc.stop() - } } From e4a04d32be284f9a7ab2d3f57d745342912930a7 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 08:46:38 -0700 Subject: [PATCH 03/14] minor fix --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 18f507c2f1b46..adf1d2dc10fb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -373,4 +373,4 @@ object Word2Vec extends Serializable with Logging { minCount: Int): Word2VecModel = { new Word2Vec(size,startingAlpha, window, minCount).fit(input) } -} +} \ No newline at end of file From 57dc50d3f24beda8eb0348c0baf8dc343065fd2d Mon Sep 17 
00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 09:20:10 -0700 Subject: [PATCH 04/14] code formatting --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 5 ++--- .../scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index adf1d2dc10fb4..1cc432c6a6214 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.feature -import scala.util._ +import scala.util.{Random => Random} import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -373,4 +372,4 @@ object Word2Vec extends Serializable with Logging { minCount: Int): Word2VecModel = { new Word2Vec(size,startingAlpha, window, minCount).fit(input) } -} \ No newline at end of file +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 6e8cb94d44726..2a02ace83c380 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -37,4 +37,4 @@ class Word2VecSuite extends FunSuite with LocalSparkContext { assert(synons(0)._1 == "taiwan") assert(synons(1)._1 == "japan") } -} \ No newline at end of file +} From 2e92b5991ad8f3f73bbeab9a056f452c4b532b3c Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 18:17:38 -0700 Subject: [PATCH 05/14] modify according to feedback --- .../apache/spark/mllib/feature/Word2Vec.scala | 146 +++++++++++------- .../spark/mllib/feature/Word2VecSuite.scala | 32 ++-- 2 files changed, 102 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1cc432c6a6214..f4266c94f63e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -1,33 +1,33 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* Add a comment to this line -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.mllib.feature -import scala.util.{Random => Random} +import scala.util.Random import scala.collection.mutable.ArrayBuffer import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.HashPartitioner /** @@ -42,8 +42,27 @@ private case class VocabWord( ) /** - * Vector representation of word + * :: Experimental :: + * Word2Vec creates vector representation of words in a text corpus. + * The algorithm first constructs a vocabulary from the corpus + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in + * natural language processing and machine learning algorithms. + * + * We used skip-gram model in our implementation and hierarchical softmax + * method to train the model. + * + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see + * Efficient Estimation of Word Representations in Vector Space + * and + * Distributed Representations of Words and Phrases and their Compositionality + * @param size vector dimension + * @param startingAlpha initial learning rate + * @param window context words from [-window, window] + * @param minCount minimum frequncy to consider a vocabulary word */ +@Experimental class Word2Vec( val size: Int, val startingAlpha: Double, @@ -64,11 +83,15 @@ class Word2Vec( private var vocabHash = mutable.HashMap.empty[String, Int] private var alpha = startingAlpha - private def learnVocab(dataset: RDD[String]) { - vocab = dataset.flatMap(line => line.split(" ")) - .map(w => (w, 1)) + private def learnVocab(words:RDD[String]) { + vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) - .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) + .map(x => VocabWord( + x._1, + x._2, + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b)=> a.cn > b.cn) @@ -172,15 +195,16 @@ class Word2Vec( } /** - * Computes the vector representation of each word in - * vocabulary - * @param dataset an RDD of strings + * Computes the vector representation of each word in vocabulary. 
+ * @param dataset an RDD of words * @return a Word2VecModel */ - def fit(dataset:RDD[String]): Word2VecModel = { + def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = { - learnVocab(dataset) + val words = dataset.flatMap(x => x) + + learnVocab(words) createBinaryTree() @@ -190,9 +214,10 @@ class Word2Vec( val V = sc.broadcast(vocab) val VHash = sc.broadcast(vocabHash) - val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions { + val sentences = words.mapPartitions { iter => { new Iterator[Array[Int]] { def hasNext = iter.hasNext + def next = { var sentence = new ArrayBuffer[Int] var sentenceLength = 0 @@ -215,7 +240,8 @@ class Word2Vec( val newSentences = sentences.repartition(1).cache() val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) val (aggSyn0, _, _, _) = - // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor + // TODO: broadcast temp instead of serializing it directly + // or initialize the model in each executor newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))( seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -241,7 +267,7 @@ class Word2Vec( val lastWord = sentence(c) val l1 = lastWord * layer1Size val neu1e = new Array[Double](layer1Size) - //HS + // Hierarchical softmax var d = 0 while (d < vocab(word).codeLen) { val l2 = vocab(word).point(d) * layer1Size @@ -265,11 +291,12 @@ class Word2Vec( } (syn0, syn1, lwc, wc) }, - combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) - blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) - (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2) + combOp = (c1, c2) => (c1, c2) match { + case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) + blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) + (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2) }) val wordMap = new Array[(String, Array[Double])](vocabSize) @@ -281,7 +308,8 @@ class Word2Vec( wordMap(i) = (word, vector) i += 1 } - val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum)) + val modelRDD = sc.parallelize(wordMap, modelPartitionNum) + .partitionBy(new HashPartitioner(modelPartitionNum)) new Word2VecModel(modelRDD) } } @@ -289,11 +317,9 @@ class Word2Vec( /** * Word2Vec model */ -class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable { - - val model = _model +class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable { - private def distance(v1: Array[Double], v2: Array[Double]): Double = { + private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = { require(v1.length == v2.length, "Vectors should have the same length") val n = v1.length val norm1 = blas.dnrm2(n, v1, 1) @@ -307,11 +333,12 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab * @param word a word * @return vector representation of word */ - - def transform(word: String): Array[Double] = { + def transform(word: String): Vector = { val result = model.lookup(word) - if (result.isEmpty) Array[Double]() - else result(0) + if (result.isEmpty) { + throw new IllegalStateException(s"${word} not in vocabulary") + } + else Vectors.dense(result(0)) } /** @@ -319,8 +346,7 @@ class 
Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab * @param dataset a an RDD of words * @return RDD of vector representation */ - - def transform(dataset: RDD[String]): RDD[Array[Double]] = { + def transform(dataset: RDD[String]): RDD[Vector] = { dataset.map(word => transform(word)) } @@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - if (vector.isEmpty) Array[(String, Double)]() - else findSynonyms(vector,num) + findSynonyms(vector,num) } /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ - def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = { + def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - val topK = model.map( - {case(w, vec) => (distance(vector, vec), w)}) + val topK = model.map { case(w, vec) => + (cosineSimilarity(vector.toArray, vec), w) } .sortByKey(ascending = false) .take(num + 1) - .map({case (dist, w) => (w, dist)}).drop(1) + .map(_.swap) + .tail topK } } -object Word2Vec extends Serializable with Logging { +object Word2Vec{ /** * Train Word2Vec model * @param input RDD of words - * @param size vectoer dimension + * @param size vector dimension * @param startingAlpha initial learning rate * @param window context words from [-window, window] * @param minCount minimum frequncy to consider a vocabulary word * @return Word2Vec model */ - def train( - input: RDD[String], + def train[S <: Iterable[String]]( + input: RDD[S], size: Int, startingAlpha: Double, window: Int, minCount: Int): Word2VecModel = { - new Word2Vec(size,startingAlpha, window, minCount).fit(input) + new Word2Vec(size,startingAlpha, window, minCount).fit[S](input) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 2a02ace83c380..54e56529c5a47 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -1,24 +1,24 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* Add a comment to this line -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.mllib.feature import org.scalatest.FunSuite + import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.LocalSparkContext From 720b5a3ea697a881fc7d7c286b65ef110421f89e Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 22:53:03 -0700 Subject: [PATCH 06/14] Add test for Word2Vec algorithm, minor fixes --- .../apache/spark/mllib/feature/Word2Vec.scala | 17 ++++++++------ .../spark/mllib/feature/Word2VecSuite.scala | 22 ++++++++++++++++++- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f4266c94f63e0..b55122d3c9f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -50,24 +50,27 @@ private case class VocabWord( * natural language processing and machine learning algorithms. * * We used skip-gram model in our implementation and hierarchical softmax - * method to train the model. + * method to train the model. The variable names in the implementation + * mathes the original C implementation. * * For original C implementation, see https://code.google.com/p/word2vec/ * For research papers, see * Efficient Estimation of Word Representations in Vector Space * and - * Distributed Representations of Words and Phrases and their Compositionality + * Distributed Representations of Words and Phrases and their Compositionality. 
* @param size vector dimension * @param startingAlpha initial learning rate * @param window context words from [-window, window] * @param minCount minimum frequncy to consider a vocabulary word + * @param parallelisum number of partitions to run Word2Vec */ @Experimental class Word2Vec( val size: Int, val startingAlpha: Double, val window: Int, - val minCount: Int) + val minCount: Int, + val parallelism:Int = 1) extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 @@ -237,7 +240,7 @@ class Word2Vec( } } - val newSentences = sentences.repartition(1).cache() + val newSentences = sentences.repartition(parallelism).cache() val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) val (aggSyn0, _, _, _) = // TODO: broadcast temp instead of serializing it directly @@ -248,7 +251,7 @@ class Word2Vec( var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1)) + alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } @@ -296,7 +299,7 @@ class Word2Vec( val n = syn0_1.length blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) - (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) }) val wordMap = new Array[(String, Array[Double])](vocabSize) @@ -309,7 +312,7 @@ class Word2Vec( i += 1 } val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)) + .partitionBy(new HashPartitioner(modelPartitionNum)).cache() new Word2VecModel(modelRDD) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 54e56529c5a47..e2b71c16f3308 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -23,7 +23,27 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.LocalSparkContext class Word2VecSuite extends FunSuite with LocalSparkContext { - test("word2vec") { + test("Word2Vec") { + val sentence = "a b " * 100 + "a c " * 10 + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + val size = 10 + val startingAlpha = 0.025 + val window = 2 + val minCount = 2 + val num = 2 + val word = "a" + + val model = Word2Vec.train(doc, size, startingAlpha, window, minCount) + val synons = model.findSynonyms("a", 2) + assert(synons.length == num) + assert(synons(0)._1 == "b") + assert(synons(1)._1 == "c") + } + + + test("Word2VecModel") { val num = 2 val localModel = Seq( ("china" , Array(0.50, 0.50, 0.50, 0.50)), From 6bcc8be34f6253bc7d4f9d4dcb478bf91f108c86 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 11:15:09 -0700 Subject: [PATCH 07/14] add multiple iteration support --- .../apache/spark/mllib/feature/Word2Vec.scala | 130 ++++++++++-------- 1 file changed, 70 insertions(+), 60 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b55122d3c9f1e..21c2395fb18ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -70,7 +70,8 @@ class Word2Vec( val startingAlpha: Double, val window: Int, val minCount: Int, - val parallelism:Int = 1) + val parallelism:Int = 1, + val numIterations:Int = 1) extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 @@ -241,73 +242,80 @@ class Word2Vec( } val newSentences = sentences.repartition(parallelism).cache() - val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) - val (aggSyn0, _, _, _) = - // TODO: broadcast temp instead of serializing it directly - // or initialize the model in each executor - newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))( - seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => - var lwc = lastWordCount - var wc = wordCount - if (wordCount - lastWordCount > 10000) { - lwc = wordCount - alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 - logInfo("wordCount = " + wordCount + ", alpha = " + alpha) - } - wc += sentence.size - var pos = 0 - while (pos < sentence.size) { - val word = sentence(pos) - // TODO: fix random seed - val b = Random.nextInt(window) - // Train Skip-gram - var a = b - while (a < window * 2 + 1 - b) { - if (a != window) { - val c = pos - window + a - if (c >= 0 && c < sentence.size) { - val lastWord = sentence(c) - val l1 = lastWord * layer1Size - val neu1e = new Array[Double](layer1Size) - // Hierarchical softmax - var d = 0 - while (d < vocab(word).codeLen) { - val l2 = vocab(word).point(d) * layer1Size - // Propagate hidden -> output - var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) - if (f > -MAX_EXP && f < MAX_EXP) { - val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt - f = expTable.value(ind) - val g = (1 - vocab(word).code(d) - f) * alpha - blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + var syn0Global + = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) + var syn1Global = new Array[Double](vocabSize * layer1Size) + + for(iter <- 1 to numIterations) { + val (aggSyn0, aggSyn1, _, _) = + // TODO: broadcast temp instead of serializing it directly + // or initialize the model in each executor + newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))( + seqOp = (c, v) => (c, v) match { + case ((syn0, syn1, lastWordCount, wordCount), sentence) => + var lwc = lastWordCount + var wc = wordCount + if (wordCount - lastWordCount > 10000) { + lwc = wordCount + alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + } + wc += sentence.size + var pos = 0 + while (pos < sentence.size) { + val word = sentence(pos) + // TODO: fix random seed + val b = Random.nextInt(window) + // Train Skip-gram + var a = b + while (a < window * 2 + 1 - b) { + if (a != window) { + val c = pos - window + a + if (c >= 0 && c < sentence.size) { + val lastWord = sentence(c) + val l1 = lastWord * layer1Size + val neu1e = new Array[Double](layer1Size) + // Hierarchical softmax + var d = 0 + while (d < vocab(word).codeLen) { + val l2 = vocab(word).point(d) * layer1Size + // Propagate hidden -> output + var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) + if (f > 
-MAX_EXP && f < MAX_EXP) { + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt + f = expTable.value(ind) + val g = (1 - vocab(word).code(d) - f) * alpha + blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + } + d += 1 } - d += 1 + blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1) } - blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1) } + a += 1 } - a += 1 + pos += 1 } - pos += 1 - } - (syn0, syn1, lwc, wc) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) - blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) - }) - + (syn0, syn1, lwc, wc) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) + blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + }) + syn0Global = aggSyn0 + syn1Global = aggSyn1 + } val wordMap = new Array[(String, Array[Double])](vocabSize) var i = 0 while (i < vocabSize) { val word = vocab(i).word val vector = new Array[Double](layer1Size) - Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size) + Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) wordMap(i) = (word, vector) i += 1 } @@ -398,7 +406,9 @@ object Word2Vec{ size: Int, startingAlpha: Double, window: Int, - minCount: Int): Word2VecModel = { - new Word2Vec(size,startingAlpha, window, minCount).fit[S](input) + minCount: Int, + parallelism: Int = 1, + numIterations:Int = 1): Word2VecModel = { + new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input) } } From 7efbb6f91ca94f9243dbb7a16ea3fc9b6f548b99 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 12:16:19 -0700 Subject: [PATCH 08/14] use broadcast version of vocab in aggregate --- .../apache/spark/mllib/feature/Word2Vec.scala | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 21c2395fb18ae..3ace0800fb9f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.HashPartitioner - +import org.apache.spark.storage.StorageLevel /** * Entry in vocabulary */ @@ -215,10 +215,10 @@ class Word2Vec( val sc = dataset.context val expTable = sc.broadcast(createExpTable()) - val V = sc.broadcast(vocab) - val VHash = sc.broadcast(vocabHash) + val bcVocab = sc.broadcast(vocab) + val bcVocabHash = sc.broadcast(vocabHash) - val sentences = words.mapPartitions { + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => { new Iterator[Array[Int]] { def hasNext = iter.hasNext @@ -226,7 +226,7 @@ class Word2Vec( var sentence = new ArrayBuffer[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { - val word = VHash.value.get(iter.next) + val word = bcVocabHash.value.get(iter.next) word match { case Some(w) => { sentence += w @@ -278,14 +278,14 @@ class Word2Vec( val neu1e = new Array[Double](layer1Size) // Hierarchical softmax var d = 0 
- while (d < vocab(word).codeLen) { - val l2 = vocab(word).point(d) * layer1Size + while (d < bcVocab.value(word).codeLen) { + val l2 = bcVocab.value(word).point(d) * layer1Size // Propagate hidden -> output var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) - val g = (1 - vocab(word).code(d) - f) * alpha + val g = (1 - bcVocab.value(word).code(d) - f) * alpha blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) } @@ -310,17 +310,21 @@ class Word2Vec( syn0Global = aggSyn0 syn1Global = aggSyn1 } + newSentences.unpersist() + val wordMap = new Array[(String, Array[Double])](vocabSize) var i = 0 while (i < vocabSize) { - val word = vocab(i).word + val word = bcVocab.value(i).word val vector = new Array[Double](layer1Size) Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) wordMap(i) = (word, vector) i += 1 } val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)).cache() + .partitionBy(new HashPartitioner(modelPartitionNum)) + .persist(StorageLevel.MEMORY_AND_DISK) + new Word2VecModel(modelRDD) } } From 1a8fb4127b9433945e75beea16fc2d485a249219 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 16:24:35 -0700 Subject: [PATCH 09/14] use weighted sum in combOp --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3ace0800fb9f8..66429f5af1a46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -87,7 +87,7 @@ class Word2Vec( private var vocabHash = mutable.HashMap.empty[String, Int] private var alpha = startingAlpha - private def learnVocab(words:RDD[String]) { + private def learnVocab(words:RDD[String]){ vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) .map(x => VocabWord( @@ -110,6 +110,10 @@ class Word2Vec( logInfo("trainWordsCount = " + trainWordsCount) } + private def learnVocabPerPartition(words:RDD[String]) { + + } + private def createExpTable(): Array[Double] = { val expTable = new Array[Double](EXP_TABLE_SIZE) var i = 0 @@ -303,8 +307,12 @@ class Word2Vec( combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => val n = syn0_1.length - blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) - blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) + val weight1 = 1.0 * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0 * wc_2 / (wc_1 + wc_2) + blas.dscal(n, weight1, syn0_1, 1) + blas.dscal(n, weight1, syn1_1, 1) + blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1) (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) }) syn0Global = aggSyn0 From e93e7263d74879379257e6fff40d5efc8417f2ce Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 20:53:21 -0700 Subject: [PATCH 10/14] use treeAggregate instead of aggregate --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 66429f5af1a46..b966f775f7b01 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,6 +30,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.HashPartitioner import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Entry in vocabulary */ @@ -111,9 +112,9 @@ class Word2Vec( } private def learnVocabPerPartition(words:RDD[String]) { - + } - + private def createExpTable(): Array[Double] = { val expTable = new Array[Double](EXP_TABLE_SIZE) var i = 0 @@ -254,7 +255,7 @@ class Word2Vec( val (aggSyn0, aggSyn1, _, _) = // TODO: broadcast temp instead of serializing it directly // or initialize the model in each executor - newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))( + newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))( seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount From 384c77185544d6f80de96bd366e19760eacbd936 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 Aug 2014 21:33:05 -0700 Subject: [PATCH 11/14] remove minCount and window from constructor change model to use float instead of double --- .../apache/spark/mllib/feature/Word2Vec.scala | 130 +++++++++--------- 1 file changed, 63 insertions(+), 67 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b966f775f7b01..03cb0ff11027f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.HashPartitioner import org.apache.spark.storage.StorageLevel import org.apache.spark.mllib.rdd.RDDFunctions._ + /** * Entry in vocabulary */ @@ -61,18 +62,15 @@ private case class VocabWord( * Distributed Representations of Words and Phrases and their Compositionality. 
* @param size vector dimension * @param startingAlpha initial learning rate - * @param window context words from [-window, window] - * @param minCount minimum frequncy to consider a vocabulary word - * @param parallelisum number of partitions to run Word2Vec + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations to run, should be smaller than or equal to parallelism */ @Experimental class Word2Vec( val size: Int, val startingAlpha: Double, - val window: Int, - val minCount: Int, - val parallelism:Int = 1, - val numIterations:Int = 1) + val parallelism: Int = 1, + val numIterations: Int = 1) extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 @@ -81,7 +79,13 @@ class Word2Vec( private val MAX_SENTENCE_LENGTH = 1000 private val layer1Size = size private val modelPartitionNum = 100 - + + /** context words from [-window, window] */ + private val window = 5 + + /** minimum frequency to consider a vocabulary word */ + private val minCount = 5 + private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null @@ -99,7 +103,7 @@ class Word2Vec( 0)) .filter(_.cn >= minCount) .collect() - .sortWith((a, b)=> a.cn > b.cn) + .sortWith((a, b) => a.cn > b.cn) vocabSize = vocab.length var a = 0 @@ -111,16 +115,12 @@ class Word2Vec( logInfo("trainWordsCount = " + trainWordsCount) } - private def learnVocabPerPartition(words:RDD[String]) { - - } - - private def createExpTable(): Array[Double] = { - val expTable = new Array[Double](EXP_TABLE_SIZE) + private def createExpTable(): Array[Float] = { + val expTable = new Array[Float](EXP_TABLE_SIZE) var i = 0 while (i < EXP_TABLE_SIZE) { val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) - expTable(i) = tmp / (tmp + 1) + expTable(i) = (tmp / (tmp + 1.0)).toFloat i += 1 } expTable @@ -209,7 +209,7 @@ class Word2Vec( * @return a Word2VecModel */ - def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = { + def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { val words = dataset.flatMap(x => x) @@ -223,39 +223,37 @@ class Word2Vec( val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - val sentences: RDD[Array[Int]] = words.mapPartitions { - iter => { new Iterator[Array[Int]] { - def hasNext = iter.hasNext - - def next = { - var sentence = new ArrayBuffer[Int] - var sentenceLength = 0 - while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { - val word = bcVocabHash.value.get(iter.next) - word match { - case Some(w) => { - sentence += w - sentenceLength += 1 - } - case None => - } + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => + new Iterator[Array[Int]] { + def hasNext: Boolean = iter.hasNext + + def next(): Array[Int] = { + var sentence = new ArrayBuffer[Int] + var sentenceLength = 0 + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { + val word = bcVocabHash.value.get(iter.next()) + word match { + case Some(w) => + sentence += w + sentenceLength += 1 + case None => } - sentence.toArray } + sentence.toArray } } } val newSentences = sentences.repartition(parallelism).cache() - var syn0Global - = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) - var syn1Global = new Array[Double](vocabSize * layer1Size) + var syn0Global = + Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size) + var syn1Global = new Array[Float](vocabSize * layer1Size) for(iter <- 1 to numIterations) { 
val (aggSyn0, aggSyn1, _, _) = - // TODO: broadcast temp instead of serializing it directly + // TODO: broadcast temp instead of serializing it directly // or initialize the model in each executor - newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))( + newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))( seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -280,23 +278,23 @@ class Word2Vec( if (c >= 0 && c < sentence.size) { val lastWord = sentence(c) val l1 = lastWord * layer1Size - val neu1e = new Array[Double](layer1Size) + val neu1e = new Array[Float](layer1Size) // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { val l2 = bcVocab.value(word).point(d) * layer1Size // Propagate hidden -> output - var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) + var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) - val g = (1 - bcVocab.value(word).code(d) - f) * alpha - blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat + blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) } d += 1 } - blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1) + blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) } } a += 1 @@ -308,12 +306,12 @@ class Word2Vec( combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => val n = syn0_1.length - val weight1 = 1.0 * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0 * wc_2 / (wc_1 + wc_2) - blas.dscal(n, weight1, syn0_1, 1) - blas.dscal(n, weight1, syn1_1, 1) - blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1) + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) }) syn0Global = aggSyn0 @@ -321,11 +319,11 @@ class Word2Vec( } newSentences.unpersist() - val wordMap = new Array[(String, Array[Double])](vocabSize) + val wordMap = new Array[(String, Array[Float])](vocabSize) var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word - val vector = new Array[Double](layer1Size) + val vector = new Array[Float](layer1Size) Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) wordMap(i) = (word, vector) i += 1 @@ -341,15 +339,15 @@ class Word2Vec( /** * Word2Vec model */ -class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable { +class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable { - private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = { + private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") val n = v1.length - val norm1 = blas.dnrm2(n, v1, 1) - val norm2 = blas.dnrm2(n, v2, 1) + val norm1 = blas.snrm2(n, v1, 1) + val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.ddot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 } /** @@ -362,7 +360,7 @@ class Word2VecModel (private val 
model:RDD[(String, Array[Double])]) extends Ser if (result.isEmpty) { throw new IllegalStateException(s"${word} not in vocabulary") } - else Vectors.dense(result(0)) + else Vectors.dense(result(0).map(_.toDouble)) } /** @@ -394,7 +392,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") val topK = model.map { case(w, vec) => - (cosineSimilarity(vector.toArray, vec), w) } + (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } .sortByKey(ascending = false) .take(num + 1) .map(_.swap) @@ -410,18 +408,16 @@ object Word2Vec{ * @param input RDD of words * @param size vector dimension * @param startingAlpha initial learning rate - * @param window context words from [-window, window] - * @param minCount minimum frequncy to consider a vocabulary word - * @return Word2Vec model - */ + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations, should be smaller than or equal to parallelism + * @return Word2Vec model + */ def train[S <: Iterable[String]]( input: RDD[S], size: Int, startingAlpha: Double, - window: Int, - minCount: Int, parallelism: Int = 1, numIterations:Int = 1): Word2VecModel = { - new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input) + new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) } } From c14da411d4da1b6553759afff7952ac746c9fa15 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 Aug 2014 22:09:58 -0700 Subject: [PATCH 12/14] fix styles --- .../apache/spark/mllib/feature/Word2Vec.scala | 37 ++++++++++--------- .../spark/mllib/feature/Word2VecSuite.scala | 29 ++++++++------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 03cb0ff11027f..87c81e7b0bd2f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -17,20 +17,18 @@ package org.apache.spark.mllib.feature -import scala.util.Random -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} - -import org.apache.spark.annotation.Experimental -import org.apache.spark.Logging -import org.apache.spark.rdd._ +import org.apache.spark.{HashPartitioner, Logging} import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.HashPartitioner -import org.apache.spark.storage.StorageLevel import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel /** * Entry in vocabulary @@ -53,7 +51,7 @@ private case class VocabWord( * * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation - * mathes the original C implementation. + * matches the original C implementation. 
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
@@ -69,10 +67,14 @@ private case class VocabWord(
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val parallelism: Int = 1,
-    val numIterations: Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
@@ -92,7 +94,7 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha

-  private def learnVocab(words:RDD[String]){
+  private def learnVocab(words:RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -126,7 +128,7 @@ class Word2Vec(
     expTable
   }

-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
     val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,7 +210,6 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

     val words = dataset.flatMap(x => x)
@@ -339,7 +340,7 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {

   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
     else Vectors.dense(result(0).map(_.toDouble))
   }

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index e2b71c16f3308..3ec3208f5fa34 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature

 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext

 class Word2VecSuite extends FunSuite with LocalSparkContext {
+
+  // TODO: add more tests
+
   test("Word2Vec") {
     val sentence = "a b " * 100 + "a c " * 10
     val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val window = 2
     val minCount = 2
     val num = 2
-    val word = "a"

     val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
-    val synons = model.findSynonyms("a", 2)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "b")
-    assert(synons(1)._1 == "c")
+    val syms = model.findSynonyms("a", 2)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "b")
+    assert(syms(1)._1 == "c")
   }

   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
-      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
-      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
-      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
+      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
-    val synons = model.findSynonyms("china", num)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "taiwan")
-    assert(synons(1)._1 == "japan")
+    val syms = model.findSynonyms("china", num)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "taiwan")
+    assert(syms(1)._1 == "japan")
   }
 }

From e2484414d65c3b8aebffa79c3cac34452cf53d38 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sun, 3 Aug 2014 22:47:53 -0700
Subject: [PATCH 13/14] minor style change

---
 .../org/apache/spark/mllib/feature/Word2VecSuite.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 3ec3208f5fa34..359497cbfd037 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -47,10 +47,10 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
-      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
-      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
-      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
+      ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
     val syms = model.findSynonyms("china", num)

From 2ba948384e96e79e95a529f032d4768f24236547 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sun, 3 Aug 2014 22:59:40 -0700
Subject: [PATCH 14/14] minor fix for Word2Vec test

---
 .../scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 359497cbfd037..b5db39b68a223 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -36,7 +36,7 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val minCount = 2
     val num = 2

-    val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
+    val model = Word2Vec.train(doc, size, startingAlpha)
     val syms = model.findSynonyms("a", 2)
     assert(syms.length == num)
     assert(syms(0)._1 == "b")