Merge pull request #3073 from RaRe-Technologies/wmd_norm
[MRG] Make WMD normalization optional
piskvorky authored Mar 14, 2021
2 parents 700d6b1 + 6a4dc0d commit 1300929
Showing 1 changed file with 10 additions and 11 deletions.
gensim/models/keyedvectors.py (10 additions, 11 deletions)
@@ -836,7 +836,7 @@ def similar_by_vector(self, vector, topn=10, restrict_vocab=None):
         """
         return self.most_similar(positive=[vector], topn=topn, restrict_vocab=restrict_vocab)
 
-    def wmdistance(self, document1, document2):
+    def wmdistance(self, document1, document2, norm=True):
         """Compute the Word Mover's Distance between two documents.
 
         When using this code, please consider citing the following papers:
@@ -854,6 +854,9 @@ def wmdistance(self, document1, document2):
             Input document.
         document2 : list of str
             Input document.
+        norm : boolean
+            Normalize all word vectors to unit length before computing the distance?
+            Defaults to True.
 
         Returns
         -------
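
A quick usage sketch of the new parameter (the vectors file and tokens below are hypothetical, for illustration only): `norm=True`, the default, reproduces the old behaviour of unit-norming word vectors before computing distances, while `norm=False` keeps raw vector magnitudes.

```python
from gensim.models import KeyedVectors

# Hypothetical vectors file and documents, for illustration only.
model = KeyedVectors.load_word2vec_format('vectors.bin', binary=True)
doc1 = ['obama', 'speaks', 'media', 'illinois']
doc2 = ['president', 'greets', 'press', 'chicago']

d_normed = model.wmdistance(doc1, doc2)           # unit-normed vectors (old behaviour)
d_raw = model.wmdistance(doc1, doc2, norm=False)  # raw vector magnitudes (new option)
```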
@@ -873,7 +876,6 @@ def wmdistance(self, document1, document2):
             If `pyemd <https://pypi.org/project/pyemd/>`_ isn't installed.
         """
 
-        # If pyemd C extension is available, import it.
         # If pyemd is attempted to be used, but isn't installed, ImportError will be raised in wmdistance
         from pyemd import emd
 
@@ -889,17 +891,14 @@ def wmdistance(self, document1, document2):
             logger.info('Removed %d and %d OOV words from document 1 and 2 (respectively).', diff1, diff2)
 
         if not document1 or not document2:
-            logger.info(
-                "At least one of the documents had no words that were in the vocabulary. "
-                "Aborting (returning inf)."
-            )
+            logger.warning("At least one of the documents had no words that were in the vocabulary.")
             return float('inf')
 
         dictionary = Dictionary(documents=[document1, document2])
         vocab_len = len(dictionary)
 
         if vocab_len == 1:
-            # Both documents are composed by a single unique token
+            # Both documents are composed of a single unique token => zero distance.
             return 0.0
 
         # Sets for faster look-up.
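
For context, `gensim.corpora.Dictionary` maps every unique token across both documents to an integer id, and the number of unique tokens sizes the distance matrix. A minimal sketch with toy tokens:

```python
from gensim.corpora import Dictionary

doc1 = ['hello', 'world']
doc2 = ['hello', 'there']
dictionary = Dictionary(documents=[doc1, doc2])
print(len(dictionary))      # 3 -- unique tokens across both documents
print(dictionary.token2id)  # e.g. {'hello': 0, 'world': 1, 'there': 2}
```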
@@ -916,11 +915,11 @@ def wmdistance(self, document1, document2):
                 if t2 not in docset2 or distance_matrix[i, j] != 0.0:
                     continue
 
-                # Compute Euclidean distance between unit-normed word vectors.
+                # Compute Euclidean distance between (potentially unit-normed) word vectors.
                 distance_matrix[i, j] = distance_matrix[j, i] = np.sqrt(
-                    np_sum((self.get_vector(t1, norm=True) - self.get_vector(t2, norm=True))**2))
+                    np_sum((self.get_vector(t1, norm=norm) - self.get_vector(t2, norm=norm))**2))
 
-        if np_sum(distance_matrix) == 0.0:
+        if abs(np_sum(distance_matrix)) < 1e-8:
             # `emd` gets stuck if the distance matrix contains only zeros.
             logger.info('The distance matrix is all zeros. Aborting (returning inf).')
             return float('inf')
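
Why the `norm` flag matters: for unit-length vectors u and v, the squared Euclidean distance is ||u - v||^2 = 2 - 2*cos(u, v), so with `norm=True` the ground distances are a monotonic function of cosine similarity; with `norm=False`, vector magnitudes also contribute. A small NumPy check of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.standard_normal(50)
v = rng.standard_normal(50)
u /= np.linalg.norm(u)  # unit-norm both vectors
v /= np.linalg.norm(v)

euclid = np.sqrt(np.sum((u - v) ** 2))
cosine = u @ v
assert np.isclose(euclid, np.sqrt(2.0 - 2.0 * cosine))
```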
@@ -933,7 +932,7 @@ def nbow(document):
                 d[idx] = freq / float(doc_len)  # Normalized word frequencies.
             return d
 
-        # Compute nBOW representation of documents.
+        # Compute nBOW representation of documents. This is what pyemd expects on input.
         d1 = nbow(document1)
         d2 = nbow(document2)
 
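
The nBOW vectors are per-document relative token frequencies over the shared dictionary; `pyemd.emd` then takes these two histograms plus the distance matrix. A standalone sketch of the same idea, with an assumed toy token-to-id mapping:

```python
from collections import Counter

import numpy as np

token2id = {'hello': 0, 'world': 1, 'there': 2}  # assumed toy mapping

def nbow(document):
    d = np.zeros(len(token2id), dtype=np.double)
    for token, freq in Counter(document).items():
        d[token2id[token]] = freq / float(len(document))
    return d

print(nbow(['hello', 'world', 'hello']))  # ~[0.667, 0.333, 0.0], sums to 1
```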
