Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize word2vec.predict_output_word for speed #3153

Merged
merged 9 commits into from
Jul 19, 2021
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ jobs:
#
- name: Update sbt
run: |
echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list
echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | sudo tee /etc/apt/sources.list.d/sbt_old.list
curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add
sudo apt-get update -y
sudo apt-get install -y sbt
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Changes
* [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness)
* [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko)
* [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar)
* [#3153](https://github.com/RaRe-Technologies/gensim/pull/3153): Vectorize word2vec.predict_output_word for speed, by [@M-Demay](https://github.com/M-Demay)
* [#3157](https://github.com/RaRe-Technologies/gensim/pull/3157): New KeyedVectors.vectors_for_all method for vectorizing all words in a dictionary, by [@Witiko](https://github.com/Witiko)
* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0)
* [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro)
Expand Down
9 changes: 5 additions & 4 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,8 +1806,9 @@ def predict_output_word(self, context_words_list, topn=10):

Parameters
----------
context_words_list : list of str
List of context words.
context_words_list : list of (str and/or int)
List of context words, which may be words themselves (str)
or their index in `self.wv.vectors` (int).
topn : int, optional
Return `topn` words and their probabilities.

Expand All @@ -1825,8 +1826,8 @@ def predict_output_word(self, context_words_list, topn=10):

if not hasattr(self.wv, 'vectors') or not hasattr(self, 'syn1neg'):
raise RuntimeError("Parameters required for predicting the output words not found.")

word2_indices = [self.wv.get_index(w) for w in context_words_list if w in self.wv]

if not word2_indices:
logger.warning("All the input context words are out-of-vocabulary for the current model.")
return None
Expand All @@ -1837,7 +1838,7 @@ def predict_output_word(self, context_words_list, topn=10):

# propagate hidden -> output and take softmax to get probabilities
prob_values = np.exp(np.dot(l1, self.syn1neg.T))
prob_values /= sum(prob_values)
prob_values /= np.sum(prob_values)
top_indices = matutils.argsort(prob_values, topn=topn, reverse=True)
# returning the most probable output words with their probabilities
return [(self.wv.index_to_key[index1], prob_values[index1]) for index1 in top_indices]
Expand Down
10 changes: 10 additions & 0 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,16 @@ def test_predict_output_word(self):
model_without_neg = word2vec.Word2Vec(sentences, min_count=1, negative=0)
self.assertRaises(RuntimeError, model_without_neg.predict_output_word, ['system', 'human'])

# passing indices instead of words in context
str_context = ['system', 'human']
mixed_context = [model_with_neg.wv.get_index(str_context[0]), str_context[1]]
idx_context = [model_with_neg.wv.get_index(w) for w in str_context]
prediction_from_str = model_with_neg.predict_output_word(str_context, topn=5)
prediction_from_mixed = model_with_neg.predict_output_word(mixed_context, topn=5)
prediction_from_idx = model_with_neg.predict_output_word(idx_context, topn=5)
self.assertEqual(prediction_from_str, prediction_from_mixed)
self.assertEqual(prediction_from_str, prediction_from_idx)

def test_load_old_model(self):
"""Test loading an old word2vec model of indeterminate version"""

Expand Down