Commit

Compatibility updates
jvamvas authored Dec 11, 2023
2 parents ee2e8ad + a455176 commit 1a5b006
Showing 5 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.7"]
+        python-version: ["3.8"]
 
     steps:
       - uses: actions/checkout@v3
8 changes: 4 additions & 4 deletions README.md
@@ -10,7 +10,7 @@ To learn more about how these measures work, have a look at [Jannis' blog post](
 
 ## Installation
 
-- Requires Python >= 3.7 and PyTorch
+- Requires Python >= 3.8 and PyTorch
 - `pip install nmtscore`
 - Extra requirements for the Prism model: `pip install nmtscore[prism]`
 
@@ -25,7 +25,7 @@ from nmtscore import NMTScorer
 scorer = NMTScorer()
 
 scorer.score("This is a sentence.", "This is another sentence.")
-# 0.5025776988808766
+# 0.4677300455046415
 ```
 
 #### Different similarity measures
@@ -52,7 +52,7 @@ scorer.score(
     ["This is a sentence.", "This is a sentence.", "This is another sentence."],
     ["This is another sentence.", "This sentence is completely unrelated.", "This is another sentence."],
 )
-# [0.5025777998113548, 0.1640727324003354, 1.0000000000000049]
+# [0.46772973967003206, 0.15306852595255185, 1.0]
 ```
 
 The sentences in the first list are compared element-wise to the sentences in the second list.
@@ -132,7 +132,7 @@ model.translate("de", ["This is a test."])
 # ["Das ist ein Test."]
 
 model.score("de", ["This is a test."], ["Das ist ein Test."])
-# [0.7708902359008789]
+# [0.8293135166168213]
 ```
 
 ## Experiments
4 changes: 2 additions & 2 deletions setup.cfg
@@ -18,9 +18,9 @@ classifiers =
 package_dir =
     = src
 packages = find:
-python_requires = >=3.7
+python_requires = >=3.8
 install_requires =
-    transformers
+    transformers<4.34 # https://github.com/ZurichNLP/nmtscore/issues/7
     sentencepiece
     tqdm
     sqlitedict
2 changes: 0 additions & 2 deletions src/nmtscore/models/m2m100.py
@@ -94,8 +94,6 @@ def _score(self,
             batch(hypothesis_sentences, batch_size),
         )
         for src_sentences, tgt_sentences in batch_iterator:
-            # Hack: Append a second EOS token to make sure that one EOS is still there after shift_tokens_right
-            tgt_sentences = [f"{sentence} {self.tokenizer.eos_token}" for sentence in tgt_sentences]
             inputs = self.tokenizer(
                 src_sentences,
                 text_target=tgt_sentences,
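
The two deleted lines were a workaround: a second EOS token was appended to each target sentence so that one EOS would still be present after `shift_tokens_right` turns the labels into decoder inputs. The commit drops the workaround, presumably because it is no longer needed with the pinned `transformers<4.34`. Below is a minimal, hedged sketch of the same scoring idea; it is not the library's exact code, and the `facebook/m2m100_418M` checkpoint and the `exp(-loss)` normalization are assumptions made here for illustration.

```python
# A minimal sketch of scoring one translation pair with M2M100 via Hugging Face
# transformers, without the extra-EOS workaround. Not the library's exact code:
# the checkpoint name and the exp(-loss) score normalization are assumptions.
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
model.eval()

tokenizer.src_lang = "en"
tokenizer.tgt_lang = "de"

# text_target tokenizes the hypothesis with the target-language settings and
# returns it as `labels`; the model shifts the labels right internally to
# build the decoder inputs, so no manual EOS handling is required.
inputs = tokenizer(
    ["This is a test."],
    text_target=["Das ist ein Test."],
    return_tensors="pt",
    padding=True,
)

with torch.no_grad():
    loss = model(**inputs).loss  # mean cross-entropy over target tokens

print(float(torch.exp(-loss)))  # geometric mean of per-token probabilities
```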
24 changes: 12 additions & 12 deletions tests/test_readme.py
@@ -24,7 +24,7 @@ def tearDownClass(cls) -> None:
     def test_nmtscorer(self):
         scorer = NMTScorer()
         score = scorer.score("This is a sentence.", "This is another sentence.")
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
 
     def test_batch_processing(self):
         scorer = NMTScorer()
@@ -33,20 +33,20 @@ def test_batch_processing(self):
             ["This is another sentence.", "This sentence is completely unrelated.", "This is another sentence."],
         )
         self.assertEqual(3, len(scores))
-        self.assertAlmostEqual(0.5025777998113548, scores[0], places=4)
-        self.assertAlmostEqual(0.1640727324003354, scores[1], places=4)
-        self.assertAlmostEqual(1.0000000000000049, scores[2], places=4)
+        self.assertAlmostEqual(0.46772973967003206, scores[0], places=4)
+        self.assertAlmostEqual(0.15306852595255185, scores[1], places=4)
+        self.assertAlmostEqual(1.0, scores[2], places=4)
 
     def test_different_similarity_measures(self):
         scorer = NMTScorer()
         a = "This is a sentence."
         b = "This is another sentence."
         score = scorer.score_cross_likelihood(a, b, tgt_lang="en", normalize=True, both_directions=True)
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
         score = scorer.score_direct(a, b, a_lang="en", b_lang="en", normalize=True, both_directions=True)
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
         score = scorer.score_pivot(a, b, a_lang="en", b_lang="en", pivot_lang="en", normalize=True, both_directions=True)
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
     def test_different_nmt_models(self):
@@ -59,18 +59,18 @@ def test_batch_size(self):
         a = "This is a sentence."
         b = "This is another sentence."
         score = scorer.score_cross_likelihood(a, b, translate_kwargs={"batch_size": 16}, score_kwargs={"batch_size": 16})
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
         score = scorer.score_direct(a, b, a_lang="en", b_lang="en", score_kwargs={"batch_size": 16})
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
 
     def test_caching(self):
         scorer = NMTScorer()
         a = "This is a sentence."
         b = "This is another sentence."
         score = scorer.score_cross_likelihood(a, b, translate_kwargs={"use_cache": True}, score_kwargs={"use_cache": True})
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
         score = scorer.score_direct(a, b, a_lang="en", b_lang="en", score_kwargs={"use_cache": True})
-        self.assertAlmostEqual(0.5025776988808766, score, places=4)
+        self.assertAlmostEqual(0.4677300455046415, score, places=4)
 
     @mock.patch('sys.stdout', new_callable=io.StringIO)
     def test_version_signature(self, mock_stdout):
@@ -85,4 +85,4 @@ def test_nmt_models(self):
         translations = model.translate("de", ["This is a test."], src_lang="en")
         self.assertEqual(["Das ist ein Test."], translations)
         scores = model.score("de", ["This is a test."], ["Das ist ein Test."], src_lang="en")
-        self.assertAlmostEqual(0.7708902359008789, scores[0], places=4)
+        self.assertAlmostEqual(0.8293135166168213, scores[0], places=4)
