Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Make WMD normalization optional #3073

Merged
merged 1 commit into from
Mar 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def similar_by_vector(self, vector, topn=10, restrict_vocab=None):
"""
return self.most_similar(positive=[vector], topn=topn, restrict_vocab=restrict_vocab)

def wmdistance(self, document1, document2):
def wmdistance(self, document1, document2, norm=True):
"""Compute the Word Mover's Distance between two documents.

When using this code, please consider citing the following papers:
Expand All @@ -854,6 +854,9 @@ def wmdistance(self, document1, document2):
Input document.
document2 : list of str
Input document.
norm : bool, optional
If True, normalize all word vectors to unit length before computing the distance.
Defaults to True.

Returns
-------
Expand All @@ -873,7 +876,6 @@ def wmdistance(self, document1, document2):
If `pyemd <https://pypi.org/project/pyemd/>`_ isn't installed.

"""

# If pyemd C extension is available, import it.
# If pyemd is needed but isn't installed, the import below raises ImportError from within wmdistance.
from pyemd import emd
Expand All @@ -889,17 +891,14 @@ def wmdistance(self, document1, document2):
logger.info('Removed %d and %d OOV words from document 1 and 2 (respectively).', diff1, diff2)

if not document1 or not document2:
logger.info(
"At least one of the documents had no words that were in the vocabulary. "
"Aborting (returning inf)."
)
logger.warning("At least one of the documents had no words that were in the vocabulary.")
return float('inf')

dictionary = Dictionary(documents=[document1, document2])
vocab_len = len(dictionary)

if vocab_len == 1:
# Both documents are composed by a single unique token
# Both documents are composed of a single unique token => zero distance.
return 0.0

# Sets for faster look-up.
Expand All @@ -916,11 +915,11 @@ def wmdistance(self, document1, document2):
if t2 not in docset2 or distance_matrix[i, j] != 0.0:
continue

# Compute Euclidean distance between unit-normed word vectors.
# Compute Euclidean distance between (potentially unit-normed) word vectors.
distance_matrix[i, j] = distance_matrix[j, i] = np.sqrt(
np_sum((self.get_vector(t1, norm=True) - self.get_vector(t2, norm=True))**2))
np_sum((self.get_vector(t1, norm=norm) - self.get_vector(t2, norm=norm))**2))

if np_sum(distance_matrix) == 0.0:
if abs(np_sum(distance_matrix)) < 1e-8:
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
# `emd` gets stuck if the distance matrix contains only zeros.
logger.info('The distance matrix is all zeros. Aborting (returning inf).')
return float('inf')
Expand All @@ -933,7 +932,7 @@ def nbow(document):
d[idx] = freq / float(doc_len) # Normalized word frequencies.
return d

# Compute nBOW representation of documents.
# Compute nBOW representation of documents. This is what pyemd expects on input.
d1 = nbow(document1)
d2 = nbow(document2)

Expand Down