Fix the train method of TranslationMatrix #1838

Merged
merged 22 commits into from
Jan 15, 2018
Changes from 21 commits
Commits
1aa3f33
fix the compatibility between python2 & 3
robotcator Mar 17, 2017
24e6331
Merge https://github.com/RaRe-Technologies/gensim into fix-word2vec-n…
robotcator Mar 18, 2017
f6f571f
require explicit corpus size, epochs for train()
gojomo Feb 9, 2017
5e9529b
make all train() calls use explicit count, epochs
gojomo Feb 9, 2017
5c24a90
add tests to make sure that ValueError is indeed thrown
robotcator Mar 23, 2017
c89f285
update test
robotcator Mar 24, 2017
10ff8a5
fix the word2vec's reset_from()
robotcator Mar 25, 2017
a6312ca
Merge branch 'fix-word2vec' into fix-word2vec-notebook
robotcator Mar 29, 2017
be5216a
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
robotcator Mar 29, 2017
504bd09
require explicit corpus size, epochs for train()
gojomo Feb 9, 2017
43f9689
make all train() calls use explicit count, epochs
gojomo Feb 9, 2017
49e3d00
update notebooks
robotcator Mar 29, 2017
c9eab32
fix some error
robotcator Mar 29, 2017
8024eb5
fix test error
robotcator Mar 29, 2017
d3562b6
Merge branch 'test-word2vec' of https://github.com/robotcator/gensim …
robotcator Apr 9, 2017
ff93cdf
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
robotcator May 24, 2017
c11d007
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
robotcator Jan 6, 2018
155e1db
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
robotcator Jan 9, 2018
6cfc651
make tagged_docs optional
robotcator Jan 12, 2018
c47303b
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
robotcator Jan 12, 2018
fd2f753
fix the train method
robotcator Jan 15, 2018
6f05130
add comments for the translation matrix
robotcator Jan 15, 2018
18 changes: 10 additions & 8 deletions gensim/models/translation_matrix.py
@@ -365,23 +365,23 @@ class BackMappingTranslationMatrix(utils.SaveLoad):
>>> src_model = Doc2Vec.load(datapath("small_tag_doc_5_iter50"))
>>> dst_model = Doc2Vec.load(datapath("large_tag_doc_10_iter50"))
>>>
- >>> model_trans = BackMappingTranslationMatrix(data, src_model, dst_model)
+ >>> model_trans = BackMappingTranslationMatrix(src_model, dst_model)
>>> trans_matrix = model_trans.train(data)
>>>
>>> result = model_trans.infer_vector(dst_model.docvecs[data[3].tags])

"""
- def __init__(self, tagged_docs, source_lang_vec, target_lang_vec, random_state=None):
+ def __init__(self, source_lang_vec, target_lang_vec, tagged_docs=None, random_state=None):
"""

Parameters
----------
- tagged_docs : list of :class:`~gensim.models.doc2vec.TaggedDocument`, optional
- Documents that will be used for training
source_lang_vec : :class:`~gensim.models.doc2vec.Doc2Vec`
Source Doc2Vec model.
target_lang_vec : :class:`~gensim.models.doc2vec.Doc2Vec`
Target Doc2Vec model.
+ tagged_docs : list of :class:`~gensim.models.doc2vec.TaggedDocument`, optional
+ Documents that will be used for training
random_state : {None, int, array_like}, optional
Seed for random state.

@@ -393,22 +393,24 @@ def __init__(self, tagged_docs, source_lang_vec, target_lang_vec, random_state=N
self.random_state = utils.get_random_state(random_state)
self.translation_matrix = None

+ if tagged_docs is not None:

Review comment (Contributor): The problem isn't fixed in the train method (you continue to ignore the passed parameter).

+ self.train(tagged_docs)

def train(self, tagged_docs):
"""Build the translation matrix that mapping from the source model's vector to target model's vector

Parameters
----------
- tagged_docs : list of :class:`~gensim.models.doc2vec.TaggedDocument`, optional
- THIS ARGUMENT WILL BE IGNORED.
+ tagged_docs : list of :class:`~gensim.models.doc2vec.TaggedDocument`

Review comment (Contributor): This parameter needs a description of what it is.


Returns
-------
numpy.ndarray
Translation matrix that mapping from the source model's vector to target model's vector.

"""
- m1 = [self.source_lang_vec.docvecs[item.tags].flatten() for item in self.tagged_docs]
- m2 = [self.target_lang_vec.docvecs[item.tags].flatten() for item in self.tagged_docs]
+ m1 = [self.source_lang_vec.docvecs[item.tags].flatten() for item in tagged_docs]
+ m2 = [self.target_lang_vec.docvecs[item.tags].flatten() for item in tagged_docs]

self.translation_matrix = np.linalg.lstsq(m2, m1, -1)[0]
return self.translation_matrix
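
The least-squares step at the end of train() can be looked at in isolation. The following is a minimal sketch, not the gensim implementation: it uses random NumPy arrays in place of the Doc2Vec document vectors, and the names `target_vecs`, `source_vecs`, and `true_map` are assumptions introduced only for illustration. It also passes `rcond=None` instead of the `-1` used in the diff.

```python
import numpy as np

rng = np.random.default_rng(0)

n_docs, dim = 50, 4
# Stand-in for the target-model document vectors, one row per document (m2 in the diff).
target_vecs = rng.normal(size=(n_docs, dim))
# Hypothetical ground-truth linear map, used only to generate consistent toy data.
true_map = rng.normal(size=(dim, dim))
# Stand-in for the source-model document vectors (m1 in the diff).
source_vecs = target_vecs @ true_map

# Solve target_vecs @ X ~= source_vecs in the least-squares sense,
# mirroring the argument order of np.linalg.lstsq(m2, m1, ...)[0].
translation_matrix = np.linalg.lstsq(target_vecs, source_vecs, rcond=None)[0]

# With noise-free data and more documents than dimensions, the map is recovered.
recovered = np.allclose(translation_matrix, true_map)
```

Because the first argument holds the target vectors and the second the source vectors, the resulting matrix maps a target-model vector back to the source-model space when applied as `vector @ translation_matrix`.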
4 changes: 2 additions & 2 deletions gensim/test/test_translation_matrix.py
@@ -99,14 +99,14 @@ def setUp(self):

def test_translation_matrix(self):
model = translation_matrix.BackMappingTranslationMatrix(
- self.train_docs[:5], self.source_doc_vec, self.target_doc_vec
+ self.source_doc_vec, self.target_doc_vec, self.train_docs[:5]
)
transmat = model.train(self.train_docs[:5])
self.assertEqual(transmat.shape, (100, 100))

def test_infer_vector(self):
model = translation_matrix.BackMappingTranslationMatrix(
- self.train_docs[:5], self.source_doc_vec, self.target_doc_vec
+ self.source_doc_vec, self.target_doc_vec, self.train_docs[:5]
)
model.train(self.train_docs[:5])
infered_vec = model.infer_vector(self.target_doc_vec.docvecs[self.train_docs[5].tags])
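
The tests above check only the output shape, so they would also pass if train() still ignored its argument. One possible way to check the behaviour raised in the review is sketched below. It is an assumption-laden toy, not gensim code: `ToyBackMapping` and `Doc` are hypothetical stand-ins, the vectors are plain dicts keyed by a single string tag, and no Doc2Vec model is loaded.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for gensim's TaggedDocument; only the .tags field is used,
# and here it is a single string rather than a list of tags.
Doc = namedtuple("Doc", ["words", "tags"])


class ToyBackMapping:
    """Toy mirror of the fixed constructor/train contract; not the gensim class."""

    def __init__(self, source_vecs, target_vecs, tagged_docs=None):
        # source_vecs / target_vecs: dicts mapping a tag to a 1-D vector (assumed layout).
        self.source_vecs = source_vecs
        self.target_vecs = target_vecs
        self.translation_matrix = None
        if tagged_docs is not None:
            self.train(tagged_docs)

    def train(self, tagged_docs):
        # Use the documents passed in, not a copy stored on the instance.
        m1 = np.array([self.source_vecs[doc.tags] for doc in tagged_docs])
        m2 = np.array([self.target_vecs[doc.tags] for doc in tagged_docs])
        self.translation_matrix = np.linalg.lstsq(m2, m1, rcond=None)[0]
        return self.translation_matrix


rng = np.random.default_rng(1)
tags = ["doc_%d" % i for i in range(8)]
source = {t: rng.normal(size=3) for t in tags}
target = {t: rng.normal(size=3) for t in tags}
docs = [Doc(words=[], tags=t) for t in tags]

model = ToyBackMapping(source, target)  # no documents at construction time
first = model.train(docs[:4]).copy()
second = model.train(docs[4:]).copy()

# Different training documents give different matrices, so the argument is used.
argument_is_used = not np.allclose(first, second)
```

A comparable check against the real class would train on two different document slices and compare the returned matrices, in addition to asserting their shape.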