Integrate fastChrF and set as default metric (#6)
Showing 15 changed files with 363 additions and 15 deletions.
@@ -0,0 +1,32 @@
Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.

## Setup
* Task: Machine translation
* Translation directions: en–de, de–en, en–ru, ru–en
* Model: [facebook/wmt19-*](https://huggingface.co/facebook/wmt19-en-de) ([Ng et al., 2019](https://aclanthology.org/W19-5333/))
* MBR metrics: `fastchrf.pairwise_chrf` (a fast implementation of standard ChrF) and `fastchrf.aggregate_chrf` (a streamlined ChrF variant for MBR)
* Number of samples: 256
* Sampling approach: epsilon sampling with ε=0.02
* Samples and references are the same
* Test set: newstest2019
* Evaluation metrics: chrF ([sacreBLEU](https://github.com/mjpost/sacrebleu)) and COMET-22 ([Rei et al., 2022](https://aclanthology.org/2022.wmt-1.52/))
* Baseline: beam search with beam size 4
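
To make the difference between the two MBR metrics concrete, here is a minimal, self-contained sketch of MBR selection when the samples double as references. It does not call fastChrF. The helpers `char_ngrams`, `f_score`, `mbr_pairwise`, and `mbr_aggregate` are hypothetical, the score uses a single character n-gram order instead of full ChrF, and the aggregate variant assumes that references are combined by averaging their n-gram counts, which may differ from what `fastchrf.aggregate_chrf` actually does.

```python
from collections import Counter
from typing import List


def char_ngrams(text: str, n: int = 3) -> Counter:
    """Count character n-grams after removing spaces (simplified, single order)."""
    text = text.replace(" ", "")
    return Counter(text[i:i + n] for i in range(len(text) - n + 1))


def f_score(hyp: Counter, ref: Counter, beta: float = 2.0) -> float:
    """F-beta over clipped n-gram overlap; a simplified stand-in for ChrF."""
    overlap = sum(min(count, ref[ngram]) for ngram, count in hyp.items())
    hyp_total = sum(hyp.values())
    ref_total = sum(ref.values())
    if overlap == 0 or hyp_total == 0 or ref_total == 0:
        return 0.0
    precision = overlap / hyp_total
    recall = overlap / ref_total
    return (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)


def mbr_pairwise(samples: List[str], references: List[str]) -> str:
    """Pick the sample with the highest average score over all references."""
    ref_bags = [char_ngrams(r) for r in references]

    def utility(sample: str) -> float:
        bag = char_ngrams(sample)
        return sum(f_score(bag, rb) for rb in ref_bags) / len(ref_bags)

    return max(samples, key=utility)


def mbr_aggregate(samples: List[str], references: List[str]) -> str:
    """Pick the sample that scores best against one averaged reference bag.

    Assumption: aggregation = averaging n-gram counts across references.
    """
    total = Counter()
    for r in references:
        total.update(char_ngrams(r))
    averaged = Counter({ng: c / len(references) for ng, c in total.items()})
    return max(samples, key=lambda s: f_score(char_ngrams(s), averaged))


samples = ["the cat sat", "the cat sat down", "a dog ran"]
print(mbr_pairwise(samples, samples))
print(mbr_aggregate(samples, samples))
```

In this sketch the pairwise variant scores every sample against every reference, while the aggregate variant scores each sample once against the pooled reference.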

## Results
| Language Pair | Method                                | ChrF     | COMET     | Duration (s) |
|---------------|---------------------------------------|---------:|----------:|-------------:|
| en-de         | MBR with `fastchrf.pairwise_chrf`     | 67.7     | 0.867     | 7798         |
| en-de         | MBR with `fastchrf.aggregate_chrf`    | 67.7     | 0.867     | 7480         |
| en-de         | Beam search                           | 67.7     | 0.868     | 62           |
| de-en         | MBR with `fastchrf.pairwise_chrf`     | 65.4     | 0.851     | 6894         |
| de-en         | MBR with `fastchrf.aggregate_chrf`    | 65.6     | 0.850     | 6849         |
| de-en         | Beam search                           | 65.1     | 0.851     | 53           |
| en-ru         | MBR with `fastchrf.pairwise_chrf`     | 57.5     | 0.862     | 7802         |
| en-ru         | MBR with `fastchrf.aggregate_chrf`    | 57.5     | 0.862     | 7465         |
| en-ru         | Beam search                           | 56.9     | 0.863     | 64           |
| ru-en         | MBR with `fastchrf.pairwise_chrf`     | 64.2     | 0.847     | 7541         |
| ru-en         | MBR with `fastchrf.aggregate_chrf`    | 64.3     | 0.848     | 6689         |
| ru-en         | Beam search                           | 63.5     | 0.847     | 61           |
| **Average**   | **MBR with `fastchrf.pairwise_chrf`** | **63.7** | **0.857** | **7509**     |
| **Average**   | **MBR with `fastchrf.aggregate_chrf`** | **63.7** | **0.857** | **7121**     |
| **Average**   | **Beam search**                       | **63.3** | **0.857** | **60**       |
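
The averages in the duration column can be reproduced from the per-direction rows. The sketch below restates those durations from the table and derives the relative time difference between the two MBR variants; the dictionary keys and variable names are illustrative only.

```python
from statistics import mean

# Durations in seconds, copied from the table (en-de, de-en, en-ru, ru-en)
durations = {
    "pairwise": [7798, 6894, 7802, 7541],
    "aggregate": [7480, 6849, 7465, 6689],
    "beam": [62, 53, 64, 61],
}

averages = {method: mean(values) for method, values in durations.items()}

# Fraction of wall-clock time saved by the aggregate variant vs. pairwise
relative_saving = 1 - averages["aggregate"] / averages["pairwise"]

for method, avg in averages.items():
    print(f"{method}: {avg:.0f} s")
print(f"aggregate vs. pairwise: {relative_saving:.1%} less time")
```

By this calculation the aggregate variant took roughly 5% less time than the pairwise variant on average, and both MBR runs took on the order of a hundred times longer than beam search.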
@@ -0,0 +1,142 @@
import sys
import time
from copy import deepcopy
from pathlib import Path

import evaluate
import jsonlines
import sacrebleu
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig

from mbr import MBR, MBRConfig

language_pair = sys.argv[1]
assert language_pair in ["de-en", "en-de", "en-ru", "ru-en"]

batch_size = 32

results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w")

model_name = f"facebook/wmt19-{language_pair}"
model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
mt_pipeline = pipeline(
    "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1],
    model=model,
    tokenizer=tokenizer,
    device=(0 if torch.cuda.is_available() else -1),
)
evaluation_metric_chrf = evaluate.load("chrf")
evaluation_metric_comet = evaluate.load("comet", "Unbabel/wmt22-comet-da")

src_path = sacrebleu.get_source_file("wmt19", language_pair)
ref_path = sacrebleu.get_reference_files("wmt19", language_pair)[0]
dataset = load_dataset("text", data_files={"test": src_path})
references = Path(ref_path).read_text().splitlines()
assert len(dataset["test"]) == len(references)

# MBR
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.do_sample = True
generation_config.num_beams = 1
generation_config.early_stopping = False
generation_config.epsilon_cutoff = 0.02

base_mbr_config = MBRConfig(
    num_samples=256,
    num_references=256,
)
base_mbr_config.metric_cache_size = batch_size * base_mbr_config.num_samples * base_mbr_config.num_references
mbr_configs = {}

# MBR with fastchrf.pairwise_chrf
# The metric name must be one of the aliases accepted by load_metric_runner
mbr_config = deepcopy(base_mbr_config)
mbr_config.metric = "fastchrf.pairwise_chrf"
mbr_configs["MBR with fastchrf.pairwise_chrf"] = mbr_config

# MBR with fastchrf.aggregate_chrf
mbr_config = deepcopy(base_mbr_config)
mbr_config.metric = "fastchrf.aggregate_chrf"
mbr_configs["MBR with fastchrf.aggregate_chrf"] = mbr_config

for method, mbr_config in mbr_configs.items():

    set_seed(42)
    time_start = time.time()
    outputs = mt_pipeline(
        dataset["test"]["text"],
        mbr_config=mbr_config,
        generation_config=generation_config,
        tokenizer=tokenizer,
        batch_size=batch_size,
        progress_bar=True
    )
    translations = []
    for batch in tqdm(outputs):
        if isinstance(batch, dict):
            batch = [batch]
        translations += [translation["translation_text"] for translation in batch]
    time_end = time.time()

    chrf_score = evaluation_metric_chrf.compute(
        predictions=translations,
        references=references,
    )
    comet_score = evaluation_metric_comet.compute(
        predictions=translations,
        references=references,
        sources=dataset["test"]["text"],
        gpus=0,
    )
    results_file.write({
        "language_pair": language_pair,
        "method": method,
        "chrf": chrf_score["score"],
        "comet22": comet_score["mean_score"],
        "duration": time_end - time_start,
        "translations": translations,
    })

# Beam search
model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device)
mt_pipeline.model = model
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.num_beams = 4

set_seed(42)
time_start = time.time()
outputs = mt_pipeline(
    dataset["test"]["text"],
    generation_config=generation_config,
    batch_size=batch_size,
)
translations = []
for batch in tqdm(outputs):
    if isinstance(batch, dict):
        batch = [batch]
    translations += [translation["translation_text"] for translation in batch]
time_end = time.time()

chrf_score = evaluation_metric_chrf.compute(
    predictions=translations,
    references=references,
)
comet_score = evaluation_metric_comet.compute(
    predictions=translations,
    references=references,
    sources=dataset["test"]["text"],
    gpus=0,
)
results_file.write({
    "language_pair": language_pair,
    "method": f"beam search (beam size {generation_config.num_beams})",
    "chrf": chrf_score["score"],
    "comet22": comet_score["mean_score"],
    "duration": time_end - time_start,
    "translations": translations,
})

results_file.close()
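
Each run of the script writes one JSON object per method to `results_{language_pair}.jsonl`. As a rough sketch of how those records could be turned into rows like the ones in the results table, the snippet below reads such a file with the standard-library `json` module instead of `jsonlines`. `format_row` and `load_rows` are hypothetical helpers and are not part of this commit; the field names match those written by the script.

```python
import json
from pathlib import Path
from typing import List


def format_row(record: dict) -> str:
    """Render one result record as a Markdown table row."""
    return (
        f"| {record['language_pair']} | {record['method']} "
        f"| {record['chrf']:.1f} | {record['comet22']:.3f} "
        f"| {record['duration']:.0f} |"
    )


def load_rows(path: Path) -> List[str]:
    """Read a JSON Lines results file and format every non-empty line."""
    with path.open(encoding="utf-8") as f:
        return [format_row(json.loads(line)) for line in f if line.strip()]
```

For example, `load_rows(Path("results_en-de.jsonl"))` would return one formatted row per method, assuming the file exists next to where the snippet is run.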
@@ -1,5 +1,6 @@
jsonlines==4.0.0
datasets==2.14.6
sacrebleu==2.3.1
sacremoses==0.0.53  # For OpusMT
nltk==3.8.1
rouge_score==0.1.2
@@ -1,3 +1,4 @@
sacrebleu==2.4.0
unbabel-comet==2.1.1
git+https://github.com/google-research/bleurt.git
sentencepiece==0.1.99  # M2M100 model
@@ -0,0 +1,2 @@
sacrebleu==2.4.0
@@ -1 +1,13 @@
from mbr import MBRConfig
from mbr.metrics.base import metric_is_source_based, MetricRunner


def load_metric_runner(mbr_config: MBRConfig, tokenizer=None) -> MetricRunner:
    if mbr_config.metric in {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}:
        from mbr.metrics.fastchrf import FastChrfMetricRunner
        return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=False)
    elif mbr_config.metric in {"pairwise_chrf", "fastchrf.pairwise_chrf"}:
        from mbr.metrics.fastchrf import FastChrfMetricRunner
        return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=True)
    else:
        # Any other metric name is handled by the generic MetricRunner
        return MetricRunner(mbr_config, tokenizer)
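
The dispatch above keys on the exact metric string. The standalone sketch below mirrors the two alias sets so the mapping can be checked without installing `mbr` or `fastchrf`. `resolve_fastchrf_mode` is a hypothetical helper, not part of this commit: it returns the value that would be passed as `compute_pairwise_average`, or `None` when the name would reach the generic `MetricRunner`.

```python
from typing import Optional

# Alias sets copied from load_metric_runner
AGGREGATE_ALIASES = {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}
PAIRWISE_ALIASES = {"pairwise_chrf", "fastchrf.pairwise_chrf"}


def resolve_fastchrf_mode(metric_name: str) -> Optional[bool]:
    """Return compute_pairwise_average for fastChrF names, else None."""
    if metric_name in AGGREGATE_ALIASES:
        return False
    if metric_name in PAIRWISE_ALIASES:
        return True
    return None


for name in ["fastchrf", "pairwise_chrf", "chrf"]:
    print(name, resolve_fastchrf_mode(name))
```

A name outside both sets, such as a hyphenated `fastchrf-pairwise`, would fall through to the generic `MetricRunner`.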