Commit
Merge pull request #1428 from didiforgithub/action_graph
Update Action Graph Solver Version 0.1
Showing 15 changed files with 10,194 additions and 0 deletions.
1,319 changes: 1,319 additions & 0 deletions
examples/ags/benchmark/data/gsm8k_main_test.jsonl
Large diffs are not rendered by default.
7,473 changes: 7,473 additions & 0 deletions
examples/ags/benchmark/data/gsm8k_main_train.jsonl
Large diffs are not rendered by default.
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on gsm8k
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on hotpotqa
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
# @Date : 7/7/2024 17:07 PM
# @Author : didi
# @Desc : test on human eval graph

import asyncio
import json
import os
import subprocess
import sys
from typing import Literal, Optional

import aiofiles
from evalplus.data import get_human_eval_plus

from examples.ags.w_action_node.graph import HumanEvalGraph
from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock
from examples.ags.w_action_node.utils import sort_json_by_key
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.common import add_jsonl_file, read_json_file
from metagpt.utils.exceptions import handle_exception

generate_code = GenerateCode(llm=LLM())
generate_code_block = GenerateCodeBlock(llm=LLM())
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria="correctness, efficiency, readability", vote_count=5)

ModeType = Literal["ags", "alpha_codium", "llm"]


async def llm_generate(id):
    case = get_human_eval_plus()[f"{id}"]
    solution_result = await generate_code_block(case["prompt"], case["entry_point"])
    sample_dict = dict(task_id=case["task_id"], solution=solution_result["code_solution"])
    return sample_dict


async def ags_generate(id, ensemble_count: int = 5):
    case = get_human_eval_plus()[f"{id}"]
    solution_result = await solver(case["prompt"], ensemble_count=ensemble_count)
    sample_dict = dict(task_id=case["task_id"], solution=solution_result["final_solution"])
    return sample_dict


async def alpha_codium_generate(id):
    case = get_human_eval_plus()[f"{id}"]
    solution_result = await solver.alpha_codium(case["task_id"], case["prompt"], ensemble_count=5)
    sample_dict = dict(task_id=case["task_id"], solution=solution_result["final_solution"])
    return sample_dict


async def route_generate(mode: ModeType, id: str):
    if mode == "ags":
        sample_dict = await ags_generate(id)
    elif mode == "alpha_codium":
        sample_dict = await alpha_codium_generate(id)
    elif mode == "llm":
        sample_dict = await llm_generate(id)
    else:
        raise ValueError(f"Invalid mode: {mode}")
    return sample_dict


async def sample_generate(id, result_path: str = "samples.jsonl", mode: ModeType = "ags"):
    sample_dict = await route_generate(mode, id)
    add_jsonl_file(result_path, [sample_dict])
    sort_json_by_key(result_path, result_path)
async def samples_generate(mode: ModeType, result_path: str = "samples.jsonl"):
    ids = list(get_human_eval_plus().keys())
    file_lock = asyncio.Lock()

    async def solve_and_write(id: str, mode: ModeType) -> Optional[str]:
        try:
            sample_dict = await route_generate(mode, id)
        except Exception:
            return id
        async with file_lock:
            async with aiofiles.open(result_path, mode="a") as f:
                await f.write(json.dumps(sample_dict) + "\n")
        return None

    tasks = [solve_and_write(id, mode) for id in ids]
    results = await asyncio.gather(*tasks)
    failed_tasks = [task_id for task_id in results if task_id is not None]

    if failed_tasks:
        logger.info(failed_tasks)
        # retry failed tasks; iterate over a copy so removing items does not skip elements
        for task_id in list(failed_tasks):
            try:
                await sample_generate(task_id, result_path, mode)
                failed_tasks.remove(task_id)
            except Exception:
                logger.error(f"{task_id} fail")

    sort_json_by_key(result_path, result_path)

    if not failed_tasks:
        if automatic_evalplus(result_path):
            eval_path = result_path[:-6] + "_eval_results.json"
            unpassed_example = extract_failure_tests(eval_path)
            logger.info(unpassed_example)
    else:
        logger.info(failed_tasks)
@handle_exception(exception_type=subprocess.CalledProcessError, exception_msg="sanitize error", default_return=None)
def automatic_sanitize(result_path: str = "samples.jsonl") -> Optional[str]:
    """
    Automatically run `evalplus.sanitize --samples result_path` on the command line.
    Returns the result_path prefix with "-sanitized.jsonl" appended.
    """
    command = ["evalplus.sanitize", "--samples", result_path]

    subprocess.run(command, check=True)

    base_name = os.path.splitext(result_path)[0]
    sanitized_path = f"{base_name}-sanitized.jsonl"

    return sanitized_path


@handle_exception(
    exception_type=subprocess.CalledProcessError,
    exception_msg="Error in automatic_evalplus function",
    default_return=False,
)
def automatic_evalplus(result_path: str = "samples.jsonl") -> bool:
    """
    Automatically run `evalplus.evaluate --dataset humaneval --samples samples.jsonl --parallel 2 --base-only` on the command line.
    """
    command = [
        sys.executable,  # use the current Python interpreter
        "-m",
        "evalplus.evaluate",
        "--dataset",
        "humaneval",
        "--samples",
        result_path,
        "--parallel",
        "2",
        "--base-only",
    ]

    result = subprocess.run(command, check=True, capture_output=True, text=True)
    logger.info(f"output: \n {result.stdout}")
    return True
def extract_failure_tests(file_path: str = "samples_eval_results.json"):
    task_results = read_json_file(file_path)

    failed_tests = []
    for task in task_results["eval"].values():
        if task[0]["base_status"] == "fail":
            failed_test = {
                "task_id": task[0]["task_id"],
            }
            failed_tests.append(failed_test)
    logger.info(f"length of failed tests: {len(failed_tests)}")

    return failed_tests
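The diff defines these helpers but does not show how they are invoked. A minimal driver sketch, assuming it were appended to the benchmark file above so that samples_generate is in scope (samples_generate already retries failures, sorts the JSONL output, and triggers the evalplus evaluation once every task has a solution):

import asyncio

if __name__ == "__main__":
    # Solve every HumanEval task with the "ags" graph and append results to samples.jsonl;
    # evaluation runs automatically when no task remains in failed_tasks.
    asyncio.run(samples_generate(mode="ags", result_path="samples.jsonl"))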
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# @Date : 6/27/2024 22:07 PM
# @Author : didi
# @Desc : graph & an instance - humanevalgraph

from typing import List

from evalplus.data import get_human_eval_plus

from examples.ags.w_action_node.operator import (
    FuEnsemble,
    Generate,
    GenerateCode,
    GenerateCodeBlock,
    MdEnsemble,
    Rephrase,
    Review,
    Revise,
    Test,
)
from examples.ags.w_action_node.utils import extract_test_cases_from_jsonl
from metagpt.llm import LLM
class Graph:
    def __init__(self, name: str, llm: LLM) -> None:
        self.name = name
        self.model = llm

    def __call__(self):
        raise NotImplementedError("Subclasses must implement __call__ method")

    def optimize(self, dataset: List):
        pass
class HumanEvalGraph(Graph):
    def __init__(self, name: str, llm: LLM, criteria: str, vote_count: int = 5) -> None:
        super().__init__(name, llm)
        self.criteria = criteria  # TODO: when building graphs automatically, the graph's init parameters must match the external parameters required by the operators it uses
        self.generate_code = GenerateCode(llm=llm)
        self.generate_code_block = GenerateCodeBlock(llm=llm)
        self.review = Review(llm=llm, criteria=criteria)
        self.revise = Revise(llm=llm)
        self.rephrase = Rephrase(llm=llm)
        self.tester = Test(llm=llm)
        self.fuensemble = FuEnsemble(llm=llm)
        self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)

    async def __call__(self, problem: str, ensemble_count: int = 3):
        solution_list = []
        for _ in range(ensemble_count):
            solution = await self.generate_code_block(problem)
            solution = solution.get("code_solution")
            solution_list.append(solution)
        solution = await self.mdensemble("code", solution_list, problem)
        return solution

    async def alpha_codium(self, problem_id: str, problem: str, ensemble_count: int = 3):
        """
        Paper: Code Generation with AlphaCodium: From Prompt Engineering to Flow Engineering
        Link: https://arxiv.org/abs/2404.14963
        Flow: An incomplete version of AlphaCodium, implementing the basic process of rephrase -> code ensemble -> test
        """
        test_cases = extract_test_cases_from_jsonl(problem_id)
        entry_point = get_human_eval_plus()[problem_id]["entry_point"]
        rephrase_problem = await self.rephrase(problem)  # the original problem description is concatenated inside rephrase
        solution_list = []
        for _ in range(ensemble_count):
            solution = await self.generate_code_block.rephrase_generate(
                problem, rephrase_problem, function_name=entry_point
            )
            solution = solution.get("code_solution")
            solution_list.append(solution)
        solution = await self.mdensemble("code", solution_list, problem)
        solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases)
        return solution
    async def review_revise_ensemble(self, problem: str, ensemble_count: int = 2, revise_round: int = 3):
        solution_list = []
        for _ in range(ensemble_count):
            solution = await self.single_solve(problem, revise_round)
            solution_list.append(solution)
        # assumes the FuEnsemble operator here; the class defines no `self.ensemble` attribute
        solution = await self.fuensemble(solution_list, problem)
        return solution
    async def simple_ensemble(self, problem: str, ensemble_count: int = 3):
        solution_list = []
        for _ in range(ensemble_count):
            solution = await self.generate_code(problem)
            # solution = await self.generate_code_block(problem)
            solution = solution.get("code_solution")
            solution_list.append(solution)
        solution = await self.fuensemble(solution_list, problem)
        return solution

    async def single_solve(self, problem: str, max_loop: int):
        solution = await self.generate_code(problem)
        solution = solution.get("code_solution")
        for _ in range(max_loop):
            review_feedback = await self.review(problem, solution)
            if review_feedback["review_result"]:
                break
            solution = await self.revise(problem, solution, review_feedback["feedback"])
            solution = solution.get("revised_solution")
        return solution
class Gsm8kGraph(Graph):
    def __init__(self, name: str, llm: LLM) -> None:
        super().__init__(name, llm)
        self.generate = Generate(llm=llm)
        self.rephrase = Rephrase(llm=llm)

    async def __call__(self, problem: str):
        solution = await self.generate(problem)
        return solution


class HotpotQAGraph(Graph):
    def __init__(self, name: str, llm: LLM) -> None:
        super().__init__(name, llm)
        self.generate = Generate(llm=llm)
        self.rephrase = Rephrase(llm=llm)

    async def __call__(self, problem: str):
        solution = await self.generate(problem)
        return solution
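For orientation, a minimal usage sketch of HumanEvalGraph (hypothetical, not part of this commit). It mirrors how the benchmark script above constructs the solver and reads the "final_solution" field; "HumanEval/0" is one of the standard EvalPlus task ids:

import asyncio

from evalplus.data import get_human_eval_plus
from metagpt.llm import LLM

from examples.ags.w_action_node.graph import HumanEvalGraph


async def main():
    # Build the solver the same way the benchmark script does.
    solver = HumanEvalGraph(
        name="solver", llm=LLM(), criteria="correctness, efficiency, readability", vote_count=5
    )
    case = get_human_eval_plus()["HumanEval/0"]
    result = await solver(case["prompt"], ensemble_count=3)
    print(result["final_solution"])


if __name__ == "__main__":
    asyncio.run(main())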