modify according to feedback
Liquan Pei committed Aug 2, 2014
1 parent 57dc50d commit 2e92b59
Showing 2 changed files with 102 additions and 76 deletions.
146 changes: 86 additions & 60 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -1,33 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import scala.util.{Random => Random}
import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner

/**
@@ -42,8 +42,27 @@ private case class VocabWord(
)

/**
* Vector representation of word
* :: Experimental ::
* Word2Vec creates vector representations of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
* and then learns vector representations of the words in the vocabulary.
* These representations can be used as features in
* natural language processing and machine learning algorithms.
*
* We use the skip-gram model in our implementation, with the hierarchical
* softmax method to train it.
*
* For the original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum word frequency to be included in the vocabulary
*/
@Experimental
class Word2Vec(
val size: Int,
val startingAlpha: Double,
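The scaladoc above documents the new public API. As a rough usage sketch (the SparkContext sc, the corpus path, and the hyperparameter values are assumptions for illustration, not part of this commit):

import org.apache.spark.mllib.feature.Word2Vec

// sc is an existing SparkContext; the corpus path is hypothetical.
val corpus = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
val model = new Word2Vec(size = 100, startingAlpha = 0.025, window = 5, minCount = 5)
  .fit(corpus)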
@@ -64,11 +83,15 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(dataset: RDD[String]) {
vocab = dataset.flatMap(line => line.split(" "))
.map(w => (w, 1))
private def learnVocab(words: RDD[String]) {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
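The reformatted pipeline above is the classic word-count pattern, followed by construction of one VocabWord per distinct word and a frequency filter. A toy sketch of just the counting step (the sample data and the SparkContext sc are assumptions):

import org.apache.spark.SparkContext._  // pair-RDD implicits for reduceByKey

val toyWords = sc.parallelize(Seq("spark", "word", "spark", "vector", "spark"))
val counts = toyWords.map(w => (w, 1)).reduceByKey(_ + _).collect()
// counts holds ("spark", 3), ("word", 1), ("vector", 1), in no fixed order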
@@ -172,15 +195,16 @@
}

/**
* Computes the vector representation of each word in
* vocabulary
* @param dataset an RDD of strings
* Computes the vector representation of each word in the vocabulary.
* @param dataset an RDD of sentences, each expressed as an iterable of words
* @return a Word2VecModel
*/

def fit(dataset:RDD[String]): Word2VecModel = {
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

learnVocab(dataset)
val words = dataset.flatMap(x => x)

learnVocab(words)

createBinaryTree()

@@ -190,9 +214,10 @@
val V = sc.broadcast(vocab)
val VHash = sc.broadcast(vocabHash)

val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
val sentences = words.mapPartitions {
iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext

def next = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
@@ -215,7 +240,8 @@
val newSentences = sentences.repartition(1).cache()
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
val (aggSyn0, _, _, _) =
// TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
// TODO: broadcast temp instead of serializing it directly
// or initialize the model in each executor
newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -241,7 +267,7 @@
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
//HS
// Hierarchical softmax
var d = 0
while (d < vocab(word).codeLen) {
val l2 = vocab(word).point(d) * layer1Size
@@ -265,11 +291,12 @@
}
(syn0, syn1, lwc, wc)
},
combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})

val wordMap = new Array[(String, Array[Double])](vocabSize)
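The seqOp above (partially folded here) performs the skip-gram update with hierarchical softmax: for each inner node on the center word's Huffman path it computes a sigmoid activation, adjusts that node's vector in syn1, and accumulates an error that is finally added to the context word's vector in syn0. A self-contained sketch of one such update, mirroring the variable names in the diff (the exact sigmoid is a simplification; the reference C implementation uses a precomputed EXP table, and this sketch is an illustration, not the commit's code):

def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

// One hierarchical-softmax update for a (center word, context word) pair.
def hsUpdate(
    syn0: Array[Double],   // input vectors, one block of layer1Size per word
    syn1: Array[Double],   // inner-node vectors of the Huffman tree
    l1: Int,               // offset of the context word's block in syn0
    code: Array[Int],      // Huffman code of the center word (0/1 per node)
    point: Array[Int],     // indices of the inner nodes on the path
    codeLen: Int,
    layer1Size: Int,
    alpha: Double): Unit = {
  val neu1e = new Array[Double](layer1Size)
  var d = 0
  while (d < codeLen) {
    val l2 = point(d) * layer1Size
    var f = 0.0
    var i = 0
    while (i < layer1Size) { f += syn0(l1 + i) * syn1(l2 + i); i += 1 }
    val g = (1.0 - code(d) - sigmoid(f)) * alpha   // gradient times learning rate
    i = 0
    while (i < layer1Size) {
      neu1e(i) += g * syn1(l2 + i)     // accumulate error for syn0
      syn1(l2 + i) += g * syn0(l1 + i) // update the inner-node vector
      i += 1
    }
    d += 1
  }
  var i = 0
  while (i < layer1Size) { syn0(l1 + i) += neu1e(i); i += 1 } // back-propagate
}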
@@ -281,19 +308,18 @@
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
.partitionBy(new HashPartitioner(modelPartitionNum))
new Word2VecModel(modelRDD)
}
}

/**
* Word2Vec model
*/
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {

val model = _model
class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {

private def distance(v1: Array[Double], v2: Array[Double]): Double = {
private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
@@ -307,20 +333,20 @@
* @param word a word
* @return vector representation of word
*/

def transform(word: String): Array[Double] = {
def transform(word: String): Vector = {
val result = model.lookup(word)
if (result.isEmpty) Array[Double]()
else result(0)
if (result.isEmpty) {
throw new IllegalStateException(s"${word} not in vocabulary")
}
else Vectors.dense(result(0))
}

/**
* Transforms an RDD to its vector representation
* @param dataset an RDD of words
* @return RDD of vector representation
*/

def transform(dataset: RDD[String]): RDD[Array[Double]] = {
def transform(dataset: RDD[String]): RDD[Vector] = {
dataset.map(word => transform(word))
}

Expand All @@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
if (vector.isEmpty) Array[(String, Double)]()
else findSynonyms(vector,num)
findSynonyms(vector, num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, similarity)
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should be > 0")
val topK = model.map(
{case(w, vec) => (distance(vector, vec), w)})
val topK = model.map { case(w, vec) =>
(cosineSimilarity(vector.toArray, vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map({case (dist, w) => (w, dist)}).drop(1)
.map(_.swap)
.tail

topK
}
}

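With the changes above, transform now returns an MLlib Vector (or throws IllegalStateException for an out-of-vocabulary word), and findSynonyms ranks vocabulary words by cosine similarity, v1·v2 / (||v1|| ||v2||). A hypothetical usage sketch, assuming a trained model whose vocabulary contains "spark":

val vec = model.transform("spark")            // Vector representation
val synonyms = model.findSynonyms("spark", 5) // Array[(String, Double)]
synonyms.foreach { case (word, sim) => println(word + " " + sim) }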
object Word2Vec extends Serializable with Logging {
object Word2Vec {
/**
* Train Word2Vec model
* @param input RDD of sentences, each expressed as an iterable of words
* @param size vectoer dimension
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum word frequency to be included in the vocabulary
* @return Word2Vec model
*/
def train(
input: RDD[String],
def train[S <: Iterable[String]](
input: RDD[S],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
}
}
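End to end, the generic train signature now accepts any RDD of string collections (tokenized sentences) rather than an RDD of raw lines. A minimal sketch, assuming a SparkContext sc and a whitespace-tokenized corpus file (the path and hyperparameter values are illustrative):

import org.apache.spark.rdd.RDD

val input: RDD[Seq[String]] = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025, window = 5, minCount = 5)
val synonyms = model.findSynonyms("spark", 10)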
32 changes: 16 additions & 16 deletions mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -1,24 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

