From 50d886de4a23073510df12ce3005527d10e0fca4 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Mon, 25 Dec 2023 20:22:56 +0100 Subject: [PATCH 2/5] Revise MetricRunner: return MetricOutput dict instead of aggregated scores --- src/mbr/__init__.py | 2 +- src/mbr/generation/utils.py | 18 ++++++------ src/mbr/metrics/base.py | 31 ++++++++++++++++----- src/mbr/metrics/comet.py | 1 - tests/test_generate.py | 18 +++++++----- tests/test_metrics.py | 55 +++++++++++++++++++++++-------------- 6 files changed, 79 insertions(+), 46 deletions(-) diff --git a/src/mbr/__init__.py b/src/mbr/__init__.py index 88ba904..834addc 100644 --- a/src/mbr/__init__.py +++ b/src/mbr/__init__.py @@ -1,6 +1,6 @@
from mbr.generation.configuration_utils import MBRGenerationConfig from mbr.generation.utils import MBROutput, MBRGenerationMixin -from mbr.metrics.base import MetricRunner +from mbr.metrics.base import MetricOutput, MetricRunner from mbr.modeling import MBR diff --git a/src/mbr/generation/utils.py b/src/mbr/generation/utils.py index a9686a8..dd90068 100644 --- a/src/mbr/generation/utils.py +++ b/src/mbr/generation/utils.py @@ -15,7 +15,7 @@ from transformers.utils import logging, ModelOutput from mbr.generation.configuration_utils import MBRGenerationConfig -from mbr.metrics.base import MetricRunner +from mbr.metrics.base import MetricRunner, MetricOutput if TYPE_CHECKING: from transformers import PreTrainedModel, PreTrainedTokenizer @@ -42,16 +42,16 @@ class MBROutput(ModelOutput): The indices (in `all_samples`) of the selected sequences for each batch item. references (`tuple(ModelOutput)`), *optional*, returned when `output_all_samples=True` is passed or when `config.output_all_samples=True`): - metric_scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`), *optional*, returned when - `output_metric_scores=True` is passed or when `config.output_metric_scores=True`): - The metric score for each sample. + metric_scores (`MetricOutput`), *optional*, returned when `output_metric_scores=True` is passed or when + `config.output_metric_scores=True`): + The output of the metric. """ sequences: torch.LongTensor = None all_samples: Optional[Tuple[ModelOutput]] = None selected_samples_indices: Optional[torch.LongTensor] = None references: Optional[Tuple[ModelOutput]] = None - metric_scores: Optional[torch.FloatTensor] = None + metric_scores: Optional[MetricOutput] = None class MBRGenerationMixin(GenerationMixin): @@ -483,11 +483,11 @@ def generate( else: reference_ids = references - metric_scores = metric_runner(input_ids, sample_ids, reference_ids) + metric_output = metric_runner(input_ids, sample_ids, reference_ids) if not mbr_config.lower_is_better: - top_metric_scores, top_metric_indices = metric_scores.max(dim=-1) + top_metric_scores, top_metric_indices = metric_output.scores.max(dim=-1) else: - top_metric_scores, top_metric_indices = metric_scores.min(dim=-1) + top_metric_scores, top_metric_indices = metric_output.scores.min(dim=-1) # Copy top samples into a tensor of shape (batch_size, max_length) max_length = max(sample.shape[1] for sample in sample_ids) @@ -496,7 +496,7 @@ def generate( all_samples=(tuple(samples) if mbr_config.output_all_samples else None), selected_samples_indices=(top_metric_indices if mbr_config.output_all_samples else None), references=(tuple(references) if mbr_config.output_all_samples else None), - metric_scores=(metric_scores if mbr_config.output_metric_scores else None), + metric_scores=(metric_output if mbr_config.output_metric_scores else None), ) for batch_idx, sample_idx in enumerate(top_metric_indices): output.sequences[batch_idx][:sample_ids[sample_idx].shape[1]] = sample_ids[sample_idx][batch_idx] diff --git a/src/mbr/metrics/base.py b/src/mbr/metrics/base.py index 0e92358..c6cc6fb 100644 --- a/src/mbr/metrics/base.py +++ b/src/mbr/metrics/base.py @@ -1,17 +1,32 @@ import functools -from typing import Tuple, Union, List +from dataclasses import dataclass +from typing import Tuple, Union, List, Optional import evaluate import torch from datasets import Metric from evaluate import EvaluationModule from transformers import PreTrainedTokenizerBase +from transformers.utils import ModelOutput from mbr import MBRGenerationConfig MetricType = 
Union[Metric, EvaluationModule] +@dataclass +class MetricOutput(ModelOutput): + """ + Args: + scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`): + The metric scores for each sample (aggregated over all references). + scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`): + The pairwise metric scores for each sample and reference. `None` if the metric is computed corpus-level. + """ + scores: torch.FloatTensor + scores_per_reference: Optional[torch.FloatTensor] = None + + class MetricRunner: """ Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample. @@ -46,7 +61,7 @@ def __call__(self, input_ids: torch.LongTensor, sample_ids: Tuple[torch.LongTensor], reference_ids: Tuple[torch.LongTensor], - ) -> torch.FloatTensor: + ) -> MetricOutput: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -59,8 +74,7 @@ def __call__(self, the reference sequences. Returns: - `torch.FloatTensor` of shape `(batch_size, num_samples)`: - The metric scores for each sample (aggregated over all references). + `MetricOutput` containing the metric scores. """ # Detokenize @@ -83,8 +97,12 @@ def __call__(self, raise ValueError("Number of references must match `mbr_config.num_references`") # Compute metric - metric_scores = self._compute_str_metric(str_samples, str_references, str_inputs) - return metric_scores + scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs) + + return MetricOutput( + scores=scores_per_reference.mean(dim=-1), + scores_per_reference=scores_per_reference, + ) def _compute_str_metric(self, samples: List[List[str]], @@ -112,7 +130,6 @@ def _compute_str_metric(self, **self.mbr_config.metric_kwargs, ) metric_scores[i, j, k] = score - metric_scores = metric_scores.mean(dim=-1) # average over references return metric_scores @functools.lru_cache(maxsize=(1024 ** 2)) diff --git a/src/mbr/metrics/comet.py b/src/mbr/metrics/comet.py index eaab3fa..aa65347 100644 --- a/src/mbr/metrics/comet.py +++ b/src/mbr/metrics/comet.py @@ -95,5 +95,4 @@ def _compute_str_metric(self, for k in range(self.mbr_config.num_references): metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])] - metric_scores = metric_scores.mean(dim=-1) # average over references return metric_scores diff --git a/tests/test_generate.py b/tests/test_generate.py index 0a311ed..7ab605b 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -6,7 +6,7 @@ from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput -from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner +from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner, MetricOutput class DecoderOnlyTestCase(TestCase): @@ -90,9 +90,11 @@ def test_model_output_extended(self): self.assertEqual(5, len(output.references)) self.assertIsInstance(output.references[0], SampleDecoderOnlyOutput) self.assertIsNotNone(output.metric_scores) - self.assertTrue(torch.is_floating_point(output.metric_scores)) - self.assertEqual(1, output.metric_scores.shape[0]) - self.assertEqual(5, output.metric_scores.shape[1]) + self.assertIsInstance(output.metric_scores, MetricOutput) + self.assertTrue(torch.is_floating_point(output.metric_scores.scores)) + self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference)) + 
self.assertEqual([1, 5], list(output.metric_scores.scores.shape)) + self.assertEqual([1, 5, 5], list(output.metric_scores.scores_per_reference.shape)) # Test the model output for a selected sample sample = output.all_samples[output.selected_samples_indices[0]] @@ -267,9 +269,11 @@ def test_model_output_extended(self): self.assertEqual(5, len(output.references)) self.assertIsInstance(output.references[0], SampleEncoderDecoderOutput) self.assertIsNotNone(output.metric_scores) - self.assertTrue(torch.is_floating_point(output.metric_scores)) - self.assertEqual(2, output.metric_scores.shape[0]) - self.assertEqual(5, output.metric_scores.shape[1]) + self.assertIsInstance(output.metric_scores, MetricOutput) + self.assertTrue(torch.is_floating_point(output.metric_scores.scores)) + self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference)) + self.assertEqual([2, 5], list(output.metric_scores.scores.shape)) + self.assertEqual([2, 5, 5], list(output.metric_scores.scores_per_reference.shape)) # Test the model output for a selected sample (batch index 0) sample = output.all_samples[output.selected_samples_indices[0]] diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 93f9e1c..d0e4e47 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -75,15 +75,19 @@ def test_metric_config_name(self): self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder") def test_compute_metric__chrf(self): - metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) - self.assertTrue(torch.is_floating_point(metric_scores)) + metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) + self.assertTrue(torch.is_floating_point(metric_output.scores)) + self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) + torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) + self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples + self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references # Duplicate samples should have the same scores - self.assertEqual(metric_scores[0, 0], metric_scores[0, 1]) + torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) + torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) # The metric scores should rank as expected, given the test strings in self.samples and self.references - self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples - self.assertGreater(metric_scores[0, 0], metric_scores[0, 2]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 1]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 2]) + self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_compute_metric__comet(self): @@ -91,15 +95,19 @@ def test_compute_metric__comet(self): self.mbr_config.metric_output_field = "mean_score" self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer) self.assertEqual(self.metric_runner.metric.name, "comet") - metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) - self.assertTrue(torch.is_floating_point(metric_scores)) + 
metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) + self.assertTrue(torch.is_floating_point(metric_output.scores)) + self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) + torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) + self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples + self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references # Duplicate samples should have the same scores - self.assertEqual(metric_scores[0, 0], metric_scores[0, 1]) + torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) + torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) # The metric scores should rank as expected, given the test strings in self.samples and self.references - self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples - self.assertGreater(metric_scores[0, 0], metric_scores[0, 2]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 1]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 2]) + self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_compute_metric__bleurt(self): @@ -107,15 +115,19 @@ def test_compute_metric__bleurt(self): self.mbr_config.metric_output_field = "scores" self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer) self.assertEqual(self.metric_runner.metric.name, "bleurt") - metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) - self.assertTrue(torch.is_floating_point(metric_scores)) + metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) + self.assertTrue(torch.is_floating_point(metric_output.scores)) + self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) + torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) + self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples + self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references # Duplicate samples should have the same scores - self.assertEqual(metric_scores[0, 0], metric_scores[0, 1]) + torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) + torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) # The metric scores should rank as expected, given the test strings in self.samples and self.references - self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples - self.assertGreater(metric_scores[0, 0], metric_scores[0, 2]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 1]) - self.assertLess(metric_scores[1, 0], metric_scores[1, 2]) + self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) + self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") def test_comet_metric_runner(self): @@ -125,6 +137,7 @@ def test_comet_metric_runner(self): 
base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer) self.assertEqual(base_metric_runner.metric.name, "comet") comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer) + # Output should be the same as the base MetricRunner base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) torch.testing.assert_close(base_metric_scores, metric_scores) From 0a31da278467993f54e75306b7d926712c6de998 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Tue, 26 Dec 2023 17:49:31 +0100 Subject: [PATCH 3/5] Add test case for output consistency --- tests/test_output_consistency.py | 51 ++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/test_output_consistency.py diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py new file mode 100644 index 0000000..27e2f67 --- /dev/null +++ b/tests/test_output_consistency.py @@ -0,0 +1,51 @@ +from unittest import TestCase + +import torch.testing +from transformers import GPT2LMHeadModel, AutoTokenizer, set_seed + +from mbr import MBR, MBRGenerationConfig, MBROutput + + +class OutputConsistencyTestCase(TestCase): + """ + Test that the output of MBR remains the same across different versions of this library. + """ + + def setUp(self): + self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval() + self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + def test_output(self): + set_seed(42) + mbr_config = MBRGenerationConfig( + num_samples=5, + return_dict_in_generate=True, + output_hidden_states=True, + output_attentions=True, + output_all_samples=True, + output_reference_sequences=True, + output_metric_scores=True, + ) + input_sentences = [ + "Hello, my name is", + "This is another sentence because", + ] + encoding = self.tokenizer(input_sentences, return_tensors="pt") + output: MBROutput = self.model.generate( + **encoding, + mbr_config=mbr_config, + tokenizer=self.tokenizer, + do_sample=True, + progress_bar=True, + ) + torch.testing.assert_close(output.sequences[0], + torch.tensor([15496, 11, 616, 1438, 318, 3977, 11, 290, 314, 716, + 262, 1772, 286, 257, 11648, 1444, 366, 464, 7443, 286])) + torch.testing.assert_close(output.selected_samples_indices, torch.tensor([1, 1])) + torch.testing.assert_close(output.references[0].sequences, torch.tensor( + [[15496, 11, 616, 1438, 318, 449, 13, 41, 13, 53, 13, 447, 103, 290, 356, 423, 257, 1049, 6180, 13], + [1212, 318, 1194, 6827, 780, 612, 373, 2147, 2642, 351, 340, 13, 447, 237, 198, 1532, 484, 547, 284, + 423]])) + torch.testing.assert_close(output.metric_scores, torch.tensor( + [[43.1201, 46.1530, 43.5142, 43.8980, 44.0345], + [57.1227, 57.2903, 54.9877, 57.1268, 56.8152]])) From c60c08d395767e34e8dab010069de3b9a479fb23 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Tue, 26 Dec 2023 17:53:57 +0100 Subject: [PATCH 4/5] Update output consistency test case --- tests/test_output_consistency.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py index 27e2f67..84ea675 100644 --- a/tests/test_output_consistency.py +++ b/tests/test_output_consistency.py @@ -46,6 +46,12 @@ def test_output(self): [[15496, 11, 616, 1438, 318, 449, 13, 41, 13, 53, 13, 447, 103, 290, 356, 423, 257, 1049, 6180, 13], [1212, 318, 1194, 6827, 780, 612, 373, 2147, 2642, 351, 340, 13, 447, 237, 198, 1532, 484, 547, 284, 423]])) 
- torch.testing.assert_close(output.metric_scores, torch.tensor( - [[43.1201, 46.1530, 43.5142, 43.8980, 44.0345], - [57.1227, 57.2903, 54.9877, 57.1268, 56.8152]])) + torch.testing.assert_close(output.metric_scores.scores, + torch.tensor([[43.1201, 46.1530, 43.5142, 43.8980, 44.0345], + [57.1227, 57.2903, 54.9877, 57.1268, 56.8152]]), atol=1e-4, rtol=1e-4) + torch.testing.assert_close(output.metric_scores.scores_per_reference[0], + torch.tensor([[100.0000, 23.5668, 28.9961, 29.3708, 33.6670], + [31.3434, 100.0000, 33.3013, 31.0525, 35.0679], + [31.0233, 26.7555, 100.0000, 30.4199, 29.3724], + [32.2136, 25.5630, 31.1885, 100.0000, 30.5251], + [35.0868, 27.4656, 28.6148, 29.0051, 100.0000]]), atol=1e-4, rtol=1e-4) From 4e22a87cdeea472bda9cf11fb22918d6cc633e90 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Wed, 27 Dec 2023 10:59:45 +0100 Subject: [PATCH 5/5] Delete output consistency test file again --- tests/test_output_consistency.py | 57 -------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 tests/test_output_consistency.py diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py deleted file mode 100644 index 84ea675..0000000 --- a/tests/test_output_consistency.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest import TestCase - -import torch.testing -from transformers import GPT2LMHeadModel, AutoTokenizer, set_seed - -from mbr import MBR, MBRGenerationConfig, MBROutput - - -class OutputConsistencyTestCase(TestCase): - """ - Test that the output of MBR remains the same across different versions of this library. - """ - - def setUp(self): - self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval() - self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - - def test_output(self): - set_seed(42) - mbr_config = MBRGenerationConfig( - num_samples=5, - return_dict_in_generate=True, - output_hidden_states=True, - output_attentions=True, - output_all_samples=True, - output_reference_sequences=True, - output_metric_scores=True, - ) - input_sentences = [ - "Hello, my name is", - "This is another sentence because", - ] - encoding = self.tokenizer(input_sentences, return_tensors="pt") - output: MBROutput = self.model.generate( - **encoding, - mbr_config=mbr_config, - tokenizer=self.tokenizer, - do_sample=True, - progress_bar=True, - ) - torch.testing.assert_close(output.sequences[0], - torch.tensor([15496, 11, 616, 1438, 318, 3977, 11, 290, 314, 716, - 262, 1772, 286, 257, 11648, 1444, 366, 464, 7443, 286])) - torch.testing.assert_close(output.selected_samples_indices, torch.tensor([1, 1])) - torch.testing.assert_close(output.references[0].sequences, torch.tensor( - [[15496, 11, 616, 1438, 318, 449, 13, 41, 13, 53, 13, 447, 103, 290, 356, 423, 257, 1049, 6180, 13], - [1212, 318, 1194, 6827, 780, 612, 373, 2147, 2642, 351, 340, 13, 447, 237, 198, 1532, 484, 547, 284, - 423]])) - torch.testing.assert_close(output.metric_scores.scores, - torch.tensor([[43.1201, 46.1530, 43.5142, 43.8980, 44.0345], - [57.1227, 57.2903, 54.9877, 57.1268, 56.8152]]), atol=1e-4, rtol=1e-4) - torch.testing.assert_close(output.metric_scores.scores_per_reference[0], - torch.tensor([[100.0000, 23.5668, 28.9961, 29.3708, 33.6670], - [31.3434, 100.0000, 33.3013, 31.0525, 35.0679], - [31.0233, 26.7555, 100.0000, 30.4199, 29.3724], - [32.2136, 25.5630, 31.1885, 100.0000, 30.5251], - [35.0868, 27.4656, 28.6148, 29.0051, 100.0000]]), atol=1e-4, rtol=1e-4)
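Usage note (illustrative, not part of the patch series): after this change, MetricRunner.__call__ returns a MetricOutput instead of a plain tensor, and MBROutput.metric_scores carries that MetricOutput when output_metric_scores=True. The minimal sketch below is adapted from the tests in this series (distilgpt2 is simply the small checkpoint the tests use) and shows how the new fields are accessed; it is a sketch of intended usage under those assumptions, not additional code from the patches.

import torch.testing
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed

from mbr import MBR, MBRGenerationConfig, MBROutput, MetricOutput

# Illustrative sketch only; checkpoint and inputs are taken from the tests above.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval()

set_seed(42)
mbr_config = MBRGenerationConfig(
    num_samples=5,
    return_dict_in_generate=True,
    output_metric_scores=True,  # attach the MetricOutput to MBROutput.metric_scores
)
encoding = tokenizer(["Hello, my name is"], return_tensors="pt")
output: MBROutput = model.generate(
    **encoding,
    mbr_config=mbr_config,
    tokenizer=tokenizer,
    do_sample=True,
)

metric_output: MetricOutput = output.metric_scores
print(metric_output.scores.shape)  # (batch_size, num_samples)

# scores_per_reference may be None if the metric is computed corpus-level.
if metric_output.scores_per_reference is not None:
    print(metric_output.scores_per_reference.shape)  # (batch_size, num_samples, num_references)
    # The aggregated scores are the per-reference scores averaged over the references.
    torch.testing.assert_close(
        metric_output.scores, metric_output.scores_per_reference.mean(dim=-1)
    )

Returning the full per-reference matrix alongside the aggregated scores lets callers inspect the pairwise utility estimates without recomputing the metric.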