Update Action Graph Solver Version 0.1 #1428

Merged · 5 commits · Aug 1, 2024
1,319 changes: 1,319 additions & 0 deletions examples/ags/benchmark/data/gsm8k_main_test.jsonl


7,473 changes: 7,473 additions & 0 deletions examples/ags/benchmark/data/gsm8k_main_train.jsonl


1 change: 1 addition & 0 deletions examples/ags/benchmark/data/hotpot.json


70 changes: 70 additions & 0 deletions examples/ags/benchmark/gsm8k.py
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on gsm8k

import json
import re
import os

# Read the raw dataset
def read_jsonl(path: str):
    with open(path) as fh:
        return [json.loads(line) for line in fh if line.strip()]

# Interact directly with the graph / base model to get an answer
def LLM(question):
    answer = ""
    # given the input question, return the generated answer
    # answer = response generated from question
    return answer

def gsm_extract_answer(completion):
    ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
    INVALID_ANS = "[invalid]"

    match = ANS_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    else:
        return INVALID_ANS


def gsm_is_correct(data):
    INVALID_ANS = "[invalid]"

    gt_answer = gsm_extract_answer(data["answer"])
    assert gt_answer != INVALID_ANS
    return gsm_extract_answer(data["answer_llm"]) == gt_answer


# Load the dataset and collect the model's answers on it
def get_examples(split):
    path = os.path.join("data", f"{split}.jsonl")  # e.g. data/gsm8k_main_test.jsonl, added in this PR
    output_path = "gsm8k_generate.jsonl"
    examples = read_jsonl(path)

    processed_examples = []  # holds the processed examples

    for ex in examples:
        answer_llm = LLM(ex['question'])
        ex['answer_llm'] = answer_llm
        ex['is_correct'] = gsm_is_correct(ex)
        # append the processed example
        processed_examples.append(ex)

    # write the processed examples to a new JSONL file
    with open(output_path, 'w', encoding='utf-8') as f:
        for example in processed_examples:
            # serialize the dict as one JSON line
            json_line = json.dumps(example) + '\n'
            f.write(json_line)

    print(f"{len(examples)} {split} examples")
    return examples

if __name__ == "__main__":
    example = get_examples("gsm8k_main_test")  # matches the dataset file added under data/
    print(example[:5])
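
For reference, the `#### ` marker parsing above behaves like this on a GSM8K-style record (a standalone sketch with made-up strings, not part of this diff):

# Hypothetical usage of gsm_extract_answer / gsm_is_correct
sample = {
    "answer": "She sold 48 / 2 = 24 clips in May, so 48 + 24 = 72 in total.\n#### 72",
    "answer_llm": "Adding both months gives 72.\n#### 72",
}
print(gsm_extract_answer(sample["answer"]))  # -> "72"
print(gsm_is_correct(sample))                # -> True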
155 changes: 155 additions & 0 deletions examples/ags/benchmark/hotpotQA.py
@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on hotpotqa

import json
import re
import string
from collections import Counter

def normalize_answer(s):

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def update_answer(metrics, prediction, gold):
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    metrics['em'] += float(em)
    metrics['f1'] += f1
    metrics['prec'] += prec
    metrics['recall'] += recall
    return em, prec, recall

def update_sp(metrics, prediction, gold):
    cur_sp_pred = set(map(tuple, prediction))
    gold_sp_pred = set(map(tuple, gold))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    metrics['sp_em'] += em
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return em, prec, recall

def LLM(question):
    answer = ""
    # given the input question, return the generated answer
    # answer = response generated from question
    return answer

def eval(prediction_file, gold_file):
    with open(prediction_file) as f:
        prediction = json.load(f)
    with open(gold_file) as f:
        gold = json.load(f)

    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
               'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
               'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
    for dp in gold:
        cur_id = dp['_id']
        can_eval_joint = True
        if cur_id not in prediction['answer']:
            print('missing answer {}'.format(cur_id))
            can_eval_joint = False
        else:
            em, prec, recall = update_answer(
                metrics, prediction['answer'][cur_id], dp['answer'])

    N = len(gold)
    for k in metrics.keys():
        metrics[k] /= N

    print(metrics)


def answer(prediction_file, gold_file):
    with open(gold_file) as f:
        gold = json.load(f)

    # initialize the prediction dict with an 'answer' key
    prediction = {'answer': {}}

    for dp in gold:
        cur_id = dp['_id']
        paragraphs = [item[1] for item in dp['context'] if isinstance(item[1], list)]  # keep only list-valued entries
        # join all paragraphs into a single context string
        context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
        question = dp['question']

        # build the LLM input
        input_llm = f"question:{question}\n\ncontext:{context_str}"

        # LLM is the stub above; it returns the model's predicted answer
        response = LLM(input_llm)

        # store the prediction keyed by cur_id
        prediction['answer'][cur_id] = response

    # write the predictions to file
    with open(prediction_file, 'w') as f:
        json.dump(prediction, f)


if __name__ == '__main__':
    answer('hotpot_pre.json', 'your path here')
    eval('hotpot_pre.json', 'your path here')
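
As a quick sanity check on the metric helpers above, here is how they behave on made-up strings (a standalone sketch, not part of this diff):

# Hypothetical usage of normalize_answer / exact_match_score / f1_score
print(normalize_answer("The Eiffel Tower!"))                   # -> "eiffel tower"
print(exact_match_score("the Eiffel Tower", "Eiffel Tower"))   # -> True
f1, prec, recall = f1_score("Eiffel Tower in Paris", "the Eiffel Tower")
print(round(f1, 2), prec, recall)                              # -> 0.67 0.5 1.0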
171 changes: 171 additions & 0 deletions examples/ags/benchmark/humaneval.py
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
# @Date : 7/7/2024 17:07
# @Author : didi
# @Desc : test on human eval graph

import os
import json
import subprocess
import sys
import asyncio
import aiofiles
from metagpt.llm import LLM
from evalplus.data import get_human_eval_plus
from examples.ags.w_action_node.utils import jsonl_ranker
from examples.ags.w_action_node.graph import HumanEvalGraph
from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock

generate_code = GenerateCode(llm=LLM())
generate_code_block = GenerateCodeBlock(llm=LLM())
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5)

async def sample_generate(task_id, result_path: str = "samples.jsonl", mode: str = "ags"):
    case = get_human_eval_plus()[task_id]
    if mode == "ags":
        solution_result = await solver(case['prompt'], ensemble_count=5)
        sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
    elif mode == "alpha":
        solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
        sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
    elif mode == "llm":
        solution_result = await generate_code_block(case['prompt'], case['entry_point'])
        sample_dict = dict(task_id=case['task_id'], solution=solution_result['code_solution'])
    print(sample_dict)
    with open(result_path, mode='a') as f:
        f.write(json.dumps(sample_dict) + '\n')
    jsonl_ranker(result_path, result_path)

async def samples_generate(mode: str, result_path: str = "samples.jsonl"):
    cases = list(get_human_eval_plus().values())
    file_lock = asyncio.Lock()
Owner commented on this line: "Suggestion: tag the output with a timestamp or similar identifying information rather than using file_lock."

    async def solve_and_write(case, mode):
        try:
            if mode == 'llm':
                solution_result = await generate_code_block(problem_description=case['prompt'], function_name=case['entry_point'])
                # solution_result = await generate_code(case['prompt'])
                sample_dict = {
                    'task_id': case['task_id'],
                    'solution': solution_result['code_solution']
                }
            elif mode == "ags":
                solution_result = await solver(case['prompt'], ensemble_count=5)
                sample_dict = {
                    'task_id': case['task_id'],
                    'solution': solution_result['final_solution']
                }
            elif mode == "alpha":
                solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
                sample_dict = {
                    'task_id': case['task_id'],
                    'solution': solution_result['final_solution']
                }
            # TODO: once the final_solution issue is fixed, formal evaluation can start
            async with file_lock:
                async with aiofiles.open(result_path, mode='a') as f:
                    await f.write(json.dumps(sample_dict) + '\n')
            return None

        except Exception as e:
            print(e)
            return case['task_id']

    tasks = [solve_and_write(case, mode) for case in cases]
    results = await asyncio.gather(*tasks)
    failed_tasks = [task_id for task_id in results if task_id is not None]

    if failed_tasks:
        print(failed_tasks)
        if mode == 'llm':
            for task_id in list(failed_tasks):  # iterate over a copy; the list is mutated below
                case = get_human_eval_plus()[task_id]
                for _ in range(3):
                    try:
                        solution_result = await generate_code_block(case['prompt'], function_name=case['entry_point'])
                        task_dict = {
                            'task_id': case['task_id'],
                            'solution': solution_result['code_solution']
                        }
                        with open(result_path, mode='a') as f:
                            f.write(json.dumps(task_dict) + '\n')
                        failed_tasks.remove(task_id)
                        break
                    except Exception as e:
                        print(f"{e} \n failure {task_id}")
        elif mode == "ags" or mode == "alpha":
            for task_id in failed_tasks:
                try:
                    await sample_generate(task_id, result_path, mode)
                except Exception as e:
                    print(f"{e} \n failure {task_id}")

    jsonl_ranker(result_path, result_path)

    if not failed_tasks:
        # automatic sanitize
        # result_path = automatic_sanitize(result_path)
        if automatic_evalplus(result_path):
            eval_path = result_path[:-6] + "_eval_results.json"
            unpassed_examples = extract_failure_tests(eval_path)
            print(unpassed_examples)
    else:
        print(failed_tasks)

def automatic_sanitize(result_path: str = "samples.jsonl"):
    """
    Run `evalplus.sanitize --samples result_path` on the command line.
    Returns result_path with a "-sanitized.jsonl" suffix.
    """
    command = ["evalplus.sanitize", "--samples", result_path]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error while running the command: {e}")
        return None

    # build the sanitized file path
    base_name = os.path.splitext(result_path)[0]
    sanitized_path = f"{base_name}-sanitized.jsonl"

    return sanitized_path

def automatic_evalplus(result_path: str = "samples.jsonl"):
    """
    Run `evalplus.evaluate --dataset humaneval --samples samples.jsonl --parallel 2 --base-only` on the command line.
    """
    command = [
        sys.executable,  # use the current Python interpreter
        "-m",
        "evalplus.evaluate",
        "--dataset", "humaneval",
        "--samples", result_path,
        "--parallel", "2",
        "--base-only"
    ]

    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        print("Output:", result.stdout)
        return True
    except subprocess.CalledProcessError as e:
        print("Error output:", e.stderr)
        return False

def extract_failure_tests(file_path: str = "samples_eval_results.json"):
    with open(file_path, 'r') as f:
        task_results = json.load(f)

    failed_tests = []

    for task in task_results['eval'].values():
        if task[0]["base_status"] == "fail":
            failed_test = {
                "task_id": task[0]["task_id"],
                # "solution": task["solution"],
                # "fail_tests": task["base_fail_tests"]
            }
            failed_tests.append(failed_test)
    print(len(failed_tests))

    return failed_tests
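
Tying the module together, a plausible driver looks like this (a hedged sketch, not part of the diff; `samples_generate` is async and therefore needs an event loop):

# Hypothetical entry point for this module
if __name__ == "__main__":
    asyncio.run(samples_generate(mode="llm", result_path="samples.jsonl"))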