Skip to content

Commit

Permalink
Integrate fastChrF and set as default metric (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvamvas authored Jan 2, 2024
1 parent 9588b26 commit 3949960
Show file tree
Hide file tree
Showing 15 changed files with 363 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
python -m pip install --upgrade pip
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install .
pip install -r requirements-test.txt
- name: Lint with flake8
run: |
pip install flake8
Expand Down
23 changes: 16 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,25 @@ model.generate(..., references_config=references_config)
```

### Choosing a metric
By default, **mbr** integrates metrics via the [Hugging Face Evaluate](https://github.com/huggingface/evaluate) library.
By default, **mbr** uses [fastChrF](https://github.com/jvamvas/fastChrF), which is optimized for efficient comparison of many samples to many references.

You can also plug in metrics from the [**Hugging Face Evaluate**](https://github.com/huggingface/evaluate) library.

A full list of metrics is found [here](https://huggingface.co/metrics). Some typical choices are:
- [ChrF](https://huggingface.co/spaces/evaluate-metric/chrf) ([Popović, 2015](https://www.aclweb.org/anthology/W15-3049/))
- [COMET](https://huggingface.co/spaces/evaluate-metric/comet) ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/))
- [BLEURT](https://huggingface.co/spaces/evaluate-metric/bleurt) ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704))

In the MBR config, you can either specify the metric's name (e.g., `"chrf"`, `"comet"`) or pass an `evaluate.Metric` object directly.
To use a metric from Hugging Face, either specify the metric's name (e.g., `"comet"`, `"bleurt"`) or pass an `evaluate.Metric` object directly.

Since different metrics output differently structured dicts, you need to specify the `metric_output_field` that should be used as the metric score.

```python
from evaluate import load

metric = load('chrf')
metric = load('bleu')
mbr_config = MBRGenerationConfig(
metric=metric,
metric_output_field="score", # the ChrF metric returns a dict with a "score" field
metric_output_field="bleu", # the BLEU metric returns a dict with a "bleu" field
...
)
```
Expand Down Expand Up @@ -188,8 +189,9 @@ model.generate(..., metric_runner=metric_runner)
### Optimizations
MBR decoding is notoriously slow. **mbr** implements some optimizations:
- Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
- Cached metric: The metric is computed only once for each unique sample–reference pair (since there will be duplicate samples and references).
- Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), sequence embeddings are cached and reused for all pairwise comparisons.
- Optimized ChrF metric: [fastChrF](https://github.com/jvamvas/fastChrF) is used by default, which is a streamlined ChrF variant for MBR, implemented in Rust.
- Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), `CometMetricRunner` caches sequence embeddings and reuses them for all pairwise comparisons.
- Cached metrics: Most metrics are computed only once for each unique sample–reference pair (since there will be duplicate samples and references).

## Example scripts

Expand All @@ -199,13 +201,20 @@ The [experiments](experiments) directory contains the code for reproductions of
- [MBR with neural metrics and epsilon sampling for machine translation](experiments/freitag-et-al-2023-epsilon) ([Freitag et al., 2023](https://arxiv.org/abs/2305.09860))
- [MBR for summarization](experiments/bertsch-et-al-2023-mbr) ([Bertsch et al., 2023](https://arxiv.org/abs/2310.01387))

### Other experiments
- Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR

## Related projects
- https://github.com/roxot/mbr-nmt: Original implementation ([demo](https://colab.research.google.com/github/probabll/demo-mbr-nmt/blob/main/German-English.ipynb))
- https://github.com/ZurichNLP/understanding-mbr: MBR with Sockeye
- https://github.com/ZurichNLP/mbr-sensitivity and https://github.com/Unbabel/COMET#minimum-bayes-risk-decoding: COMET metric for MBR
- https://github.com/rainavyas/mbr_gec: MBR for Grammatical Error Correction

## Changelog

- v0.3.0 (draft)
- Use [fastChrF](https://github.com/jvamvas/fastChrF) as default metric

- v0.2.0
- **Breaking change:** Rename `MBRGenerationConfig` to `MBRConfig`
- **Breaking change:** `MetricRunner` now returns a `MetricOutput` dict instead of the raw tensor of scores.
Expand Down
3 changes: 3 additions & 0 deletions experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
- It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk (Bertsch et al., 2023)
- Epsilon Sampling Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation (Freitag et al., 2023)
- Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation (Müller & Sennrich, ACL-IJCNLP 2021)

**Other experiments**
- Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.
32 changes: 32 additions & 0 deletions experiments/chrf-vs-fastchrf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.

## Setup
* Task: Machine translation
* Translation directions: en–de, de–en, en–ru, ru–en
* Model: [facebook/wmt19-*](https://huggingface.co/facebook/wmt19-en-de) ([Ng et al., 2019](https://aclanthology.org/W19-5333/)).
* MBR metrics: `fastchrf.pairwise_chrf` (a fast implementation of standard ChrF) and `fastchrf.aggregate_chrf` (a streamlined ChrF variant for MBR)
* Number of samples: 256
* Sampling approach: epsilon sampling with ε=0.02
* Samples and references are the same
* Test set: newstest2019
* Evaluation metrics: chrF ([sacreBLEU](https://github.com/mjpost/sacrebleu)) and COMET-22 ([Rei et al., 2022](https://aclanthology.org/2022.wmt-1.52/))
* Baseline: beam search with beam size 4

## Results
| Language Pair | Method | ChrF | COMET | duration (s) |
|---------------|--------------------------------------|---------:|----------:|-------------:|
| en-de | MBR with `fastchrf.pairwise_chrf` | 67.7 | 0.867 | 7798 |
| en-de | MBR with `fastchrf.aggregate_chrf` | 67.7 | 0.867 | 7480 |
| en-de | Beam search | 67.7 | 0.868 | 62 |
| de-en | MBR with `fastchrf.pairwise_chrf` | 65.4 | 0.851 | 6894 |
| de-en | MBR with `fastchrf.aggregate_chrf` | 65.6 | 0.850 | 6849 |
| de-en | Beam search | 65.1 | 0.851 | 53 |
| en-ru | MBR with `fastchrf.pairwise_chrf` | 57.5 | 0.862 | 7802 |
| en-ru | MBR with `fastchrf.aggregate_chrf` | 57.5 | 0.862 | 7465 |
| en-ru | Beam search | 56.9 | 0.863 | 64 |
| ru-en | MBR with `fastchrf.pairwise_chrf` | 64.2 | 0.847 | 7541 |
| ru-en | MBR with `fastchrf.aggregate_chrf` | 64.3 | 0.848 | 6689 |
| ru-en | Beam search | 63.5 | 0.847 | 61 |
| **Average** | **MBR with `fastchrf.pairwise_chrf`** | **63.7** | **0.857** | **7509** |
| **Average** | **MBR with `fastchrf.aggregate_chrf`** | **63.7** | **0.857** | **7121** |
| **Average** | **Beam search** | **63.3** | **0.857** | **60** |
142 changes: 142 additions & 0 deletions experiments/chrf-vs-fastchrf/run_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import sys
import time
from copy import deepcopy
from pathlib import Path

import evaluate
import jsonlines
import sacrebleu
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig

from mbr import MBR, MBRConfig

language_pair = sys.argv[1]
assert language_pair in ["de-en", "en-de", "en-ru", "ru-en"]

batch_size = 32

results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w")

model_name = f"facebook/wmt19-{language_pair}"
model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
mt_pipeline = pipeline(
"translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1],
model=model,
tokenizer=tokenizer,
device=(0 if torch.cuda.is_available() else -1),
)
evaluation_metric_chrf = evaluate.load("chrf")
evaluation_metric_comet = evaluate.load("comet", "Unbabel/wmt22-comet-da")

src_path = sacrebleu.get_source_file("wmt19", language_pair)
ref_path = sacrebleu.get_reference_files("wmt19", language_pair)[0]
dataset = load_dataset("text", data_files={"test": src_path})
references = Path(ref_path).read_text().splitlines()
assert len(dataset["test"]) == len(references)

# MBR
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.do_sample = True
generation_config.num_beams = 1
generation_config.early_stopping = False
generation_config.epsilon_cutoff = 0.02

base_mbr_config = MBRConfig(
num_samples=256,
num_references=256,
)
base_mbr_config.metric_cache_size = batch_size * base_mbr_config.num_samples * base_mbr_config.num_references
mbr_configs = {}

# MBR with fastchrf.pairwise_chrf
mbr_config = deepcopy(base_mbr_config)
mbr_config.metric = "fastchrf-pairwise"
mbr_configs["MBR with fastchrf.pairwise_chrf"] = mbr_config

# MBR with fastchrf.aggregate_chrf
mbr_config = deepcopy(base_mbr_config)
mbr_config.metric = "fastchrf-aggregate"
mbr_configs["MBR with fastchrf.aggregate_chrf"] = mbr_config

for method, mbr_config in mbr_configs.items():

set_seed(42)
time_start = time.time()
outputs = mt_pipeline(
dataset["test"]["text"],
mbr_config=mbr_config,
generation_config=generation_config,
tokenizer=tokenizer,
batch_size=batch_size,
progress_bar=True
)
translations = []
for batch in tqdm(outputs):
if isinstance(batch, dict):
batch = [batch]
translations += [translation["translation_text"] for translation in batch]
time_end = time.time()

chrf_score = evaluation_metric_chrf.compute(
predictions=translations,
references=references,
)
comet_score = evaluation_metric_comet.compute(
predictions=translations,
references=references,
sources=dataset["test"]["text"],
gpus=0,
)
results_file.write({
"language_pair": language_pair,
"method": method,
"chrf": chrf_score["score"],
"comet22": comet_score["mean_score"],
"duration": time_end - time_start,
"translations": translations,
})

# Beam search
model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device)
mt_pipeline.model = model
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.num_beams = 4

set_seed(42)
time_start = time.time()
outputs = mt_pipeline(
dataset["test"]["text"],
generation_config=generation_config,
batch_size=batch_size,
)
translations = []
for batch in tqdm(outputs):
if isinstance(batch, dict):
batch = [batch]
translations += [translation["translation_text"] for translation in batch]
time_end = time.time()

chrf_score = evaluation_metric_chrf.compute(
predictions=translations,
references=references,
)
comet_score = evaluation_metric_comet.compute(
predictions=translations,
references=references,
sources=dataset["test"]["text"],
gpus=0,
)
results_file.write({
"language_pair": language_pair,
"method": f"beam search (beam size {generation_config.num_beams})",
"chrf": chrf_score["score"],
"comet22": comet_score["mean_score"],
"duration": time_end - time_start,
"translations": translations,
})

results_file.close()
1 change: 1 addition & 0 deletions experiments/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
jsonlines==4.0.0
datasets==2.14.6
sacrebleu==2.3.1
sacremoses==0.0.53 # For OpusMT
nltk==3.8.1
rouge_score==0.1.2
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ requires-python = ">=3.9"
dependencies = [
"transformers",
"evaluate",
"sacrebleu",
"cachetools",
"tqdm",
"fastchrf",
]
classifiers = [
"Programming Language :: Python :: 3",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
sacrebleu==2.4.0
unbabel-comet==2.1.1
git+https://github.com/google-research/bleurt.git
sentencepiece==0.1.99 # M2M100 model
2 changes: 2 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sacrebleu==2.4.0

6 changes: 3 additions & 3 deletions src/mbr/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MBRConfig:
Example:
```python
>>> config = MBRConfig(num_samples=10, num_references=10, metric="chrf")
>>> config = MBRConfig(num_samples=10, num_references=10, metric="fastchrf")
>>> model.generate(..., mbr_config=config)
```
Expand All @@ -31,7 +31,7 @@ class MBRConfig:
Number of samples generated. 1 means no MBR decoding.
num_references (`int`, *optional*, defaults to `num_samples`):
Number of pseudo-references used for MBR decoding.
metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'chrf'):
metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'fastchrf'):
Metric used for MBR decoding.
metric_config_name (`str`, *optional*, defaults to None):
Metric configuration to pass to `evaluate.load` (e.g., the model for a trained metric, such as
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, **kwargs):
# Parameters that control the generation strategy used
self.num_samples = kwargs.pop("num_samples", 10)
self.num_references = kwargs.pop("num_references", self.num_samples)
self.metric = kwargs.pop("metric", "chrf")
self.metric = kwargs.pop("metric", "fastchrf")
self.metric_config_name = kwargs.pop("metric_config_name", None)
self.metric_output_field = kwargs.pop("metric_output_field", "score")
self.metric_kwargs = kwargs.pop("metric_kwargs", {})
Expand Down
3 changes: 2 additions & 1 deletion src/mbr/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from transformers.utils import logging, ModelOutput

from mbr.generation.configuration_utils import MBRConfig
from mbr.metrics import load_metric_runner
from mbr.metrics.base import MetricRunner, MetricOutput

if TYPE_CHECKING:
Expand Down Expand Up @@ -476,7 +477,7 @@ def generate(

# 15. apply metric to samples
if metric_runner is None:
metric_runner = MetricRunner(mbr_config, tokenizer)
metric_runner = load_metric_runner(mbr_config, tokenizer)

if isinstance(samples[0], ModelOutput):
sample_ids = tuple(sample.sequences for sample in samples)
Expand Down
14 changes: 13 additions & 1 deletion src/mbr/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
from mbr.metrics.base import metric_is_source_based
from mbr import MBRConfig
from mbr.metrics.base import metric_is_source_based, MetricRunner


def load_metric_runner(mbr_config: MBRConfig, tokenizer=None) -> MetricRunner:
if mbr_config.metric in {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}:
from mbr.metrics.fastchrf import FastChrfMetricRunner
return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=False)
elif mbr_config.metric in {"pairwise_chrf", "fastchrf.pairwise_chrf"}:
from mbr.metrics.fastchrf import FastChrfMetricRunner
return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=True)
else:
return MetricRunner(mbr_config, tokenizer)
Loading

0 comments on commit 3949960

Please sign in to comment.