Add BHASA LINDSEA scenarios #2694
Changes from 2 commits
@@ -0,0 +1,316 @@
import re
import string
from dataclasses import replace
from functools import partial
from typing import Any, Callable, Dict, List, cast
from collections import Counter

import numpy as np
from nltk.metrics.scores import f_measure
from pythainlp.tokenize import word_tokenize
from sacrebleu.metrics import CHRF

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.evaluate_reference_metrics import exact_match
from helm.benchmark.metrics.evaluate_reference_metrics import rouge_score as rouge_score_fn
from helm.benchmark.metrics.metric import Metric, MetricResult, MetricSpec
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.benchmark.metrics.xlsum import rouge_scorer
from helm.benchmark.metrics.xlsum.scoring import BootstrapAggregator


class BhasaMachineTranslationMetric(Metric):
    """Machine Translation Metrics

    This class computes the following standard machine translation metric:

    1. ChrF++

    @inproceedings{popovic-2015-chrf,
        title = "chr{F}: character n-gram {F}-score for automatic {MT} evaluation",
        author = "Popovi{\'c}, Maja",
        editor = "Bojar, Ond{\v{r}}ej and
          Chatterjee, Rajan and
          Federmann, Christian and
          Haddow, Barry and
          Hokamp, Chris and
          Huck, Matthias and
          Logacheva, Varvara and
          Pecina, Pavel",
        booktitle = "Proceedings of the Tenth Workshop on Statistical Machine Translation",
        month = sep,
        year = "2015",
        address = "Lisbon, Portugal",
        publisher = "Association for Computational Linguistics",
        url = "https://aclanthology.org/W15-3049",
        doi = "10.18653/v1/W15-3049",
        pages = "392--395",
        github = "https://github.com/mjpost/sacrebleu",
    }
    """

    def __init__(self):
        self.chrf_scorer = CHRF(word_order=2)

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

    def _compute_chrf(self, refs: List[str], pred: str) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        metrics['ChrF++'] = self.chrf_scorer.sentence_score(pred, refs).score
        return metrics
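For context, a minimal standalone sketch of how sacrebleu's `CHRF` scorer is being used here: `word_order=2` selects the ChrF++ variant (word n-grams in addition to character n-grams), and `sentence_score` takes a single hypothesis plus a list of references and returns an object whose `.score` is on a 0 to 100 scale. The strings below are illustrative only.

```python
from sacrebleu.metrics import CHRF

chrf_scorer = CHRF(word_order=2)  # word_order=2 -> ChrF++ (adds word n-grams up to order 2)

# Score one hypothesis against a list of reference translations (toy strings).
score = chrf_scorer.sentence_score(
    "the cat sits on the mat",
    ["the cat sat on the mat", "a cat is sitting on the mat"],
).score
print(score)  # ChrF++ score on a 0-100 scale
```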

    def _remove_braces(self, text: str) -> str:
        if text.startswith("{"):
            text = text[1:]
        if text.endswith("}"):
            text = text[:-1]
        return text

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        refs: List[str] = [self._remove_braces(ref.output.text) for ref in request_state.instance.references]
        inp: str = self._remove_braces(request_state.instance.input.text)

        assert request_state.result is not None
        pred: str = self._remove_braces(request_state.result.completions[0].text.strip())

        result: List[Stat] = []

        # Compute ChrF++ metrics
        result.extend([Stat(MetricName(name)).add(float(val)) for name, val in self._compute_chrf(refs, pred).items()])

        return result


class BhasaQAMetric(Metric):
    """Bhasa QA Metrics

    This class computes the following standard SQuAD v1.1 metrics

    1. SQuAD exact match
    2. SQuAD macro-averaged F1 score

    @inproceedings{rajpurkar-etal-2016-squad,
        title = "{SQ}u{AD}: 100,000+ Questions for Machine Comprehension of Text",
        author = "Rajpurkar, Pranav and
          Zhang, Jian and
          Lopyrev, Konstantin and
          Liang, Percy",
        editor = "Su, Jian and
          Duh, Kevin and
          Carreras, Xavier",
        booktitle = "Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing",
        month = nov,
        year = "2016",
        address = "Austin, Texas",
        publisher = "Association for Computational Linguistics",
        url = "https://aclanthology.org/D16-1264",
        doi = "10.18653/v1/D16-1264",
        pages = "2383--2392",
    }
    """

    def __init__(self, language: str = 'en'):
        self.language: str = language
        self.metrics: Dict[str, Callable] = {
            "squad_exact_match_score": self.squad_exact_match_score,
            "squad_f1_score": self.squad_f1_score,
        }

Review comment (on the `self.metrics: Dict[str, Callable]` declaration): I think the type should be …
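The reviewer's comment above is cut off, but taken together with the later note about the `cast` in `evaluate_generation`, the suggestion appears to be to give `self.metrics` the precise callable type up front. A minimal sketch of that change (the exact type is an assumption, inferred from the cast used later in this file):

```python
# Hypothetical tighter annotation: each metric maps (gold, pred) -> score,
# which would make the later cast(Callable[[str, str], float], metric) unnecessary.
self.metrics: Dict[str, Callable[[str, str], float]] = {
    "squad_exact_match_score": self.squad_exact_match_score,
    "squad_f1_score": self.squad_f1_score,
}
```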

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

    def normalize_answer(self, text: str) -> List[str]:
        """
        For Thai, this will split the text using PyThaiNLP's tokenizer.
        For all other languages, this will:
        - Lower text
        - Remove punctuation
        - Remove extra whitespace

        If the language is English, it will also:
        - Remove the articles "a", "an", and "the"

        Modifies code from [SQuAD v1.1](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py).
        """

        def remove_articles(text: str) -> str:
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text: str) -> str:
            return " ".join(text.split())

Review comment (on `white_space_fix`): optional: can maybe delete this - see later comment.

        def remove_punc(text: str) -> str:
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text: str) -> str:
            return text.lower()

        normalized_text = remove_punc(lower(text))
        if self.language == "th":
            return word_tokenize(normalized_text, engine="newmm")
        elif self.language == "en":
            return white_space_fix(remove_articles(normalized_text)).split()
        else:
            return white_space_fix(normalized_text).split()

Review comment (on the English branch): optional: you can skip …

Author reply: Thanks for the suggestion! I was testing this out with the string " abc abc abc ", and it seems like we get different results when using each method. But besides this difference, this code was actually largely taken from the SQuAD v1.1 evaluation code, and I guess we were trying to preserve as much of the original as possible for transparency. What do you think?

Review comment: Sounds good. Optionally: it would be nice to add a comment to the code that this is to match the SQuAD v1.1 behavior.
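A small sketch of the reviewer's optional follow-up, i.e. noting in the code that the normalization deliberately mirrors the SQuAD v1.1 reference script (the comment wording is illustrative, not part of the PR):

```python
def white_space_fix(text: str) -> str:
    # Kept as in the official SQuAD v1.1 evaluation script
    # (https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
    # so that normalization matches the original behavior.
    return " ".join(text.split())
```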

    def squad_f1_score(self, gold: str, pred: str) -> float:
        prediction_tokens = self.normalize_answer(pred)
        ground_truth_tokens = self.normalize_answer(gold)
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1
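To make the token-overlap F1 concrete, a small worked example (illustrative strings; English normalization, so "the" is dropped as an article):

```python
# gold: "the cat sat"  -> tokens ["cat", "sat"]
# pred: "cat sat down" -> tokens ["cat", "sat", "down"]
# overlap = 2, precision = 2/3, recall = 2/2 = 1.0
# F1 = 2 * (2/3 * 1.0) / (2/3 + 1.0) = 0.8
metric = BhasaQAMetric(language="en")
assert abs(metric.squad_f1_score("the cat sat", "cat sat down") - 0.8) < 1e-9
```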

    def squad_exact_match_score(self, gold: str, pred: str) -> float:
        return self.normalize_answer(pred) == self.normalize_answer(gold)

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        stats: List[Stat] = []
        if len(request_state.instance.references) > 0:
            golds = [reference for reference in request_state.instance.references if reference.is_correct]
            assert len(golds) > 0

            assert request_state.result is not None
            sorted_completions = sorted(request_state.result.completions, key=lambda x: -x.logprob)
            preds = [completion.text.strip() for completion in sorted_completions]

            for name, metric in self.metrics.items():
                name = MetricName(name)
                metric = cast(Callable[[str, str], float], metric)

Review comment (on the `cast`): shouldn't need this cast - put the type in the declaration of …

                score_1 = max(metric(gold.output.text.strip(), preds[0]) for gold in golds)
                score_k = max(metric(gold.output.text.strip(), pred) for gold in golds for pred in preds)

Review comment: nit: can move this immediately after …

                metrics = [Stat(name).add(score_1)]
                if adapter_spec.num_outputs != 1:
                    metrics.append(Stat(replace(name, name=f"{name.name}@{adapter_spec.num_outputs}")).add(score_k))
                stats.extend(metrics)

        return stats
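For intuition on the two stats recorded above: `score_1` grades only the highest-logprob completion, while the `@k` variant takes the best score over all `num_outputs` completions. A toy, standalone illustration of the difference (not using the HELM classes):

```python
# Suppose 5 completions, each scored against the best-matching gold answer.
per_completion_scores = [0.0, 1.0, 0.0, 0.0, 0.0]  # index 0 is the top-logprob completion

score_1 = per_completion_scores[0]    # 0.0 -> recorded as "squad_exact_match_score"
score_k = max(per_completion_scores)  # 1.0 -> recorded as "squad_exact_match_score@5"
```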


class BhasaSummarizationMetric(Metric):
    """Summarization Metrics

    This class computes the following standard summarization metric:

    1. XL-Sum Rouge-L (F1 score, using the "mid" result when performing bootstrap aggregation)

    @inproceedings{hasan-etal-2021-xl,
        title = "{XL}-Sum: Large-Scale Multilingual Abstractive Summarization for 44 Languages",
        author = "Hasan, Tahmid and
          Bhattacharjee, Abhik and
          Islam, Md. Saiful and
          Mubasshir, Kazi and
          Li, Yuan-Fang and
          Kang, Yong-Bin and
          Rahman, M. Sohel and
          Shahriyar, Rifat",
        editor = "Zong, Chengqing and
          Xia, Fei and
          Li, Wenjie and
          Navigli, Roberto",
        booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
        month = aug,
        year = "2021",
        address = "Online",
        publisher = "Association for Computational Linguistics",
        url = "https://aclanthology.org/2021.findings-acl.413",
        doi = "10.18653/v1/2021.findings-acl.413",
        pages = "4693--4703",
        github = "https://github.com/csebuetnlp/xl-sum",
    }
    """

    def __init__(self, language: str = 'en'):
        self.language: str = language
        self.rouge_metrics = {
            "rougeL": "xlsum_rouge_l",
        }
        self.rouge_scorer = self._get_bhasa_rouge_scorer(self.rouge_metrics)

Review comment (on the metric name "xlsum_rouge_l"): nit: prefer "rouge_l" to maintain conventions.

    def _get_bhasa_rouge_scorer(self, rouge_metrics: Dict[str, str]) -> rouge_scorer.RougeScorer:
        language = "thai" if self.language == "th" else None
        return rouge_scorer.RougeScorer(list(rouge_metrics.keys()), use_stemmer=False, lang=language)

    def _compute_rouge(self, refs: List[str], pred: str) -> Dict[str, float]:
        metrics: Dict[str, float] = {}

        aggregator = BootstrapAggregator()
        for ref in refs:
            aggregator.add_scores(self.rouge_scorer.score(ref, pred))
        aggregates = aggregator.aggregate()
        for key, value in self.rouge_metrics.items():
            metrics[value] = aggregates[key].mid.fmeasure * 100

        return metrics

Review comment (on the `* 100` scaling): Is this to make the range 0 to 100 instead of 0 to 1? We generally prefer the 0 to 1 range.
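If the reviewer's preference for the 0 to 1 range is adopted, the change would be a one-liner (sketch of the suggested edit, not part of the PR):

```python
# Report ROUGE-L F1 in the 0-1 range, per HELM convention (drop the * 100 scaling).
metrics[value] = aggregates[key].mid.fmeasure
```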

    def _remove_braces(self, text: str) -> str:
        if text.startswith("{"):
            text = text[1:]
        if text.endswith("}"):
            text = text[:-1]
        return text

Review comment (on `_remove_braces`): Optional: Don't know if this is important, but if you want to ensure that braces are removed in a balanced way, then you should do

    if text.startswith("{") and text.endswith("}"):
        text = text[1:-1]

otherwise you might strip the brace from only the start or only the end. Likewise for the other occurrence of this function. Also, for my education, why do we need to remove braces?

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

Review comment (on `evaluate`): nit: can omit this definition, if all you are doing is calling the super.

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        refs: List[str] = [self._remove_braces(ref.output.text) for ref in request_state.instance.references]
        inp: str = self._remove_braces(request_state.instance.input.text)

        assert request_state.result is not None
        pred: str = self._remove_braces(request_state.result.completions[0].text.strip())

        result: List[Stat] = []

        # Compute rouge metrics
        result.extend([Stat(MetricName(name)).add(float(val)) for name, val in self._compute_rouge(refs, pred).items()])

        return result


def get_bhasa_machine_translation_metric_specs() -> List[MetricSpec]:
    return [
        MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaMachineTranslationMetric")
    ]


def get_bhasa_summarization_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
    return [
        MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaSummarizationMetric", args=args)
    ]


def get_bhasa_qa_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
    return [
        MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaQAMetric", args=args)
    ]

Review comment (on the `get_*_metric_specs` functions): Move these to a different file. The reason for doing this is that we don't want the …
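The reviewer's comment is cut off, but HELM generally keeps `MetricSpec` factory functions in a lightweight module separate from the metric implementations, presumably so that constructing specs does not pull in heavy dependencies such as sacrebleu or pythainlp. A sketch of what that might look like; the file name `bhasa_metrics_specs.py` is an assumption, not something stated in the PR:

```python
# bhasa_metrics_specs.py (hypothetical location): imports only MetricSpec,
# not the metric implementations or their third-party dependencies.
from typing import Any, Dict, List

from helm.benchmark.metrics.metric import MetricSpec


def get_bhasa_machine_translation_metric_specs() -> List[MetricSpec]:
    return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaMachineTranslationMetric")]


def get_bhasa_summarization_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
    return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaSummarizationMetric", args=args)]


def get_bhasa_qa_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
    return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaQAMetric", args=args)]
```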

Review comment (on the "ChrF++" metric name): Generally our convention for metrics is camel case - could you make this "chr_f_plus_plus" instead?

Review comment: Correction - convention is snake case, not camel case.
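A minimal sketch of the requested rename in `_compute_chrf` (snake_case metric name, per the reviewer's convention note):

```python
# Use a snake_case metric name instead of "ChrF++".
metrics["chr_f_plus_plus"] = self.chrf_scorer.sentence_score(pred, refs).score
```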