Skip to content

Commit

Permalink
[WIP] Add "predict_output_word" to word2vec negative sampling. Fix #863
Browse files Browse the repository at this point in the history
…. (#1209)

* added function to predict output word in CBOW from context words

* handling negative_sampling case

* added warnings for out-of-vocabulary and not negative sampling cases

* added unit tests for predict_output_word

* updated CHANGELOG
  • Loading branch information
chinmayapancholi13 authored and tmylk committed Mar 20, 2017
1 parent 854fad6 commit cc86005
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 23 deletions.
42 changes: 22 additions & 20 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Changes

Unreleased:

New features:
* Add output word prediction for negative sampling scheme. (@chinmayapancholi13,[#1209](https://github.com/RaRe-Technologies/gensim/pull/1209))

========
1.0.1, 2017-03-03
Expand Down Expand Up @@ -35,35 +37,35 @@ Improvements:
* Phrases and Phraser allow a generator corpus (ELind77 [#1099](https://github.com/RaRe-Technologies/gensim/pull/1099))
* Ignore DocvecsArray.doctag_syn0norm in save. Fix #789 (@accraze,[#1053](https://github.com/RaRe-Technologies/gensim/pull/1053))
* Fix bug in LsiModel that occurs when id2word is a Python 3 dictionary. (@cvangysel,[#1103](https://github.com/RaRe-Technologies/gensim/pull/1103)
* Fix broken link to paper in readme (@bhargavvader,[#1101](https://github.com/RaRe-Technologies/gensim/pull/1101))
* Lazy formatting in evaluate_word_pairs (@akutuzov,[#1084](https://github.com/RaRe-Technologies/gensim/pull/1084))
* Fix broken link to paper in readme (@bhargavvader,[#1101](https://github.com/RaRe-Technologies/gensim/pull/1101))
* Lazy formatting in evaluate_word_pairs (@akutuzov,[#1084](https://github.com/RaRe-Technologies/gensim/pull/1084))
* Deacc option to keywords pre-processing (@bhargavvader,[#1076](https://github.com/RaRe-Technologies/gensim/pull/1076))
* Generate Deprecated exception when using Word2Vec.load_word2vec_format (@tmylk, [#1165](https://github.com/RaRe-Technologies/gensim/pull/1165))
* Fix hdpmodel constructor docstring for print_topics (#1152) (@toliwa, [#1152](https://github.com/RaRe-Technologies/gensim/pull/1152))
* Default to per_word_topics=False in LDA get_item for performance (@menshikh-iv, [#1154](https://github.com/RaRe-Technologies/gensim/pull/1154))
* Generate Deprecated exception when using Word2Vec.load_word2vec_format (@tmylk, [#1165](https://github.com/RaRe-Technologies/gensim/pull/1165))
* Fix hdpmodel constructor docstring for print_topics (#1152) (@toliwa, [#1152](https://github.com/RaRe-Technologies/gensim/pull/1152))
* Default to per_word_topics=False in LDA get_item for performance (@menshikh-iv, [#1154](https://github.com/RaRe-Technologies/gensim/pull/1154))
* Fix bound computation in Author Topic models. (@olavurmortensen, [#1156](https://github.com/RaRe-Technologies/gensim/pull/1156))
* Write UTF-8 byte strings in tensorboard conversion (@tmylk,[#1144](https://github.com/RaRe-Technologies/gensim/pull/1144))
* Make top_topics and sparse2full compatible with numpy 1.12 strictly int idexing (@tmylk,[#1146](https://github.com/RaRe-Technologies/gensim/pull/1146))

Tutorial and doc improvements:

* Clarifying comment in is_corpus func in utils.py (@greninja,[#1109](https://github.com/RaRe-Technologies/gensim/pull/1109))
* Clarifying comment in is_corpus func in utils.py (@greninja,[#1109](https://github.com/RaRe-Technologies/gensim/pull/1109))
* Tutorial Topics_and_Transformations fix markdown and add references (@lgmoneda,[#1120](https://github.com/RaRe-Technologies/gensim/pull/1120))
* Fix doc2vec-lee.ipynb results to match previous behavior (@bahbbc,[#1119](https://github.com/RaRe-Technologies/gensim/pull/1119))
* Fix doc2vec-lee.ipynb results to match previous behavior (@bahbbc,[#1119](https://github.com/RaRe-Technologies/gensim/pull/1119))
* Remove Pattern lib dependency in News Classification tutorial (@luizcavalcanti,[#1118](https://github.com/RaRe-Technologies/gensim/pull/1118))
* Corpora_and_Vector_Spaces tutorial text clarification (@lgmoneda,[#1116](https://github.com/RaRe-Technologies/gensim/pull/1116))
* Update Transformation and Topics link from quick start notebook (@mariana393,[#1115](https://github.com/RaRe-Technologies/gensim/pull/1115))
* Quick Start Text clarification and typo correction (@luizcavalcanti,[#1114](https://github.com/RaRe-Technologies/gensim/pull/1114))
* Fix typos in Author-topic tutorial (@Fil,[#1102](https://github.com/RaRe-Technologies/gensim/pull/1102))
* Address benchmark inconsistencies in Annoy tutorial (@droudy,[#1113](https://github.com/RaRe-Technologies/gensim/pull/1113))
* Add note about Annoy speed depending on numpy BLAS setup in annoytutorial.ipynb (@greninja,[#1137](https://github.com/RaRe-Technologies/gensim/pull/1137))
* Fix dependencies description on doc2vec-IMDB notebook (@luizcavalcanti, [#1132](https://github.com/RaRe-Technologies/gensim/pull/1132))
* Add documentation for WikiCorpus metadata. (@kirit93, [#1163](https://github.com/RaRe-Technologies/gensim/pull/1163))
* Add note about Annoy speed depending on numpy BLAS setup in annoytutorial.ipynb (@greninja,[#1137](https://github.com/RaRe-Technologies/gensim/pull/1137))
* Fix dependencies description on doc2vec-IMDB notebook (@luizcavalcanti, [#1132](https://github.com/RaRe-Technologies/gensim/pull/1132))
* Add documentation for WikiCorpus metadata. (@kirit93, [#1163](https://github.com/RaRe-Technologies/gensim/pull/1163))



1.0.0RC2, 2017-02-16

* Add note about Annoy speed depending on numpy BLAS setup in annoytutorial.ipynb (@greninja,[#1137](https://github.com/RaRe-Technologies/gensim/pull/1137))
* Add note about Annoy speed depending on numpy BLAS setup in annoytutorial.ipynb (@greninja,[#1137](https://github.com/RaRe-Technologies/gensim/pull/1137))
* Remove direct access to properties moved to KeyedVectors (@tmylk,[#1147](https://github.com/RaRe-Technologies/gensim/pull/1147))
* Remove support for Python 2.6, 3.3 and 3.4 (@tmylk,[#1145](https://github.com/RaRe-Technologies/gensim/pull/1145))
* Write UTF-8 byte strings in tensorboard conversion (@tmylk,[#1144](https://github.com/RaRe-Technologies/gensim/pull/1144))
Expand All @@ -83,15 +85,15 @@ Improvements:
* Ignore DocvecsArray.doctag_syn0norm in save. Fix #789 (@accraze,[#1053](https://github.com/RaRe-Technologies/gensim/pull/1053))
* Move load and save word2vec_format out of word2vec class to KeyedVectors (@tmylk,[#1107](https://github.com/RaRe-Technologies/gensim/pull/1107))
* Fix bug in LsiModel that occurs when id2word is a Python 3 dictionary. (@cvangysel,[#1103](https://github.com/RaRe-Technologies/gensim/pull/1103)
* Fix broken link to paper in readme (@bhargavvader,[#1101](https://github.com/RaRe-Technologies/gensim/pull/1101))
* Lazy formatting in evaluate_word_pairs (@akutuzov,[#1084](https://github.com/RaRe-Technologies/gensim/pull/1084))
* Fix broken link to paper in readme (@bhargavvader,[#1101](https://github.com/RaRe-Technologies/gensim/pull/1101))
* Lazy formatting in evaluate_word_pairs (@akutuzov,[#1084](https://github.com/RaRe-Technologies/gensim/pull/1084))
* Deacc option to keywords pre-processing (@bhargavvader,[#1076](https://github.com/RaRe-Technologies/gensim/pull/1076))

Tutorial and doc improvements:

* Clarifying comment in is_corpus func in utils.py (@greninja,[#1109](https://github.com/RaRe-Technologies/gensim/pull/1109))
* Clarifying comment in is_corpus func in utils.py (@greninja,[#1109](https://github.com/RaRe-Technologies/gensim/pull/1109))
* Tutorial Topics_and_Transformations fix markdown and add references (@lgmoneda,[#1120](https://github.com/RaRe-Technologies/gensim/pull/1120))
* Fix doc2vec-lee.ipynb results to match previous behavior (@bahbbc,[#1119](https://github.com/RaRe-Technologies/gensim/pull/1119))
* Fix doc2vec-lee.ipynb results to match previous behavior (@bahbbc,[#1119](https://github.com/RaRe-Technologies/gensim/pull/1119))
* Remove Pattern lib dependency in News Classification tutorial (@luizcavalcanti,[#1118](https://github.com/RaRe-Technologies/gensim/pull/1118))
* Corpora_and_Vector_Spaces tutorial text clarification (@lgmoneda,[#1116](https://github.com/RaRe-Technologies/gensim/pull/1116))
* Update Transformation and Topics link from quick start notebook (@mariana393,[#1115](https://github.com/RaRe-Technologies/gensim/pull/1115))
Expand All @@ -103,9 +105,9 @@ Tutorial and doc improvements:
0.13.4.1, 2017-01-04

* Disable direct access warnings on save and load of Word2vec/Doc2vec (@tmylk, [#1072](https://github.com/RaRe-Technologies/gensim/pull/1072))
* Making Default hs error explicit (@accraze, [#1054](https://github.com/RaRe-Technologies/gensim/pull/1054))
* Making Default hs error explicit (@accraze, [#1054](https://github.com/RaRe-Technologies/gensim/pull/1054))
* Removed unnecessary numpy imports (@bhargavvader, [#1065](https://github.com/RaRe-Technologies/gensim/pull/1065))
* Utils and Matutils changes (@bhargavvader, [#1062](https://github.com/RaRe-Technologies/gensim/pull/1062))
* Utils and Matutils changes (@bhargavvader, [#1062](https://github.com/RaRe-Technologies/gensim/pull/1062))
* Tests for the evaluate_word_pairs function (@akutuzov, [#1061](https://github.com/RaRe-Technologies/gensim/pull/1061))

0.13.4, 2016-12-22
Expand All @@ -127,8 +129,8 @@ Tutorial and doc improvements:
* Remove warning on gensim import "pattern not installed". Fix #1009 (@shashankg7, [#1018](https://github.com/RaRe-Technologies/gensim/pull/1018))
* Add delete_temporary_training_data() function to word2vec and doc2vec models. (@deepmipt-VladZhukov, [#987](https://github.com/RaRe-Technologies/gensim/pull/987))
* Documentation improvements (@IrinaGoloshchapova, [#1010](https://github.com/RaRe-Technologies/gensim/pull/1010), [#1011](https://github.com/RaRe-Technologies/gensim/pull/1011))
* LDA tutorial by Olavur, tips and tricks (@olavurmortensen, [#779](https://github.com/RaRe-Technologies/gensim/pull/779))
* Add double quote in commmand line to run on Windows (@akarazeev, [#1005](https://github.com/RaRe-Technologies/gensim/pull/1005))
* LDA tutorial by Olavur, tips and tricks (@olavurmortensen, [#779](https://github.com/RaRe-Technologies/gensim/pull/779))
* Add double quote in commmand line to run on Windows (@akarazeev, [#1005](https://github.com/RaRe-Technologies/gensim/pull/1005))
* Fix directory names in notebooks to be OS-independent (@mamamot, [#1004](https://github.com/RaRe-Technologies/gensim/pull/1004))
* Respect clip_start, clip_end in most_similar. Fix #601. (@parulsethi, [#994](https://github.com/RaRe-Technologies/gensim/pull/994))
* Replace Python sigmoid function with scipy in word2vec & doc2vec (@markroxor, [#989](https://github.com/RaRe-Technologies/gensim/pull/989))
Expand Down
26 changes: 26 additions & 0 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,32 @@ def similarity(self, w1, w2):
def n_similarity(self, ws1, ws2):
return self.wv.n_similarity(ws1, ws2)

def predict_output_word(self, context_words_list, topn=10):
"""Report the probability distribution of the center word given the context words as input to the trained model."""
if not self.negative:
raise RuntimeError("We have currently only implemented predict_output_word "
"for the negative sampling scheme, so you need to have "
"run word2vec with negative > 0 for this to work.")

if not hasattr(self.wv, 'syn0') or not hasattr(self, 'syn1neg'):
raise RuntimeError("Parameters required for predicting the output words not found.")

word_vocabs = [self.wv.vocab[w] for w in context_words_list if w in self.wv.vocab]
if not word_vocabs:
warnings.warn("All the input context words are out-of-vocabulary for the current model.")
return None

word2_indices = [word.index for word in word_vocabs]

l1 = np_sum(self.wv.syn0[word2_indices], axis=0)
if word2_indices and self.cbow_mean:
l1 /= len(word2_indices)

prob_values = exp(dot(l1, self.syn1neg.T)) # propagate hidden -> output and take softmax to get probabilities
prob_values /= sum(prob_values)
top_indices = matutils.argsort(prob_values, topn=topn, reverse=True)
return [(self.wv.index2word[index1], prob_values[index1]) for index1 in top_indices] #returning the most probable output words with their probabilities

def init_sims(self, replace=False):
"""
init_sims() resides in KeyedVectors because it deals with syn0 mainly, but because syn1 is not an attribute
Expand Down
29 changes: 26 additions & 3 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def testLoadPreKeyedVectorModel(self):
model_file_suffix = '_py2'
else:
model_file_suffix = '_py3'

# Model stored in one file
model_file = 'word2vec_pre_kv%s' % model_file_suffix
model = word2vec.Word2Vec.load(datapath(model_file))
Expand Down Expand Up @@ -620,6 +620,29 @@ def testNormalizeAfterTrainingData(self):
norm_only_model.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
self.assertFalse(np.allclose(model['human'], norm_only_model['human']))

def testPredictOutputWord(self):
'''Test word2vec predict_output_word method handling for negative sampling scheme'''
#under normal circumstances
model_with_neg = word2vec.Word2Vec(sentences, min_count=1)
predictions_with_neg = model_with_neg.predict_output_word(['system', 'human'], topn=5)
self.assertTrue(len(predictions_with_neg)==5)

#out-of-vobaculary scenario
predictions_out_of_vocab = model_with_neg.predict_output_word(['some', 'random', 'words'], topn=5)
self.assertEqual(predictions_out_of_vocab, None)

#when required model parameters have been deleted
model_with_neg.init_sims()
model_with_neg.wv.save_word2vec_format(testfile(), binary=True)
kv_model_with_neg = keyedvectors.KeyedVectors.load_word2vec_format(testfile(), binary=True)
binary_model_with_neg = word2vec.Word2Vec()
binary_model_with_neg.wv = kv_model_with_neg
self.assertRaises(RuntimeError, binary_model_with_neg.predict_output_word, ['system', 'human'])

#negative sampling scheme not used
model_without_neg = word2vec.Word2Vec(sentences, min_count=1, negative=0)
self.assertRaises(RuntimeError, model_without_neg.predict_output_word, ['system', 'human'])

@log_capture()
def testBuildVocabWarning(self, l):
"""Test if warning is raised on non-ideal input to a word2vec model"""
Expand All @@ -644,14 +667,14 @@ def testTrainWarning(self, l):
model.alpha += 0.05
warning = "Effective 'alpha' higher than previous training cycles"
self.assertTrue(warning in str(l))

def test_sentences_should_not_be_a_generator(self):
"""
Is sentences a generator object?
"""
gen = (s for s in sentences)
self.assertRaises(TypeError, word2vec.Word2Vec, (gen,))

def testLoadOnClassError(self):
"""Test if exception is raised when loading word2vec model on instance"""
self.assertRaises(AttributeError, load_on_instance)
Expand Down

0 comments on commit cc86005

Please sign in to comment.