-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Pivot Normalization for gensim.models.TfidfModel
. Fix #220
#1780
Changes from 50 commits
efb7e3c
b7d07d4
e8a3f16
648bf21
a6f1afb
d091138
951c549
40c0558
b35344c
634d595
0917e75
bef79cc
d3d431c
0e6f21e
7ee7560
b2def84
5b2d37a
ac4b154
0bacc08
51e0eb9
3039732
99e6a6f
7d63d9c
e5140f8
4afbadd
d2fe235
5565c78
099dbdf
ef67f63
52ee3c4
3087030
62bba1b
0a9f816
dc63ab9
035c8c5
dc4ca52
1ee449d
4ea6caa
b3cead6
044332b
1c2196c
309b4e8
3866a9c
12b42e6
0ff6ad7
65c651b
4a947ba
619bb33
f105190
6410f21
2eb6fc2
a65dccf
8717350
95cb630
5f46d2f
2c7115d
1fe46f8
63c8385
5e87229
9f2b02c
fc701a1
1868da5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"docs/notebooks/test_notebooks.py": true, | ||
"gensim/test/test_tfidfmodel.py::TestTfidfModel::testPersistence": true, | ||
"gensim/test/test_tfidfmodel.py::TestTfidfModel::testPersistenceCompressed": true, | ||
"gensim/test/test_tfidfmodel.py::TestTfidfModel::testPivotedNormalization": true, | ||
"gensim/test/test_tfidfmodel.py::TestTfidfModel::test_pivotedNormalization": true | ||
} |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,7 +70,7 @@ def resolve_weights(smartirs): | |
if w_df not in 'ntp': | ||
raise ValueError("Expected inverse document frequency weight to be one of 'ntp', except got {}".format(w_df)) | ||
|
||
if w_n not in 'ncb': | ||
if w_n not in 'nc': | ||
raise ValueError("Expected normalization weight to be one of 'ncb', except got {}".format(w_n)) | ||
|
||
return w_tf, w_df, w_n | ||
|
@@ -177,7 +177,7 @@ def updated_wglobal(docfreq, totaldocs, n_df): | |
return np.log((1.0 * totaldocs - docfreq) / docfreq) / np.log(2) | ||
|
||
|
||
def updated_normalize(x, n_n): | ||
def updated_normalize(x, n_n, return_norm=False): | ||
mpenkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Normalizes the final tf-idf value according to the value of `n_n`. | ||
|
||
Parameters | ||
|
@@ -194,9 +194,12 @@ def updated_normalize(x, n_n): | |
|
||
""" | ||
if n_n == "n": | ||
return x | ||
if return_norm: | ||
return x, 1 | ||
else: | ||
return x | ||
elif n_n == "c": | ||
return matutils.unitvec(x) | ||
return matutils.unitvec(x, return_norm=return_norm) | ||
|
||
|
||
class TfidfModel(interfaces.TransformationABC): | ||
|
@@ -219,7 +222,8 @@ class TfidfModel(interfaces.TransformationABC): | |
""" | ||
|
||
def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.identity, | ||
wglobal=df2idf, normalize=True, smartirs=None): | ||
wglobal=df2idf, normalize=True, smartirs=None, | ||
pivot_norm=False, slope=0.65, pivot=100): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about default values here, we really can't reduce it to 2 values (join |
||
"""Compute tf-idf by multiplying a local component (term frequency) with a global component | ||
(inverse document frequency), and normalizing the resulting documents to unit length. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missed docstring for new parameters |
||
Formula for non-normalized weight of term :math:`i` in document :math:`j` in a corpus of :math:`D` documents | ||
|
@@ -273,21 +277,34 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden | |
* `c` - cosine. | ||
|
||
For more information visit [1]_. | ||
|
||
pivot_norm : bool, optional | ||
If pivot_norm is True, then pivoted document length normalization will be applied. | ||
slope : float, optional | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can confuse users, need to mention that works only if |
||
It is the parameter required by pivoted document length normalization which determines the slope to which | ||
the `old normalization` can be tilted. | ||
pivot : float, optional | ||
Pivot is the point before which we consider a document to be short and after which the document is | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. too "broad" description, next question will be "what is retrieval and relevence curves" and "how to plot it" |
||
considered long. It can be found by plotting the retrieval and relevence curves of a set of documents using | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, you can add latex formula here using http://www.sphinx-doc.org/en/master/ext/math.html |
||
a general normalization function. The point where both these curves coincide is the pivot point. | ||
""" | ||
|
||
self.id2word = id2word | ||
self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize | ||
self.num_docs, self.num_nnz, self.idfs = None, None, None | ||
self.smartirs = smartirs | ||
self.pivot_norm = pivot_norm | ||
self.slope = slope | ||
self.pivot = pivot | ||
self.eps = 1e-12 | ||
|
||
# If smartirs is not None, override wlocal, wglobal and normalize | ||
if smartirs is not None: | ||
n_tf, n_df, n_n = resolve_weights(smartirs) | ||
|
||
self.wlocal = partial(updated_wlocal, n_tf=n_tf) | ||
self.wglobal = partial(updated_wglobal, n_df=n_df) | ||
self.normalize = partial(updated_normalize, n_n=n_n) | ||
# also return norm factor if pivot_norm is True | ||
self.normalize = partial(updated_normalize, n_n=n_n, return_norm=self.pivot_norm) | ||
|
||
if dictionary is not None: | ||
# user supplied a Dictionary object, which already contains all the | ||
|
@@ -309,6 +326,19 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden | |
# be initialized in some other way | ||
pass | ||
|
||
@classmethod | ||
def load(cls, *args, **kwargs): | ||
""" | ||
Load a previously saved TfidfModel class. Handles backwards compatibility from | ||
older TfidfModel versions which did not use pivoted document normalization. | ||
""" | ||
model = super(TfidfModel, cls).load(*args, **kwargs) | ||
if not hasattr(model, 'pivot_norm'): | ||
logger.info('older version of %s loaded without pivot_norm arg', cls.__name__) | ||
logger.info('Setting pivot_norm to False.') | ||
model.pivot_norm = False | ||
return model | ||
|
||
def __str__(self): | ||
return "TfidfModel(num_docs=%s, num_nnz=%s)" % (self.num_docs, self.num_nnz) | ||
|
||
|
@@ -360,6 +390,7 @@ def __getitem__(self, bow, eps=1e-12): | |
TfIdf corpus, if `bow` is corpus. | ||
|
||
""" | ||
self.eps = eps | ||
# if the input vector is in fact a corpus, return a transformed corpus as a result | ||
is_corpus, bow = utils.is_corpus(bow) | ||
if is_corpus: | ||
|
@@ -377,7 +408,7 @@ def __getitem__(self, bow, eps=1e-12): | |
|
||
vector = [ | ||
(termid, tf * self.idfs.get(termid)) | ||
for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > eps | ||
for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > self.eps | ||
] | ||
|
||
if self.normalize is True: | ||
|
@@ -387,8 +418,13 @@ def __getitem__(self, bow, eps=1e-12): | |
|
||
# and finally, normalize the vector either to unit length, or use a | ||
# user-defined normalization function | ||
vector = self.normalize(vector) | ||
|
||
# make sure there are no explicit zeroes in the vector (must be sparse) | ||
vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we used |
||
return vector | ||
if self.pivot_norm: | ||
_, old_norm = self.normalize(vector, return_norm=True) | ||
pivoted_norm = (1 - self.slope) * self.pivot + self.slope * old_norm | ||
norm_vector = [(termid, weight / float(pivoted_norm)) | ||
for termid, weight in vector if abs(weight / float(pivoted_norm)) > self.eps | ||
] | ||
else: | ||
norm_vector = self.normalize(vector) | ||
norm_vector = [(termid, weight) for termid, weight in norm_vector if abs(weight) > self.eps] | ||
return norm_vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use
git add *
carefully, please remove this file