-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Word2Vec/Doc2Vec offer model-minimization method Fix issue #446 #987
Changes from 5 commits
2e9d2a5
a2efb8c
26e6042
ba8c8c4
51a64ba
c730984
a7cd9ba
a8cb0e7
18ca26f
a258241
9acf119
66fe5e3
85891f3
4395b75
06c6028
aa3942a
5f96aa0
84f174e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,7 +465,7 @@ def __init__( | |
self.total_train_time = 0 | ||
self.sorted_vocab = sorted_vocab | ||
self.batch_words = batch_words | ||
|
||
self.training_finished = False | ||
if sentences is not None: | ||
if isinstance(sentences, GeneratorType): | ||
raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.") | ||
|
@@ -757,6 +757,8 @@ def train(self, sentences, total_words=None, word_count=0, | |
sentences are the same as those that were used to initially build the vocabulary. | ||
|
||
""" | ||
if (self.training_finished): | ||
raise RuntimeError("parameters for training were discarded") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make a better message starting with the capital letter "Parameters for training were discarded using model_trimmed_post_training method" |
||
if FAST_VERSION < 0: | ||
import warnings | ||
warnings.warn("C extension not loaded for Word2Vec, training will be slow. " | ||
|
@@ -1750,6 +1752,27 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, c | |
def __str__(self): | ||
return "%s(vocab=%s, size=%s, alpha=%s)" % (self.__class__.__name__, len(self.index2word), self.vector_size, self.alpha) | ||
|
||
def _minimize_model(self, save_syn1 = False, save_syn1neg = False, save_syn0_lockf = False): | ||
self.training_finished = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flag is best set in the end of the method |
||
if hasattr(self, 'syn1') and not save_syn1: | ||
del self.syn1 | ||
if hasattr(self, 'syn1neg') and not save_syn1neg: | ||
del self.syn1neg | ||
if hasattr(self, 'syn0_lockf') and not save_syn0_lockf: | ||
del self.syn0_lockf | ||
|
||
def discard_model_parameters(self, replace=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My English language skills allows me to only agree with you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But in this question obviously yes |
||
""" | ||
Discard parameters that are used in training and score. Use if you're sure you're done training a model. | ||
If `replace` is set, forget the original vectors and only keep the normalized | ||
ones = saves lots of memory! | ||
""" | ||
if replace: | ||
for i in xrange(self.syn0.shape[0]): | ||
self.syn0[i, :] /= sqrt((self.syn0[i, :] ** 2).sum(-1)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why duplicate code and not just call |
||
self.syn0norm = self.syn0 | ||
self._minimize_model() | ||
|
||
def save(self, *args, **kwargs): | ||
# don't bother storing the cached normalized vectors, recalculable table | ||
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'table', 'cum_table']) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,6 +280,24 @@ def models_equal(self, model, model2): | |
self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag)) | ||
self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0)) | ||
|
||
def test_discard_model_parameters(self): | ||
"""Test doc2vec model after discard_model_parameters""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, hs=i, negative=j) | ||
model.discard_model_parameters(remove_doctags_vectors=True) | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(model.vocab['graph'].count, 5) | ||
if (i == 1): | ||
self.assertTrue(hasattr(model, 'syn1')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we assert here that |
||
else: | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
if (j == 1): | ||
self.assertTrue(hasattr(model, 'syn1neg')) | ||
else: | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(hasattr(model, 'syn0_lockf')) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems I "sync' in git without "commit", when I added self.docvecs, 'doctag_syn0' checks :) will fix it |
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if logger warning is raised on non-ideal input to a doc2vec model""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,7 +434,7 @@ def testSimilarities(self): | |
model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2) | ||
model.build_vocab(sentences) | ||
model.train(sentences) | ||
|
||
self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph'])) | ||
self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees')) | ||
self.assertRaises(ZeroDivisionError, model.n_similarity, ['graph', 'trees'], []) | ||
|
@@ -482,6 +482,24 @@ def models_equal(self, model, model2): | |
most_common_word = max(model.vocab.items(), key=lambda item: item[1].count)[0] | ||
self.assertTrue(numpy.allclose(model[most_common_word], model2[most_common_word])) | ||
|
||
def testDiscardModelParameters(self): | ||
"""Test word2vec model after discard_model_parameters""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j) | ||
model.discard_model_parameters(replace=True) | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(len(model.vocab), 12) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tests that necessary attributes are indeed deleted |
||
self.assertTrue(model.vocab['graph'].count, 3) | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(not hasattr(model, 'syn0_lockf')) | ||
model = word2vec.Word2Vec(sentences, min_count=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a separate test. |
||
model.save_word2vec_format(testfile(), binary=True) | ||
norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True) | ||
norm_only_model.discard_model_parameters(replace=True) | ||
self.assertFalse(numpy.allclose(model['human'], norm_only_model['human'])) | ||
|
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if warning is raised on non-ideal input to a word2vec model""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A better name would be
model_trimmed_post_training = False