diff --git a/gensim/matutils.py b/gensim/matutils.py index 93c750efd8..fbfa383a34 100644 --- a/gensim/matutils.py +++ b/gensim/matutils.py @@ -532,6 +532,10 @@ def jaccard(vec1, vec2): return 1 - float(len(intersection)) / float(len(union)) +def jaccard_set(set1, set2): + return 1. - float(len(set1 & set2)) / float(len(set1 | set2)) + + def dirichlet_expectation(alpha): """ For a vector `theta~Dir(alpha)`, compute `E[log(theta)]`. diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 67398ab099..59657fe8a0 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -33,11 +33,13 @@ import logging import numpy as np import numbers +from random import sample import os from gensim import interfaces, utils, matutils from gensim.matutils import dirichlet_expectation from gensim.models import basemodel +from gensim.matutils import kullback_leibler, hellinger, jaccard_set from itertools import chain from scipy.special import gammaln, psi # gamma function utils @@ -965,6 +967,72 @@ def get_term_topics(self, word_id, minimum_probability=None): return values + def diff(self, other, distance="kulback_leibler", num_words=100, n_ann_terms=10, normed=True): + """ + Calculate difference topic2topic between two Lda models + `other` instances of `LdaMulticore` or `LdaModel` + `distance` is function that will be applied to calculate difference between any topic pair. + Available values: `kulback_leibler`, `hellinger` and `jaccard` + `num_words` is quantity of most relevant words that used if distance == `jaccard` (also used for annotation) + `n_ann_terms` is max quantity of words in intersection/symmetric difference between topics (used for annotation) + Returns a matrix Z with shape (m1.num_topics, m2.num_topics), where Z[i][j] - difference between topic_i and topic_j + and matrix annotation with shape (m1.num_topics, m2.num_topics, 2, None), + where + annotation[i][j] = [[`int_1`, `int_2`, ...], [`diff_1`, `diff_2`, ...]] and + `int_k` is word from intersection of `topic_i` and `topic_j` and + `diff_l` is word from symmetric difference of `topic_i` and `topic_j` + `normed` is a flag. If `true`, matrix Z will be normalized + Example: + >>> m1, m2 = LdaMulticore.load(path_1), LdaMulticore.load(path_2) + >>> mdiff, annotation = m1.diff(m2) + >>> print(mdiff) # get matrix with difference for each topic pair from `m1` and `m2` + >>> print(annotation) # get array with positive/negative words for each topic pair from `m1` and `m2` + """ + + distances = {"kulback_leibler": kullback_leibler, + "hellinger": hellinger, + "jaccard": jaccard_set} + + if distance not in distances: + valid_keys = ", ".join("`{}`".format(x) for x in distances.keys()) + raise ValueError("Incorrect distance, valid only {}".format(valid_keys)) + + if not isinstance(other, self.__class__): + raise ValueError("The parameter `other` must be of type `{}`".format(self.__name__)) + + distance_func = distances[distance] + d1, d2 = self.state.get_lambda(), other.state.get_lambda() + t1_size, t2_size = d1.shape[0], d2.shape[0] + + fst_topics = [{w for (w, _) in self.show_topic(topic, topn=num_words)} for topic in xrange(t1_size)] + snd_topics = [{w for (w, _) in other.show_topic(topic, topn=num_words)} for topic in xrange(t2_size)] + + if distance == "jaccard": + d1, d2 = fst_topics, snd_topics + + z = np.zeros((t1_size, t2_size)) + for topic1 in range(t1_size): + for topic2 in range(t2_size): + z[topic1][topic2] = distance_func(d1[topic1], d2[topic2]) + + if normed: + if np.abs(np.max(z)) > 1e-8: + z /= np.max(z) + + annotation = [[None for _ in range(t1_size)] for _ in range(t2_size)] + + for topic1 in range(t1_size): + for topic2 in range(t2_size): + pos_tokens = fst_topics[topic1] & snd_topics[topic2] + neg_tokens = fst_topics[topic1].symmetric_difference(snd_topics[topic2]) + + pos_tokens = sample(pos_tokens, min(len(pos_tokens), n_ann_terms)) + neg_tokens = sample(neg_tokens, min(len(neg_tokens), n_ann_terms)) + + annotation[topic1][topic2] = [pos_tokens, neg_tokens] + + return z, annotation + def __getitem__(self, bow, eps=None): """ Return topic distribution for the given document `bow`, as a list of diff --git a/gensim/test/test_tmdiff.py b/gensim/test/test_tmdiff.py new file mode 100644 index 0000000000..2a00f81b01 --- /dev/null +++ b/gensim/test/test_tmdiff.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2016 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +import unittest +import numpy as np + +from gensim.corpora import Dictionary +from gensim.models import LdaModel + + +class TestLdaDiff(unittest.TestCase): + def setUp(self): + texts = [['human', 'interface', 'computer'], + ['survey', 'user', 'computer', 'system', 'response', 'time'], + ['eps', 'user', 'interface', 'system'], + ['system', 'human', 'system', 'eps'], + ['user', 'response', 'time'], + ['trees'], + ['graph', 'trees'], + ['graph', 'minors', 'trees'], + ['graph', 'minors', 'survey']] + self.dictionary = Dictionary(texts) + self.corpus = [self.dictionary.doc2bow(text) for text in texts] + self.num_topics = 5 + self.n_ann_terms = 10 + self.model = LdaModel(corpus=self.corpus, id2word=self.dictionary, num_topics=self.num_topics, passes=10) + + def testBasic(self): + mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms) + + self.assertEqual(mdiff.shape, (self.num_topics, self.num_topics)) + self.assertEquals(len(annotation), self.num_topics) + self.assertEquals(len(annotation[0]), self.num_topics) + + def testIdentity(self): + for dist_name in ["hellinger", "kulback_leibler", "jaccard"]: + mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, distance=dist_name) + + for row in annotation: + for (int_tokens, diff_tokens) in row: + self.assertEquals(diff_tokens, []) + self.assertEquals(len(int_tokens), self.n_ann_terms) + + self.assertTrue(np.allclose(np.diag(mdiff), np.zeros(mdiff.shape[0], dtype=mdiff.dtype))) + + if dist_name == "jaccard": + self.assertTrue(np.allclose(mdiff, np.zeros(mdiff.shape, dtype=mdiff.dtype))) + + def testInput(self): + self.assertRaises(ValueError, self.model.diff, self.model, n_ann_terms=self.n_ann_terms, distance='something') + self.assertRaises(ValueError, self.model.diff, [], n_ann_terms=self.n_ann_terms, distance='something')