-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Word2Vec/Doc2Vec offer model-minimization method Fix issue #446 #987
Changes from all commits
2e9d2a5
a2efb8c
26e6042
ba8c8c4
51a64ba
c730984
a7cd9ba
a8cb0e7
18ca26f
a258241
9acf119
66fe5e3
85891f3
4395b75
06c6028
aa3942a
5f96aa0
84f174e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -287,6 +287,34 @@ def models_equal(self, model, model2): | |
self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag)) | ||
self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0)) | ||
|
||
def test_delete_temporary_training_data(self): | ||
"""Test doc2vec model after delete_temporary_training_data""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, window=4, hs=i, negative=j) | ||
if i: | ||
self.assertTrue(hasattr(model, 'syn1')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we assert here that |
||
if j: | ||
self.assertTrue(hasattr(model, 'syn1neg')) | ||
self.assertTrue(hasattr(model, 'syn0_lockf')) | ||
model.delete_temporary_training_data(keep_doctags_vectors=False, keep_inference=False) | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(model.vocab['graph'].count, 5) | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(not hasattr(model, 'syn0_lockf')) | ||
self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0')) | ||
self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0_lockf')) | ||
model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=1, negative=0, alpha=0.05, min_count=2, iter=20) | ||
model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True) | ||
self.assertTrue(model.docvecs and hasattr(model.docvecs, 'doctag_syn0')) | ||
self.assertTrue(hasattr(model, 'syn1')) | ||
self.model_sanity(model) | ||
model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=0, negative=1, alpha=0.05, min_count=2, iter=20) | ||
model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True) | ||
self.model_sanity(model) | ||
self.assertTrue(hasattr(model, 'syn1neg')) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems I "sync' in git without "commit", when I added self.docvecs, 'doctag_syn0' checks :) will fix it |
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if logger warning is raised on non-ideal input to a doc2vec model""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,7 +434,7 @@ def testSimilarities(self): | |
model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2) | ||
model.build_vocab(sentences) | ||
model.train(sentences) | ||
|
||
self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph'])) | ||
self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees')) | ||
self.assertRaises(ZeroDivisionError, model.n_similarity, ['graph', 'trees'], []) | ||
|
@@ -482,6 +482,31 @@ def models_equal(self, model, model2): | |
most_common_word = max(model.vocab.items(), key=lambda item: item[1].count)[0] | ||
self.assertTrue(np.allclose(model[most_common_word], model2[most_common_word])) | ||
|
||
def testDeleteTemporaryTrainingData(self): | ||
"""Test word2vec model after delete_temporary_training_data""" | ||
for i in [0, 1]: | ||
for j in [0, 1]: | ||
model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j) | ||
if i: | ||
self.assertTrue(hasattr(model, 'syn1')) | ||
if j: | ||
self.assertTrue(hasattr(model, 'syn1neg')) | ||
self.assertTrue(hasattr(model, 'syn0_lockf')) | ||
model.delete_temporary_training_data(replace_word_vectors_with_normalized=True) | ||
self.assertTrue(len(model['human']), 10) | ||
self.assertTrue(len(model.vocab), 12) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tests that necessary attributes are indeed deleted |
||
self.assertTrue(model.vocab['graph'].count, 3) | ||
self.assertTrue(not hasattr(model, 'syn1')) | ||
self.assertTrue(not hasattr(model, 'syn1neg')) | ||
self.assertTrue(not hasattr(model, 'syn0_lockf')) | ||
|
||
def testNormalizeAfterTrainingData(self): | ||
model = word2vec.Word2Vec(sentences, min_count=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a separate test. |
||
model.save_word2vec_format(testfile(), binary=True) | ||
norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True) | ||
norm_only_model.delete_temporary_training_data(replace_word_vectors_with_normalized=True) | ||
self.assertFalse(np.allclose(model['human'], norm_only_model['human'])) | ||
|
||
@log_capture() | ||
def testBuildVocabWarning(self, l): | ||
"""Test if warning is raised on non-ideal input to a word2vec model""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add asserts that it has all the attributes that are about to be deleted