-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Word2Vec/Doc2Vec offer model-minimization method Fix issue #446 #987
Changes from 1 commit
2e9d2a5
a2efb8c
26e6042
ba8c8c4
51a64ba
c730984
a7cd9ba
a8cb0e7
18ca26f
a258241
9acf119
66fe5e3
85891f3
4395b75
06c6028
aa3942a
5f96aa0
84f174e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -392,6 +392,7 @@ def init_sims(self, replace=False): | |
etc., but not `train` or `infer_vector`. | ||
|
||
""" | ||
print ('HELLO DOC!!!') | ||
if getattr(self, 'doctag_syn0norm', None) is None or replace: | ||
logger.info("precomputing L2-norms of doc weight vectors") | ||
if replace: | ||
|
@@ -780,13 +781,9 @@ def __str__(self): | |
|
||
def finished_training(self): | ||
""" | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model, | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model. | ||
""" | ||
self.training_finished = True | ||
if hasattr(self, 'syn1') and not self.hs: | ||
del self.syn1 | ||
if hasattr(self, 'syn1neg') and not self.negative: | ||
del self.syn1neg | ||
self._minimize_model(self.hs, self.negative > 0, True) | ||
if hasattr(self, 'doctag_syn0'): | ||
del self.doctag_syn0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Many will consider the bulk-trained doctag-vectors a part of the model they want to retain. |
||
if hasattr(self, 'doctag_syn0_lockf'): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1750,16 +1750,23 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, c | |
def __str__(self): | ||
return "%s(vocab=%s, size=%s, alpha=%s)" % (self.__class__.__name__, len(self.index2word), self.vector_size, self.alpha) | ||
|
||
def _minimize_model(self, save_syn1 = False, save_syn1neg = False, save_syn0_lockf = False): | ||
if hasattr(self, 'syn1') and not save_syn1: | ||
del self.syn1 | ||
if hasattr(self, 'syn1neg') and not save_syn1neg: | ||
del self.syn1neg | ||
if hasattr(self, 'syn0_lockf') and not save_syn0_lockf: | ||
del self.syn0_lockf | ||
|
||
def finished_training(self): | ||
""" | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model, | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo |
||
""" | ||
self.training_finished = True | ||
self.init_sims(replace = True) | ||
if hasattr(self, 'syn1neg'): | ||
del self.syn1neg | ||
if hasattr(self, 'syn0_lockf'): | ||
del self.syn0_lockf | ||
for i in xrange(self.syn0.shape[0]): | ||
self.syn0[i, :] /= sqrt((self.syn0[i, :] ** 2).sum(-1)) | ||
self.syn0norm = self.syn0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not all post-training applications want the unit-normalized vectors! |
||
self._minimize_model() | ||
|
||
def save(self, *args, **kwargs): | ||
# don't bother storing the cached normalized vectors, recalculable table | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -284,9 +284,19 @@ def test_finished_training(self): | |
"""Test doc2vec model after finishing training""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, negative=i, hs=j) | ||
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, hs=i, negative=j) | ||
model.finished_training() | ||
self.assertTrue(len(model.infer_vector(['graph'])), 5) | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(model.vocab['graph'].count, 5) | ||
if (i == 1): | ||
self.assertTrue(hasattr(model, 'syn1')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we assert here that |
||
else: | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
if (j == 1): | ||
self.assertTrue(hasattr(model, 'syn1neg')) | ||
else: | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(hasattr(model, 'syn0_lockf')) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems I "sync' in git without "commit", when I added self.docvecs, 'doctag_syn0' checks :) will fix it |
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -488,8 +488,12 @@ def testFinishedTraining(self): | |
for j in [0, 1]: | ||
model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j) | ||
model.finished_training() | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(len(model.vocab), 12) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tests that necessary attributes are indeed deleted |
||
self.assertTrue(model.vocab['graph'].count, 3) | ||
self.assertTrue(model.vocab['graph'].count, 3) | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(not hasattr(model, 'syn0_lockf')) | ||
model = word2vec.Word2Vec(sentences, min_count=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a separate test. |
||
model.save_word2vec_format(testfile(), binary=True) | ||
norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deleted in next commit