diff --git a/src/mbr/__init__.py b/src/mbr/__init__.py
index 88ba904..834addc 100644
--- a/src/mbr/__init__.py
+++ b/src/mbr/__init__.py
@@ -1,6 +1,6 @@
 from mbr.generation.configuration_utils import MBRGenerationConfig
 from mbr.generation.utils import MBROutput, MBRGenerationMixin
-from mbr.metrics.base import MetricRunner
+from mbr.metrics.base import MetricOutput, MetricRunner
 from mbr.modeling import MBR
diff --git a/src/mbr/generation/configuration_utils.py b/src/mbr/generation/configuration_utils.py
index 51d3bb6..f238532 100644
--- a/src/mbr/generation/configuration_utils.py
+++ b/src/mbr/generation/configuration_utils.py
@@ -24,7 +24,7 @@ class MBRGenerationConfig(GenerationConfig):
         num_samples (`int`, *optional*, defaults to 10):
             Number of samples generated. 1 means no MBR decoding.
         num_references (`int`, *optional*, defaults to `num_samples`):
-            Number of pseudo-references used for MBR decoding. Needs to be smaller or equal to `num_samples`.
+            Number of pseudo-references used for MBR decoding.
         metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'chrf'):
             Metric used for MBR decoding.
         metric_config_name (`str`, *optional*, defaults to None):
@@ -100,12 +100,6 @@ def __init__(self, **kwargs):
         # Validate the values of the attributes
         self.validate(is_init=True)
 
-    def validate(self, is_init=False):
-        if self.num_references > self.num_samples:
-            raise ValueError(
-                f"`num_references` ({self.num_references}) must be <= `num_samples` ({self.num_samples})."
-            )
-
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/src/mbr/generation/utils.py b/src/mbr/generation/utils.py
index a9686a8..efe66fe 100644
--- a/src/mbr/generation/utils.py
+++ b/src/mbr/generation/utils.py
@@ -15,7 +15,7 @@
 from transformers.utils import logging, ModelOutput
 
 from mbr.generation.configuration_utils import MBRGenerationConfig
-from mbr.metrics.base import MetricRunner
+from mbr.metrics.base import MetricRunner, MetricOutput
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -42,16 +42,16 @@ class MBROutput(ModelOutput):
            The indices (in `all_samples`) of the selected sequences for each batch item.
         references (`tuple(ModelOutput)`), *optional*, returned when `output_all_samples=True` is passed or when
            `config.output_all_samples=True`):
-        metric_scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`), *optional*, returned when
-            `output_metric_scores=True` is passed or when `config.output_metric_scores=True`):
-            The metric score for each sample.
+        metric_scores (`MetricOutput`), *optional*, returned when `output_metric_scores=True` is passed or when
+            `config.output_metric_scores=True`):
+            The output of the metric.
     """
 
     sequences: torch.LongTensor = None
     all_samples: Optional[Tuple[ModelOutput]] = None
     selected_samples_indices: Optional[torch.LongTensor] = None
     references: Optional[Tuple[ModelOutput]] = None
-    metric_scores: Optional[torch.FloatTensor] = None
+    metric_scores: Optional[MetricOutput] = None
 
 
 class MBRGenerationMixin(GenerationMixin):
@@ -443,6 +443,11 @@ def generate(
         # 14. references
         if references_config is None:
             # Use samples as references
+            if mbr_config.num_references > mbr_config.num_samples:
+                raise ValueError(
+                    f"`mbr_config.num_references` ({mbr_config.num_references}) must be smaller than or equal to "
+                    f"`mbr_config.num_samples` ({mbr_config.num_samples}) if samples are re-used as references."
+                )
             references = samples[:mbr_config.num_references]
         else:
             # Generate references
@@ -483,11 +488,11 @@ def generate(
         else:
             reference_ids = references
 
-        metric_scores = metric_runner(input_ids, sample_ids, reference_ids)
+        metric_output = metric_runner(input_ids, sample_ids, reference_ids)
         if not mbr_config.lower_is_better:
-            top_metric_scores, top_metric_indices = metric_scores.max(dim=-1)
+            top_metric_scores, top_metric_indices = metric_output.scores.max(dim=-1)
         else:
-            top_metric_scores, top_metric_indices = metric_scores.min(dim=-1)
+            top_metric_scores, top_metric_indices = metric_output.scores.min(dim=-1)
 
         # Copy top samples into a tensor of shape (batch_size, max_length)
         max_length = max(sample.shape[1] for sample in sample_ids)
@@ -496,7 +501,7 @@ def generate(
             all_samples=(tuple(samples) if mbr_config.output_all_samples else None),
             selected_samples_indices=(top_metric_indices if mbr_config.output_all_samples else None),
             references=(tuple(references) if mbr_config.output_all_samples else None),
-            metric_scores=(metric_scores if mbr_config.output_metric_scores else None),
+            metric_scores=(metric_output if mbr_config.output_metric_scores else None),
         )
         for batch_idx, sample_idx in enumerate(top_metric_indices):
             output.sequences[batch_idx][:sample_ids[sample_idx].shape[1]] = sample_ids[sample_idx][batch_idx]
diff --git a/src/mbr/metrics/base.py b/src/mbr/metrics/base.py
index 0e92358..c6cc6fb 100644
--- a/src/mbr/metrics/base.py
+++ b/src/mbr/metrics/base.py
@@ -1,17 +1,32 @@
 import functools
-from typing import Tuple, Union, List
+from dataclasses import dataclass
+from typing import Tuple, Union, List, Optional
 
 import evaluate
 import torch
 from datasets import Metric
 from evaluate import EvaluationModule
 from transformers import PreTrainedTokenizerBase
+from transformers.utils import ModelOutput
 
 from mbr import MBRGenerationConfig
 
 MetricType = Union[Metric, EvaluationModule]
 
 
+@dataclass
+class MetricOutput(ModelOutput):
+    """
+    Args:
+        scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
+            The metric scores for each sample (aggregated over all references).
+        scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`):
+            The pairwise metric scores for each sample and reference. `None` if the metric is computed corpus-level.
+    """
+    scores: torch.FloatTensor
+    scores_per_reference: Optional[torch.FloatTensor] = None
+
+
 class MetricRunner:
     """
     Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample.
@@ -46,7 +61,7 @@ def __call__(self,
                  input_ids: torch.LongTensor,
                  sample_ids: Tuple[torch.LongTensor],
                  reference_ids: Tuple[torch.LongTensor],
-                 ) -> torch.FloatTensor:
+                 ) -> MetricOutput:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -59,8 +74,7 @@ def __call__(self,
                 the reference sequences.
 
         Returns:
-            `torch.FloatTensor` of shape `(batch_size, num_samples)`:
-                The metric scores for each sample (aggregated over all references).
+            `MetricOutput` containing the metric scores.
         """
 
         # Detokenize
@@ -83,8 +97,12 @@ def __call__(self,
            raise ValueError("Number of references must match `mbr_config.num_references`")
 
         # Compute metric
-        metric_scores = self._compute_str_metric(str_samples, str_references, str_inputs)
-        return metric_scores
+        scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs)
+
+        return MetricOutput(
+            scores=scores_per_reference.mean(dim=-1),
+            scores_per_reference=scores_per_reference,
+        )
 
     def _compute_str_metric(self,
                             samples: List[List[str]],
@@ -112,7 +130,6 @@ def _compute_str_metric(self,
                        **self.mbr_config.metric_kwargs,
                    )
                    metric_scores[i, j, k] = score
-        metric_scores = metric_scores.mean(dim=-1)  # average over references
         return metric_scores
 
     @functools.lru_cache(maxsize=(1024 ** 2))
diff --git a/src/mbr/metrics/comet.py b/src/mbr/metrics/comet.py
index eaab3fa..aa65347 100644
--- a/src/mbr/metrics/comet.py
+++ b/src/mbr/metrics/comet.py
@@ -95,5 +95,4 @@ def _compute_str_metric(self,
                for k in range(self.mbr_config.num_references):
                    metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]
 
-        metric_scores = metric_scores.mean(dim=-1)  # average over references
         return metric_scores
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 0a311ed..7ab605b 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -6,7 +6,7 @@
 from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig
 from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput
 
-from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner
+from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner, MetricOutput
 
 
 class DecoderOnlyTestCase(TestCase):
@@ -90,9 +90,11 @@ def test_model_output_extended(self):
         self.assertEqual(5, len(output.references))
         self.assertIsInstance(output.references[0], SampleDecoderOnlyOutput)
         self.assertIsNotNone(output.metric_scores)
-        self.assertTrue(torch.is_floating_point(output.metric_scores))
-        self.assertEqual(1, output.metric_scores.shape[0])
-        self.assertEqual(5, output.metric_scores.shape[1])
+        self.assertIsInstance(output.metric_scores, MetricOutput)
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
+        self.assertEqual([1, 5], list(output.metric_scores.scores.shape))
+        self.assertEqual([1, 5, 5], list(output.metric_scores.scores_per_reference.shape))
 
         # Test the model output for a selected sample
         sample = output.all_samples[output.selected_samples_indices[0]]
@@ -267,9 +269,11 @@ def test_model_output_extended(self):
         self.assertEqual(5, len(output.references))
         self.assertIsInstance(output.references[0], SampleEncoderDecoderOutput)
         self.assertIsNotNone(output.metric_scores)
-        self.assertTrue(torch.is_floating_point(output.metric_scores))
-        self.assertEqual(2, output.metric_scores.shape[0])
-        self.assertEqual(5, output.metric_scores.shape[1])
+        self.assertIsInstance(output.metric_scores, MetricOutput)
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
+        self.assertEqual([2, 5], list(output.metric_scores.scores.shape))
+        self.assertEqual([2, 5, 5], list(output.metric_scores.scores_per_reference.shape))
 
         # Test the model output for a selected sample (batch index 0)
         sample = output.all_samples[output.selected_samples_indices[0]]
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 93f9e1c..d0e4e47 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -75,15 +75,19 @@ def test_metric_config_name(self):
         self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder")
 
     def test_compute_metric__chrf(self):
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_compute_metric__comet(self):
@@ -91,15 +95,19 @@
         self.mbr_config.metric_output_field = "mean_score"
         self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(self.metric_runner.metric.name, "comet")
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_compute_metric__bleurt(self):
@@ -107,15 +115,19 @@
         self.mbr_config.metric_output_field = "scores"
         self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(self.metric_runner.metric.name, "bleurt")
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_comet_metric_runner(self):
@@ -125,6 +137,7 @@ def test_comet_metric_runner(self):
         base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(base_metric_runner.metric.name, "comet")
         comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+        # Output should be the same as the base MetricRunner
         base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         torch.testing.assert_close(base_metric_scores, metric_scores)
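
Illustration (not part of the patch): a minimal standalone sketch of the scoring flow after this change, using plain torch and made-up shapes. MetricRunner.__call__ now returns a MetricOutput whose `scores` field is the per-reference score tensor averaged over the last dimension, and generate() selects candidates from `metric_output.scores` via max() or min() depending on `lower_is_better`.

import torch

# Made-up shapes mirroring the tests: batch_size=2, num_samples=3, num_references=2.
scores_per_reference = torch.rand(2, 3, 2)

# Aggregation as in MetricRunner.__call__: average the pairwise scores over references.
scores = scores_per_reference.mean(dim=-1)  # shape: (batch_size, num_samples)

# Selection as in generate(): maximize utility metrics, minimize error/distance metrics.
lower_is_better = False  # assumption for this sketch
if not lower_is_better:
    top_scores, top_indices = scores.max(dim=-1)
else:
    top_scores, top_indices = scores.min(dim=-1)

print(top_indices)  # index of the selected sample for each batch item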
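
Illustration (not part of the patch): the `validate()` override is removed from MBRGenerationConfig because `num_references > num_samples` is only invalid when the samples themselves are re-used as references; the check now lives in generate() behind the `references_config is None` branch. The sketch below mirrors that relocated guard with hypothetical names and no dependency on the library.

from typing import List, Optional

def pick_references(samples: List[str], num_references: int,
                    separate_references: Optional[List[str]] = None) -> List[str]:
    # Sketch only: mirrors the relocated check, not the library's generate() implementation.
    if separate_references is None:
        # Samples are re-used as references, so there cannot be more references than samples.
        if num_references > len(samples):
            raise ValueError(
                f"num_references ({num_references}) must be <= num_samples ({len(samples)}) "
                "if samples are re-used as references."
            )
        return samples[:num_references]
    # With separately generated references, num_references is not bounded by num_samples.
    return separate_references[:num_references]

print(pick_references(["a", "b", "c"], num_references=2))  # ['a', 'b']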