InfiAgent-DABench #1494

Merged 16 commits on Oct 28, 2024
143 changes: 143 additions & 0 deletions examples/di/InfiAgent-DABench/DABench.py
@@ -0,0 +1,143 @@
import json
from pathlib import Path

from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH


# This code is adapted from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
def evaluate_accuracy_by_question(results):
correct = sum("correctness" in result and all(result["correctness"].values()) for result in results)
total = len(results)
return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_by_sub_question(results):
correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result)
total = sum(len(result["correctness"]) for result in results if "correctness" in result)
return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
total_score = 0
for result in results:
if "correctness" in result:
sub_question_count = len(result["correctness"])
score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
question_score = sum(result["correctness"].values()) * score_per_sub_question
total_score += question_score
return round(total_score / len(results), 4) if results else 0
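# Worked example: a result whose "correctness" is {"a": True, "b": True, "c": False}
# contributes 0 to accuracy_by_question (not all sub-questions correct), 2 correct out
# of 3 to accuracy_by_sub_question, and 2/3 to the proportional score.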


class DABench:
def __init__(
self,
questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
template="",
):
# Read questions from a JSONL file
with open(questions_file, "r") as file:
self.questions = {int(json.loads(line)["id"]): json.loads(line) for line in file}

# Read answers from a JSONL file
with open(answers_file, "r") as file:
self.answers = {int(json.loads(line)["id"]): json.loads(line) for line in file}

self.template = template if template else DABENCH

def get_question(self, question_id):
"""Retrieve the question by its id."""
return self.questions.get(question_id, "Question not found.")

def get_prompt(self, question_id):
"""Retrieve the question by its id."""
temp = self.get_question(question_id)
return self.template.format(
question=temp["question"],
constraints=temp["constraints"],
format=temp["format"],
file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
level=temp["level"],
)

def get_answer(self, answer_id):
"""Retrieve the answer list by its id."""
return self.answers.get(answer_id, "Answer not found.")

def eval(self, id, prediction):
"""Evaluate the prediction against the true label."""
true_label = self.get_answer(id)["common_answers"]
# Parse the prediction string into a dictionary of metric-value pairs
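        # e.g. "@mean_fare[34.89], @correlation_coefficient[0.21]" -> {"mean_fare": 34.89, "correlation_coefficient": 0.21}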
pred_dict = {}
for pred in prediction.split(","):
parts = pred.strip().split("[")
metric = parts[0].strip().replace("@", "")
value = float(parts[1].rstrip("]"))
pred_dict[metric] = value

# Sort the true labels to match the order of predictions
sorted_true_label = sorted(true_label, key=lambda x: x[0])

# Compare each prediction with the corresponding true label
correct = True
for metric, true_value in sorted_true_label:
if metric not in pred_dict or abs(pred_dict[metric] - float(true_value)) > 1e-6:
correct = False
break

return correct

def eval_all(self, id_list, predictions):
"""Evaluate all predictions and calculate accuracy rates."""

        def single_eval(id, prediction):
"""Evaluate the prediction against the true label for a single question and return a dictionary indicating the correctness of each metric."""
true_label = self.get_answer(id)["common_answers"]
pred_dict = {}

# Parse the prediction string into a dictionary of metric-value pairs
for pred in prediction.split(","):
parts = pred.strip().split("[")
metric = parts[0].strip().replace("@", "")
value = float(parts[1].rstrip("]"))
pred_dict[metric] = value

# Initialize the correctness dictionary with False values
correctness = {metric: False for metric, _ in true_label}

# Check each metric's prediction against the true label
for metric, true_value in true_label:
if metric in pred_dict:
# Consider the prediction correct if it's within a small tolerance
if abs(pred_dict[metric] - float(true_value)) < 1e-6:
correctness[metric] = True

return correctness

results = []
for id, prediction in zip(id_list, predictions):
            correct = single_eval(id, prediction)
results.append({"id": id, "correctness": correct})

# Calculate the three accuracy rates
accuracy_by_question = evaluate_accuracy_by_question(results)
accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)

return {
"accuracy_by_question": accuracy_by_question,
"accuracy_by_sub_question": accuracy_by_sub_question,
"proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
}


if __name__ == "__main__":
DA = DABench()
id = [0, 5, 6]
prediction = [
"@mean_fare[34.89]",
"@correlation_coefficient[0.21]",
"@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
]
print(DA.eval_all(id, prediction))
12 changes: 12 additions & 0 deletions examples/di/InfiAgent-DABench/README.md
@@ -0,0 +1,12 @@
# InfiAgent-DABench
This example solves InfiAgent-DABench with Data Interpreter (DI) and achieves 94.93% accuracy using gpt-4o.

## Dataset installation
Clone the InfiAgent repository into this directory (`examples/di/InfiAgent-DABench/`) so the dataset resolves under `DABENCH_PATH` (`InfiAgent/examples/DA-Agent/data`, see `metagpt/const.py`):
```
git clone https://github.com/InfiAgent/InfiAgent.git
```
## How to run
```
python run_InfiAgent-DABench_sigle.py --id x # run a single task, where x is the question id
python run_InfiAgent-DABench_all.py # run all tasks
```
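For programmatic evaluation you can also drive the `DABench` helper directly. A minimal sketch (run from this directory; the question id and prediction string are the illustrative values from the `__main__` block of `DABench.py`):
```
from DABench import DABench

bench = DABench()
print(bench.get_prompt(0))  # formatted task prompt for question 0
print(bench.eval_all([0], ["@mean_fare[34.89]"]))  # accuracy summary dict
```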
24 changes: 24 additions & 0 deletions examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
@@ -0,0 +1,24 @@
import json

import fire
from DABench import DABench

from metagpt.roles.di.data_interpreter import DataInterpreter


async def main():
"""Evaluate all"""
DA = DABench()
id_list, predictions = [], []
    for key in DA.answers:
requirement = DA.get_prompt(key)
di = DataInterpreter()
result = await di.run(requirement)
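        # Recover the final answer: DI's string output lists completed tasks between "Current Plan" and "## Current Task";
        # parse that JSON list and take the last task's "result" field.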
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
id_list.append(key)
predictions.append(prediction)
print(DA.eval_all(id_list, predictions))


if __name__ == "__main__":
fire.Fire(main)
20 changes: 20 additions & 0 deletions examples/di/InfiAgent-DABench/run_InfiAgent-DABench_sigle.py
@@ -0,0 +1,20 @@
import json

import fire
from DABench import DABench

from metagpt.roles.di.data_interpreter import DataInterpreter


async def main(id=5):
DA = DABench()
requirement = DA.get_prompt(id)
di = DataInterpreter()
result = await di.run(requirement)
prediction = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"]
is_correct = DA.eval(id, prediction)
print(f"Prediction is {'correct' if is_correct else 'incorrect'}.")


if __name__ == "__main__":
fire.Fire(main)
3 changes: 3 additions & 0 deletions examples/di/requirements_prompt.py
@@ -1,3 +1,6 @@
# InfiAgent-DABench requirements
DABENCH = "You are required to {question} from a CSV file named {file_name}. {constraints}. The output format should be {format}. This task is categorized as {level}."

# ML-Benchmark requirements
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"
1 change: 1 addition & 0 deletions metagpt/const.py
@@ -43,6 +43,7 @@ def get_metagpt_root():
EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
DABENCH_PATH = EXAMPLE_PATH / "di/InfiAgent-DABench/InfiAgent/examples/DA-Agent/data"
EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"