Expose max_final_vocab parameter in FastText constructor (#2867)
* Expose max_final_vocab parameter in FastText constructor

* Fix lint error

* Respond to reviewer comments

* Add unit test

Co-authored-by: Cristi Burca <mail@scribu.net>
mpenkov and scribu authored Jun 27, 2020
1 parent 42be086 commit a74f8e3
Showing 2 changed files with 31 additions and 2 deletions.
12 changes: 10 additions & 2 deletions gensim/models/fasttext.py
@@ -346,7 +346,7 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
max_vocab_size=None, word_ngrams=1, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6,
sorted_vocab=1, bucket=2000000, trim_rule=None, batch_words=MAX_WORDS_IN_BATCH, callbacks=(),
compatible_hash=True):
compatible_hash=True, max_final_vocab=None):
"""
Parameters
@@ -448,6 +448,12 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
Older versions were not 100% compatible due to a bug.
To use the older, incompatible hash function, set this to False.
max_final_vocab : int, optional
Limits the vocab to a target vocab size by automatically selecting
``min_count``. If the specified ``min_count`` is more than the
automatically calculated ``min_count``, the former will be used.
Set to ``None`` if not required.
Examples
--------
Initialize and train a `FastText` model:
@@ -472,7 +478,9 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
self.wv = FastTextKeyedVectors(size, min_n, max_n, bucket, compatible_hash)
self.vocabulary = FastTextVocab(
max_vocab_size=max_vocab_size, min_count=min_count, sample=sample,
sorted_vocab=bool(sorted_vocab), null_word=null_word, ns_exponent=ns_exponent)
sorted_vocab=bool(sorted_vocab), null_word=null_word, ns_exponent=ns_exponent,
max_final_vocab=max_final_vocab,
)
self.trainables = FastTextTrainables(vector_size=size, seed=seed, bucket=bucket, hashfxn=hashfxn)
self.trainables.prepare_weights(hs, negative, self.wv, update=False, vocabulary=self.vocabulary)
self.wv.bucket = self.trainables.bucket
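For context, here is a minimal usage sketch of the new parameter (not part of the diff itself); the toy corpus and the expected result mirror the unit test added below.

from gensim.models.fasttext import FastText

sentences = [
    ["graph", "system"],
    ["graph", "system"],
    ["system", "eps"],
    ["graph", "system"],
]

# min_count=2 alone would keep both "graph" (count 3) and "system" (count 4);
# max_final_vocab=1 tightens the cutoff further so only the single most
# frequent word survives.
model = FastText(sentences, size=10, min_count=2, max_final_vocab=1)
print(sorted(model.wv.vocab))  # expected: ['system']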
21 changes: 21 additions & 0 deletions gensim/test/test_fasttext.py
@@ -866,6 +866,27 @@ def test_sg_hs_against_wrapper(self):
self.assertFalse((orig0 == model_gensim.wv.vectors[0]).all()) # vector should vary after training
self.compare_with_wrapper(model_gensim, model_wrapper)

def test_vocab_pruning(self):
"""Does the model correctly interpret the max_final_vocab parameter?"""
sentences = [
["graph", "system"],
["graph", "system"],
["system", "eps"],
["graph", "system"],
]
model = FT_gensim(sentences, size=10, min_count=2, max_final_vocab=2)
self.assertEqual(len(model.wv.vocab), 2)
self.assertEqual(model.wv.vocab['graph'].count, 3)
self.assertEqual(model.wv.vocab['system'].count, 4)

model = FT_gensim(sentences, size=10, min_count=2, max_final_vocab=1)
self.assertEqual(len(model.wv.vocab), 1)
self.assertEqual(model.wv.vocab['system'].count, 4)

model = FT_gensim(sentences, size=10, min_count=4)
self.assertEqual(len(model.wv.vocab), 1)
self.assertEqual(model.wv.vocab['system'].count, 4)


with open(datapath('toy-data.txt')) as fin:
TOY_SENTENCES = [fin.read().strip().split(' ')]
