Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
additional tests for python bindings
Browse files Browse the repository at this point in the history
Summary: See title.

Differential Revision: D6228155

fbshipit-source-id: a96944c6cb003449734fdc1cf2fc3fa365237619
  • Loading branch information
cpuhrsch authored and facebook-github-bot committed Nov 13, 2017
1 parent c5cb6b2 commit 1de0624
Showing 1 changed file with 66 additions and 7 deletions.
73 changes: 66 additions & 7 deletions python/fastText/test/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,72 @@ def test_subword_vector(self):
self.assertTrue(np.isclose(vec3, vec4, atol=1e-5, rtol=0).all())
self.assertTrue(np.isclose(vec4, vec1, atol=1e-5, rtol=0).all())

# TODO: Compare with .vec file
def test_get_words(self):
    # get_words must return the same number of entries whether or not
    # frequencies are requested, and exactly one frequency per word.
    def check_lengths(model_path):
        model = load_model(model_path)
        vocab, counts = model.get_words(include_freq=True)
        vocab_only = model.get_words(include_freq=False)
        self.assertEqual(len(vocab), len(vocab_only))
        self.assertEqual(len(vocab), len(counts))

    # Same checks against both saved models, in the original order.
    check_lengths(self.output + '.bin')
    check_lengths(self.output_sup + '.bin')

# TODO: Compare with .vec file for unsup
def test_get_labels(self):
    # Model at self.output: get_labels is compared entry by entry with
    # get_words, on top of the usual length consistency checks.
    model = load_model(self.output + '.bin')
    labels, label_counts = model.get_labels(include_freq=True)
    labels_only = model.get_labels(include_freq=False)
    words_only = model.get_words(include_freq=False)
    n_labels = len(labels)
    self.assertEqual(n_labels, len(labels_only))
    self.assertEqual(n_labels, len(label_counts))
    self.assertEqual(n_labels, len(words_only))
    for label, word in zip(labels_only, words_only):
        self.assertEqual(label, word)

    # Model at self.output_sup: only the length consistency checks.
    model = load_model(self.output_sup + '.bin')
    labels, label_counts = model.get_labels(include_freq=True)
    labels_only = model.get_labels(include_freq=False)
    n_labels = len(labels)
    self.assertEqual(n_labels, len(labels_only))
    self.assertEqual(n_labels, len(label_counts))

def test_exercise_is_quant(self):
    # Quantizing the model saved at self.output must be rejected with
    # ValueError. assertRaises reports a missing exception directly
    # instead of going through a hand-maintained flag.
    model = load_model(self.output + '.bin')
    with self.assertRaises(ValueError):
        model.quantize()

    # The model saved at self.output_sup starts out unquantized and
    # reports itself as quantized once quantize() has run.
    model = load_model(self.output_sup + '.bin')
    self.assertFalse(model.is_quantized())
    model.quantize()
    self.assertTrue(model.is_quantized())

def test_newline_predict_sentence(self):
    # predict() accepts the generated sentence as-is but must raise
    # ValueError once a trailing newline is appended to it.
    model = load_model(self.output_sup + '.bin')
    sentence = get_random_words(1, 1000, 2000)[0]
    model.predict(sentence, k=5)
    sentence += "\n"
    with self.assertRaises(ValueError):
        model.predict(sentence, k=5)

    # Same contract for get_sentence_vector() on the model saved at
    # self.output: fine without the newline, ValueError with it.
    model = load_model(self.output + '.bin')
    sentence = get_random_words(1, 1000, 2000)[0]
    model.get_sentence_vector(sentence)
    sentence += "\n"
    with self.assertRaises(ValueError):
        model.get_sentence_vector(sentence)


class TestFastTextPyIntegration(TestFastTextPy):
@classmethod
Expand Down Expand Up @@ -502,13 +568,6 @@ def check(

return sup_test

# TODO:
# Exercise get_words with and without include_freq
# Exercise get_labels with and without include_freq
# Compare labels to words for unsup model
# Test failure for quantizing unsupervised model
# Test isQuant
# Test failure for predict/sentence vector if text includes \n

if __name__ == "__main__":
sup_job_lr = [0.25, 0.5, 0.5, 0.1, 0.1, 0.1, 0.05, 0.05]
Expand Down

0 comments on commit 1de0624

Please sign in to comment.