diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8e3ad48871..41a608ef90 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,7 +39,8 @@ jobs: # - name: Update sbt run: | - echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list + echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list + echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | sudo tee /etc/apt/sources.list.d/sbt_old.list curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add sudo apt-get update -y sudo apt-get install -y sbt diff --git a/CHANGELOG.md b/CHANGELOG.md index d2c72ee66c..7a05a5567a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Changes * [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness) * [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko) * [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar) +* [#3153](https://github.com/RaRe-Technologies/gensim/pull/3153): Vectorize word2vec.predict_output_word for speed, by [@M-Demay](https://github.com/M-Demay) * [#3157](https://github.com/RaRe-Technologies/gensim/pull/3157): New KeyedVectors.vectors_for_all method for vectorizing all words in a dictionary, by [@Witiko](https://github.com/Witiko) * [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0) * [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro) diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index 265364890b..1a34c367e6 100755 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -1806,8 +1806,9 @@ def predict_output_word(self, context_words_list, topn=10): Parameters ---------- - context_words_list : list of str - List of context words. + context_words_list : list of (str and/or int) + List of context words, which may be words themselves (str) + or their index in `self.wv.vectors` (int). topn : int, optional Return `topn` words and their probabilities. @@ -1825,8 +1826,8 @@ def predict_output_word(self, context_words_list, topn=10): if not hasattr(self.wv, 'vectors') or not hasattr(self, 'syn1neg'): raise RuntimeError("Parameters required for predicting the output words not found.") - word2_indices = [self.wv.get_index(w) for w in context_words_list if w in self.wv] + if not word2_indices: logger.warning("All the input context words are out-of-vocabulary for the current model.") return None @@ -1837,7 +1838,7 @@ def predict_output_word(self, context_words_list, topn=10): # propagate hidden -> output and take softmax to get probabilities prob_values = np.exp(np.dot(l1, self.syn1neg.T)) - prob_values /= sum(prob_values) + prob_values /= np.sum(prob_values) top_indices = matutils.argsort(prob_values, topn=topn, reverse=True) # returning the most probable output words with their probabilities return [(self.wv.index_to_key[index1], prob_values[index1]) for index1 in top_indices] diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index e85cee0d5d..43505b0be2 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -875,6 +875,16 @@ def test_predict_output_word(self): model_without_neg = word2vec.Word2Vec(sentences, min_count=1, negative=0) self.assertRaises(RuntimeError, model_without_neg.predict_output_word, ['system', 'human']) + # passing indices instead of words in context + str_context = ['system', 'human'] + mixed_context = [model_with_neg.wv.get_index(str_context[0]), str_context[1]] + idx_context = [model_with_neg.wv.get_index(w) for w in str_context] + prediction_from_str = model_with_neg.predict_output_word(str_context, topn=5) + prediction_from_mixed = model_with_neg.predict_output_word(mixed_context, topn=5) + prediction_from_idx = model_with_neg.predict_output_word(idx_context, topn=5) + self.assertEqual(prediction_from_str, prediction_from_mixed) + self.assertEqual(prediction_from_str, prediction_from_idx) + def test_load_old_model(self): """Test loading an old word2vec model of indeterminate version"""