From fb8e4e5ea0da6530cf61ef09b88666e19ad22932 Mon Sep 17 00:00:00 2001 From: Devashish Deshpande Date: Wed, 6 Jul 2016 21:45:40 +0530 Subject: [PATCH] Add malletmodel2ldamodel transformation function (#766) --- gensim/models/wrappers/ldamallet.py | 25 +++++++++++++++++++++++++ gensim/test/test_ldamallet_wrapper.py | 12 ++++++++++++ 2 files changed, 37 insertions(+) diff --git a/gensim/models/wrappers/ldamallet.py b/gensim/models/wrappers/ldamallet.py index 209ffc5f1d..c976c2c8b6 100644 --- a/gensim/models/wrappers/ldamallet.py +++ b/gensim/models/wrappers/ldamallet.py @@ -44,6 +44,7 @@ from gensim import utils, matutils from gensim.utils import check_output +from gensim.models.ldamodel import LdaModel logger = logging.getLogger(__name__) @@ -350,3 +351,27 @@ def read_doctopics(self, fname, eps=1e-6, renorm=True): if total_weight: doc = [(id_, float(weight) / total_weight) for id_, weight in doc] yield doc + + +def malletmodel2ldamodel(mallet_model, gamma_threshold=0.001, iterations=50): + """ + Function to convert mallet model to gensim LdaModel. This works by copying the + training model weights (alpha, beta...) from a trained mallet model into the + gensim model. + + Args: + ---- + mallet_model : Trained mallet model + gamma_threshold : To be used for inference in the new LdaModel. + iterations : number of iterations to be used for inference in the new LdaModel. + + Returns: + ------- + model_gensim : LdaModel instance; copied gensim LdaModel + """ + model_gensim = LdaModel( + id2word=mallet_model.id2word, num_topics=mallet_model.num_topics, + alpha=mallet_model.alpha, iterations=iterations, + gamma_threshold=gamma_threshold) + model_gensim.expElogbeta[:] = mallet_model.wordtopics + return model_gensim diff --git a/gensim/test/test_ldamallet_wrapper.py b/gensim/test/test_ldamallet_wrapper.py index 374641987e..9a84f29952 100644 --- a/gensim/test/test_ldamallet_wrapper.py +++ b/gensim/test/test_ldamallet_wrapper.py @@ -95,6 +95,18 @@ def testSparseTransform(self): (i, sorted(vec), sorted(expected))) self.assertTrue(passed) + def testMallet2Model(self): + if not self.mallet_path: + return + passed = False + tm1 = ldamallet.LdaMallet(self.mallet_path, corpus=corpus, num_topics=2, id2word=dictionary) + tm2 = ldamallet.malletmodel2ldamodel(tm1) + for document in corpus: + self.assertEqual(tm1[document][0], tm2[document][0]) + self.assertEqual(tm1[document][1], tm2[document][1]) + logging.debug('%d %d', tm1[document][0], tm2[document][0]) + logging.debug('%d %d', tm1[document][1], tm2[document][1]) + def testPersistence(self): if not self.mallet_path: