Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gensim.models.BaseKeyedVectors.add_entity method for fill KeyedVectors in manual way. Fix #1942 #1957

64 changes: 62 additions & 2 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
PYEMD_EXT = False

from numpy import dot, zeros, float32 as REAL, empty, memmap as np_memmap, \
double, array, vstack, sqrt, newaxis, integer, \
double, array, zeros, vstack, sqrt, newaxis, integer, \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zeros imported twice

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gojomo yeah, i've fixed it

ndarray, sum as np_sum, prod, argmax, divide as np_divide
import numpy as np
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
Expand Down Expand Up @@ -109,7 +109,7 @@ def __str__(self):
class BaseKeyedVectors(utils.SaveLoad):

def __init__(self, vector_size):
self.vectors = []
self.vectors = zeros((0, vector_size))
self.vocab = {}
self.vector_size = vector_size
self.index2entity = []
Expand Down Expand Up @@ -154,6 +154,65 @@ def get_vector(self, entity):
else:
raise KeyError("'%s' not in vocabulary" % entity)

def add(self, entities, weights, replace=False):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is keeped unless `replace` flag is True.

Parameters
----------
entities : list of str
Entities specified by string tags.
weights: {list of numpy.ndarray, numpy.ndarray}
List of 1D np.array vectors or 2D np.array of vectors.
replace: bool, optional
Flag indicating whether to replace vectors for entities which are already in the vocabulary,
if True - replace vectors, otherwise - keep old vectors.

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: multiline docstring should ends with empty line, i.e.

"""
...
last text

"""

if isinstance(entities, string_types):
entities = [entities]
weights = np.array(weights).reshape(1, -1)
elif isinstance(weights, list):
weights = np.array(weights)

in_vocab_mask = np.zeros(len(entities), dtype=np.bool)
for idx, entity in enumerate(entities):
if entity in self.vocab:
in_vocab_mask[idx] = True

# add new entities to the vocab
for idx in np.nonzero(~in_vocab_mask)[0]:
entity = entities[idx]
self.vocab[entity] = Vocab(index=len(self.vocab), count=1)
self.index2entity.append(entity)

# add vectors for new entities
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))

# change vectors for in_vocab entities if `replace` flag is specified
if replace:
in_vocab_idxs = [self.vocab[entities[idx]].index for idx in np.nonzero(in_vocab_mask)[0]]
self.vectors[in_vocab_idxs] = weights[in_vocab_mask]

def __setitem__(self, entities, weights):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is replaced with the new one.
This method is alias for `add` with `replace=True`.

Parameters
----------
entities : {str, list of str}
Entities specified by string tags.
weights: {list of numpy.ndarray, numpy.ndarray}
List of 1D np.array vectors or 2D np.array of vectors.

"""
if not isinstance(entities, list):
entities = [entities]
weights = weights.reshape(1, -1)

self.add(entities, weights, replace=True)

def __getitem__(self, entities):
"""
Accept a single entity (string tag) or list of entities as input.
Expand All @@ -163,6 +222,7 @@ def __getitem__(self, entities):

If a list, return designated tags' vector representations as a
2D numpy array: #tags x #vector_size.

"""
if isinstance(entities, string_types):
# allow calls like trained_model['office'], as a shorthand for trained_model[['office']]
Expand Down
72 changes: 72 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,78 @@ def test_wv_property(self):
"""Test that the deprecated `wv` property returns `self`. To be removed in v4.0.0."""
self.assertTrue(self.vectors is self.vectors.wv)

def test_add_single(self):
"""Test that adding entity in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add` on already filled kv.
for ent, vector in zip(entities, vectors):
self.vectors.add(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
for ent, vector in zip(entities, vectors):
kv.add(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_add_multiple(self):
"""Test that adding a bulk of entities in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add` on already filled kv.
vocab_size = len(self.vectors.vocab)
self.vectors.add(entities, vectors, replace=False)
self.assertEqual(vocab_size + len(entities), len(self.vectors.vocab))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
kv[entities] = vectors
self.assertEqual(len(kv.vocab), len(entities))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_set_item(self):
"""Test that __setitem__ works correctly."""
vocab_size = len(self.vectors.vocab)

# Add new entity.
entity = '___some_new_entity___'
vector = np.random.randn(self.vectors.vector_size)
self.vectors[entity] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size + 1)
self.assertTrue(np.allclose(self.vectors[entity], vector))

# Replace vector for entity in vocab.
vocab_size = len(self.vectors.vocab)
vector = np.random.randn(self.vectors.vector_size)
self.vectors['war'] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size)
self.assertTrue(np.allclose(self.vectors['war'], vector))

# __setitem__ on several entities.
vocab_size = len(self.vectors.vocab)
entities = ['war', '___some_new_entity1___', '___some_new_entity2___', 'terrorism', 'conflict']
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(len(entities))]

self.vectors[entities] = vectors

self.assertEqual(len(self.vectors.vocab), vocab_size + 2)
for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down