From 9b6c4a519c33e9eba5f999bc00053f0f44333a1e Mon Sep 17 00:00:00 2001
From: Jannis Vamvas <vamvas@cl.uzh.ch>
Date: Mon, 25 Dec 2023 20:22:56 +0100
Subject: [PATCH 1/2] Revise MetricRunner: return a MetricOutput instead of a
 raw tensor of aggregated scores

---
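A minimal sketch of how the new return type is consumed (assuming a
MetricRunner instance and tokenized inputs, samples and references are in
scope, as inside MBRGenerationMixin.generate()):

    metric_output = metric_runner(input_ids, sample_ids, reference_ids)
    scores = metric_output.scores                 # (batch_size, num_samples)
    per_ref = metric_output.scores_per_reference  # (batch_size, num_samples, num_references), or None
    top_scores, top_indices = scores.max(dim=-1)  # assuming a higher-is-better metric
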
 src/mbr/__init__.py         |  2 +-
 src/mbr/generation/utils.py | 18 ++++++------
 src/mbr/metrics/base.py     | 31 ++++++++++++++++-----
 src/mbr/metrics/comet.py    |  1 -
 tests/test_generate.py      | 18 +++++++-----
 tests/test_metrics.py       | 55 +++++++++++++++++++++++--------------
 6 files changed, 79 insertions(+), 46 deletions(-)

diff --git a/src/mbr/__init__.py b/src/mbr/__init__.py
index 88ba904..834addc 100644
--- a/src/mbr/__init__.py
+++ b/src/mbr/__init__.py
@@ -1,6 +1,6 @@
 from mbr.generation.configuration_utils import MBRGenerationConfig
 from mbr.generation.utils import MBROutput, MBRGenerationMixin
-from mbr.metrics.base import MetricRunner
+from mbr.metrics.base import MetricOutput, MetricRunner
 from mbr.modeling import MBR
 
 
diff --git a/src/mbr/generation/utils.py b/src/mbr/generation/utils.py
index a9686a8..dd90068 100644
--- a/src/mbr/generation/utils.py
+++ b/src/mbr/generation/utils.py
@@ -15,7 +15,7 @@
 from transformers.utils import logging, ModelOutput
 
 from mbr.generation.configuration_utils import MBRGenerationConfig
-from mbr.metrics.base import MetricRunner
+from mbr.metrics.base import MetricRunner, MetricOutput
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -42,16 +42,16 @@ class MBROutput(ModelOutput):
             The indices (in `all_samples`) of the selected sequences for each batch item.
         references (`tuple(ModelOutput)`), *optional*, returned when `output_all_samples=True` is passed or when
         `config.output_all_samples=True`):
-        metric_scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`), *optional*, returned when
-        `output_metric_scores=True` is passed or when `config.output_metric_scores=True`):
-            The metric score for each sample.
+        metric_scores (`MetricOutput`, *optional*, returned when `output_metric_scores=True` is passed or when
+        `config.output_metric_scores=True`):
+            The output of the metric, including per-sample scores and, if available, per-reference scores.
     """
 
     sequences: torch.LongTensor = None
     all_samples: Optional[Tuple[ModelOutput]] = None
     selected_samples_indices: Optional[torch.LongTensor] = None
     references: Optional[Tuple[ModelOutput]] = None
-    metric_scores: Optional[torch.FloatTensor] = None
+    metric_scores: Optional[MetricOutput] = None
 
 
 class MBRGenerationMixin(GenerationMixin):
@@ -483,11 +483,11 @@ def generate(
         else:
             reference_ids = references
 
-        metric_scores = metric_runner(input_ids, sample_ids, reference_ids)
+        metric_output = metric_runner(input_ids, sample_ids, reference_ids)
         if not mbr_config.lower_is_better:
-            top_metric_scores, top_metric_indices = metric_scores.max(dim=-1)
+            top_metric_scores, top_metric_indices = metric_output.scores.max(dim=-1)
         else:
-            top_metric_scores, top_metric_indices = metric_scores.min(dim=-1)
+            top_metric_scores, top_metric_indices = metric_output.scores.min(dim=-1)
 
         # Copy top samples into a tensor of shape (batch_size, max_length)
         max_length = max(sample.shape[1] for sample in sample_ids)
@@ -496,7 +496,7 @@ def generate(
             all_samples=(tuple(samples) if mbr_config.output_all_samples else None),
             selected_samples_indices=(top_metric_indices if mbr_config.output_all_samples else None),
             references=(tuple(references) if mbr_config.output_all_samples else None),
-            metric_scores=(metric_scores if mbr_config.output_metric_scores else None),
+            metric_scores=(metric_output if mbr_config.output_metric_scores else None),
         )
         for batch_idx, sample_idx in enumerate(top_metric_indices):
             output.sequences[batch_idx][:sample_ids[sample_idx].shape[1]] = sample_ids[sample_idx][batch_idx]
diff --git a/src/mbr/metrics/base.py b/src/mbr/metrics/base.py
index 0e92358..c6cc6fb 100644
--- a/src/mbr/metrics/base.py
+++ b/src/mbr/metrics/base.py
@@ -1,17 +1,32 @@
 import functools
-from typing import Tuple, Union, List
+from dataclasses import dataclass
+from typing import Tuple, Union, List, Optional
 
 import evaluate
 import torch
 from datasets import Metric
 from evaluate import EvaluationModule
 from transformers import PreTrainedTokenizerBase
+from transformers.utils import ModelOutput
 
 from mbr import MBRGenerationConfig
 
 MetricType = Union[Metric, EvaluationModule]
 
 
+@dataclass
+class MetricOutput(ModelOutput):
+    """
+    Args:
+        scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
+            The metric scores for each sample (aggregated over all references).
+        scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`, *optional*):
+            The pairwise metric scores for each sample and reference. `None` if the metric is computed at the corpus level.
+    """
+    scores: torch.FloatTensor
+    scores_per_reference: Optional[torch.FloatTensor] = None
+
+
 class MetricRunner:
     """
     Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample.
@@ -46,7 +61,7 @@ def __call__(self,
                  input_ids: torch.LongTensor,
                  sample_ids: Tuple[torch.LongTensor],
                  reference_ids: Tuple[torch.LongTensor],
-                 ) -> torch.FloatTensor:
+                 ) -> MetricOutput:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -59,8 +74,7 @@ def __call__(self,
                 the reference sequences.
 
         Returns:
-            `torch.FloatTensor` of shape `(batch_size, num_samples)`:
-                The metric scores for each sample (aggregated over all references).
+            A `MetricOutput` containing the aggregated scores and, if available, the per-reference scores.
         """
 
         # Detokenize
@@ -83,8 +97,12 @@ def __call__(self,
             raise ValueError("Number of references must match `mbr_config.num_references`")
 
         # Compute metric
-        metric_scores = self._compute_str_metric(str_samples, str_references, str_inputs)
-        return metric_scores
+        scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs)
+
+        return MetricOutput(
+            scores=scores_per_reference.mean(dim=-1),
+            scores_per_reference=scores_per_reference,
+        )
 
     def _compute_str_metric(self,
                             samples: List[List[str]],
@@ -112,7 +130,6 @@ def _compute_str_metric(self,
                             **self.mbr_config.metric_kwargs,
                         )
                     metric_scores[i, j, k] = score
-        metric_scores = metric_scores.mean(dim=-1)  # average over references
         return metric_scores
 
     @functools.lru_cache(maxsize=(1024 ** 2))
diff --git a/src/mbr/metrics/comet.py b/src/mbr/metrics/comet.py
index eaab3fa..aa65347 100644
--- a/src/mbr/metrics/comet.py
+++ b/src/mbr/metrics/comet.py
@@ -95,5 +95,4 @@ def _compute_str_metric(self,
                 for k in range(self.mbr_config.num_references):
                     metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]
 
-        metric_scores = metric_scores.mean(dim=-1)  # average over references
         return metric_scores
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 0a311ed..7ab605b 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -6,7 +6,7 @@
 from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig
 from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput
 
-from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner
+from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner, MetricOutput
 
 
 class DecoderOnlyTestCase(TestCase):
@@ -90,9 +90,11 @@ def test_model_output_extended(self):
         self.assertEqual(5, len(output.references))
         self.assertIsInstance(output.references[0], SampleDecoderOnlyOutput)
         self.assertIsNotNone(output.metric_scores)
-        self.assertTrue(torch.is_floating_point(output.metric_scores))
-        self.assertEqual(1, output.metric_scores.shape[0])
-        self.assertEqual(5, output.metric_scores.shape[1])
+        self.assertIsInstance(output.metric_scores, MetricOutput)
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
+        self.assertEqual([1, 5], list(output.metric_scores.scores.shape))
+        self.assertEqual([1, 5, 5], list(output.metric_scores.scores_per_reference.shape))
 
         # Test the model output for a selected sample
         sample = output.all_samples[output.selected_samples_indices[0]]
@@ -267,9 +269,11 @@ def test_model_output_extended(self):
         self.assertEqual(5, len(output.references))
         self.assertIsInstance(output.references[0], SampleEncoderDecoderOutput)
         self.assertIsNotNone(output.metric_scores)
-        self.assertTrue(torch.is_floating_point(output.metric_scores))
-        self.assertEqual(2, output.metric_scores.shape[0])
-        self.assertEqual(5, output.metric_scores.shape[1])
+        self.assertIsInstance(output.metric_scores, MetricOutput)
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
+        self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
+        self.assertEqual([2, 5], list(output.metric_scores.scores.shape))
+        self.assertEqual([2, 5, 5], list(output.metric_scores.scores_per_reference.shape))
 
         # Test the model output for a selected sample (batch index 0)
         sample = output.all_samples[output.selected_samples_indices[0]]
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 93f9e1c..d0e4e47 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -75,15 +75,19 @@ def test_metric_config_name(self):
         self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder")
 
     def test_compute_metric__chrf(self):
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_compute_metric__comet(self):
@@ -91,15 +95,19 @@ def test_compute_metric__comet(self):
         self.mbr_config.metric_output_field = "mean_score"
         self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(self.metric_runner.metric.name, "comet")
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_compute_metric__bleurt(self):
@@ -107,15 +115,19 @@ def test_compute_metric__bleurt(self):
         self.mbr_config.metric_output_field = "scores"
         self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(self.metric_runner.metric.name, "bleurt")
-        metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
-        self.assertTrue(torch.is_floating_point(metric_scores))
+        metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
         # Duplicate samples should have the same scores
-        self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
         # The metric scores should rank as expected, given the test strings in self.samples and self.references
-        self.assertEqual(metric_scores.shape, (2, 3))  # batch_size x num_samples
-        self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
-        self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
 
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_comet_metric_runner(self):
@@ -125,6 +137,7 @@ def test_comet_metric_runner(self):
         base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(base_metric_runner.metric.name, "comet")
         comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+        # Output should be the same as the base MetricRunner
         base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         torch.testing.assert_close(base_metric_scores, metric_scores)

From 73a38485dd5c342ba5e115509941a67751d79f83 Mon Sep 17 00:00:00 2001
From: Jannis Vamvas <vamvas@cl.uzh.ch>
Date: Wed, 27 Dec 2023 11:09:12 +0100
Subject: [PATCH 2/2] Do not globally require that num_references <=
 num_samples

---
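A sketch of the behaviour after this change (values are illustrative): the
constraint is no longer checked at config construction time and only triggers
inside generate() when samples are re-used as references.

    from mbr import MBRGenerationConfig

    mbr_config = MBRGenerationConfig(num_samples=5, num_references=8)  # now accepted
    # Without a separate `references_config`, generate() still raises a ValueError,
    # because only the 5 samples would be available for re-use as references.
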
 src/mbr/generation/configuration_utils.py | 8 +-------
 src/mbr/generation/utils.py               | 5 +++++
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/mbr/generation/configuration_utils.py b/src/mbr/generation/configuration_utils.py
index 51d3bb6..f238532 100644
--- a/src/mbr/generation/configuration_utils.py
+++ b/src/mbr/generation/configuration_utils.py
@@ -24,7 +24,7 @@ class MBRGenerationConfig(GenerationConfig):
         num_samples (`int`, *optional*, defaults to 10):
             Number of samples generated. 1 means no MBR decoding.
         num_references (`int`, *optional*, defaults to `num_samples`):
-            Number of pseudo-references used for MBR decoding. Needs to be smaller or equal to `num_samples`.
+            Number of pseudo-references used for MBR decoding.
         metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'chrf'):
             Metric used for MBR decoding.
         metric_config_name (`str`, *optional*, defaults to None):
@@ -100,12 +100,6 @@ def __init__(self, **kwargs):
         # Validate the values of the attributes
         self.validate(is_init=True)
 
-    def validate(self, is_init=False):
-        if self.num_references > self.num_samples:
-            raise ValueError(
-                f"`num_references` ({self.num_references}) must be <= `num_samples` ({self.num_samples})."
-            )
-
     def save_pretrained(
             self,
             save_directory: Union[str, os.PathLike],
diff --git a/src/mbr/generation/utils.py b/src/mbr/generation/utils.py
index dd90068..efe66fe 100644
--- a/src/mbr/generation/utils.py
+++ b/src/mbr/generation/utils.py
@@ -443,6 +443,11 @@ def generate(
             # 14. references
             if references_config is None:
                 # Use samples as references
+                if mbr_config.num_references > mbr_config.num_samples:
+                    raise ValueError(
+                        f"`mbr_config.num_references` ({mbr_config.num_references}) must be smaller than or equal to "
+                        f"`mbr_config.num_samples` ({mbr_config.num_samples}) if samples are re-used as references."
+                    )
                 references = samples[:mbr_config.num_references]
             else:
                 # Generate references