diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 5ded70eddb..41d652f3a9 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -50,6 +50,12 @@ logger = logging.getLogger('gensim.models.ldamodel') +DTYPE_TO_EPS = { + np.float16: 1e-5, + np.float32: 1e-35, + np.float64: 1e-100, +} + def logsumexp(x): """Log of sum of exponentials @@ -275,6 +281,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, `callbacks` a list of metric callbacks to log/visualize evaluation metrics of topic model during training. `dtype` is data-type to use during calculations inside model. All inputs are also converted to this dtype. + Available types: `numpy.float16`, `numpy.float32`, `numpy.float64`. Example: @@ -286,6 +293,11 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, >>> lda = LdaModel(corpus, num_topics=50, alpha='auto', eval_every=5) # train asymmetric alpha from data """ + if dtype not in DTYPE_TO_EPS: + raise ValueError( + "Incorrect 'dtype', please choose one of {}".format( + ", ".join(sorted("numpy.{}".format(tp.__name__) for tp in DTYPE_TO_EPS)))) + self.dtype = dtype # store user-supplied parameters @@ -497,8 +509,9 @@ def inference(self, chunk, collect_sstats=False): # The optimal phi_{dwk} is proportional to expElogthetad_k * expElogbetad_w. # phinorm is the normalizer. - # TODO treat zeros explicitly, instead of adding 1e-100? - phinorm = np.dot(expElogthetad, expElogbetad) + 1e-100 + # TODO treat zeros explicitly, instead of adding epsilon? 
+ eps = DTYPE_TO_EPS[self.dtype] + phinorm = np.dot(expElogthetad, expElogbetad) + eps # Iterate between gamma and phi until convergence for _ in xrange(self.iterations): @@ -509,7 +522,7 @@ def inference(self, chunk, collect_sstats=False): gammad = self.alpha + expElogthetad * np.dot(cts / phinorm, expElogbetad.T) Elogthetad = dirichlet_expectation(gammad) expElogthetad = np.exp(Elogthetad) - phinorm = np.dot(expElogthetad, expElogbetad) + 1e-100 + phinorm = np.dot(expElogthetad, expElogbetad) + eps # If gamma hasn't changed much, we're done. meanchange = np.mean(abs(gammad - lastgamma)) if meanchange < self.gamma_threshold: