diff --git a/requirements-dev.txt b/requirements-dev.txt index aca6bbc..333ddaf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ sacrebleu==2.4.0 -unbabel-comet==2.1.1 +unbabel-comet==2.2.1 git+https://github.com/google-research/bleurt.git sentencepiece==0.1.99 # M2M100 model diff --git a/requirements-test.txt b/requirements-test.txt index 16cffc9..1181efa 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,2 @@ sacrebleu==2.4.0 - +unbabel-comet==2.2.1 diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5a83eef..f4b3eac 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -43,7 +43,6 @@ def test_is_source_based__chrf(self): chrf = evaluate.load("chrf") self.assertFalse(metric_is_source_based(chrf)) - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_is_source_based__comet(self): comet = evaluate.load("comet", "eamt22-cometinho-da") self.assertTrue(metric_is_source_based(comet)) @@ -63,7 +62,6 @@ def test_load_metric(self): self.assertIsInstance(metric, evaluate.Metric) self.assertEqual(metric.name, "chr_f") - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_metric_config_name(self): self.mbr_config.metric = "comet" self.mbr_config.metric_config_name = "eamt22-cometinho-da" @@ -89,7 +87,6 @@ def test_compute_metric__chrf(self): self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_compute_metric__comet(self): self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") self.mbr_config.metric.scorer.eval() @@ -130,7 +127,6 @@ def test_compute_metric__bleurt(self): self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_comet_metric_runner(self): from mbr.metrics.comet import CometMetricRunner self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") @@ -146,7 +142,6 @@ def test_comet_metric_runner(self): metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) torch.testing.assert_close(base_metric_scores, metric_scores) - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_comet_metric_runner__cache(self): """Output should be identical irrespective of cache size""" from mbr.metrics.comet import CometMetricRunner @@ -161,7 +156,6 @@ def test_comet_metric_runner__cache(self): metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) torch.testing.assert_close(base_metric_scores, metric_scores) - @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_comet_metric_runner__aggregate(self): from mbr.metrics.comet import AggregateCometMetricRunner self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")