Implement metrics ensemble
eladven committed Jul 23, 2024
1 parent 3d1b3ea commit ceee695
Showing 5 changed files with 191 additions and 4 deletions.
13 changes: 12 additions & 1 deletion docs/docs/examples.rst
@@ -93,7 +93,7 @@ Evaluate the quality of an LLM as judge

Demonstrates how to evaluate an LLM as judge by checking its scores using the gold references of a dataset.
It checks if the judge consistently prefers correct outputs over clearly wrong ones.
Note that checking the ability of the LLM as judge to discern subtle differences between
partially correct answers requires more refined tests and corresponding labeled data.
The example shows an 8b llama based judge is not a good judge for a summarization task,
while the 70b model performs much better.
@@ -122,5 +122,16 @@ The model is evaluated on its capability to give a judgment that is in correlati

Related documentation: :ref:`Evaluate a Model on Arena Hard Benchmark <arena_hard_evaluation>`.

Evaluate using ensemble of LLM as a judge metrics
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Demonstrates how to create a metric that is an ensemble of LLM as a judge metrics.
The example shows how to ensemble two judges that use different templates.

`Example code <https://github.com/IBM/unitxt/blob/main/examples/evaluate_using_metrics_ensemble.py>`_

Related documentation: :ref:`LLM as a Judge Metrics Guide <llm_as_judge>`.
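
A minimal sketch of the pattern the example follows (the metric names, weights, and dataset card below mirror the example file added in this commit):

.. code-block:: python

    from unitxt.api import load_dataset
    from unitxt.metrics import MetricsEnsemble

    # Ensemble of two LLM-as-judge metrics that use different templates,
    # combined with weights 0.75 and 0.25.
    ensemble_metric = MetricsEnsemble(
        metrics=[
            "metrics.llm_as_judge.rating.llama_3_70b_instruct_ibm_genai_template_generic_single_turn",
            "metrics.llm_as_judge.rating.llama_3_70b_instruct_ibm_genai_template_mt_bench_single_turn",
        ],
        weights=[0.75, 0.25],
    )

    # The ensemble is passed like any other metric when loading a dataset.
    dataset = load_dataset(
        card="cards.squad",
        template="templates.qa.with_context.simple",
        metrics=[ensemble_metric],
    )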




50 changes: 50 additions & 0 deletions examples/evaluate_using_metrics_ensemble.py
@@ -0,0 +1,50 @@
from unitxt import get_logger
from unitxt.api import evaluate, load_dataset
from unitxt.inference import (
    HFPipelineBasedInferenceEngine,
)
from unitxt.metrics import MetricsEnsemble
from unitxt.text_utils import print_dict

logger = get_logger()

# Define the metrics ensemble: two LLM-as-judge metrics combined with weights 0.75 and 0.25.
ensemble_metric = MetricsEnsemble(
    metrics=[
        "metrics.llm_as_judge.rating.llama_3_70b_instruct_ibm_genai_template_generic_single_turn",
        "metrics.llm_as_judge.rating.llama_3_70b_instruct_ibm_genai_template_mt_bench_single_turn",
    ],
    weights=[0.75, 0.25],
)
# Use the unitxt load_dataset API to load the SQuAD QA dataset with the standard template in the catalog.
# We set loader_limit to 20 to reduce download time.
dataset = load_dataset(
    card="cards.squad",
    template="templates.qa.with_context.simple",
    metrics=[ensemble_metric],
    loader_limit=20,
)
test_dataset = dataset["test"]

# Run inference with a small HF model to get predictions.
model_name = "google/flan-t5-base"
inference_model = HFPipelineBasedInferenceEngine(
    model_name=model_name, max_new_tokens=32
)
predictions = inference_model.infer(test_dataset)

# Evaluate the predictions using the ensemble metric defined above.
evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

# Print the results.
for instance in evaluated_dataset:
    print_dict(
        instance,
        keys_to_print=[
            "source",
            "prediction",
            "processed_prediction",
            "references",
            "score",
        ],
    )
62 changes: 60 additions & 2 deletions src/unitxt/metrics.py
@@ -10,7 +10,7 @@
from dataclasses import field
from operator import itemgetter
from statistics import mean
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import evaluate
import numpy
@@ -19,7 +19,7 @@
from scipy.stats import bootstrap
from scipy.stats._warnings_errors import DegenerateDataWarning

from .artifact import Artifact
from .artifact import Artifact, fetch_artifact
from .dataclass import (
AbstractField,
InternalField,
@@ -4525,3 +4525,61 @@ def _prepare_instances_for_model(self, texts: List[str]):
        )
        processed_stream = self.processor.process(stream)
        return processed_stream.to_dataset()["test"]


class MetricsEnsemble(InstanceMetric):
    """Metrics Ensemble class for creating an ensemble of given metrics.

    Attributes:
        main_score (str): The main score label used for evaluation.
        metrics (List[Union[Metric, str]]): List of metrics that will be ensembled.
        weights (List[float]): Weight of each metric in the ensemble.
        reduction_map (Dict[str, List[str]]): Parameter for specifying the reduction method of the
            global score (see its definition in the InstanceMetric class, which currently allows
            two reductions). This class defines its default value to reduce by the mean of the
            main score.
    """

    main_score = "ensemble_score"
    reduction_map = {"mean": [main_score]}
    metrics: List[Union[Metric, str]]
    weights: List[float] = None

    def get_prefix_name(self, i):
        return f"ensemble_{i}_"

    def prepare(self):
        super().prepare()
        # Resolve catalog names to metric objects and give each sub-metric a unique score prefix.
        self.metrics = [fetch_artifact(metric)[0] for metric in self.metrics]
        for i, metric in enumerate(self.metrics):
            metric.score_prefix = self.get_prefix_name(i)
        if self.weights is None:
            self.weights = [1 / len(self.metrics) for _ in range(len(self.metrics))]

    def create_ensemble_scores(self, instance):
        score = self.ensemble(instance)
        instance[
            "prediction"
        ] = score  # We use the prediction field here to pass the score to the compute method.
        return instance

    def ensemble(self, instance):
        # Weighted sum of the per-metric instance scores.
        score = 0
        for i, (metric, weight) in enumerate(zip(self.metrics, self.weights)):
            score += (
                instance["score"]["instance"][
                    self.get_prefix_name(i) + metric.main_score
                ]
                * weight
            )
        return score

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for metric in self.metrics:
            stream = list(metric.process(stream=stream, stream_name=stream_name))
        stream = [self.create_ensemble_scores(g) for g in stream]
        return super().process(stream=stream, stream_name=stream_name)

    def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
        return {self.main_score: prediction}
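
For intuition, the ensemble() method above computes a plain weighted sum of the per-metric instance scores. A minimal sketch of that arithmetic, using made-up scores and the weights from the documentation example:

# Sketch of the weighted combination performed by MetricsEnsemble.ensemble() for one instance.
# The per-metric scores below are made up; the weights match the example (0.75, 0.25).
judge_scores = [0.8, 0.4]
weights = [0.75, 0.25]
ensemble_score = sum(score * weight for score, weight in zip(judge_scores, weights))
assert abs(ensemble_score - 0.7) < 1e-9  # 0.75 * 0.8 + 0.25 * 0.4 == 0.7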
1 change: 1 addition & 0 deletions tests/library/test_examples.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ def test_examples(self):
            "evaluate_a_judge_model_capabilities_on_arena_hard.py",
            "evaluate_a_model_using_arena_hard.py",
            "evaluate_llm_as_judge.py",
            "evaluate_using_metrics_ensemble.py",
        ]
        for file in all_example_files:
            logger.info(
69 changes: 68 additions & 1 deletion tests/library/test_metrics.py
@@ -43,6 +43,7 @@
    KendallTauMetric,
    LlamaIndexCorrectness,
    MaxAccuracy,
    MetricsEnsemble,
    NormalizedSacrebleu,
    Perplexity,
    PrecisionBinary,
@@ -52,7 +53,7 @@
    TokenOverlap,
    UnsortedListExactMatch,
)
from unitxt.test_utils.metrics import apply_metric, check_scores
from unitxt.test_utils.metrics import apply_metric, check_scores, test_metric

from tests.utils import UnitxtTestCase

@@ -1663,3 +1664,69 @@ def test_fin_qa_eval(self):

        for i in range(len(actual_scores)):
            self.assertAlmostEqual(actual_scores[i], target_scores[i])

    def test_metrics_ensemble(self):
        metric = MetricsEnsemble(
            main_score="ensemble_score",
            metrics=[
                "metrics.precision_micro_multi_label",
                "metrics.recall_macro_multi_label",
            ],
            weights=None,
        )

        predictions = [["A"], ["B"], [""], ["A"]]
        references = [[["B", "A"]], [["B"]], [["A"]], [[""]]]

        instance_targets = [
            {
                "ensemble_score": 0.75,
                "ensemble_0_precision_micro": 1.0,
                "ensemble_1_recall_macro": 0.5,
                "score": 0.75,
                "score_name": "ensemble_score",
            },
            {
                "ensemble_score": 1.0,
                "ensemble_0_precision_micro": 1.0,
                "ensemble_1_recall_macro": 1.0,
                "score": 1.0,
                "score_name": "ensemble_score",
            },
            {
                "ensemble_score": 0.0,
                "ensemble_0_precision_micro": 0.0,
                "ensemble_1_recall_macro": 0.0,
                "score": 0.0,
                "score_name": "ensemble_score",
            },
            {
                "ensemble_score": 0.0,
                "ensemble_0_precision_micro": 0.0,
                "ensemble_1_recall_macro": 0.0,
                "score": 0.0,
                "score_name": "ensemble_score",
            },
        ]

        global_target = {
            "ensemble_0_precision_micro": 0.5,
            "ensemble_0_precision_micro_ci_high": 1.0,
            "ensemble_0_precision_micro_ci_low": 0.0,
            "ensemble_1_recall_macro": 0.33,
            "ensemble_1_recall_macro_ci_high": 0.56,
            "ensemble_1_recall_macro_ci_low": 0.0,
            "ensemble_score": 0.44,
            "score": 0.44,
            "score_ci_high": 0.56,
            "score_ci_low": 0.0,
            "score_name": "ensemble_score",
        }

        test_metric(
            metric=metric,
            predictions=predictions,
            references=references,
            instance_targets=instance_targets,
            global_target=global_target,
        )
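
As a sanity check on the expected values above: with weights=None, prepare() assigns each of the two metrics an equal weight of 1/2, so the first instance target follows from the per-metric scores it lists:

# First instance: prediction ["A"] vs. reference ["B", "A"] gives precision_micro = 1.0
# and recall_macro = 0.5 (values taken from the instance target above), with equal weights of 0.5.
ensemble_score = 0.5 * 1.0 + 0.5 * 0.5
assert ensemble_score == 0.75  # matches "ensemble_score" in the first instance target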
