Add BHASA LINDSEA scenarios #2694

Merged: 11 commits, Jun 26, 2024

Changes from 2 commits
5 changes: 5 additions & 0 deletions setup.cfg
@@ -110,6 +110,11 @@ mongo =
unitxt =
evaluate~=0.4.1

bhasa =
pythainlp==5.0.0
pyonmttok==1.37.0
sacrebleu~=2.2.1

# Model extras
aleph-alpha =
aleph-alpha-client~=2.14.0
316 changes: 316 additions & 0 deletions src/helm/benchmark/metrics/bhasa_metrics.py
@@ -0,0 +1,316 @@
import re
import string
from dataclasses import replace
from functools import partial
from typing import Any, Callable, Dict, List, cast
from collections import Counter

import numpy as np
from nltk.metrics.scores import f_measure
from pythainlp.tokenize import word_tokenize
from sacrebleu.metrics import CHRF

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.evaluate_reference_metrics import exact_match
from helm.benchmark.metrics.evaluate_reference_metrics import rouge_score as rouge_score_fn
from helm.benchmark.metrics.metric import Metric, MetricResult, MetricSpec
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.benchmark.metrics.xlsum import rouge_scorer
from helm.benchmark.metrics.xlsum.scoring import BootstrapAggregator

class BhasaMachineTranslationMetric(Metric):
"""Machine Translation Metrics

This class computes the following standard machine translation metrics

1. ChrF++

@inproceedings{popovic-2015-chrf,
title = "chr{F}: character n-gram {F}-score for automatic {MT} evaluation",
author = "Popovi{\'c}, Maja",
editor = "Bojar, Ond{\v{r}}ej and
Chatterjee, Rajan and
Federmann, Christian and
Haddow, Barry and
Hokamp, Chris and
Huck, Matthias and
Logacheva, Varvara and
Pecina, Pavel",
booktitle = "Proceedings of the Tenth Workshop on Statistical Machine Translation",
month = sep,
year = "2015",
address = "Lisbon, Portugal",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/W15-3049",
doi = "10.18653/v1/W15-3049",
pages = "392--395",
github = "https://github.com/mjpost/sacrebleu",
}
"""

def __init__(self):
self.chrf_scorer = CHRF(word_order=2)

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

def _compute_chrf(self, refs: List[str], pred: str) -> Dict[str, float]:
metrics: Dict[str, float] = {}
metrics['ChrF++'] = self.chrf_scorer.sentence_score(pred, refs).score
Collaborator: Our convention for metric names is snake case; could you make this "chr_f_plus_plus" instead?
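A possible rename following this suggestion (illustrative sketch only, not part of the current diff):

    metrics["chr_f_plus_plus"] = self.chrf_scorer.sentence_score(pred, refs).score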

return metrics

def _remove_braces(self, text: str) -> str:
if text.startswith("{"):
text = text[1:]
if text.endswith("}"):
text = text[:-1]
return text

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
refs: List[str] = [self._remove_braces(ref.output.text) for ref in request_state.instance.references]
inp: str = self._remove_braces(request_state.instance.input.text)

assert request_state.result is not None
pred: str = self._remove_braces(request_state.result.completions[0].text.strip())

result: List[Stat] = []

# Compute ChrF++ metrics
result.extend([Stat(MetricName(name)).add(float(val)) for name, val in self._compute_chrf(refs, pred).items()])

return result

class BhasaQAMetric(Metric):
"""Bhasa QA Metrics

This class computes the following standard SQuAD v1.1 metrics

1. SQuAD exact match
2. SQuAD macro-averaged F1 score

@inproceedings{rajpurkar-etal-2016-squad,
title = "{SQ}u{AD}: 100,000+ Questions for Machine Comprehension of Text",
author = "Rajpurkar, Pranav and
Zhang, Jian and
Lopyrev, Konstantin and
Liang, Percy",
editor = "Su, Jian and
Duh, Kevin and
Carreras, Xavier",
booktitle = "Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing",
month = nov,
year = "2016",
address = "Austin, Texas",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D16-1264",
doi = "10.18653/v1/D16-1264",
pages = "2383--2392",
}
"""

def __init__(self, language: str = 'en'):
self.language: str = language
self.metrics: Dict[str, Callable] = {
Collaborator: I think the type should be Dict[str, Callable[[str, str], float]].
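A typed declaration along these lines might look like the following (illustrative sketch; it would also make the later cast() in evaluate_generation unnecessary):

    self.metrics: Dict[str, Callable[[str, str], float]] = {
        "squad_exact_match_score": self.squad_exact_match_score,
        "squad_f1_score": self.squad_f1_score,
    }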

"squad_exact_match_score": self.squad_exact_match_score,
"squad_f1_score": self.squad_f1_score,
}

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

def normalize_answer(self, text: str) -> List[str]:
"""
For Thai, this will split the text using PyThaiNLP's tokenizer.
For all other languages, this will:
- Lower text
- Remove punctuation
- Remove extra whitespace

If the language is English, it will
- Remove articles "a", "an", and "the"

Modifies code from [SQuAD v1.1](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py).
"""

def remove_articles(text: str) -> str:
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text: str) -> str:
Collaborator: optional: can maybe delete this - see later comment.

return " ".join(text.split())

def remove_punc(text: str) -> str:
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text: str) -> str:
return text.lower()

normalized_text = remove_punc(lower(text))
if self.language == "th":
return word_tokenize(normalized_text, engine="newmm")
elif self.language == "en":
return white_space_fix(remove_articles(normalized_text)).split()
Collaborator: optional: you can skip white_space_fix() by doing re.split("\s+", remove_articles(normalized_text)), which would also allow you to delete the definition of white_space_fix().

Contributor: Thanks for the suggestion! I was testing this out with the string " abc abc abc ", and it seems like we get different results with each method:

    re.split("\s+", ...): ["", "abc", "abc", "abc", ""]

    white_space_fix(): ["abc", "abc", "abc"]

But besides this difference, this code was largely taken from the SQuAD v1.1 evaluation code, and we were trying to preserve as much of the original as possible for transparency. What do you think?

Collaborator: Sounds good.

Optionally: it would be nice to add a comment to the code that this is to match the SQuAD v1.1 behavior.
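A minimal sketch of the difference under discussion (plain Python, not part of the diff; r"\s+" is used here to avoid the invalid-escape warning):

    import re

    text = " abc abc abc "
    print(re.split(r"\s+", text))          # ['', 'abc', 'abc', 'abc', ''] (empty strings at the edges)
    print(" ".join(text.split()).split())  # ['abc', 'abc', 'abc']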

else:
return white_space_fix(normalized_text).split()

def squad_f1_score(self, gold: str, pred: str) -> float:
prediction_tokens = self.normalize_answer(pred)
ground_truth_tokens = self.normalize_answer(gold)
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1

def squad_exact_match_score(self, gold: str, pred: str) -> float:
return self.normalize_answer(pred) == self.normalize_answer(gold)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:

stats: List[Stat] = []
if len(request_state.instance.references) > 0:
golds = [reference for reference in request_state.instance.references if reference.is_correct]
assert len(golds) > 0

assert request_state.result is not None
sorted_completions = sorted(request_state.result.completions, key=lambda x: -x.logprob)
preds = [completion.text.strip() for completion in sorted_completions]

for name, metric in self.metrics.items():
name = MetricName(name)
metric = cast(Callable[[str, str], float], metric)
Collaborator: shouldn't need this cast - put the type in the declaration of self.metrics.

score_1 = max(metric(gold.output.text.strip(), preds[0]) for gold in golds)
score_k = max(metric(gold.output.text.strip(), pred) for gold in golds for pred in preds)
Collaborator: nit: can move this immediately after if adapter_spec.num_outputs != 1: to avoid making an unnecessary extra call to metric() when preds is length 1.
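i.e., the score_k computation could move inside the existing conditional (illustrative sketch, not part of the current diff):

    if adapter_spec.num_outputs != 1:
        score_k = max(metric(gold.output.text.strip(), pred) for gold in golds for pred in preds)
        metrics.append(Stat(replace(name, name=f"{name.name}@{adapter_spec.num_outputs}")).add(score_k))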


metrics = [Stat(name).add(score_1)]
if adapter_spec.num_outputs != 1:
metrics.append(Stat(replace(name, name=f"{name.name}@{adapter_spec.num_outputs}")).add(score_k))
stats.extend(metrics)

return stats

class BhasaSummarizationMetric(Metric):
"""Summarization Metrics

This class computes the following standard summarization metrics

1. XL-Sum Rouge-L (F1 score, using the "mid" result when performing bootstrap aggregation)

@inproceedings{hasan-etal-2021-xl,
title = "{XL}-Sum: Large-Scale Multilingual Abstractive Summarization for 44 Languages",
author = "Hasan, Tahmid and
Bhattacharjee, Abhik and
Islam, Md. Saiful and
Mubasshir, Kazi and
Li, Yuan-Fang and
Kang, Yong-Bin and
Rahman, M. Sohel and
Shahriyar, Rifat",
editor = "Zong, Chengqing and
Xia, Fei and
Li, Wenjie and
Navigli, Roberto",
booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
month = aug,
year = "2021",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.findings-acl.413",
doi = "10.18653/v1/2021.findings-acl.413",
pages = "4693--4703",
github = "https://github.com/csebuetnlp/xl-sum",
}

"""

def __init__(self, language: str = 'en'):
self.language: str = language
self.rouge_metrics = {
"rougeL": "xlsum_rouge_l",
Collaborator: nit: prefer "rouge_l" to maintain conventions.

}
self.rouge_scorer = self._get_bhasa_rouge_scorer(self.rouge_metrics)

def _get_bhasa_rouge_scorer(self, rouge_metrics: Dict[str, str]) -> rouge_scorer.RougeScorer:
language = "thai" if self.language == "th" else None
return rouge_scorer.RougeScorer(list(rouge_metrics.keys()), use_stemmer=False, lang=language)

def _compute_rouge(self, refs: List[str], pred: str) -> Dict[str, float]:
metrics: Dict[str, float] = {}

aggregator = BootstrapAggregator()
for ref in refs:
aggregator.add_scores(self.rouge_scorer.score(ref, pred))
aggregates = aggregator.aggregate()
for key, value in self.rouge_metrics.items():
metrics[value] = aggregates[key].mid.fmeasure * 100
Collaborator: Is this to make the range 0 to 100 instead of 0 to 1? We generally prefer the 0 to 1 range.
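i.e., dropping the multiplication by 100 would keep the score in the 0 to 1 range (illustrative sketch):

    metrics[value] = aggregates[key].mid.fmeasure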

return metrics

def _remove_braces(self, text: str) -> str:
Collaborator: Optional: Don't know if this is important, but if you want to ensure that braces are removed in a balanced way, then you should do

        if text.startswith("{") and text.endswith("}"):
            text = text[1:-1]

otherwise you might strip the brace from only the start or only the end. Likewise for the other occurrence of this function.

Also, for my education, why do we need to remove braces?

if text.startswith("{"):
text = text[1:]
if text.endswith("}"):
text = text[:-1]
return text

def evaluate(
Collaborator: nit: can omit this definition, if all you are doing is calling the super.

self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=parallelism)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
refs: List[str] = [self._remove_braces(ref.output.text) for ref in request_state.instance.references]
inp: str = self._remove_braces(request_state.instance.input.text)

assert request_state.result is not None
pred: str = self._remove_braces(request_state.result.completions[0].text.strip())

result: List[Stat] = []

# Compute rouge metrics
result.extend([Stat(MetricName(name)).add(float(val)) for name, val in self._compute_rouge(refs, pred).items()])

return result

def get_bhasa_machine_translation_metric_specs() -> List[MetricSpec]:
Collaborator: Move these to a different file, bhasa_metric_specs.py (follow the conventions in common_metric_specs.py).

The reason is that we don't want the bhasa_run_specs file to import bhasa_metrics, which transitively imports optional dependencies; otherwise helm-run would fail for someone who doesn't have the optional dependencies installed.
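A minimal sketch of what the suggested bhasa_metric_specs.py might contain (the file name follows the review comment; the layout is an assumption modeled on common_metric_specs.py). Because it only imports MetricSpec and refers to the metric classes by their fully qualified names as strings, importing it does not pull in the optional bhasa dependencies:

    from typing import Any, Dict, List

    from helm.benchmark.metrics.metric import MetricSpec


    def get_bhasa_machine_translation_metric_specs() -> List[MetricSpec]:
        return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaMachineTranslationMetric")]


    def get_bhasa_qa_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
        return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaQAMetric", args=args)]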

return [
MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaMachineTranslationMetric")
]

def get_bhasa_summarization_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
return [
MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaSummarizationMetric", args=args)
]

def get_bhasa_qa_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
return [
MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaQAMetric", args=args)
]