From 179a2c170414c08e9f5d176203b5d86937bd67ef Mon Sep 17 00:00:00 2001
From: horpto <__Singleton__@hackerdom.ru>
Date: Mon, 28 Jan 2019 07:42:37 +0500
Subject: [PATCH] Fix infinite diff in `LdaModel.do_mstep` (#2344)

* Fix #416, #2051: Infinite diff in LdaModel.do_mstep

* fix build
---
 gensim/models/ldamodel.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py
index 503c2b48e3..786ec41c0b 100755
--- a/gensim/models/ldamodel.py
+++ b/gensim/models/ldamodel.py
@@ -594,9 +594,19 @@ def __str__(self):
             self.num_terms, self.num_topics, self.decay, self.chunksize
         )
 
-    def sync_state(self):
-        """Propagate the states topic probabilities to the inner object's attribute."""
-        self.expElogbeta = np.exp(self.state.get_Elogbeta())
+    def sync_state(self, current_Elogbeta=None):
+        """Propagate the state's topic probabilities to the inner object's attribute.
+
+        Parameters
+        ----------
+        current_Elogbeta : numpy.ndarray, optional
+            Posterior probabilities for each topic.
+            If omitted, it will be fetched from the state.
+        """
+
+        if current_Elogbeta is None:
+            current_Elogbeta = self.state.get_Elogbeta()
+        self.expElogbeta = np.exp(current_Elogbeta)
         assert self.expElogbeta.dtype == self.dtype
 
     def clear(self):
@@ -1027,14 +1037,16 @@ def do_mstep(self, rho, other, extra_pass=False):
         logger.debug("updating topics")
         # update self with the new blend; also keep track of how much did
         # the topics change through this update, to assess convergence
-        diff = np.log(self.expElogbeta)
+        previous_Elogbeta = self.state.get_Elogbeta()
         self.state.blend(rho, other)
-        diff -= self.state.get_Elogbeta()
-        self.sync_state()
+
+        current_Elogbeta = self.state.get_Elogbeta()
+        self.sync_state(current_Elogbeta)
 
         # print out some debug info at the end of each EM iteration
         self.print_topics(5)
-        logger.info("topic diff=%f, rho=%f", np.mean(np.abs(diff)), rho)
+        diff = mean_absolute_difference(previous_Elogbeta.ravel(), current_Elogbeta.ravel())
+        logger.info("topic diff=%f, rho=%f", diff, rho)
 
         if self.optimize_eta:
             self.update_eta(self.state.get_lambda(), rho)
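
Note on the failure mode this patch fixes: `expElogbeta` is stored as `exp(Elogbeta)`, and the expected log-probability of a rare term can be so negative that `np.exp` underflows to exactly 0.0; `np.log` of that is `-inf`, so the topic diff reported each EM iteration became `inf` (#416, #2051). The patch avoids the exp/log round-trip by diffing the `Elogbeta` arrays directly through the `mean_absolute_difference` helper (presumably imported from `gensim.matutils`, which is equivalent to `np.mean(np.abs(a - b))`). A minimal standalone sketch of the bug and the fix, using made-up array values for illustration:

    import numpy as np

    # Expected log-probabilities before and after a blend; -1000.0 stands in
    # for a rare term whose probability underflows float64 when exponentiated.
    previous_Elogbeta = np.array([-1000.0, -2.0, -0.5])
    current_Elogbeta = np.array([-999.0, -2.1, -0.4])

    # Old approach: round-trip through exp/log. np.exp(-1000.0) underflows
    # to 0.0, np.log(0.0) yields -inf, so the mean absolute diff is inf.
    expElogbeta = np.exp(previous_Elogbeta)
    with np.errstate(divide='ignore'):  # silence the log-of-zero warning
        old_diff = np.log(expElogbeta) - current_Elogbeta
    print(np.mean(np.abs(old_diff)))  # inf

    # Fixed approach: compare the Elogbeta arrays directly, no round-trip.
    new_diff = np.mean(np.abs(previous_Elogbeta.ravel() - current_Elogbeta.ravel()))
    print(new_diff)  # 0.4 -- finite, and identical to the old value whenever
                     # no entry underflows

Keeping `previous_Elogbeta` before the blend also means `get_Elogbeta()` is called once per side instead of recovering the old values from `expElogbeta`, so the convergence diff is exact rather than reconstructed.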