diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 11268a28fb..c9012c144f 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -9,6 +9,7 @@ Changes * Better internal handling of job batching in word2vec (#535) - up to 300% speed up when training on very short documents (~tweets) * Word2vec allows non-strict unicode error handling (ignore or replace) (Gordon Mohr, #466) +* Fix `DocvecsArray.index_to_doctag` so `most_similar()` returns string doctags (Gordon Mohr, #560) * On-demand loading of the `pattern` library in utils.lemmatize (Jan Zikes, #461) - `utils.HAS_PATTERN` flag moved to `utils.has_pattern()` * Forwards compatibility for NumPy > 1.10 (Matti Lyra, #494, #513) diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py index 4007e70310..8869b3b7ed 100644 --- a/gensim/models/doc2vec.py +++ b/gensim/models/doc2vec.py @@ -319,7 +319,7 @@ def _key_index(self, i_index, missing=None): def index_to_doctag(self, i_index): """Return string key for given i_index, if available. Otherwise return raw int doctag (same int).""" - candidate_offset = self.max_rawint - i_index - 1 + candidate_offset = i_index - self.max_rawint - 1 if 0 <= candidate_offset < len(self.offset2doctag): return self.offset2doctag[candidate_offset] else: diff --git a/gensim/test/test_doc2vec.py b/gensim/test/test_doc2vec.py index b0628e9e0a..cfb04a97c1 100644 --- a/gensim/test/test_doc2vec.py +++ b/gensim/test/test_doc2vec.py @@ -103,6 +103,8 @@ def test_string_doctags(self): self.assertTrue(all(model.docvecs['_*0'] == model.docvecs[0])) self.assertTrue(max(d.offset for d in model.docvecs.doctags.values()) < len(model.docvecs.doctags)) self.assertTrue(max(model.docvecs._int_index(str_key) for str_key in model.docvecs.doctags.keys()) < len(model.docvecs.doctag_syn0)) + # verify docvecs.most_similar() returns string doctags rather than indexes + self.assertEqual(model.docvecs.offset2doctag[0], model.docvecs.most_similar([model.docvecs[0]])[0][0]) def test_empty_errors(self): # no input => "RuntimeError: you must first build vocabulary before training the model" @@ -242,8 +244,6 @@ def test_mixed_tag_types(self): model = doc2vec.Doc2Vec() model.build_vocab(mixed_tag_corpus) expected_length = len(sentences) + len(model.docvecs.doctags) # 9 sentences, 7 unique first tokens - print(model.docvecs.doctags) - print(model.docvecs.count) self.assertEquals(len(model.docvecs.doctag_syn0), expected_length) def models_equal(self, model, model2):