Update Operator & Benchmark
didiforgithub committed Oct 21, 2024
1 parent fe3fca5 commit 2d1d7ca
Showing 8 changed files with 89 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/aflow/benchmark/gsm8k.py
@@ -40,7 +40,7 @@ def calculate_score(self, expected_output: float, prediction: float) -> Tuple[fl
async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, float, float, float]:
max_retries = 5
retries = 0

while retries < max_retries:
try:
prediction, cost = await graph(problem["question"])
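
The retry loop above wraps the async call to the workflow graph. A minimal self-contained sketch of the pattern, assuming a simple backoff and an empty-prediction fallback (neither is shown in the hunk):

import asyncio
from typing import Callable, Tuple

async def call_with_retries(graph: Callable, question: str, max_retries: int = 5) -> Tuple[str, float]:
    """Call an async workflow graph, retrying on failure (sketch)."""
    retries = 0
    while retries < max_retries:
        try:
            prediction, cost = await graph(question)
            return prediction, cost
        except Exception:
            retries += 1
            # Assumed policy: wait a little longer after each failed attempt
            await asyncio.sleep(retries)
    # Assumed fallback once retries are exhausted
    return "", 0.0
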
64 changes: 49 additions & 15 deletions examples/aflow/data/download_data.py
@@ -1,9 +1,16 @@
# -*- coding: utf-8 -*-
# @Date : 2024-10-20
# @Author : MoshiQAQ & didi
# @Desc : Download and extract dataset files

import os
import requests
import tarfile
from tqdm import tqdm
from typing import List, Dict

def download_file(url, filename):
def download_file(url: str, filename: str) -> None:
"""Download a file from the given URL and show progress."""
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
block_size = 1024
@@ -15,22 +22,49 @@ def download_file(url, filename):
progress_bar.update(size)
progress_bar.close()
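
The hunk hides most of download_file's body; the lines below are a sketch of the streamed-download-with-progress pattern it appears to use (the chunked write is an assumption about the elided code):

import requests
from tqdm import tqdm

def download_file_sketch(url: str, filename: str) -> None:
    """Stream a file to disk while showing a tqdm progress bar (sketch)."""
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(filename, 'wb') as f:
        for chunk in response.iter_content(block_size):
            size = f.write(chunk)
            progress_bar.update(size)
    progress_bar.close()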

def extract_tar_gz(filename, extract_path):
def extract_tar_gz(filename: str, extract_path: str) -> None:
"""Extract a tar.gz file to the specified path."""
with tarfile.open(filename, 'r:gz') as tar:
tar.extractall(path=extract_path)

url = "https://drive.google.com/uc?export=download&id=1tXp5cLw89egeKRwDuood2TPqoEWd8_C0"
filename = "aflow_data.tar.gz"
extract_path = "./"

print("Downloading data file...")
download_file(url, filename)

print("Extracting data file...")
extract_tar_gz(filename, extract_path)
def process_dataset(url: str, filename: str, extract_path: str) -> None:
"""Download, extract, and clean up a dataset."""
print(f"Downloading {filename}...")
download_file(url, filename)

print(f"Extracting {filename}...")
extract_tar_gz(filename, extract_path)

print(f"{filename} download and extraction completed.")

os.remove(filename)
print(f"Removed {filename}")

print("Download and extraction completed.")
# Define the datasets to be downloaded
# Users can modify this list to choose which datasets to download
datasets_to_download: List[Dict[str, str]] = [
{
"name": "datasets",
"url": "https://drive.google.com/uc?export=download&id=1tXp5cLw89egeKRwDuood2TPqoEWd8_C0",
"filename": "aflow_data.tar.gz",
"extract_path": "examples/aflow/data"
},
{
"name": "results",
"url": "", # Please fill in the correct URL
"filename": "result.tar.gz",
"extract_path": "examples/aflow/data/results"
},
{
"name": "initial_rounds",
"url": "", # Please fill in the correct URL
"filename": "first_round.tar.gz",
"extract_path": "examples/aflow/scripts/optimized"
}
]

# Clean up the compressed file
os.remove(filename)
print(f"Removed {filename}")
def download(datasets):
"""Main function to process all selected datasets."""
for dataset_name in datasets:
# Look the matching entry up by its "name" field, since datasets_to_download is a list
dataset = next(d for d in datasets_to_download if d["name"] == dataset_name)
process_dataset(dataset['url'], dataset['filename'], dataset['extract_path'])
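
A minimal usage sketch for the new download() entry point, mirroring the commented-out call added to optimize.py later in this commit (the "results" and "initial_rounds" entries still need their URLs filled in before they can be fetched):

from examples.aflow.data.download_data import download

# Fetch only the benchmark datasets for now; add "results" and
# "initial_rounds" once their download URLs are filled in above.
download(["datasets"])
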
File renamed without changes.
12 changes: 4 additions & 8 deletions examples/aflow/scripts/evaluator.py
@@ -39,25 +39,21 @@ async def graph_evaluate(self, dataset: DatasetType, graph, params: dict, path:

data_path = self._get_data_path(dataset, is_test)
benchmark_class = self.dataset_configs[dataset]
benchmark = benchmark_class(dataset, data_path, path)
benchmark = benchmark_class(name=dataset, file_path=data_path, log_path=path)

# Use params to configure the graph and benchmark
configured_graph = await self._configure_graph(graph, params)
configured_graph = await self._configure_graph(dataset, graph, params)

va_list = [1,2,3] # Hard-coded validation run indices; could be taken from params instead
return await benchmark.run_evaluation(configured_graph, va_list)

async def _configure_graph(self, graph, params: dict):
async def _configure_graph(self, dataset, graph, params: dict):
# Here you can configure the graph based on params
# For example: set LLM configuration, dataset configuration, etc.
dataset_config = params.get("dataset", {})
llm_config = params.get("llm_config", {})
return graph(name=self.dataset_configs[dataset]["name"], llm_config=llm_config, dataset=dataset_config)
return graph(name=dataset, llm_config=llm_config, dataset=dataset_config)

def _get_data_path(self, dataset: DatasetType, test: bool) -> str:
base_path = f"examples/aflow/data/{dataset.lower()}"
return f"{base_path}_test.jsonl" if test else f"{base_path}_validate.jsonl"

# Alias methods for backward compatibility
for dataset in ["gsm8k", "math", "humaneval", "mbpp", "hotpotqa", "drop"]:
setattr(Evaluator, f"_{dataset}_eval", lambda self, *args, dataset=dataset.upper(), **kwargs: self.graph_evaluate(dataset, *args, **kwargs))
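
The alias loop above works because dataset is bound as a default argument at lambda-creation time, so each generated _<dataset>_eval method keeps its own value instead of all of them seeing the last loop variable. A standalone sketch of that capture pattern (the Demo class and evaluate method are illustrative, not repo code):

class Demo:
    def evaluate(self, dataset: str, x: int) -> str:
        return f"{dataset}:{x}"

for name in ["gsm8k", "math"]:
    # dataset=name.upper() is evaluated now, freezing the value per method
    setattr(Demo, f"_{name}_eval",
            lambda self, *args, dataset=name.upper(), **kwargs: self.evaluate(dataset, *args, **kwargs))

demo = Demo()
print(demo._gsm8k_eval(1))  # GSM8K:1
print(demo._math_eval(2))   # MATH:2
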
40 changes: 22 additions & 18 deletions examples/aflow/scripts/optimizer.py
@@ -82,23 +82,27 @@ def optimize(self, mode: OptimizerType = "Graph"):
retry_count = 0
max_retries = 1

while retry_count < max_retries:
try:
score = loop.run_until_complete(self._optimize_graph())
break
except Exception as e:
retry_count += 1
logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
if retry_count == max_retries:
logger.info("Max retries reached. Moving to next round.")
score = None

wait_time = 5 * retry_count
time.sleep(wait_time)

if retry_count < max_retries:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

score = loop.run_until_complete(self._optimize_graph())

# while retry_count < max_retries:
# try:
# score = loop.run_until_complete(self._optimize_graph())
# break
# except Exception as e:
# retry_count += 1
# logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
# if retry_count == max_retries:
# logger.info("Max retries reached. Moving to next round.")
# score = None

# wait_time = 5 * retry_count
# time.sleep(wait_time)

# if retry_count < max_retries:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)

self.round += 1
logger.info(f"Score for round {self.round}: {score}")
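
With the retry block commented out, each round now simply runs the _optimize_graph coroutine to completion on the event loop. A minimal sketch of running one coroutine per round on a freshly created loop (the round coroutine here is a stand-in, not the real _optimize_graph):

import asyncio

async def fake_round(round_id: int) -> float:
    """Stand-in for Optimizer._optimize_graph (illustrative)."""
    await asyncio.sleep(0)  # placeholder for the real async optimization work
    return float(round_id)

for round_id in range(1, 4):
    # Create and install a fresh event loop for each round
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        score = loop.run_until_complete(fake_round(round_id))
    finally:
        loop.close()
    print(f"Score for round {round_id}: {score}")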

@@ -114,7 +118,7 @@ def optimize(self, mode: OptimizerType = "Graph"):
time.sleep(5)

async def _optimize_graph(self):
validation_n = 5 # number of runs over the validation dataset
validation_n = 2 # number of runs over the validation dataset
graph_path = f"{self.root_path}/workflows"
data = self.data_utils.load_results(graph_path)

10 changes: 4 additions & 6 deletions metagpt/actions/action_node.py
@@ -622,21 +622,19 @@ async def fill(
if self.schema:
schema = self.schema

if mode == FillMode.CODE_FILL:

if mode == FillMode.CODE_FILL.value:
result = await self.code_fill(context, function_name, timeout)
self.instruct_content = self.create_class()(**result)
return self

elif mode == FillMode.CONTEXT_FILL:
"""
Uses xml_compile, but this version has no way to support a system message or temperature.
"""
elif mode == FillMode.CONTEXT_FILL.value:
context = self.xml_compile(context=self.context)
result = await self.context_fill(context)
self.instruct_content = self.create_class()(**result)
return self

elif mode == FillMode.SINGLE_FILL:
elif mode == FillMode.SINGLE_FILL.value:
result = await self.single_fill(context)
self.instruct_content = self.create_class()(**result)
return self
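
The switch from comparing against the enum member to comparing against its .value matters when mode is passed in as a plain string. A small sketch of the difference (the FillMode values below are illustrative stand-ins for the real enum in metagpt):

from enum import Enum

class FillMode(Enum):
    CODE_FILL = "code_fill"
    CONTEXT_FILL = "context_fill"
    SINGLE_FILL = "single_fill"

mode = "code_fill"  # the mode argument arrives as a plain string

print(mode == FillMode.CODE_FILL)        # False: a str never equals an Enum member
print(mode == FillMode.CODE_FILL.value)  # True: compares the underlying string
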
14 changes: 9 additions & 5 deletions optimize.py
@@ -4,6 +4,7 @@
# @Desc : Entrance of AFlow.

from examples.aflow.scripts.optimizer import Optimizer
from examples.aflow.data.download_data import download
from metagpt.configs.models_config import ModelsConfig
from typing import Literal

@@ -12,10 +13,13 @@
QuestionType = Literal["math", "code", "qa"]
OptimizerType = Literal["Graph", "Test"]

# When you first use AFlow, please download the datasets and initial rounds; if you want to take a look at the results, download the results as well.
# download(["datasets", "results", "initial_rounds"])

# Crucial Parameters
dataset: DatasetType = "HotpotQA" # Ensure the type is consistent with DatasetType
dataset: DatasetType = "GSM8K" # Ensure the type is consistent with DatasetType
sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows
question_type: QuestionType = "quiz" # Ensure the type is consistent with QuestionType
question_type: QuestionType = "code" # Ensure the type is consistent with QuestionType
optimized_path: str = "examples/aflow/scripts/optimized" # Optimized Result Save Path
initial_round: int = 1 # Corrected the case from Initial_round to initial_round
max_rounds: int = 20
@@ -30,9 +34,9 @@
"Custom", # It's the basic unit of a fixed node. The optimizer can modify its prompt to get various nodes.
# "AnswerGenerate" # It's for qa
# "CustomCodeGenerate", # It's for code
# "ScEnsemble", # It's for code, math and qa
"ScEnsemble", # It's for code, math and qa
# "Test", # It's for code
# "Programmer", # It's for math
"Programmer", # It's for math
]

# Create an optimizer instance
@@ -53,4 +57,4 @@
# Optimize workflow via setting the optimizer's mode to 'Graph'
optimizer.optimize("Graph")
# Test workflow via setting the optimizer's mode to 'Test'
optimizer.optimize("Test")
# optimizer.optimize("Test")
