diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 1cc432c6a6214..f4266c94f63e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -1,33 +1,33 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* Add a comment to this line
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

 package org.apache.spark.mllib.feature

-import scala.util.{Random => Random}
+import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable

 import com.github.fommil.netlib.BLAS.{getInstance => blas}

-import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
 import org.apache.spark.rdd._
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.HashPartitioner

 /**
@@ -42,8 +42,27 @@ private case class VocabWord(
 )

 /**
- * Vector representation of word
+ * :: Experimental ::
+ * Word2Vec creates vector representation of words in a text corpus.
+ * The algorithm first constructs a vocabulary from the corpus
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
+ * natural language processing and machine learning algorithms.
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
+ * method to train the model.
+ *
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
+ * Efficient Estimation of Word Representations in Vector Space
+ * and
+ * Distributed Representations of Words and Phrases and their Compositionality
+ * @param size vector dimension
+ * @param startingAlpha initial learning rate
+ * @param window context words from [-window, window]
+ * @param minCount minimum frequency to consider a vocabulary word
  */
+@Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
@@ -64,11 +83,15 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha

-  private def learnVocab(dataset: RDD[String]) {
-    vocab = dataset.flatMap(line => line.split(" "))
-      .map(w => (w, 1))
+  private def learnVocab(words:RDD[String]) {
+    vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
-      .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
+      .map(x => VocabWord(
+        x._1,
+        x._2,
+        new Array[Int](MAX_CODE_LENGTH),
+        new Array[Int](MAX_CODE_LENGTH),
+        0))
       .filter(_.cn >= minCount)
       .collect()
       .sortWith((a, b)=> a.cn > b.cn)
@@ -172,15 +195,16 @@ class Word2Vec(
   }

   /**
-   * Computes the vector representation of each word in
-   * vocabulary
-   * @param dataset an RDD of strings
+   * Computes the vector representation of each word in vocabulary.
+   * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-  def fit(dataset:RDD[String]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {

-    learnVocab(dataset)
+    val words = dataset.flatMap(x => x)
+
+    learnVocab(words)

     createBinaryTree()

@@ -190,9 +214,10 @@ class Word2Vec(
     val V = sc.broadcast(vocab)
     val VHash = sc.broadcast(vocabHash)
-    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
+    val sentences = words.mapPartitions {
       iter => { new Iterator[Array[Int]] {
           def hasNext = iter.hasNext
+
           def next = {
             var sentence = new ArrayBuffer[Int]
             var sentenceLength = 0
@@ -215,7 +240,8 @@ class Word2Vec(
     val newSentences = sentences.repartition(1).cache()
     val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
     val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
+      // TODO: broadcast temp instead of serializing it directly
+      // or initialize the model in each executor
       newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
         seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
           var lwc = lastWordCount
@@ -241,7 +267,7 @@ class Word2Vec(
                 val lastWord = sentence(c)
                 val l1 = lastWord * layer1Size
                 val neu1e = new Array[Double](layer1Size)
-                //HS
+                // Hierarchical softmax
                 var d = 0
                 while (d < vocab(word).codeLen) {
                   val l2 = vocab(word).point(d) * layer1Size
@@ -265,11 +291,12 @@ class Word2Vec(
           }
           (syn0, syn1, lwc, wc)
         },
-        combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-          blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+        combOp = (c1, c2) => (c1, c2) match {
+          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+            val n = syn0_1.length
+            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+            (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
       })
     val wordMap = new Array[(String, Array[Double])](vocabSize)
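
The seqOp in the hunks above only appears here in fragments. As a rough, hedged sketch of what one hierarchical-softmax update over a (word, context word) pair does: the variable names follow the patch, but the helper itself, the direct use of math.exp instead of the precomputed exp table, and the VocabWord fields (point, code, codeLen) are filled in from the reference C implementation and are not literally part of this diff.

def skipGramHsUpdate(
    word: Int,              // vocabulary index of the center word
    lastWord: Int,          // vocabulary index of one context word in the window
    alpha: Double,          // current learning rate
    layer1Size: Int,
    vocab: Array[VocabWord],
    syn0: Array[Double],    // word vectors, flattened vocabSize * layer1Size
    syn1: Array[Double]): Unit = {  // inner-node vectors of the Huffman tree
  val l1 = lastWord * layer1Size
  val neu1e = new Array[Double](layer1Size)
  var d = 0
  while (d < vocab(word).codeLen) {
    val l2 = vocab(word).point(d) * layer1Size
    // f = sigmoid(syn0[l1 ..] dot syn1[l2 ..])
    val dot = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
    val f = 1.0 / (1.0 + math.exp(-dot))
    // gradient scaled by the Huffman code bit and the learning rate
    val g = (1.0 - vocab(word).code(d) - f) * alpha
    blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)  // accumulate update for the word vector
    blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)  // update the inner-node vector
    d += 1
  }
  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)  // apply the accumulated update
}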
@@ -281,7 +308,8 @@ class Word2Vec(
       wordMap(i) = (word, vector)
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
+    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
+      .partitionBy(new HashPartitioner(modelPartitionNum))
     new Word2VecModel(modelRDD)
   }
 }
@@ -289,11 +317,9 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
-
-  val model = _model
+class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {

-  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
     val norm1 = blas.dnrm2(n, v1, 1)
@@ -307,11 +333,12 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
    * @param word a word
    * @return vector representation of word
    */
-
-  def transform(word: String): Array[Double] = {
+  def transform(word: String): Vector = {
     val result = model.lookup(word)
-    if (result.isEmpty) Array[Double]()
-    else result(0)
+    if (result.isEmpty) {
+      throw new IllegalStateException(s"${word} not in vocabulary")
+    }
+    else Vectors.dense(result(0))
   }

   /**
@@ -319,8 +346,7 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
    * @param dataset a an RDD of words
    * @return RDD of vector representation
    */
-
-  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
+  def transform(dataset: RDD[String]): RDD[Vector] = {
     dataset.map(word => transform(word))
   }

@@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
    */
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
-    if (vector.isEmpty) Array[(String, Double)]()
-    else findSynonyms(vector,num)
+    findSynonyms(vector,num)
   }

   /**
    * Find synonyms of the vector representation of a word
    * @param vector vector representation of a word
    * @param num number of synonyms to find
-   * @return array of (word, similarity)
+   * @return array of (word, cosineSimilarity)
    */
-  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
+  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map(
-      {case(w, vec) => (distance(vector, vec), w)})
+    val topK = model.map { case(w, vec) =>
+      (cosineSimilarity(vector.toArray, vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
-      .map({case (dist, w) => (w, dist)}).drop(1)
+      .map(_.swap)
+      .tail
     topK
   }
 }

-object Word2Vec extends Serializable with Logging {
+object Word2Vec{
   /**
    * Train Word2Vec model
    * @param input RDD of words
-   * @param size vectoer dimension
+   * @param size vector dimension
    * @param startingAlpha initial learning rate
    * @param window context words from [-window, window]
    * @param minCount minimum frequency to consider a vocabulary word
    * @return Word2Vec model
    */
-  def train(
-      input: RDD[String],
+  def train[S <: Iterable[String]](
+      input: RDD[S],
       size: Int,
       startingAlpha: Double,
       window: Int,
       minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit(input)
+    new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
   }
 }
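
For context, this is roughly how the API looks after this patch: fit and train now take an RDD of pre-tokenized sentences (RDD[S] with S <: Iterable[String]) instead of an RDD of raw lines, and transform returns an MLlib Vector. A minimal usage sketch follows; the SparkContext setup, sample corpus, and parameter values are illustrative and not taken from the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.Word2Vec

object Word2VecExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("Word2VecExample"))

    // Each element is one pre-tokenized sentence (an Iterable[String]),
    // no longer a whitespace-joined line.
    val sentences = sc.parallelize(Seq(
      Seq("spark", "is", "a", "fast", "cluster", "computing", "system"),
      Seq("word2vec", "learns", "vector", "representations", "of", "words")))

    // size = 100, startingAlpha = 0.025, window = 5, minCount = 1
    val model = Word2Vec.train(sentences, 100, 0.025, 5, 1)

    val vec = model.transform("spark")            // now an mllib.linalg.Vector
    println(s"'spark' => vector of size ${vec.size}")

    val synonyms = model.findSynonyms("spark", 5) // Array[(String, Double)] of (word, cosine similarity)
    synonyms.foreach { case (w, sim) => println(s"$w $sim") }

    sc.stop()
  }
}

Note that with this change transform throws an IllegalStateException for an out-of-vocabulary word instead of returning an empty array, so callers should only look up words that survived the minCount filter.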
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 2a02ace83c380..54e56529c5a47 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -1,24 +1,24 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* Add a comment to this line
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

 package org.apache.spark.mllib.feature

 import org.scalatest.FunSuite
+
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext