Showing 6 changed files with 276 additions and 0 deletions.
@@ -0,0 +1,79 @@
import json
from json.decoder import JSONDecodeError
from typing import Any

from helm.benchmark.scenarios.scenario import CORRECT_TAG
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.annotation.annotator import Annotator
from helm.clients.auto_client import AutoClient
from helm.common.request import Request


class FinanceBenchAnnotator(Annotator):
"""Annoator for FinanceBench that uses GPT-4o to determine if the model response is correct.""" | ||

    name = "financebench"
    _PROMPT_TEMPLATE = """Classify the model's response as one of three categories: "correct_answer", "incorrect_answer", or "failure_to_answer". Additionally, provide a short, one-sentence explanation for your classification.
Categories:
correct_answer: Allow minor deviations, such as giving the answer in billions when the unit was given in the question as millions.
incorrect_answer: This includes calculations that are off by small margins to several orders of magnitude, and from making up legal information to giving the wrong direction for an effect (e.g. reporting negative growth when it is actually positive). If a model gives the right answer but with logic or calculations that explicitly contradict the evidence in the gold standard answer, label it as incorrect_answer.
failure_to_answer: If the model explicitly states that it cannot answer because it does not have access to the right information then it is a failure to answer.
Question: {{QUESTION}}
Gold answer: {{GOLD_ANSWER}}
Model's response: {{MODEL_RESPONSE}}
Respond with only a raw JSON object in the following format, without using Markdown formatting:
{"explanation": "<one sentence explanation>", "label": "<category>"}
"""  # noqa: E501

    def __init__(self, auto_client: AutoClient, file_storage_path: str):
        super().__init__()
        self._auto_client = auto_client

    def annotate(self, request_state: RequestState) -> Any:
        assert request_state.result
        assert len(request_state.result.completions) == 1
        assert len(request_state.instance.references[0].tags) == 1
        assert request_state.instance.references[0].tags[0] == CORRECT_TAG
        question = request_state.instance.input.text.split("\nQuestion: ")[-1].strip()
        gold_answer = request_state.instance.references[0].output.text.strip()
        model_response = request_state.result.completions[0].text.strip()
        if not model_response.strip():
            return {"reasoning": "BLOCKED_REQUEST_OR_EMPTY_RESPONSE", "label": "failure_to_answer"}
        annotator_prompt = (
            FinanceBenchAnnotator._PROMPT_TEMPLATE.replace("{{QUESTION}}", question)
            .replace("{{GOLD_ANSWER}}", gold_answer)
            .replace("{{MODEL_RESPONSE}}", model_response)
        )
        annotator_request = Request(
            model="openai/gpt-4o-2024-05-13",
            model_deployment="openai/gpt-4o-2024-05-13",
            prompt=annotator_prompt,
            temperature=0.0,
            max_tokens=64,
        )
        annotator_response = self._auto_client.make_request(annotator_request)
        if not annotator_response.success:
            raise Exception(f"Annotation request failed: {annotator_response.error}")
        assert len(annotator_response.completions) == 1
        annotator_response_text = annotator_response.completions[0].text
        # OpenAI models like to surround JSON objects with ```json and ``` Markdown formatting.
        # This strips everything outside the outermost {} brackets.
        json_start_index = annotator_response_text.find("{")
        json_end_index = annotator_response_text.rfind("}")
        if json_start_index < 0 or json_end_index < 0:
            raise Exception(f"Malformed annotator response: {annotator_response_text}")
        annotator_response_json = annotator_response_text[json_start_index : json_end_index + 1]
        try:
            annotator_response_parsed = json.loads(annotator_response_json)
        except JSONDecodeError:
            raise Exception(f"Malformed annotator response: {annotator_response_text}")
        return annotator_response_parsed
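
The brace-stripping fallback is easy to see in isolation. A minimal standalone sketch (not part of the commit) of how the find/rfind logic recovers the JSON object when the model wraps it in Markdown fences:

import json

# Hypothetical annotator output wrapped in ```json fences, as the comment above describes:
text = '```json\n{"explanation": "Matches the gold answer.", "label": "correct_answer"}\n```'
json_start_index = text.find("{")   # first opening brace
json_end_index = text.rfind("}")    # last closing brace
parsed = json.loads(text[json_start_index : json_end_index + 1])  # fences fall outside the slice
assert parsed["label"] == "correct_answer"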
@@ -0,0 +1,49 @@
from typing import List

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat


class AnnotationLabelMetric(Metric):
"""Binary metric for labels produced by annotators. | ||
Expects the annotation with the given annotator name and key to be a string label. | ||
For each possible label in the list of possible labels, produces a | ||
corresponding stat with a value of 1 or 0 indicating if the actual label | ||
in the annoation.""" | ||

    def __init__(self, annotator_name: str, key: str, labels: List[str]):
        super().__init__()
        self.annotator_name = annotator_name
        self.key = key
        self.labels = labels

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        assert request_state.annotations
        annotation_label = request_state.annotations[self.annotator_name][self.key]
        if annotation_label not in self.labels:
            raise ValueError(
                f"Unrecognized annotation label '{annotation_label}' "
                f"(known labels: {self.labels}) "
                f"in annotation {request_state.annotations[self.annotator_name]} "
                f"for instance id {request_state.instance.id}"
            )
        stats: List[Stat] = []
        for label in self.labels:
            stats.append(
                Stat(MetricName(f"annotation_{self.annotator_name}_{self.key}_{label}")).add(
                    1 if label == annotation_label else 0
                )
            )
        return stats
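
As a usage sketch for the FinanceBench run (the run spec wiring is not shown in this excerpt, and the import path below is an assumption about where the file lives), the metric would be configured with the annotator's three categories and emit one 0/1 stat per category:

# Assumed module path for the new metric class; adjust to the actual file location.
from helm.benchmark.metrics.annotation_metrics import AnnotationLabelMetric

metric = AnnotationLabelMetric(
    annotator_name="financebench",
    key="label",
    labels=["correct_answer", "incorrect_answer", "failure_to_answer"],
)
# For an annotation like {"financebench": {"label": "correct_answer", ...}},
# evaluate_generation yields annotation_financebench_label_correct_answer = 1
# and 0 for each of the other two labels.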
@@ -0,0 +1,53 @@
import dataclasses
import json
import os
import random
from typing import List

from helm.benchmark.scenarios.scenario import (
    CORRECT_TAG,
    TRAIN_SPLIT,
    Scenario,
    Instance,
    Reference,
    TEST_SPLIT,
    Input,
    Output,
)
from helm.common.general import ensure_directory_exists, ensure_file_downloaded


class FinanceBenchScenario(Scenario):
    """FinanceBench"""

    name = "financebench"
    description = "FinanceBench"
    tags = ["finance"]

    def get_instances(self, output_path: str) -> List[Instance]:
        cache_dir = os.path.join(output_path, "data")
        ensure_directory_exists(cache_dir)
        target_path = os.path.join(cache_dir, "financebench_open_source.jsonl")
        url: str = (
            "https://raw.githubusercontent.com/patronus-ai/financebench/d7beebe5e739e0b806ab4443c1b3e23f51804acf/data/financebench_open_source.jsonl"  # noqa: E501
        )
        ensure_file_downloaded(source_url=url, target_path=target_path)

        instances: List[Instance] = []
        with open(target_path) as f:
            for line in f:
                row = json.loads(line)
                instance_id = row["financebench_id"]
                question = row["question"]
                answer = row["answer"]
                evidence = row["evidence"][0]["evidence_text_full_page"]
                input_text = f"Evidence: {evidence}\nQuestion: {question}"
                input = Input(text=input_text)
                references = [Reference(output=Output(text=answer), tags=[CORRECT_TAG])]
                instance = Instance(id=instance_id, input=input, references=references, split=TEST_SPLIT)
                instances.append(instance)
        random.seed(0)
        train_indexes = random.sample(list(range(len(instances))), k=10)
        for train_index in train_indexes:
            instances[train_index] = dataclasses.replace(instances[train_index], split=TRAIN_SPLIT)
        return instances
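
Note that the train split is deterministic: seeding Python's module-level generator with 0 before sampling picks the same 10 of the 150 instances as in-context training examples on every run. A self-contained sketch of the same sampling pattern (the count of 150 comes from the test below):

import random

NUM_INSTANCES = 150  # number of rows in financebench_open_source.jsonl
random.seed(0)  # fixed seed makes the sample reproducible
train_indexes = random.sample(range(NUM_INSTANCES), k=10)  # 10 distinct indices, no repeats
print(sorted(train_indexes))  # same list on every run for a given Python version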
src/helm/benchmark/scenarios/test_financebench_scenario.py (26 additions, 0 deletions)
@@ -0,0 +1,26 @@
import pytest
from tempfile import TemporaryDirectory

from helm.benchmark.scenarios.financebench_scenario import FinanceBenchScenario
from helm.benchmark.scenarios.scenario import CORRECT_TAG, TEST_SPLIT, TRAIN_SPLIT


@pytest.mark.scenarios
def test_financebench_scenario_get_instances():
    scenario = FinanceBenchScenario()
    with TemporaryDirectory() as tmpdir:
        instances = scenario.get_instances(tmpdir)
        assert len(instances) == 150
        assert len([instance for instance in instances if instance.split == TRAIN_SPLIT]) == 10
        assert (
            "Evidence: Table of Contents \n3M Company and Subsidiaries\nConsolidated Statement of Cash Flow s\n"  # noqa: E501
            in instances[0].input.text
        )
        assert (
            "Question: What is the FY2018 capital expenditure amount (in USD millions) for 3M? Give a response to the question by relying on the details shown in the cash flow statement."  # noqa: E501
            in instances[0].input.text
        )
        assert len(instances[0].references) == 1
        assert instances[0].references[0].output.text == "$1577.00"
        assert instances[0].references[0].tags == [CORRECT_TAG]
        assert instances[0].split == TEST_SPLIT
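
Because the test carries the scenarios marker, it can be selected with pytest's marker filter, assuming the marker is registered in the repository's pytest configuration:

pytest -m scenarios src/helm/benchmark/scenarios/test_financebench_scenario.py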