Skip to content

Commit

Permalink
change old tests
Browse files Browse the repository at this point in the history
  • Loading branch information
markroxor committed Dec 15, 2017
1 parent a6f1afb commit f1646a6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
25 changes: 11 additions & 14 deletions gensim/models/tfidfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TfidfModel(interfaces.TransformationABC):
"""

def __init__(self, corpus=None, id2word=None, dictionary=None, smartirs="ntc",
wlocal=None, wglobal=None, wnormalize=None):
wlocal=None, wglobal=None, normalize=None):
"""
Compute tf-idf by multiplying a local component (term frequency) with a
global component (inverse document frequency), and normalizing
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, smartirs="ntc",
mapping (then `corpus`, if specified, is ignored).
"""
self.id2word = id2word
self.wlocal, self.wglobal, self.wnormalize = wlocal, wglobal, wnormalize
self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize
self.num_docs, self.num_nnz, self.idfs = None, None, None
n_tf, n_df, n_n = smartirs

Expand All @@ -106,13 +106,14 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, smartirs="ntc",
elif n_tf == "p":
self.wglobal = lambda docfreq, totaldocs: math.log((float(totaldocs) - docfreq) / docfreq)

if self.wnormalize is None:
if n_n == "n":
self.wnormalize = lambda x: x
elif n_n == "c":
self.wnormalize = matutils.unitvec
elif n_n == "t":
self.wnormalize = matutils.unitvec
if self.normalize is None or isinstance(self.normalize, bool):
if n_n == "n" or self.normalize is False:
self.normalize = lambda x: x
elif n_n == "c" or self.normalize is True:
self.normalize = matutils.unitvec
# TODO write byte-size normalisation
# elif n_n == "b":
# self.normalize = matutils.unitvec

if dictionary is not None:
# user supplied a Dictionary object, which already contains all the
Expand Down Expand Up @@ -160,10 +161,6 @@ def initialize(self, corpus):

# and finally compute the idf weights
n_features = max(dfs) if dfs else 0
logger.info(
"calculating IDF weights for %i documents and %i features (%i matrix non-zeros)",
self.num_docs, n_features, self.num_nnz
)

def __getitem__(self, bow, eps=1e-12):
"""
Expand All @@ -185,7 +182,7 @@ def __getitem__(self, bow, eps=1e-12):
# and finally, normalize the vector either to unit length, or use a
# user-defined normalization function

vector = self.wnormalize(vector)
vector = self.normalize(vector)

# make sure there are no explicit zeroes in the vector (must be sparse)
vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps]
Expand Down
7 changes: 4 additions & 3 deletions gensim/sklearn_api/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ class TfIdfTransformer(TransformerMixin, BaseEstimator):
Base Tf-Idf module
"""

def __init__(self, id2word=None, dictionary=None, wlocal=gensim.utils.identity,
wglobal=gensim.models.tfidfmodel.df2idf, normalize=True):
def __init__(self, id2word=None, dictionary=None, smartirs="ntc", wlocal=None,
wglobal=None, normalize=True):
"""
Sklearn wrapper for Tf-Idf model.
"""
self.gensim_model = None
self.id2word = id2word
self.dictionary = dictionary
self.smartirs = smartirs
self.wlocal = wlocal
self.wglobal = wglobal
self.normalize = normalize
Expand All @@ -38,7 +39,7 @@ def fit(self, X, y=None):
Fit the model according to the given training data.
"""
self.gensim_model = TfidfModel(
corpus=X, id2word=self.id2word, dictionary=self.dictionary,
corpus=X, id2word=self.id2word, dictionary=self.dictionary, smartirs="ntc",
wlocal=self.wlocal, wglobal=self.wglobal, normalize=self.normalize
)
return self
Expand Down
7 changes: 3 additions & 4 deletions gensim/test/test_sklearn_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,6 @@ def testPersistence(self):
original_matrix = self.model.transform(original_bow)
passed = numpy.allclose(loaded_matrix, original_matrix, atol=1e-1)
self.assertTrue(passed)

def testModelNotFitted(self):
lsi_wrapper = LsiTransformer(id2word=dictionary, num_topics=2)
texts_new = ['graph', 'eulerian']
Expand Down Expand Up @@ -973,13 +972,13 @@ def testTransform(self):

def testSetGetParams(self):
# updating only one param
self.model.set_params(normalize=False)
self.model.set_params(smartirs='nnn')
model_params = self.model.get_params()
self.assertEqual(model_params["normalize"], False)
self.assertEqual(model_params["smartirs"], 'nnn')

# verify that the attribute values are also changed for `gensim_model` after fitting
self.model.fit(self.corpus)
self.assertEqual(getattr(self.model.gensim_model, 'normalize'), False)
self.assertEqual(getattr(self.model.gensim_model, 'smartirs'), 'nnn')

def testPipeline(self):
with open(datapath('mini_newsgroup'), 'rb') as f:
Expand Down

0 comments on commit f1646a6

Please sign in to comment.