diff --git a/examples/di/InfiAgent-DABench/DABench.py b/examples/di/InfiAgent-DABench/DABench.py new file mode 100644 index 000000000..50ec04b29 --- /dev/null +++ b/examples/di/InfiAgent-DABench/DABench.py @@ -0,0 +1,487 @@ +import asyncio +import json +import re +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import nest_asyncio + +from examples.di.requirements_prompt import DABENCH +from metagpt.const import DABENCH_PATH +from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception + + +def evaluate_accuracy_by_question(results: dict) -> float: + """ + Calculate the accuracy of results based on complete correctness of each question. + This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py + This function checks whether each result is entirely correct, meaning all sub-questions + within that result are answered correctly. It computes the proportion of correct results + by dividing the number of fully correct results by the total number of results. + + Args: + results (dict): A collection of results where each result may contain a 'correctness' field. + + Returns: + float: The proportion of correct results, rounded to four decimal places. + Returns 0 if there are no results. + """ + correct = sum("correctness" in result and all(result["correctness"].values()) for result in results) + total = len(results) + return round(correct / total, 4) if total > 0 else 0 + + +def evaluate_accuracy_by_sub_question(results: dict) -> float: + """ + Evaluate the correctness of all sub-questions across the results. + This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py + This function calculates the total number of correct sub-questions and the overall + number of sub-questions present in all results. It returns the ratio of correct + sub-questions to the total number of sub-questions. + + Args: + results (dict): A collection of results where each result may contain a 'correctness' field. + + Returns: + float: The ratio of correct sub-questions, rounded to four decimal places. + Returns 0 if there are no sub-questions. + """ + correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result) + total = sum(len(result["correctness"]) for result in results if "correctness" in result) + return round(correct / total, 4) if total > 0 else 0 + + +def evaluate_accuracy_proportional_by_sub_question_adjusted(results: dict) -> float: + """ + Adjust the score based on the number of sub-questions in each result. + This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py + This function calculates a score for each result by considering the number of sub-questions + it contains. Each sub-question is assigned a score of 1 divided by the number of sub-questions. + The total score for each result is computed as the sum of all correct sub-questions multiplied + by the score per sub-question. Finally, it returns the average score across all results. + + Args: + results (dict): A collection of results where each result may contain a 'correctness' field. + + Returns: + float: The average score across all results, rounded to four decimal places. + Returns 0 if there are no results. 
+ """ + total_score = 0 + for result in results: + if "correctness" in result: + sub_question_count = len(result["correctness"]) + score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0 + question_score = sum(result["correctness"].values()) * score_per_sub_question + total_score += question_score + return round(total_score / len(results), 4) if results else 0 + + +async def reformat(question: str, format: str, response: str) -> str: + """ + Asynchronously reformats a given response based on specified formatting requirements. + This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py + This function constructs a prompt for the LLM (Large Language Model) to reformat + the provided response according to the specified format. It includes a system prompt + to guide the LLM's behavior and a template that outlines the expected output structure. + + Args: + question (str): The original question posed by the user. + format (str): The specific formatting requirements that the response must adhere to. + response (str): The initial response from the LLM that needs to be reformatted. + + Returns: + str: The reformatted response generated by the LLM based on the provided question + and formatting requirements. + """ + system_prompt = "You are a helpful assistant." + demons = """\Format{{ + @shapiro_wilk_statistic[test_statistic] + @shapiro_wilk_p_value[p_value] + where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places. + where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places. + }} + \Answer{{ + @shapiro_wilk_statistic[0.56] + @shapiro_wilk_p_value[0.0002] + }} + + \Format{{ + @total_votes_outliers_num[outlier_num] + where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column. + }} + \Answer{{ + @total_votes_outliers[10] + }} + """ + reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}. + Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes. + The format requirements of this question is: + {format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:""" + messages = [ + {"role": "user", "content": question}, + {"role": "assistant", "content": response}, + {"role": "user", "content": reformat_template.format(demons=demons, format=format)}, + ] + rsp = await ask(messages, system_prompt) + return rsp + + +def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]: + """ + Load data from a JSONL file into a list of dictionaries. + + Args: + file_path (Union[Path, str]): The path to the JSONL file to be loaded. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the data from the JSONL file. 
+ """ + # Convert file_path to Path if it's a string + if isinstance(file_path, str): + file_path = Path(file_path) + + data = [] + with open(file_path, "r", encoding="utf-8") as file: + for line in file: + data.append(json.loads(line)) + return data + + +def compare_predictions(pred_dict: dict, true_label: list) -> bool: + """ + Compares each prediction against the corresponding true label. + + This function checks whether the predicted values match the true values for each + metric. It sorts the true labels to ensure the comparison is made in the correct + order. The function returns True if all predictions are accurate within a small + tolerance for numerical values, or if string values match case-insensitively. + + Args: + pred_dict (dict): A dictionary of predicted metrics and their values. + true_label (list): A list of tuples containing true metrics and their values. + + Returns: + bool: True if all predictions match the true labels, False otherwise. + """ + sorted_true_label = sorted(true_label, key=lambda x: x[0]) # Sort true labels by metric name + + for metric, true_value in sorted_true_label: + try: + true_value = float(true_value) # Attempt to convert the true value to float + except ValueError: + true_value = true_value.replace(",", "") # Clean the true value if conversion fails + + # Check if the true value is numeric and compare with the prediction + if isinstance(true_value, (int, float)) and ( + metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6 + ): + return False # Return False if the prediction is inaccurate + + # Check if the true value is a string and compare with the prediction + if isinstance(true_value, str) and ( + metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower() + ): + return False # Return False if the string prediction does not match + + return True # Return True if all predictions are accurate + + +async def ask(question: str, system_prompt: str) -> str: + """ + Asynchronously sends a question to the LLM (Large Language Model) and retrieves the response. + + This function initializes an instance of the LLM and uses it to ask a question + along with a system prompt. The response from the LLM is awaited and returned. + + Args: + question (str): The question to be asked to the LLM. + system_prompt (str): A prompt that provides context or instructions to the LLM. + + Returns: + str: The response from the LLM based on the provided question and system prompt. + """ + from metagpt.llm import LLM # Importing the LLM class from the metagpt module + + llm = LLM() # Create an instance of the LLM + rsp = await llm.aask(question, system_msgs=[system_prompt]) # Await the response from the LLM + return rsp # Return the response + + +def parse_prediction(prediction: str) -> dict: + """ + Parses a prediction string into a dictionary of metric-value pairs. + + This function takes a formatted string containing metrics and their corresponding + values, separated by the "@" symbol. Each metric may be enclosed in brackets and + may include commas. The function processes the input to extract and clean the + metrics and their values, returning them in a structured dictionary format. + + Args: + prediction (str): A string representation of metrics and their values. + + Returns: + dict: A dictionary where each key is a metric name and each value is the + corresponding value, either as a float or a string. 
+ """ + pred_dict = {} + for pred in prediction.split("@"): + if pred == "": + continue # Skip any empty segments resulting from the split + temp = re.split(r"[\[\]]", pred.strip()) # Split the string by brackets + temp = [s.replace(",", "") for s in temp] # Remove commas from the segments + parts = [s for s in temp if s] # Filter out any empty strings + metric = parts[0].strip().replace(",", "") # Extract and clean the metric name + value = parts[-1].replace(",", "").replace(":", "") # Extract and clean the value + + try: + value = float(value) # Attempt to convert the value to a float + except ValueError: + pass # If conversion fails, retain the value as a string + + pred_dict[metric] = value # Store the metric-value pair in the dictionary + return pred_dict + + +class DABench: + def __init__( + self, + questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl", + answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl", + template: str = "", + ): + """ + Initializes the DABench instance with questions and answers. + + This constructor loads questions and answers from specified JSONL files. + It also sets a template for formatting prompts. If no template is provided, + a default template is used. + + Args: + questions_file (Path): The path to the JSONL file containing questions. + answers_file (Path): The path to the JSONL file containing answers. + template (str): A string template for formatting prompts. + """ + + self.questions = { + int(line["id"]): line for line in load_jsonl(questions_file) + } # Load questions from the specified file + self.answers = { + int(line["id"]): line for line in load_jsonl(answers_file) + } # Load answers from the specified file + self.template = template if template else DABENCH # Set the template, defaulting if necessary + + def get_question(self, question_id: str) -> dict: + """ + Retrieve the question associated with the given ID. + + This method looks up a question by its unique identifier. If the question + is found, it returns the question data; otherwise, it returns a message + indicating that the question was not found. + + Args: + question_id (str): The unique identifier for the question. + + Returns: + dict: The question data if found, otherwise a "Question not found." message. + """ + return self.questions.get(question_id, "Question not found.") # Return the question or an error message + + def generate_formatted_prompt(self, question_id: str) -> str: + """ + Generate a formatted prompt for the specified question ID. + + This method retrieves the question data and formats it using the specified + template. The formatted prompt includes the question, constraints, format, + file name, and level, allowing for a structured output. + + Args: + question_id (str): The unique identifier for the question. + + Returns: + str: A formatted prompt string based on the question data. + """ + temp = self.get_question(question_id) # Retrieve the question data + return self.template.format( + question=temp["question"], + constraints=temp["constraints"], + format=temp["format"], + file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"], + level=temp["level"], + ) # Format and return the prompt + + def get_answer(self, answer_id: str) -> list: + """ + Retrieve the answer list associated with the given ID. + + This method looks up an answer by its unique identifier. If the answer + is found, it returns the answer data; otherwise, it returns a message + indicating that the answer was not found. 
+ + Args: + answer_id (str): The unique identifier for the answer. + + Returns: + list: The answer data if found, otherwise an "Answer not found." message. + """ + return self.answers.get(answer_id, "Answer not found.") # Return the answer or an error message + + @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False)) + def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]: + """ + Parse the cleaned prediction and compare it with the true label. + + Args: + cleaned_prediction (str): The cleaned prediction string. + true_label (Any): The true label to compare against. + + Returns: + Tuple[str, bool]: A tuple containing the cleaned prediction and a boolean indicating + whether it matches the true label. + """ + if cleaned_prediction: # Ensure the cleaned prediction is not empty + pred_dict = parse_prediction(cleaned_prediction) # Parse the prediction + if pred_dict is not None and compare_predictions(pred_dict, true_label): + return cleaned_prediction, True # Return if the prediction matches the true label + return cleaned_prediction, False # Return the cleaned prediction with a False match + + @handle_exception(exception_msg="Error during async reformat", default_return=(None, False)) + def async_reformat_prediction(self, id: str, result: str) -> str: + """ + Reformat the prediction asynchronously and extract the answer. + + Args: + id (str): The identifier for the question. + result (str): The original prediction result. + + Returns: + str: The reformatted prediction or the original prediction if extraction fails. + """ + question = self.get_question(id)["question"] # Retrieve the question based on the ID + question_format = self.get_question(id)["format"] # Get the format of the question + prediction = asyncio.run(reformat(question, question_format, result)) # Asynchronously reformat the prediction + + # Attempt to extract the answer from the reformatted prediction + answer_part = prediction.split("Answer{{") if "Answer{{" in prediction else [] + if len(answer_part) > 1: + return answer_part[1].split("}}")[0].strip() # Return the extracted answer + + return prediction # If extraction fails, return the original prediction + + def eval(self, id: str, result: str) -> Tuple[str, bool]: + """ + Evaluate the prediction against the true label. + + Args: + id (str): The identifier for the question. + result (str): The original prediction result. + + Returns: + Tuple[str, bool]: A tuple containing the final prediction and a boolean indicating + whether it matches the true label. 
+        """
+        true_label = self.get_answer(id)["common_answers"]  # Retrieve the true label for comparison
+        nest_asyncio.apply()  # Allow nested event loops so async calls can run from synchronous code
+        if "Current Plan" in str(result):
+            # DataInterpreter output: extract the result of the last finished task from the plan dump
+            result = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"].strip()
+        cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "")  # Clean the prediction string
+
+        # Use the decorated function to handle exceptions while parsing the cleaned prediction
+        parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
+        if parsed_result[1]:  # If the parsed prediction is valid
+            return parsed_result  # Return the valid prediction
+
+        # If the cleaned prediction is not valid, attempt to asynchronously reformat it
+        prediction = self.async_reformat_prediction(id, result)
+
+        pred_dict = parse_prediction(prediction)  # Parse the reformatted prediction
+        if pred_dict is not None and compare_predictions(pred_dict, true_label):
+            return prediction, True  # Return if the reformatted prediction matches the true label
+
+        return prediction, False  # Return the final prediction with a False match
+
+    @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
+    def single_eval(self, id: str, prediction: str) -> dict:
+        """
+        Evaluate the prediction against the true label for a single question.
+        Only used by eval_all.
+
+        Args:
+            id (str): The identifier for the question.
+            prediction (str): The prediction string to evaluate.
+
+        Returns:
+            dict: A dictionary indicating the correctness of each metric.
+        """
+        true_label = self.get_answer(id)["common_answers"]  # Retrieve the true label for the question
+        prediction = prediction.replace("{", "").replace("}", "").replace("'", "")  # Clean the prediction string
+        pred_dict = parse_prediction(prediction)  # Parse the prediction into a dictionary
+
+        # Initialize the correctness dictionary with False values for each metric
+        correctness = {metric: False for metric, _ in true_label}
+
+        # Check each metric's prediction against the true label
+        for metric, true_value in true_label:
+            try:
+                true_value = float(true_value)  # Attempt to convert the true value to float
+            except ValueError:
+                true_value = true_value.replace(",", "")  # Handle non-numeric values
+
+            if metric in pred_dict:
+                # A numeric prediction is correct if it matches within a small tolerance
+                if (
+                    isinstance(true_value, (int, float))
+                    and isinstance(pred_dict[metric], (int, float))
+                    and abs(pred_dict[metric] - true_value) < 1e-6
+                ):
+                    correctness[metric] = True
+
+                # A string prediction is correct if it matches case-insensitively
+                if isinstance(true_value, str) and str(pred_dict[metric]).lower() == str(true_value).lower():
+                    correctness[metric] = True
+
+        return correctness  # Return the correctness dictionary
+
+    def eval_all(self, id_list: list, predictions: list) -> dict:
+        """
+        Evaluate all predictions and calculate accuracy rates.
+
+        Args:
+            id_list (list): A list of question identifiers.
+            predictions (list): A list of prediction strings corresponding to the questions.
+
+        Returns:
+            dict: A dictionary containing accuracy rates by question and sub-question.
+        """
+        results = []  # Initialize a list to store results for each question
+
+        # Evaluate each prediction against its corresponding question ID
+        for id, prediction in zip(id_list, predictions):
+            correct = self.single_eval(id, prediction)  # Evaluate the single prediction
+            results.append({"id": id, "correctness": correct})  # Append the result to the list
+
+        # Calculate the three accuracy rates based on the results
+        accuracy_by_question = evaluate_accuracy_by_question(results)
+        accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
+        proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)
+
+        return {
+            "accuracy_by_question": accuracy_by_question,
+            "accuracy_by_sub_question": accuracy_by_sub_question,
+            "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
+        }
+
+
+if __name__ == "__main__":
+    bench = DABench()
+    id = 0
+    prediction = "@mean_fare[34.65]"
+    logger.info(bench.eval(id, prediction))
+    ids = [0, 5, 6]
+    predictions = [
+        "@mean_fare[34.89]",
+        "@correlation_coefficient[0.21]",
+        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
+    ]
+    logger.info(bench.eval_all(ids, predictions))
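To make the three scoring helpers at the top of `DABench.py` concrete, here is a small standalone sketch of how a `results` list is aggregated. It mirrors the logic of `evaluate_accuracy_by_question`, `evaluate_accuracy_by_sub_question`, and the proportional variant rather than importing them, so it runs without MetaGPT or the dataset; the correctness values are made up for illustration.

```
# Standalone illustration of the three DABench accuracy metrics (values are invented).
results = [
    {"id": 0, "correctness": {"mean_fare": True}},
    {"id": 5, "correctness": {"correlation_coefficient": False}},
    {"id": 6, "correctness": {"mean_fare_child": True, "mean_fare_adult": True, "mean_fare_elderly": False}},
]

# By question: a question only counts if every sub-question is correct.
by_question = sum(all(r["correctness"].values()) for r in results) / len(results)  # 1/3

# By sub-question: pool all sub-questions across all questions.
by_sub_question = sum(sum(r["correctness"].values()) for r in results) / sum(
    len(r["correctness"]) for r in results
)  # 3/5

# Proportional: each question contributes equally, split evenly across its sub-questions.
proportional = sum(
    sum(r["correctness"].values()) / len(r["correctness"]) for r in results
) / len(results)  # (1 + 0 + 2/3) / 3

print(round(by_question, 4), round(by_sub_question, 4), round(proportional, 4))
```

These are the same three numbers `eval_all` returns for real predictions.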
diff --git a/examples/di/InfiAgent-DABench/README.md b/examples/di/InfiAgent-DABench/README.md
new file mode 100644
index 000000000..74783c9d1
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/README.md
@@ -0,0 +1,45 @@
+# InfiAgent-DABench
+This example solves the InfiAgent-DABench benchmark with Data Interpreter (DI) and reaches 94.93% accuracy with gpt-4o.
+
+## Dataset download
+```
+cd examples/di/InfiAgent-DABench
+git clone https://github.com/InfiAgent/InfiAgent.git
+mv InfiAgent/examples/DA-Agent/data ./
+```
+## Special note
+When running DABench, change the `ExecuteNbCode` initializer so that every execution starts from a fresh notebook instead of reusing the notebook passed in by default:
+```
+class ExecuteNbCode(Action):
+    """execute notebook code block, return result to llm, and display it."""
+
+    nb: NotebookNode
+    nb_client: NotebookClient
+    console: Console
+    interaction: str
+    timeout: int = 600
+
+    def __init__(
+        self,
+        nb=nbformat.v4.new_notebook(),
+        timeout=600,
+    ):
+        super().__init__(
+            nb=nbformat.v4.new_notebook(),  # use a fresh notebook instead of the passed-in nb
+            nb_client=NotebookClient(nb, timeout=timeout),
+            timeout=timeout,
+            console=Console(),
+            interaction=("ipython" if self.is_ipython() else "terminal"),
+        )
+```
+`ExecuteNbCode` is defined in:
+```
+metagpt.actions.di.execute_nb_code
+```
+## How to run
+```
+python run_InfiAgent-DABench_single.py --id x  # run one task, where x is the id of the question to test
+python run_InfiAgent-DABench_all.py            # run all tasks serially
+python run_InfiAgent-DABench.py --k x          # run all tasks in parallel, where x is the number of tasks per batch
+```
\ No newline at end of file
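The scripts below all build the DI requirement from the `DABENCH` template added to `examples/di/requirements_prompt.py` later in this diff. The sketch that follows shows roughly what `generate_formatted_prompt` produces for one question record; only the keys match `da-dev-questions.jsonl`, the field values are invented for illustration.

```
# Sketch of how a question record becomes a DI requirement (record values are invented).
DABENCH = (
    "You are required to {question} from a CSV file named {file_name}. "
    "**Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. "
    "The output format should be {format}. This task is categorized as {level}."
)

record = {
    "question": "calculate the mean fare paid by the passengers",
    "constraints": "the mean fare is rounded to two decimal places",
    "format": "@mean_fare[mean_fare]",
    "file_name": "titanic.csv",
    "level": "easy",
}

requirement = DABENCH.format(
    question=record["question"],
    constraints=record["constraints"],
    format=record["format"],
    file_name="examples/di/InfiAgent-DABench/data/da-dev-tables/" + record["file_name"],
    level=record["level"],
)
print(requirement)
```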
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
new file mode 100644
index 000000000..7e1fbad8b
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
@@ -0,0 +1,77 @@
+import asyncio
+import json
+
+import fire
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def get_prediction(requirement):
+    """Obtain a prediction from a fresh DataInterpreter instance.
+
+    A new agent is created for each requirement so that concurrently running tasks do not
+    share planner or memory state. If an error occurs during processing, the error is logged
+    and None is returned.
+
+    Args:
+        requirement: The input requirement for which the prediction is to be made.
+
+    Returns:
+        The predicted result if successful, otherwise None.
+    """
+    try:
+        # Run a fresh agent with the given requirement and await the result
+        agent = DataInterpreter()
+        result = await agent.run(requirement)
+
+        # Parse the result to extract the prediction from the JSON response
+        prediction_json = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])
+        prediction = prediction_json[-1]["result"]  # Extract the last result from the parsed JSON
+
+        return prediction  # Return the extracted prediction
+    except Exception as e:
+        # Log an error message if an exception occurs during processing
+        logger.info(f"Error processing requirement: {requirement}. Error: {e}")
+        return None  # Return None in case of an error
+
+
+async def evaluate_all(k):
+    """Evaluate all tasks in DABench, processing them in parallel groups of size k.
+
+    Args:
+        k (int): The number of tasks to process concurrently in each group.
+    """
+    bench = DABench()  # Create an instance of DABench to access its methods and data
+    id_list, predictions = [], []  # Initialize lists to store IDs and predictions
+    tasks = []  # Initialize a list to hold the tasks
+
+    # Iterate over the answers in DABench to generate tasks
+    for key, value in bench.answers.items():
+        requirement = bench.generate_formatted_prompt(key)  # Generate a formatted prompt for the current key
+        tasks.append(get_prediction(requirement))  # Append the prediction task to the tasks list
+        id_list.append(key)  # Append the current key to the ID list
+
+    # Process tasks in groups of size k and execute each group concurrently
+    for i in range(0, len(tasks), k):
+        # Get the current group of tasks
+        current_group = tasks[i : i + k]
+        # Execute the current group of tasks in parallel
+        group_predictions = await asyncio.gather(*current_group)
+        # Keep predictions aligned with id_list: a failed task contributes an empty prediction
+        predictions.extend(pred if pred is not None else "" for pred in group_predictions)
+
+    # Evaluate the results using all predictions and log the evaluation
+    logger.info(bench.eval_all(id_list, predictions))
+
+
+def main(k=5):
+    """Main function to run the evaluation process."""
+    asyncio.run(evaluate_all(k))  # Run the evaluate_all function asynchronously
+
+
+if __name__ == "__main__":
+    fire.Fire(main)  # Expose the --k option documented in the README
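`run_InfiAgent-DABench.py` above bounds parallelism by awaiting tasks in fixed-size batches, so each batch finishes only when its slowest task does. Purely as an illustration of the same idea (this is not what the script uses), the following sketch keeps at most `k` coroutines in flight with an `asyncio.Semaphore` while preserving submission order; `fake_task` is a hypothetical stand-in for `get_prediction`.

```
import asyncio


async def bounded_gather(coros, k):
    """Run coroutines with at most k in flight; results keep submission order."""
    semaphore = asyncio.Semaphore(k)

    async def _run(coro):
        async with semaphore:
            return await coro

    return await asyncio.gather(*(_run(coro) for coro in coros))


async def fake_task(i):
    # Hypothetical stand-in for get_prediction(requirement)
    await asyncio.sleep(0.01)
    return f"@answer_{i}[{i}]"


print(asyncio.run(bounded_gather([fake_task(i) for i in range(10)], k=3)))
```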
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
new file mode 100644
index 000000000..5cd1ef4b0
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
@@ -0,0 +1,35 @@
+import fire
+import pandas as pd
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.utils.recovery_util import save_history
+
+
+async def main():
+    """Evaluate all tasks serially."""
+    bench = DABench()
+    id_list, predictions, labels, is_true = [], [], [], []
+    for key, value in bench.answers.items():
+        id_list.append(key)
+        labels.append(str(bench.get_answer(key)))
+        try:
+            requirement = bench.generate_formatted_prompt(key)
+            di = DataInterpreter()
+            result = await di.run(requirement)
+            logger.info(result)
+            save_history(role=di)
+            temp_prediction, temp_istrue = bench.eval(key, str(result))
+            is_true.append(str(temp_istrue))
+            predictions.append(str(temp_prediction))
+        except Exception as e:
+            # Record the failure and keep predictions aligned with id_list
+            logger.info(f"Error processing task {key}: {e}")
+            is_true.append(str(False))
+            predictions.append("")
+    df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
+    df.to_excel("DABench_output.xlsx", index=False)
+    logger.info(bench.eval_all(id_list, predictions))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py
new file mode 100644
index 000000000..470f12fc8
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py
@@ -0,0 +1,22 @@
+import fire
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.utils.recovery_util import save_history
+
+
+async def main(id=0):
+    """Evaluate one task."""
+    bench = DABench()
+    requirement = bench.generate_formatted_prompt(id)
+    di = DataInterpreter()
+    result = await di.run(requirement)
+    logger.info(result)
+    save_history(role=di)
+    _, is_correct = bench.eval(id, str(result))
+    logger.info(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/examples/di/requirements_prompt.py b/examples/di/requirements_prompt.py
index 04a0414b1..34102c134 100644
--- a/examples/di/requirements_prompt.py
+++ b/examples/di/requirements_prompt.py
@@ -1,3 +1,5 @@
+# InfiAgent-DABench requirements
+DABENCH = "You are required to {question} from a CSV file named {file_name}. **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}. This task is categorized as {level}."
 # ML-Benchmark requirements
 IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
 WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"
diff --git a/metagpt/const.py b/metagpt/const.py
index f33b46b68..9497fdd1e 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -43,6 +43,7 @@ def get_metagpt_root():
 EXAMPLE_PATH = METAGPT_ROOT / "examples"
 EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
 DATA_PATH = METAGPT_ROOT / "data"
+DABENCH_PATH = EXAMPLE_PATH / "di/InfiAgent-DABench/data"
 EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
 TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
 RESEARCH_PATH = DATA_PATH / "research"
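Since `metagpt/const.py` now exposes `DABENCH_PATH`, the benchmark can also be scored offline without running the Data Interpreter: `eval_all` only parses and compares answer strings, so no LLM call is made. A minimal sketch, assuming the dataset from the README has been downloaded into `examples/di/InfiAgent-DABench/data` and the script is run from `examples/di/InfiAgent-DABench`:

```
# Offline smoke test of the scoring path (no LLM call); run from examples/di/InfiAgent-DABench
# after downloading the dataset as described in the README.
from DABench import DABench

bench = DABench()
ids = [0, 5]
predictions = [
    "@mean_fare[34.89]",               # answers follow the @name[value] convention of da-dev-labels.jsonl
    "@correlation_coefficient[0.21]",
]
print(bench.eval_all(ids, predictions))
```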