Refactor Rouge and Meteor to InstanceMetric for faster score computation #1011

Merged
4 changes: 2 additions & 2 deletions .secrets.baseline
@@ -3,7 +3,7 @@
"files": "^.secrets.baseline$",
"lines": null
},
"generated_at": "2024-07-09T07:07:12Z",
"generated_at": "2024-07-22T10:56:00Z",
"plugins_used": [
{
"name": "AWSKeyDetector"
@@ -82,7 +82,7 @@
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_secret": false,
"is_verified": false,
"line_number": 1531,
"line_number": 1607,
"type": "Hex High Entropy String",
"verified_result": null
}
61 changes: 59 additions & 2 deletions prepare/metrics/meteor.py
@@ -1,8 +1,65 @@
from unitxt import add_to_catalog
from unitxt.metrics import HuggingfaceMetric
from unitxt.metrics import HuggingfaceMetric, Meteor
from unitxt.test_utils.metrics import test_metric

metric = HuggingfaceMetric(
metric = Meteor()

predictions = [
"It is a guide to action which ensures that the military always obeys the commands of the party",
"We strive for peace",
"On the rag sat the cat",
"I caught the ball",
]
references = [
[
"It is a guide to action that ensures that the military will forever heed Party commands"
],
["We hope for peace"],
["The cat sat on the rag"],
["He threw the ball"],
]

# the floats shown here are rounded just for the test; the actual
# returned scores have 15-16 digits to the right of the decimal point
instance_targets = [
{"meteor": 0.69, "score": 0.69, "score_name": "meteor"},
{"meteor": 0.64, "score": 0.64, "score_name": "meteor"},
{"meteor": 0.5, "score": 0.5, "score_name": "meteor"},
{"meteor": 0.47, "score": 0.47, "score_name": "meteor"},
]

global_target = {
"meteor": 0.58,
"meteor_ci_high": 0.59,
"meteor_ci_low": 0.58,
"score": 0.58,
"score_ci_high": 0.59,
"score_ci_low": 0.58,
"score_name": "meteor",
}

metric.n_resamples = 3
# to match the setting used when testing the global (HF) version, metric2, below

outputs = test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

# compare results with the HF version of meteor
metric2 = HuggingfaceMetric(
hf_metric_name="meteor", main_score="meteor", prediction_type="str"
)

outputs = test_metric(
metric=metric2,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

add_to_catalog(metric, "metrics.meteor", overwrite=True)
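
For reference, a minimal sketch of what the per-instance METEOR computation now reduces to, calling nltk directly with the same alpha/beta/gamma defaults as the new Meteor class; this snippet is illustrative only and is not part of the PR, and the example values come from the test above.

import nltk
from nltk import word_tokenize
from nltk.translate import meteor_score

# word_tokenize needs the punkt models; meteor_score needs wordnet data
nltk.download("punkt", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)

prediction = "We strive for peace"
references = ["We hope for peace"]

# nltk >= 3.6.6 expects pre-tokenized references and hypothesis
score = meteor_score.meteor_score(
    [word_tokenize(ref) for ref in references],
    word_tokenize(prediction),
    alpha=0.9,  # defaults used by the Meteor metric in this PR
    beta=3,
    gamma=0.5,
)
print(round(score, 2))  # expected to be close to 0.64, per instance_targets above
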
34 changes: 14 additions & 20 deletions prepare/metrics/rouge.py
@@ -2,7 +2,7 @@
from unitxt.metrics import Rouge
from unitxt.test_utils.metrics import test_metric

metric = Rouge(n_resamples=None)
metric = Rouge()

predictions = ["hello there", "general kenobi"]
references = [["hello", "there"], ["general kenobi", "general yoda"]]
@@ -28,13 +28,22 @@

global_target = {
"rouge1": 0.83,
"rouge1_ci_high": 1.0,
"rouge1_ci_low": 0.67,
"rouge2": 0.5,
"rouge2_ci_high": 1.0,
"rouge2_ci_low": 0.0,
"rougeL": 0.83,
"rougeL_ci_high": 1.0,
"rougeL_ci_low": 0.67,
"rougeLsum": 0.83,
"rougeLsum_ci_high": 1.0,
"rougeLsum_ci_low": 0.67,
"score": 0.83,
"score_ci_high": 1.0,
"score_ci_low": 0.67,
"score_name": "rougeL",
}

outputs = test_metric(
metric=metric,
predictions=predictions,
@@ -43,27 +52,12 @@
global_target=global_target,
)
add_to_catalog(metric, "metrics.rouge", overwrite=True)

global_target_with_confidence_intervals = global_target.copy()
global_target_with_confidence_intervals.update(
{
"rougeL_ci_low": 0.83,
"rougeL_ci_high": 0.83,
"score_ci_low": 0.83,
"score_ci_high": 0.83,
}
metric = Rouge(
__description__="This is deprecated. Use 'metrics.rouge', which also generates confidence intervals"
)

metric_with_confidence_intervals = Rouge()
outputs = test_metric(
metric=metric_with_confidence_intervals,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target_with_confidence_intervals,
)
add_to_catalog(
metric_with_confidence_intervals,
metric,
"metrics.rouge_with_confidence_intervals",
overwrite=True,
)
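
For reference, a minimal sketch of the per-instance computation the refactored Rouge metric performs with rouge_score directly (see Rouge.compute in src/unitxt/metrics.py below); it is illustrative only and not part of the PR, and uses the example prediction/references from the test above.

import nltk
from rouge_score import rouge_scorer

nltk.download("punkt", quiet=True)

prediction = "hello there"
references = ["hello", "there"]

# sentence-split with newlines, as rougeLsum expects
prediction = "\n".join(nltk.sent_tokenize(prediction.strip()))
references = ["\n".join(nltk.sent_tokenize(r.strip())) for r in references]

scorer = rouge_scorer.RougeScorer(
    rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"],
    use_stemmer=False,
)
# score_multi scores one prediction against all of its references at once
scores = scorer.score_multi(references, prediction)
print({name: round(s.fmeasure, 2) for name, s in scores.items()})
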
5 changes: 1 addition & 4 deletions src/unitxt/catalog/metrics/meteor.json
@@ -1,6 +1,3 @@
{
"__type__": "huggingface_metric",
"hf_metric_name": "meteor",
"main_score": "meteor",
"prediction_type": "str"
"__type__": "meteor"
}
3 changes: 1 addition & 2 deletions src/unitxt/catalog/metrics/rouge.json
@@ -1,4 +1,3 @@
{
"__type__": "rouge",
"n_resamples": null
"__type__": "rouge"
}
src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json
@@ -1,3 +1,4 @@
{
"__type__": "rouge"
"__type__": "rouge",
"__description__": "This is deprecated. Use 'metrics.rouge' which also generate confidence intervals"
}
149 changes: 138 additions & 11 deletions src/unitxt/metrics.py
@@ -327,6 +327,7 @@ def score_based_confidence_interval(
# otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
# that is, re-form the groups, calculate the function, and take the mean of the group scores
aggregation_func = self.average_item_scores

for score_name in score_names:
# If all computed instance level scores are the same, there is no point in computing
# confidence intervals. So skip to the next score.
@@ -1300,6 +1301,81 @@ def compute(
return results


class HuggingfaceInstanceMetric(InstanceMetric):
hf_metric_name: str

hf_metric_fields: List[str]
hf_compute_args: dict = {}

def prepare(self):
super().prepare()
self.metric = evaluate.load(
self.hf_metric_name, experiment_id=str(uuid.uuid4())
)

def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
# invokes module.compute, which invokes, e.g., meteor's _compute

try:
score = self.metric.compute(
predictions=[prediction],
references=[references],
**self.hf_compute_args,
)
except Exception:
score = {self.main_score: np.nan}

if self.hf_metric_fields is not None and len(self.hf_metric_fields) > 0:
to_ret = {field: score[field] for field in self.hf_metric_fields}
score = to_ret

return score


class Meteor(InstanceMetric):
main_score = "meteor"
ci_scores = ["meteor"]
reduction_map = {"mean": ["meteor"]}
prediction_type = "str"

_requirements_list: List[str] = ["nltk"]
alpha: float = 0.9
beta: int = 3
gamma: float = 0.5
# unitxt uses nltk version >= 3.8

def prepare(self):
import nltk

nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)
from nltk import word_tokenize
from nltk.translate import meteor_score

self.word_tokenize = word_tokenize
self.meteor_score = meteor_score

def verify(self):
import importlib.metadata as importlib_metadata

from datasets.config import version

nltk_version = version.parse(importlib_metadata.version("nltk"))
assert nltk_version >= version.Version(
"3.6.6"
), "nltk version must be at least 3.6.6"

def compute(self, references, prediction, task_data):
score = self.meteor_score.meteor_score(
[self.word_tokenize(ref) for ref in references],
self.word_tokenize(prediction),
alpha=self.alpha,
beta=self.beta,
gamma=self.gamma,
)
return {"meteor": score}


class F1(GlobalMetric):
_metric = None
main_score = "f1_macro"
@@ -1691,16 +1767,60 @@ class F1MacroMultiLabel(F1MultiLabel):
average = None


class Rouge(HuggingfaceMetric):
class Rouge(InstanceMetric):
main_score = "rougeL"
prediction_type = "str"
single_reference_per_prediction = False # multiple references allowed
rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
reduction_map = {"mean": ["rouge1", "rouge2", "rougeL", "rougeLsum"]}
ci_scores = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

sent_split_newline: bool = True
_requirements_list: List[str] = ["nltk", "rouge_score"]

def prepare(self):
import nltk
from rouge_score import rouge_scorer

self.rouge_scorer = rouge_scorer

nltk.download("punkt", quiet=True)
self.sent_tokenize = nltk.sent_tokenize

def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
# for a single instance, prediction is of type str, and references: list of str
if self.sent_split_newline:
prediction = "\n".join(self.sent_tokenize(prediction.strip()))

references = [
"\n".join(self.sent_tokenize(reference.strip()))
for reference in references
]

# the following is taken from HF rouge, using the defaults:
# use_aggregator=True, use_stemmer=False, tokenizer=None
scorer = self.rouge_scorer.RougeScorer(
rouge_types=self.rouge_types, use_stemmer=False, tokenizer=None
)
# with Unitxt, references is a list
score = scorer.score_multi(references, prediction)
for key in score:
score[key] = score[key].fmeasure
return score


class RougeHF(HuggingfaceInstanceMetric):
hf_metric_name = "rouge"
main_score = "rougeL"
scale = 1.0

prediction_type = "str"
single_reference_per_prediction = False # multiple references allowed

use_aggregator: bool = True
rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
reduction_map = {"mean": ["rouge1", "rouge2", "rougeL", "rougeLsum"]}
hf_metric_fields = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
ci_scores = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

sent_split_newline: bool = True

@@ -1709,26 +1829,33 @@ class Rouge(HuggingfaceMetric):
def prepare(self):
super().prepare()

# We don't use the internal aggregation, to avoid the costly bootstrapping
# run by the HF library, which Unitxt performs itself in any case.
self.hf_compute_args.update(
{"use_aggregator": self.use_aggregator, "rouge_types": self.rouge_types}
{"use_aggregator": False, "rouge_types": self.rouge_types}
)

import nltk

nltk.download("punkt")
nltk.download("punkt", quiet=True)
self.sent_tokenize = nltk.sent_tokenize

def compute(self, references, predictions, task_data: List[Dict]):
def compute(self, references, prediction, task_data: List[Dict]):
# for a single instance, prediction is of type str, and references: list of str
if self.sent_split_newline:
predictions = [
"\n".join(self.sent_tokenize(prediction.strip()))
for prediction in predictions
]
prediction = "\n".join(self.sent_tokenize(prediction.strip()))

references = [
["\n".join(self.sent_tokenize(r.strip())) for r in reference]
"\n".join(self.sent_tokenize(reference.strip()))
for reference in references
]
return super().compute(references, predictions, task_data)

hf_score = super().compute(references, prediction, task_data)
for metric_field in self.hf_metric_fields:
if isinstance(hf_score[metric_field], list):
assert len(hf_score[metric_field]) == 1
hf_score[metric_field] = hf_score[metric_field][0]
return hf_score


# Computes char edit distance, ignoring whitespace
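
The point of moving Rouge and Meteor to InstanceMetric is that the expensive metric call runs once per instance, and confidence intervals come from bootstrapping the cached per-instance scores rather than re-running the metric on every resample. A rough sketch of that idea follows; it is illustrative only and not unitxt's actual implementation.

import random
import statistics

def bootstrap_ci(instance_scores, n_resamples=1000, seed=0):
    # resample the cached per-instance scores; no metric recomputation needed
    rng = random.Random(seed)
    n = len(instance_scores)
    means = sorted(
        statistics.fmean(rng.choices(instance_scores, k=n))
        for _ in range(n_resamples)
    )
    return means[int(0.025 * n_resamples)], means[int(0.975 * n_resamples)]

instance_scores = [0.69, 0.64, 0.5, 0.47]  # e.g., the METEOR instance scores above
low, high = bootstrap_ci(instance_scores)
print(round(statistics.fmean(instance_scores), 2), round(low, 2), round(high, 2))
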
10 changes: 10 additions & 0 deletions src/unitxt/standard.py
@@ -96,6 +96,16 @@ def verify(self):
raise ValueError(
f"max_train_instances should not exceed loader_limit ({self.loader_limit}), Got max_train_instances={self.max_train_instances}"
)
if self.metrics is not None and not isinstance(self.metrics, List):
raise ValueError(
f"metrics must be a list of metrics. Got metrics = {self.metrics}"
)
if self.postprocessors is not None and not isinstance(
self.postprocessors, List
):
raise ValueError(
f"post processors must be a list of post processor. Got postprocessors = {self.postprocessors}"
)

def prepare_refiners(self):
self.train_refiner.max_instances = self.max_train_instances
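
The new verify() checks reject a bare metric or postprocessor passed where a list is expected. A minimal standalone illustration of the check (not the actual StandardRecipe API):

metrics = "metrics.rouge"  # common mistake: a single string instead of a list
if metrics is not None and not isinstance(metrics, list):
    raise ValueError(f"metrics must be a list of metrics. Got metrics = {metrics}")
# passing metrics=["metrics.rouge"] satisfies the check
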