diff --git a/gensim/models/fasttext.py b/gensim/models/fasttext.py
index d2da493ec9..2307b04468 100644
--- a/gensim/models/fasttext.py
+++ b/gensim/models/fasttext.py
@@ -346,7 +346,7 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
                  max_vocab_size=None, word_ngrams=1, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
                  negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6,
                  sorted_vocab=1, bucket=2000000, trim_rule=None, batch_words=MAX_WORDS_IN_BATCH, callbacks=(),
-                 compatible_hash=True):
+                 compatible_hash=True, max_final_vocab=None):
         """
 
         Parameters
@@ -448,6 +448,12 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
             Older versions were not 100% compatible due to a bug.
             To use the older, incompatible hash function, set this to False.
 
+        max_final_vocab : int, optional
+            Limits the vocabulary to a target size by automatically selecting a matching
+            ``min_count``. If the specified ``min_count`` is more than the
+            automatically calculated ``min_count``, the specified value will be used.
+            Set to ``None`` if not required.
+
         Examples
         --------
         Initialize and train a `FastText` model:
@@ -472,7 +478,9 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, size=100, alpha
         self.wv = FastTextKeyedVectors(size, min_n, max_n, bucket, compatible_hash)
         self.vocabulary = FastTextVocab(
             max_vocab_size=max_vocab_size, min_count=min_count, sample=sample,
-            sorted_vocab=bool(sorted_vocab), null_word=null_word, ns_exponent=ns_exponent)
+            sorted_vocab=bool(sorted_vocab), null_word=null_word, ns_exponent=ns_exponent,
+            max_final_vocab=max_final_vocab,
+        )
         self.trainables = FastTextTrainables(vector_size=size, seed=seed, bucket=bucket, hashfxn=hashfxn)
         self.trainables.prepare_weights(hs, negative, self.wv, update=False, vocabulary=self.vocabulary)
         self.wv.bucket = self.trainables.bucket
diff --git a/gensim/test/test_fasttext.py b/gensim/test/test_fasttext.py
index 3517a355a9..8f691d4608 100644
--- a/gensim/test/test_fasttext.py
+++ b/gensim/test/test_fasttext.py
@@ -866,6 +866,27 @@ def test_sg_hs_against_wrapper(self):
         self.assertFalse((orig0 == model_gensim.wv.vectors[0]).all())  # vector should vary after training
         self.compare_with_wrapper(model_gensim, model_wrapper)
 
+    def test_vocab_pruning(self):
+        """Does the model correctly interpret the max_final_vocab parameter?"""
+        sentences = [
+            ["graph", "system"],
+            ["graph", "system"],
+            ["system", "eps"],
+            ["graph", "system"],
+        ]
+        model = FT_gensim(sentences, size=10, min_count=2, max_final_vocab=2)
+        self.assertEqual(len(model.wv.vocab), 2)
+        self.assertEqual(model.wv.vocab['graph'].count, 3)
+        self.assertEqual(model.wv.vocab['system'].count, 4)
+
+        model = FT_gensim(sentences, size=10, min_count=2, max_final_vocab=1)
+        self.assertEqual(len(model.wv.vocab), 1)
+        self.assertEqual(model.wv.vocab['system'].count, 4)
+
+        model = FT_gensim(sentences, size=10, min_count=4)
+        self.assertEqual(len(model.wv.vocab), 1)
+        self.assertEqual(model.wv.vocab['system'].count, 4)
+
 
 with open(datapath('toy-data.txt')) as fin:
     TOY_SENTENCES = [fin.read().strip().split(' ')]
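Below is a minimal usage sketch of the new ``max_final_vocab`` parameter, not part of the patch itself. It mirrors the assertions in test_vocab_pruning above and assumes the gensim 3.x API used throughout this diff (``size=``, ``model.wv.vocab``).

# Usage sketch for max_final_vocab (illustrative corpus and parameter values).
from gensim.models import FastText

sentences = [
    ["graph", "system"],
    ["graph", "system"],
    ["system", "eps"],
    ["graph", "system"],
]

# Raw frequencies: system=4, graph=3, eps=1. min_count=2 alone would keep
# both "graph" and "system"; max_final_vocab=1 raises the effective
# min_count until only the most frequent word survives.
model = FastText(sentences, size=10, min_count=2, max_final_vocab=1)
assert len(model.wv.vocab) == 1
assert model.wv.vocab['system'].count == 4

Since the cap is enforced by raising the effective ``min_count`` (as with Word2Vec's ``max_final_vocab``), tied frequencies at the cutoff can leave the final vocabulary somewhat smaller than the requested target.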