Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy up KeyedVectors.most_similar() API #3000

Merged
merged 14 commits into from
Aug 13, 2021
Merged
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Nevertheless, we describe them below.
### Improved parameter edge-case handling in KeyedVectors most_similar and most_similar_cosmul methods

We now handle both ``positive`` and ``negative`` keyword parameters consistently.
These parameters typically specify
They may now be either:

1. A string, in which case the value is reinterpreted as a list of one element (the string value)
Expand All @@ -28,7 +27,7 @@ So you can now simply do:
```python
model.most_similar(positive='war', negative='peace')
```

instead of the slightly more involved

```python
Expand Down Expand Up @@ -73,7 +72,7 @@ Plus a large number of smaller improvements and fixes, as usual.
* [#3091](https://github.com/RaRe-Technologies/gensim/pull/3091): LsiModel: Only log top words that actually exist in the dictionary, by [@kmurphy4](https://github.com/kmurphy4)
* [#2980](https://github.com/RaRe-Technologies/gensim/pull/2980): Added EnsembleLda for stable LDA topics, by [@sezanzeb](https://github.com/sezanzeb)
* [#2978](https://github.com/RaRe-Technologies/gensim/pull/2978): Optimize performance of Author-Topic model, by [@horpto](https://github.com/horpto)

* [#3000](https://github.com/RaRe-Technologies/gensim/pull/3000): Tidy up KeyedVectors.most_similar() API, by [@simonwiles](https://github.com/simonwiles)

### :books: Tutorials and docs

Expand Down
61 changes: 36 additions & 25 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,28 @@
logger = logging.getLogger(__name__)


KEY_TYPES = (str, int, np.integer)
_KEY_TYPES = (str, int, np.integer)

_EXTENDED_KEY_TYPES = (str, int, np.integer, np.ndarray)

mpenkov marked this conversation as resolved.
Show resolved Hide resolved

def _ensure_list(value):
#
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
# Ensure that the specified value is a list. Sometimes users pass a single
# value when they should really pass a list containing value.
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
#
# This is here to make invocation of e.g. most_similar method consistent
# and convenient, and to guarantee backwards compability with older
# versions. See https://github.com/RaRe-Technologies/gensim/pull/3000
# for the background.
#
if value is None:
return []

if isinstance(value, _KEY_TYPES) or (isinstance(value, ndarray) and len(value.shape) == 1):
return [value]

return value


class KeyedVectors(utils.SaveLoad):
Expand Down Expand Up @@ -377,7 +398,7 @@ def __getitem__(self, key_or_keys):
Vector representation for `key_or_keys` (1D if `key_or_keys` is single key, otherwise - 2D).

"""
if isinstance(key_or_keys, KEY_TYPES):
if isinstance(key_or_keys, _KEY_TYPES):
return self.get_vector(key_or_keys)

return vstack([self.get_vector(key) for key in key_or_keys])
Expand Down Expand Up @@ -491,7 +512,7 @@ def add_vectors(self, keys, weights, extras=None, replace=False):
if True - replace vectors, otherwise - keep old vectors.

"""
if isinstance(keys, KEY_TYPES):
if isinstance(keys, _KEY_TYPES):
keys = [keys]
weights = np.array(weights).reshape(1, -1)
elif isinstance(weights, list):
Expand Down Expand Up @@ -729,10 +750,9 @@ def most_similar(
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
positive = []
if negative is None:
negative = []
# allow passing a single string-key or vector for the positive/negative arguments
positive = _ensure_list(positive)
negative = _ensure_list(negative)

self.fill_norms()
clip_end = clip_end or len(self.vectors)
Expand All @@ -741,18 +761,14 @@ def most_similar(
clip_start = 0
clip_end = restrict_vocab

if isinstance(positive, KEY_TYPES) and not negative:
# allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
positive = [positive]

# add weights for each key, if not already present; default to 1.0 for positive and -1.0 for negative keys
positive = [
(item, 1.0) if isinstance(item, KEY_TYPES + (ndarray,))
else item for item in positive
(item, 1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item
for item in positive
]
negative = [
(item, -1.0) if isinstance(item, KEY_TYPES + (ndarray,))
else item for item in negative
(item, -1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item
for item in negative
]

# compute the weighted average of all keys
Expand Down Expand Up @@ -969,21 +985,16 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10):
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
positive = []
if negative is None:
negative = []
# allow passing a single string-key or vector for the positive/negative arguments
positive = _ensure_list(positive)
negative = _ensure_list(negative)

self.fill_norms()

if isinstance(positive, str) and not negative:
# allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog'])
positive = [positive]

all_words = {
self.get_index(word) for word in positive + negative
if not isinstance(word, ndarray) and word in self.key_to_index
}
}

positive = [
self.get_vector(word, norm=True) if isinstance(word, str) else word
Expand Down Expand Up @@ -1101,7 +1112,7 @@ def distances(self, word_or_vector, other_words=()):
If either `word_or_vector` or any word in `other_words` is absent from vocab.

"""
if isinstance(word_or_vector, KEY_TYPES):
if isinstance(word_or_vector, _KEY_TYPES):
input_vector = self.get_vector(word_or_vector)
else:
input_vector = word_or_vector
Expand Down
39 changes: 39 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Automated tests for checking the poincare module from the models package.
"""

import functools
import logging
import unittest

Expand Down Expand Up @@ -39,6 +40,44 @@ def test_most_similar(self):
predicted = [result[0] for result in self.vectors.most_similar('war', topn=5)]
self.assertEqual(expected, predicted)

def test_most_similar_vector(self):
"""Can we pass vectors to most_similar directly?"""
positive = self.vectors.vectors[0:5]
most_similar = self.vectors.most_similar(positive=positive)
assert most_similar is not None

def test_most_similar_parameter_types(self):
"""Are the positive/negative parameter types are getting interpreted correctly?"""
partial = functools.partial(self.vectors.most_similar, topn=5)

position = partial('war', 'peace')
position_list = partial(['war'], ['peace'])
keyword = partial(positive='war', negative='peace')
keyword_list = partial(positive=['war'], negative=['peace'])

#
# The above calls should all yield identical results.
#
assert position == position_list
assert position == keyword
assert position == keyword_list

def test_most_similar_cosmul_parameter_types(self):
"""Are the positive/negative parameter types are getting interpreted correctly?"""
partial = functools.partial(self.vectors.most_similar_cosmul, topn=5)

position = partial('war', 'peace')
position_list = partial(['war'], ['peace'])
keyword = partial(positive='war', negative='peace')
keyword_list = partial(positive=['war'], negative=['peace'])

#
# The above calls should all yield identical results.
#
assert position == position_list
assert position == keyword
assert position == keyword_list

def test_vectors_for_all_list(self):
"""Test vectors_for_all returns expected results with a list of keys."""
words = [
Expand Down