Skip to content

Commit

Permalink
Fix fastText word_vec() for OOV words with use_norm=True (#2764)
Browse files Browse the repository at this point in the history
* add a test for oov similarity

* fix a test for oov similarity

* fix it once more

* prepare the real fix

* remove a redundant variable

* less accurate comparison

Co-authored-by: David Dale <ddale@yandex-team.ru>
  • Loading branch information
avidale and David Dale authored Mar 21, 2020
1 parent 0d75f2d commit cb3d87c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
11 changes: 5 additions & 6 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,10 +2114,6 @@ def word_vec(self, word, use_norm=False):
raise KeyError('cannot calculate vector for OOV word without ngrams')
else:
word_vec = np.zeros(self.vectors_ngrams.shape[1], dtype=np.float32)
if use_norm:
ngram_weights = self.vectors_ngrams_norm
else:
ngram_weights = self.vectors_ngrams
ngram_hashes = ft_ngram_hashes(word, self.min_n, self.max_n, self.bucket, self.compatible_hash)
if len(ngram_hashes) == 0:
#
Expand All @@ -2131,8 +2127,11 @@ def word_vec(self, word, use_norm=False):
logger.warning('could not extract any ngrams from %r, returning origin vector', word)
return word_vec
for nh in ngram_hashes:
word_vec += ngram_weights[nh]
return word_vec / len(ngram_hashes)
word_vec += self.vectors_ngrams[nh]
result = word_vec / len(ngram_hashes)
if use_norm:
result /= sqrt(sum(result ** 2))
return result

def init_sims(self, replace=False):
"""Precompute L2-normalized vectors.
Expand Down
9 changes: 9 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ def test_load_model_non_utf8_encoding(self):
except KeyError:
self.fail('Unable to access vector for cp-852 word')

def test_oov_similarity(self):
    """Check that ``most_similar`` agrees with a directly computed cosine similarity.

    The query word is assumed to be absent from the test model's vocabulary
    (the fixture is defined outside this view), so its vector would be
    assembled from character n-grams.
    """
    wv = self.test_model.wv
    oov_word = 'someoovword'

    # Highest-ranked neighbour and the similarity score reported for it.
    best_word, reported_similarity = wv.most_similar(oov_word)[0]

    # Recompute the similarity from the two vectors returned by plain lookup.
    oov_vector = wv[oov_word]
    best_vector = wv[best_word]
    direct_similarity = wv.cosine_similarities(oov_vector, best_vector.reshape(1, -1))[0]

    # Both routes must give the same score, up to 6 decimal places.
    self.assertAlmostEqual(reported_similarity, direct_similarity, places=6)

def test_n_similarity(self):
# In vocab, sanity check
self.assertTrue(np.allclose(self.test_model.wv.n_similarity(['the', 'and'], ['and', 'the']), 1.0))
Expand Down

0 comments on commit cb3d87c

Please sign in to comment.