-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Word2Vec/Doc2Vec offer model-minimization method Fix issue #446 #987
Changes from 1 commit
2e9d2a5
a2efb8c
26e6042
ba8c8c4
51a64ba
c730984
a7cd9ba
a8cb0e7
18ca26f
a258241
9acf119
66fe5e3
85891f3
4395b75
06c6028
aa3942a
5f96aa0
84f174e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -508,8 +508,8 @@ def similarity_unseen_docs(self, model, doc_words1, doc_words2, alpha=0.1, min_a | |
d1 = model.infer_vector(doc_words=doc_words1, alpha=alpha, min_alpha=min_alpha, steps=steps) | ||
d2 = model.infer_vector(doc_words=doc_words2, alpha=alpha, min_alpha=min_alpha, steps=steps) | ||
return dot(matutils.unitvec(d1), matutils.unitvec(d2)) | ||
|
||
|
||
class Doctag(namedtuple('Doctag', 'offset, word_count, doc_count')): | ||
"""A string document tag discovered during the initial vocabulary | ||
scan. (The document-vector equivalent of a Vocab object.) | ||
|
@@ -553,7 +553,7 @@ def __init__(self, documents=None, size=300, alpha=0.025, window=8, min_count=5, | |
|
||
`alpha` is the initial learning rate (will linearly drop to zero as training progresses). | ||
|
||
`seed` = for the random number generator. | ||
`seed` = for the random number generator. | ||
Note that for a fully deterministically-reproducible run, you must also limit the model to | ||
a single worker thread, to eliminate ordering jitter from OS thread scheduling. (In Python | ||
3, reproducibility between interpreter launches also requires use of the PYTHONHASHSEED | ||
|
@@ -570,7 +570,7 @@ def __init__(self, documents=None, size=300, alpha=0.025, window=8, min_count=5, | |
|
||
`workers` = use this many worker threads to train the model (=faster training with multicore machines). | ||
|
||
`iter` = number of iterations (epochs) over the corpus. The default inherited from Word2Vec is 5, | ||
`iter` = number of iterations (epochs) over the corpus. The default inherited from Word2Vec is 5, | ||
but values of 10 or 20 are common in published 'Paragraph Vector' experiments. | ||
|
||
`hs` = if 1 (default), hierarchical sampling will be used for model training (else set to 0). | ||
|
@@ -778,6 +778,19 @@ def __str__(self): | |
segments.append('t%d' % self.workers) | ||
return '%s(%s)' % (self.__class__.__name__, ','.join(segments)) | ||
|
||
def finished_training(self): | ||
""" | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model, | ||
""" | ||
self.training_finished = True | ||
if hasattr(self, 'syn1') and not self.hs: | ||
del self.syn1 | ||
if hasattr(self, 'syn1neg') and not self.negative: | ||
del self.syn1neg | ||
if hasattr(self, 'doctag_syn0'): | ||
del self.doctag_syn0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Many will consider the bulk-trained doctag-vectors a part of the model they want to retain. |
||
if hasattr(self, 'doctag_syn0_lockf'): | ||
del self.doctag_syn0_lockf | ||
|
||
class TaggedBrownCorpus(object): | ||
"""Iterate over documents from the Brown corpus (part of NLTK data), yielding | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1750,6 +1750,17 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, c | |
def __str__(self): | ||
return "%s(vocab=%s, size=%s, alpha=%s)" % (self.__class__.__name__, len(self.index2word), self.vector_size, self.alpha) | ||
|
||
def finished_training(self): | ||
""" | ||
Discard parametrs that are used in training and score. Use if you're sure you're done training a model, | ||
""" | ||
self.training_finished = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flag is best set in the end of the method |
||
self.init_sims(replace = True) | ||
if hasattr(self, 'syn1neg'): | ||
del self.syn1neg | ||
if hasattr(self, 'syn0_lockf'): | ||
del self.syn0_lockf | ||
|
||
def save(self, *args, **kwargs): | ||
# don't bother storing the cached normalized vectors, recalculable table | ||
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'table', 'cum_table']) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,6 +280,14 @@ def models_equal(self, model, model2): | |
self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag)) | ||
self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0)) | ||
|
||
def test_finished_training(self): | ||
"""Test doc2vec model after finishing training""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, negative=i, hs=j) | ||
model.finished_training() | ||
self.assertTrue(len(model.infer_vector(['graph'])), 5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tests that necessary attributes are indeed deleted |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems I "sync' in git without "commit", when I added self.docvecs, 'doctag_syn0' checks :) will fix it |
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if logger warning is raised on non-ideal input to a doc2vec model""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,7 +434,7 @@ def testSimilarities(self): | |
model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2) | ||
model.build_vocab(sentences) | ||
model.train(sentences) | ||
|
||
self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph'])) | ||
self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees')) | ||
self.assertRaises(ZeroDivisionError, model.n_similarity, ['graph', 'trees'], []) | ||
|
@@ -482,6 +482,20 @@ def models_equal(self, model, model2): | |
most_common_word = max(model.vocab.items(), key=lambda item: item[1].count)[0] | ||
self.assertTrue(numpy.allclose(model[most_common_word], model2[most_common_word])) | ||
|
||
def testFinishedTraining(self): | ||
"""Test word2vec model after finishing training""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j) | ||
model.finished_training() | ||
self.assertTrue(len(model.vocab), 12) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tests that necessary attributes are indeed deleted |
||
self.assertTrue(model.vocab['graph'].count, 3) | ||
model = word2vec.Word2Vec(sentences, min_count=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a separate test. |
||
model.save_word2vec_format(testfile(), binary=True) | ||
norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True) | ||
norm_only_model.finished_training() | ||
self.assertFalse(numpy.allclose(model['human'], norm_only_model['human'])) | ||
|
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if warning is raised on non-ideal input to a word2vec model""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please call the super method in word2vec explicitly