diff --git a/.gitignore b/.gitignore index aa5edd74a9..73f32f75dc 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,8 @@ examples/graph_store.json examples/image__vector_store.json examples/index_store.json .chroma +.chroma_exp_data +.role_memory_data *~$* workspace/* tmp @@ -188,3 +190,4 @@ cov.xml *-structure.json *.dot .python-version +tests/data/requirements/*.jpg diff --git a/Dockerfile b/Dockerfile index dead205375..3a2de4981c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM nikolaik/python-nodejs:python3.9-nodejs20-slim # Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size RUN apt update &&\ - apt install -y libgomp1 git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ + apt install -y libgomp1 git chromium file fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ apt clean && rm -rf /var/lib/apt/lists/* # Install Mermaid CLI globally diff --git a/config/config2.example.yaml b/config/config2.example.yaml index c5454ec323..f6970dabe6 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -13,6 +13,37 @@ llm: # - gpt-4 8k: "gpt-4" # See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# Per-role custom LLM configuration +roles: + - role: "ProductManager" # the role's class name or its role_id + llm: + api_type: "openai" # or azure / ollama / open_llm etc.; check LLMType for more options + base_url: "YOUR_BASE_URL" + api_key: "YOUR_API_KEY" + proxy: "YOUR_PROXY" # for LLM API requests + model: "gpt-4-turbo-1106" + - role: "Architect" + llm: + api_type: "openai" # or azure / ollama / open_llm etc.; check LLMType for more options + base_url: "YOUR_BASE_URL" + api_key: "YOUR_API_KEY" + proxy: "YOUR_PROXY" # for LLM API requests + model: "gpt-35-turbo" + - role: "ProjectManager" + llm: + api_type: "azure" + base_url: "YOUR_BASE_URL" + api_key: "YOUR_API_KEY" + api_version: "YOUR_API_VERSION" + model: "gpt-4-1106" + - role: "Engineer" + llm: + api_type: "azure" + base_url: "YOUR_BASE_URL" + api_key: "YOUR_API_KEY" + api_version: "YOUR_API_VERSION" + model: "gpt-35-turbo-1106" + repair_llm_output: true # when the output is not valid JSON, try to repair it proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc. @@ -43,6 +74,21 @@ s3: secure: false bucket: "test" +exp_pool: + enabled: false + enable_read: false + enable_write: false + persist_path: .chroma_exp_data # The directory where experience data is persisted. + retrieval_type: bm25 # Defaults to `bm25`; set to `chroma` for vector storage, which requires an embedding configuration. + use_llm_ranker: true # Defaults to `true`; uses an LLM reranker to improve retrieval results. + collection_name: experience_pool # When `retrieval_type` is `chroma`, the collection name in chromadb. + +role_zero: + enable_longterm_memory: false # Whether to use long-term memory. Defaults to `false`. + longterm_memory_persist_path: .role_memory_data # The directory where memory data is saved. + memory_k: 200 # The capacity of short-term memory. + similarity_top_k: 5 # The number of long-term memories to retrieve. + use_llm_ranker: false # Whether to use an LLM reranker to improve retrieval results. Defaults to `false`.
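As an aside, the new `roles` section above is plain per-role data that any consumer can read. A minimal sketch of looking up a role's LLM override from this YAML (not MetaGPT's internal loader; the helper name `llm_config_for_role` is hypothetical, and PyYAML is assumed to be installed):

```python
# Minimal sketch: look up the per-role `llm` override from config2.yaml.
# `llm_config_for_role` is a hypothetical helper, not part of MetaGPT's API.
from pathlib import Path
from typing import Optional

import yaml  # assumes PyYAML is installed


def llm_config_for_role(config_path: Path, role_name: str) -> Optional[dict]:
    """Return the `llm` mapping for `role_name`, or None if no override exists."""
    data = yaml.safe_load(config_path.read_text())
    for entry in data.get("roles", []):
        if entry.get("role") == role_name:
            return entry.get("llm")
    return None


if __name__ == "__main__":
    # For the example config above, this would print the azure settings
    # (api_type, base_url, api_key, api_version, model) for ProjectManager.
    print(llm_config_for_role(Path("config/config2.yaml"), "ProjectManager"))
```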
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/config/vault.example.yaml b/config/vault.example.yaml new file mode 100644 index 0000000000..0e197d2a89 --- /dev/null +++ b/config/vault.example.yaml @@ -0,0 +1,48 @@ +# Usage: +# 1. Get value. +# >>> from metagpt.tools.libs.env import get_env +# >>> access_token = await get_env(key="access_token", app_name="github") +# >>> print(access_token) +# YOUR_ACCESS_TOKEN +# +# 2. Get description for LLM understanding. +# >>> from metagpt.tools.libs.env import get_env_description +# >>> descriptions = await get_env_description() +# >>> for k, desc in descriptions.items(): +# >>> print(f"{k}:{desc}") +# await get_env(key="access_token", app_name="github"):Get github access token +# await get_env(key="access_token", app_name="gitlab"):Get gitlab access token +# ... + +vault: + github: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get github access token" + gitlab: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get gitlab access token" + iflytek_tts: + values: + api_id: "YOUR_APP_ID" + api_key: "YOUR_API_KEY" + api_secret: "YOUR_API_SECRET" + descriptions: + api_id: "Get the API ID of IFlyTek Text to Speech" + api_key: "Get the API KEY of IFlyTek Text to Speech" + api_secret: "Get the API SECRET of IFlyTek Text to Speech" + azure_tts: + values: + subscription_key: "YOUR_SUBSCRIPTION_KEY" + region: "YOUR_REGION" + descriptions: + subscription_key: "Get the subscription key of Azure Text to Speech." + region: "Get the region of Azure Text to Speech." + default: # All key-value pairs whose app name is an empty string are placed below + values: + proxy: "YOUR_PROXY" + descriptions: + proxy: "Get proxy for tools like requests, playwright, selenium, etc." \ No newline at end of file diff --git a/examples/agent_creator.py b/examples/agent_creator.py index bd58840ce9..34160d3986 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -6,12 +6,13 @@ import re from metagpt.actions import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import METAGPT_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message +config = Config.default() EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() diff --git a/examples/cr.py b/examples/cr.py new file mode 100644 index 0000000000..295ed9fb8f --- /dev/null +++ b/examples/cr.py @@ -0,0 +1,15 @@ +import fire + +from metagpt.roles.di.engineer2 import Engineer2 +from metagpt.tools.libs.cr import CodeReview + + +async def main(msg): + role = Engineer2(tools=["Plan", "Editor:write,read", "RoleZero", "ValidateAndRewriteCode", "CodeReview"]) + cr = CodeReview() + role.tool_execution_map.update({"CodeReview.review": cr.review, "CodeReview.fix": cr.fix}) + await role.run(msg) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/data/exp_pool/engineer_exps.json b/examples/data/exp_pool/engineer_exps.json new file mode 100644 index 0000000000..6a0be36a92 --- /dev/null +++ b/examples/data/exp_pool/engineer_exps.json @@ -0,0 +1,24 @@ +[ + { + "req": [ + { + "role": "user", + "content": "\n# Current Plan\n{'goal': \"Please write a 1048 game using JavaScript and HTML code without using any frameworks, user can play with keyboard. 
Refer to the system design located at '/tmp/system_design.json' and the project schedule at '/tmp/project_schedule.json' for detailed information.\", 'tasks': []}\n\n# Current Task\n\n\n# Instruction\nBased on the context, write a plan or modify an existing plan to achieve the goal. A plan consists of one to 3 tasks.\nIf plan is created, you should track the progress and update the plan accordingly, such as Plan.finish_current_task, Plan.append_task, Plan.reset_task, Plan.replace_task, etc.\nWhen presented a current task, tackle the task using the available commands.\nPay close attention to new user message, review the conversation history, use RoleZero.reply_to_human to respond to new user requirement.\nNote:\n1. If you keeping encountering errors, unexpected situation, or you are not sure of proceeding, use RoleZero.ask_human to ask for help.\n2. Carefully review your progress at the current task, if your actions so far has not fulfilled the task instruction, you should continue with current task. Otherwise, finish current task.\n3. Each time you finish a task, use RoleZero.reply_to_human to report your progress.\n4. Each time you write a code in your response, write with the Editor directly without preparing a repetitive code block beforehand.\n5. Take on ONE task and write ONE code file in each response. DON'T attempt all tasks in one response.\n6. When not specified, you should write files in a folder named \"src\". If you know the project path, then write in a \"src\" folder under the project path.\n7. When provided system design or project schedule, you MUST read them first before making a plan, then adhere to them in your implementation, especially in the programming language, package, or framework. You MUST implement all code files prescribed in the system design or project schedule. You can create a plan first with each task corresponding to implementing one code file.\n8. Write at most one file per task, do your best to implement THE ONLY ONE FILE. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.\n9. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.\n10. When provided system design, YOU MUST FOLLOW \"Data structures and interfaces\". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.\n11. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.\n12. To modify code in a file, read the entire file, make changes, and update the file with the complete code, ensuring that no line numbers are included in the final write.\n13. When a system design or project schedule is provided, at the end of the plan, add a CodeRview Task for each file; for example, if there are three files, add three CodeRview Tasks. 
For each CodeRview Task, just call ValidateAndRewriteCode.run.\n" + } + ], + "resp": "\nExplanation: To begin the task of creating a 2048 game using JavaScript and HTML, I need to first read the system design and project schedule documents to understand the specific requirements and architecture outlined for the game development.\n\n```json\n[\n {\n \"command_name\": \"Editor.read\",\n \"args\": {\n \"path\": \"/tmp/system_design.json\"\n }\n },\n {\n \"command_name\": \"Editor.read\",\n \"args\": {\n \"path\": \"/tmp/project_schedule.json\"\n }\n }\n]\n```\n" + }, + { + "req": [ + { + "role": "user", + "content": "Command Editor.read executed: file_path='/tmp/system_design.json' block_content='001|{\"Implementation approach\":\"We will implement the 2048 game using plain JavaScript and HTML, ensuring no frameworks are used. The game logic will handle tile movements, merging, and game state updates. The UI will be simple and clean, with a responsive design to fit different screen sizes. We will use CSS for styling and ensure the game is playable with keyboard arrow keys. The game will display the current score, have a restart button, and show a game over message when no more moves are possible.\",\"File list\":[\"index.html\",\"style.css\",\"script.js\"],\"Data structures and interfaces\":\"\\\\nclassDiagram\\\\n class Game {\\\\n -grid: int[][]\\\\n -score: int\\\\n +init(): void\\\\n +move(direction: str): void\\\\n +merge(direction: str): void\\\\n +isGameOver(): bool\\\\n +restart(): void\\\\n }\\\\n class UI {\\\\n -game: Game\\\\n +init(): void\\\\n +update(): void\\\\n +showGameOver(): void\\\\n +bindEvents(): void\\\\n }\\\\n Game --> UI\\\\n\",\"Program call flow\":\"\\\\nsequenceDiagram\\\\n participant U as UI\\\\n participant G as Game\\\\n U->>G: init()\\\\n G-->>U: return\\\\n U->>U: bindEvents()\\\\n U->>G: move(direction)\\\\n G->>G: merge(direction)\\\\n G->>U: update()\\\\n U->>U: update()\\\\n U->>G: isGameOver()\\\\n G-->>U: return bool\\\\n alt Game Over\\\\n U->>U: showGameOver()\\\\n end\\\\n U->>G: restart()\\\\n G-->>U: return\\\\n\",\"Anything UNCLEAR\":\"Clarify if there are any specific design preferences or additional features required beyond the basic 2048 game functionality.\"}\\n'\n\nCommand Editor.read executed: file_path='/tmp/project_schedule.json' block_content='001|{\"Required packages\":[\"No third-party dependencies required\"],\"Required Other language third-party packages\":[\"No third-party dependencies required\"],\"Logic Analysis\":[[\"script.js\",\"Contains Game and UI classes, and their methods: init, move, merge, isGameOver, restart, update, showGameOver, bindEvents\"],[\"index.html\",\"Contains the HTML structure for the game UI\"],[\"style.css\",\"Contains the CSS styles for the game UI\"]],\"Task list\":[\"index.html\",\"style.css\",\"script.js\"],\"Full API spec\":\"\",\"Shared Knowledge\":\"The `script.js` file will contain the core game logic and UI handling. The `index.html` file will provide the structure for the game, and `style.css` will handle the styling.\",\"Anything UNCLEAR\":\"Clarify if there are any specific design preferences or additional features required beyond the basic 2048 game functionality.\"}\\n'" + }, + { + "role": "user", + "content": "\n# Current Plan\n{'goal': \"Please write a 1048 game using JavaScript and HTML code without using any frameworks, user can play with keyboard. 
Refer to the system design located at '/tmp/system_design.json' and the project schedule at '/tmp/project_schedule.json' for detailed information.\", 'tasks': []}\n\n# Current Task\n\n\n# Instruction\nBased on the context, write a plan or modify an existing plan to achieve the goal. A plan consists of one to 3 tasks.\nIf plan is created, you should track the progress and update the plan accordingly, such as Plan.finish_current_task, Plan.append_task, Plan.reset_task, Plan.replace_task, etc.\nWhen presented a current task, tackle the task using the available commands.\nPay close attention to new user message, review the conversation history, use RoleZero.reply_to_human to respond to new user requirement.\nNote:\n1. If you keeping encountering errors, unexpected situation, or you are not sure of proceeding, use RoleZero.ask_human to ask for help.\n2. Carefully review your progress at the current task, if your actions so far has not fulfilled the task instruction, you should continue with current task. Otherwise, finish current task.\n3. Each time you finish a task, use RoleZero.reply_to_human to report your progress.\n4. Each time you write a code in your response, write with the Editor directly without preparing a repetitive code block beforehand.\n5. Take on ONE task and write ONE code file in each response. DON'T attempt all tasks in one response.\n6. When not specified, you should write files in a folder named \"src\". If you know the project path, then write in a \"src\" folder under the project path.\n7. When provided system design or project schedule, you MUST read them first before making a plan, then adhere to them in your implementation, especially in the programming language, package, or framework. You MUST implement all code files prescribed in the system design or project schedule. You can create a plan first with each task corresponding to implementing one code file.\n8. Write at most one file per task, do your best to implement THE ONLY ONE FILE. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.\n9. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.\n10. When provided system design, YOU MUST FOLLOW \"Data structures and interfaces\". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.\n11. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.\n12. To modify code in a file, read the entire file, make changes, and update the file with the complete code, ensuring that no line numbers are included in the final write.\n13. When a system design or project schedule is provided, at the end of the plan, add a CodeRview Task for each file; for example, if there are three files, add three CodeRview Tasks. For each CodeRview Task, just call ValidateAndRewriteCode.run.\n" + } + ], + "resp": "\nExplanation: Based on the system design and project schedule, I will create a plan to implement the 2048 game. The plan will consist of tasks to create the necessary files (`index.html`, `style.css`, and `script.js`) and review each file to ensure they meet the design specifications.\n\nHere's the plan:\n\n1. **Task 1**: Create `index.html` - This file will contain the HTML structure necessary for the game's UI.\n2. **Task 2**: Create `style.css` - This file will define the CSS styles to make the game visually appealing and responsive.\n3. **Task 3**: Create `script.js` - This file will contain the JavaScript code for the game logic and UI interactions.\n4. 
**Code Review Tasks**: Review each file to ensure they meet the project requirements and adhere to the system design.\n\nLet's start by appending the first task to the plan.\n\n```json\n[\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Create the index.html file with the basic HTML structure for the 2048 game.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\"1\"],\n \"instruction\": \"Create the style.css file with the necessary CSS to style the 2048 game.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\"1\", \"2\"],\n \"instruction\": \"Create the script.js file containing the JavaScript logic for the 2048 game.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\"1\"],\n \"instruction\": \"Use ValidateAndRewriteCode to review the code in index.html to ensure it meets the design specifications.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"5\",\n \"dependent_task_ids\": [\"2\"],\n \"instruction\": \"Use ValidateAndRewriteCode to review the code in style.css to ensure it meets the design specifications.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"6\",\n \"dependent_task_ids\": [\"3\"],\n \"instruction\": \"Use ValidateAndRewriteCode to review the code in script.js to ensure it meets the design specifications.\",\n \"assignee\": \"Alex\"\n }\n }\n]\n```\n" + } +] \ No newline at end of file diff --git a/examples/data/exp_pool/team_leader_exps.json b/examples/data/exp_pool/team_leader_exps.json new file mode 100644 index 0000000000..125f0a48d9 --- /dev/null +++ b/examples/data/exp_pool/team_leader_exps.json @@ -0,0 +1,22 @@ +[{ + "req": [{ + "role": "user", + "content": "\n# Current Plan\n{'goal': \"from to {''}: Write a 1024 game using JavaScript and HTML code without using any frameworks, user can play with keyboard.\", 'tasks': []}\n\n# Current Task\n\n\n# Instruction\nYou are a team leader, and you are responsible for drafting tasks and routing tasks to your team members.\nYour team member:\nTim: Team Leader, \nAlice: Product Manager, efficiently create a successful product that meets market demands and user expectations\nBob: Architect, design a concise, usable, complete software system\nEve: Project Manager, break down tasks according to PRD/technical design, generate a task list, and analyze task dependencies to start with the prerequisite modules\nAlex: Engineer, Take on game, app, and web development\nDavid: DataAnalyst, Take on any data-related tasks, such as data analysis, machine learning, deep learning, web browsing, web scraping, web searching, web deployment, terminal operation, git and github operation, etc.\n\nYou should NOT assign consecutive tasks to the same team member, instead, assign an aggregated task (or the complete requirement) and let the team member to decompose it.\nWhen creating a new plan involving multiple members, create all tasks at once.\nIf plan is created, you should track the progress based on team member feedback message, and update plan accordingly, such as Plan.finish_current_task, Plan.reset_task, Plan.replace_task, etc.\nYou should use 
TeamLeader.publish_team_message to team members, asking them to start their task. DONT omit any necessary info such as path, link, environment, programming language, framework, requirement, constraint from original content to team members because you are their sole info source.\nPay close attention to new user message, review the conversation history, use RoleZero.reply_to_human to respond to the user directly, DON'T ask your team members.\n\nNote:\n1. If the requirement is a pure DATA-RELATED requirement, such as bug fixes, issue reporting, environment setup, terminal operations, pip install, web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst.\n2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise, usually the sequence of Product Manager -> Architect -> Project Manager -> Engineer -> (optional: QaEngine if present) -> (optional: DataAnalyst if user requests deployment), each assigned ONE task. When publishing message to Product Manager, you should directly copy the full original user requirement.\n3. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly.\n4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members.\n5. If you think the requirement is not clear or ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear.\n6. It is helpful for Engineer to have both the system design and the project schedule for writing the code, so include paths of both files (if available) and remind Engineer to definitely read them when publishing message to Engineer.\n7. If the requirement is writing a TRD and software framework, you should assign it to Architect. When publishing message to Architect, you should directly copy the full original user requirement.\n" + }], + "resp": "\nExplanation: The requirement is about software development. Assign each tasks to a different team member based on their expertise. 
When publishing message to Product Manager, we copy original user requirement directly to ensure no information loss.\n```json\n[\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Create a product requirement document (PRD) outlining the features, user interface, and user experience of the CLI python snake game.\",\n \"assignee\": \"Alice\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\"1\"],\n \"instruction\": \"Design the software architecture for the CLI snake game, including the choice of programming language, libraries, and data flow.\",\n \"assignee\": \"Bob\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\"2\"],\n \"instruction\": \"Break down the architecture into manageable tasks, identify task dependencies, and prepare a detailed task list for implementation.\",\n \"assignee\": \"Eve\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\"3\"],\n \"instruction\": \"Implement the core game logic for the CLI snake game, including snake movement, food generation, and score tracking.\",\n \"assignee\": \"Alex\"\n }\n },\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"5\",\n \"dependent_task_ids\": [\"4\"],\n \"instruction\": \"Write comprehensive tests for the game logic and user interface to ensure functionality and reliability.\",\n \"assignee\": \"Edward\"\n }\n },\n {\n \"command_name\": \"TeamLeader.publish_message\",\n \"args\": {\n \"content\": \"Create a cli snake game using Python\",\n \"send_to\": \"Alice\"\n }\n },\n {\n \"command_name\": \"RoleZero.reply_to_human\",\n \"args\": {\n \"content\": \"I have assigned the tasks to the team members. Alice will create the PRD, Bob will design the software architecture, Eve will break down the architecture into tasks, Alex will implement the core game logic, and Edward will write comprehensive tests. 
The team will work on the project accordingly\"\n }\n },\n {\n \"command_name\": \"end\"\n }\n]\n```\n" +}, { + "req": [{ + "role": "user", + "content": "\n# Current Plan\n{'goal': \"from to {''}: Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\", 'tasks': []}\n\n# Current Task\n\n\n# Instruction\nYou are a team leader, and you are responsible for drafting tasks and routing tasks to your team members.\nYour team member:\nTim: Team Leader, \nAlice: Product Manager, efficiently create a successful product that meets market demands and user expectations\nBob: Architect, design a concise, usable, complete software system\nEve: Project Manager, break down tasks according to PRD/technical design, generate a task list, and analyze task dependencies to start with the prerequisite modules\nAlex: Engineer, Take on game, app, and web development\nDavid: DataAnalyst, Take on any data-related tasks, such as data analysis, machine learning, deep learning, web browsing, web scraping, web searching, web deployment, terminal operation, git and github operation, etc.\n\nYou should NOT assign consecutive tasks to the same team member, instead, assign an aggregated task (or the complete requirement) and let the team member to decompose it.\nWhen creating a new plan involving multiple members, create all tasks at once.\nIf plan is created, you should track the progress based on team member feedback message, and update plan accordingly, such as Plan.finish_current_task, Plan.reset_task, Plan.replace_task, etc.\nYou should use TeamLeader.publish_team_message to team members, asking them to start their task. DONT omit any necessary info such as path, link, environment, programming language, framework, requirement, constraint from original content to team members because you are their sole info source.\nPay close attention to new user message, review the conversation history, use RoleZero.reply_to_human to respond to the user directly, DON'T ask your team members.\n\nNote:\n1. If the requirement is a pure DATA-RELATED requirement, such as bug fixes, issue reporting, environment setup, terminal operations, pip install, web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst.\n2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise, usually the sequence of Product Manager -> Architect -> Project Manager -> Engineer -> (optional: QaEngine if present) -> (optional: DataAnalyst if user requests deployment), each assigned ONE task. When publishing message to Product Manager, you should directly copy the full original user requirement.\n3. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly.\n4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members.\n5. 
If you think the requirement is not clear or ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear.\n6. It is helpful for Engineer to have both the system design and the project schedule for writing the code, so include paths of both files (if available) and remind Engineer to definitely read them when publishing message to Engineer.\n7. If the requirement is writing a TRD and software framework, you should assign it to Architect. When publishing message to Architect, you should directly copy the full original user requirement.\n" + }], + "resp": "\nExplanation: DON'T decompose requirement if it is a DATA-RELATED task, assign a single task directly to Data Analyst David. He will manage the decomposition and implementation.\n```json\n[\n {\n \"command_name\": \"Plan.append_task\",\n \"args\": {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\",\n \"assignee\": \"David\"\n }\n },\n {\n \"command_name\": \"TeamLeader.publish_message\",\n \"args\": {\n \"content\": \"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\",\n \"send_to\": \"David\"\n }\n },\n {\n \"command_name\": \"RoleZero.reply_to_human\",\n \"args\": {\n \"content\": \"I have assigned the task to David. He will break down the task further by himself and starts solving it.\"\n }\n },\n {\n \"command_name\": \"end\"\n }\n]\n```\n" +}, { + "req": "Conversation History:\n[\n ...,\n {'role': 'assistant', 'content': 'from Alice(Product Manager) to {'Bob'}: {'docs': {'20240424153821.json': {'root_path': 'docs/prd', 'filename': '20240424153821.json', 'content': '{\"Language\":\"en_us\",\"Programming Language\":\"Python\",\"Original Requirements\":\"create a cli snake game\",\"Project Name\":\"snake_game\",\"Product Goals\":[\"Develop an intuitive and addictive snake game\",...], ...}}}}},\n]\n", + "resp": "\nExplanation: You received a message from Alice, the Product Manager, that she has completed the PRD, use Plan.finish_current_task to mark her task as finished and moves the plan to the next task. Based on plan status, next task is for Bob (Architect), publish a message asking him to start. The message content should contain important path info.\n```json\n[\n {\n \"command_name\": \"Plan.finish_current_task\",\n \"args\": {}\n },\n {\n \"command_name\": \"TeamLeader.publish_message\",\n \"args\": {\n \"content\": \"Please design the software architecture for the snake game based on the PRD created by Alice. The PRD is at 'docs/prd/20240424153821.json'. Include the choice of programming language, libraries, and data flow, etc.\",\n \"send_to\": \"Bob\"\n }\n },\n {\n \"command_name\": \"RoleZero.reply_to_human\",\n \"args\": {\n \"content\": \"Alice has completed the PRD. I have marked her task as finished and sent the PRD to Bob. 
Bob will work on the software architecture.\"\n }\n },\n {\n \"command_name\": \"end\"\n }\n]\n```\n" +}, { + "req": [{ + "role": "user", + "content": "\n# Current Plan\n{'goal': \"from to {''}: how does the project go?\", 'tasks': []}\n\n# Current Task\n\n\n# Instruction\nYou are a team leader, and you are responsible for drafting tasks and routing tasks to your team members.\nYour team member:\nTim: Team Leader, \nAlice: Product Manager, efficiently create a successful product that meets market demands and user expectations\nBob: Architect, design a concise, usable, complete software system\nEve: Project Manager, break down tasks according to PRD/technical design, generate a task list, and analyze task dependencies to start with the prerequisite modules\nAlex: Engineer, Take on game, app, and web development\nDavid: DataAnalyst, Take on any data-related tasks, such as data analysis, machine learning, deep learning, web browsing, web scraping, web searching, web deployment, terminal operation, git and github operation, etc.\n\nYou should NOT assign consecutive tasks to the same team member, instead, assign an aggregated task (or the complete requirement) and let the team member to decompose it.\nWhen creating a new plan involving multiple members, create all tasks at once.\nIf plan is created, you should track the progress based on team member feedback message, and update plan accordingly, such as Plan.finish_current_task, Plan.reset_task, Plan.replace_task, etc.\nYou should use TeamLeader.publish_team_message to team members, asking them to start their task. DONT omit any necessary info such as path, link, environment, programming language, framework, requirement, constraint from original content to team members because you are their sole info source.\nPay close attention to new user message, review the conversation history, use RoleZero.reply_to_human to respond to the user directly, DON'T ask your team members.\n\nNote:\n1. If the requirement is a pure DATA-RELATED requirement, such as bug fixes, issue reporting, environment setup, terminal operations, pip install, web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst.\n2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise, usually the sequence of Product Manager -> Architect -> Project Manager -> Engineer -> (optional: QaEngine if present) -> (optional: DataAnalyst if user requests deployment), each assigned ONE task. When publishing message to Product Manager, you should directly copy the full original user requirement.\n3. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly.\n4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members.\n5. If you think the requirement is not clear or ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear.\n6. 
It is helpful for Engineer to have both the system design and the project schedule for writing the code, so include paths of both files (if available) and remind Engineer to definitely read them when publishing message to Engineer.\n7. If the requirement is writing a TRD and software framework, you should assign it to Architect. When publishing message to Architect, you should directly copy the full original user requirement.\n" + }], + "resp": "\nExplanation: The user is asking for a general update on the project status. Give a straight answer about the current task the team is working on and provide a summary of the completed tasks.\n```json\n[\n {\n \"command_name\": \"RoleZero.reply_to_human\",\n \"args\": {\n \"content\": \"The team is currently working on ... We have completed ...\"\n }\n },\n {\n \"command_name\": \"end\"\n }\n]\n```\n" +}] \ No newline at end of file diff --git a/examples/di/crawl_webpage.py b/examples/di/crawl_webpage.py index b8226f4f44..c4e1b6599c 100644 --- a/examples/di/crawl_webpage.py +++ b/examples/di/crawl_webpage.py @@ -6,16 +6,18 @@ """ from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.tools.libs.web_scraping import view_page_element_to_scrape PAPER_LIST_REQ = """" Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, -and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables* +and save it to a csv file. paper title must include `multiagent` or `large language model`. +**Notice: view the page element before writing scraping code** """ ECOMMERCE_REQ = """ Get products data from website https://scrapeme.live/shop/ and save it as a csv file. -**Notice: Firstly parse the web page encoding and the text HTML structure; -The first page product name, price, product URL, and image URL must be saved in the csv;** +The first page product name, price, product URL, and image URL must be saved in the csv. +**Notice: view the page element before writing scraping code** """ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; @@ -25,11 +27,12 @@ 3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; 4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 5. 将全部结果存在本地csv中 +**Notice: view the page element before writing scraping code** """ async def main(): - di = DataInterpreter(tools=["scrape_web_playwright"]) + di = DataInterpreter(tools=[view_page_element_to_scrape.__name__]) await di.run(ECOMMERCE_REQ) diff --git a/examples/di/fix_github_issue.py b/examples/di/fix_github_issue.py new file mode 100644 index 0000000000..4e26a375d8 --- /dev/null +++ b/examples/di/fix_github_issue.py @@ -0,0 +1,32 @@ +"""This example is from a real issue from MetaGPT: https://github.com/geekan/MetaGPT/issues/1067 with corresponding bugfix as https://github.com/geekan/MetaGPT/pull/1069 +We demonstrate that DataInterpreter has the capability to fix such issues. +Prerequisite: You need to manually add the bug back to your local file metagpt/utils/repair_llm_raw_output.py to test DataInterpreter's debugging ability. For detail, please check the issue and PR link above. +""" + +import asyncio + +from metagpt.roles.di.data_interpreter import DataInterpreter + +REQ = """ +# Requirement +Below is a github issue, solve it. Use Editor to search for the function, understand it, and modify the relevant code. 
+Write a new test file test.py with Editor and use Terminal to run the test file with Python to ensure you have fixed the issue. +When writing test.py, you should import the function from the file you modified and test it with the given input. +Notice: Don't write all codes in one response, each time, just write code for one step. + +# Issue +>> s = "-1" +>> print(extract_state_value_from_output(s)) +>> 1 +The extract_state_value_from_output function will process -1 into 1, +resulting in an infinite loop in react mode. +""" + + +async def main(): + di = DataInterpreter(tools=["Terminal", "Editor"], react_mode="react") + await di.run(REQ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/di/imitate_webpage.py b/examples/di/imitate_webpage.py index 60ebab3892..d181e0dfc5 100644 --- a/examples/di/imitate_webpage.py +++ b/examples/di/imitate_webpage.py @@ -11,10 +11,10 @@ async def main(): web_url = "https://pytorch.org/" prompt = f"""This is a URL of webpage: '{web_url}' . -Firstly, utilize Selenium and WebDriver for rendering. -Secondly, convert image to a webpage including HTML, CSS and JS in one go. +Firstly, open the page and take a screenshot of the page. +Secondly, convert the image to a webpage including HTML, CSS and JS in one go. Note: All required dependencies and environments have been fully installed and configured.""" - di = DataInterpreter(tools=["GPTvGenerator"]) + di = DataInterpreter(tools=["GPTvGenerator", "Browser"]) await di.run(prompt) diff --git a/examples/di/run_flask.py b/examples/di/run_flask.py new file mode 100644 index 0000000000..b57f763f3a --- /dev/null +++ b/examples/di/run_flask.py @@ -0,0 +1,19 @@ +import asyncio + +from metagpt.roles.di.data_interpreter import DataInterpreter + +RUN_FLASK_REQ = """ +Write a service using Flask, create a conda environment and run it, and call the service's interface for validation. +Notice: Don't write all codes in one response, each time, just write code for one step. +""" +# If you have created a conda environment, you can say: +# I have created the conda environment '{env_name}', please use this environment to execute. + + +async def main(): + di = DataInterpreter(tools=["Terminal", "Editor"]) + await di.run(RUN_FLASK_REQ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/di/software_company.py b/examples/di/software_company.py new file mode 100644 index 0000000000..ac9999ca93 --- /dev/null +++ b/examples/di/software_company.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import fire + +from metagpt.roles.di.data_interpreter import DataInterpreter + + +async def main(): + prompt = """ +This is a software requirement: +```text +write a snake game +``` +--- +1. Write a PRD based on the software requirements. +2. Write a design to the project repository, based on the PRD. +3. Write a project plan to the project repository, based on the design. +4. Write code to the project repository, based on the project plan. +5. Run QA tests on the project repository. +6. Stage and commit changes for the project repository using Git. +Note: All required dependencies and environments have been fully installed and configured. 
+""" + di = DataInterpreter( + tools=[ + "write_prd", + "write_design", + "write_project_plan", + "write_codes", + "run_qa_test", + "fix_bug", + "git_archive", + ] + ) + + await di.run(prompt) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/di/use_browser.py b/examples/di/use_browser.py new file mode 100644 index 0000000000..a3a079ccc6 --- /dev/null +++ b/examples/di/use_browser.py @@ -0,0 +1,29 @@ +import asyncio + +from metagpt.roles.di.data_interpreter import DataInterpreter + +MG_LLM_CONFIG_REQ = """ +This is a link to the doc site of MetaGPT project: https://docs.deepwisdom.ai/main/en/ +Check where you can go to on the site and try to find out the list of LLM APIs supported by MetaGPT. +Don't write all codes in one response, each time, just write code for one step. +""" + +PAPER_LIST_REQ = """" +At https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, +find the first paper whose title includes `multiagent`, open it and summarize its abstract. +Don't write all codes in one response, each time, just write code for one step. +""" + +DESCRIBE_GITHUB_ISSUE_REQ = """ +Visit https://github.com/geekan/MetaGPT, navigate to Issues page, open the first issue related to DataInterpreter, then summarize what the issue is in one sentence. +Don't write all codes in one response, each time, just write code for one step. +""" + + +async def main(): + di = DataInterpreter(tools=["Browser"], react_mode="react") + await di.run(MG_LLM_CONFIG_REQ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/di/use_github_repo.py b/examples/di/use_github_repo.py new file mode 100644 index 0000000000..7327f4597b --- /dev/null +++ b/examples/di/use_github_repo.py @@ -0,0 +1,19 @@ +import asyncio + +from metagpt.roles.di.data_interpreter import DataInterpreter + +USE_GOT_REPO_REQ = """ +This is a link to the GOT github repo: https://github.com/spcl/graph-of-thoughts.git. +Clone it, read the README to understand the usage, install it, and finally run the quick start example. +**Note the config for LLM is at `config/config_got.json`, it's outside the repo path, before using it, you need to copy it into graph-of-thoughts. +** Don't write all codes in one response, each time, just write code for one step. +""" + + +async def main(): + di = DataInterpreter(tools=["Terminal"]) + await di.run(USE_GOT_REPO_REQ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/exp_pool/README.md b/examples/exp_pool/README.md new file mode 100644 index 0000000000..37e7853f8b --- /dev/null +++ b/examples/exp_pool/README.md @@ -0,0 +1,20 @@ +# Experience Pool + +## Prerequisites +- Ensure the RAG module is installed: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html +- Set embedding: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html +- Set `enabled`、`enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml` + +## Example Files + +### 1. decorator.py +Showcases the implementation of the `@exp_cache` decorator. + +### 2. init_exp_pool.py +Demonstrates the process of initializing the experience pool. + +### 3. manager.py +Illustrates CRUD (Create, Read, Update, Delete) operations for managing experiences in the pool. + +### 4. scorer.py +Outlines methods for evaluating and scoring experiences within the pool. 
diff --git a/examples/exp_pool/decorator.py b/examples/exp_pool/decorator.py new file mode 100644 index 0000000000..8ee00905dd --- /dev/null +++ b/examples/exp_pool/decorator.py @@ -0,0 +1,28 @@ +""" +This script demonstrates how to automatically store experiences using @exp_cache and query the stored experiences. +""" + +import asyncio +import uuid + +from metagpt.exp_pool import exp_cache, get_exp_manager +from metagpt.logs import logger + + +@exp_cache() +async def produce(req=""): + return f"{req} {uuid.uuid4().hex}" + + +async def main(): + req = "Water" + + resp = await produce(req=req) + logger.info(f"The response of `produce({req})` is: {resp}") + + exps = await get_exp_manager().query_exps(req) + logger.info(f"Find experiences: {exps}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py new file mode 100644 index 0000000000..c7412af225 --- /dev/null +++ b/examples/exp_pool/init_exp_pool.py @@ -0,0 +1,97 @@ +"""Init experience pool. + +Put some useful experiences into the experience pool. +""" + +import asyncio +import json +from pathlib import Path + +from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.exp_pool import get_exp_manager +from metagpt.exp_pool.schema import EntryType, Experience, Metric, Score +from metagpt.logs import logger +from metagpt.utils.common import aread + + +async def load_file(filepath) -> list[dict]: + """Asynchronously loads and parses a JSON file. + + Args: + filepath: Path to the JSON file. + + Returns: + A list of dictionaries parsed from the JSON file. + """ + + return json.loads(await aread(filepath)) + + +async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): + """Adds a new experience to the experience pool. + + Args: + req: The request string. + resp: The response string. + tag: A tag for categorizing the experience. + metric: Optional metric for the experience. Defaults to a score of 10. + + """ + + exp = Experience( + req=req, + resp=resp, + entry_type=EntryType.MANUAL, + tag=tag, + metric=metric or Metric(score=Score(val=10, reason="Manual")), + ) + exp_manager = get_exp_manager() + exp_manager.is_writable = True + + exp_manager.create_exp(exp) + logger.info(f"New experience created for the request `{req[:10]}`.") + + +async def add_exps(exps: list, tag: str): + """Adds multiple experiences to the experience pool. + + Args: + exps: A list of experience dictionaries. + tag: A tag for categorizing the experiences. + + """ + tasks = [ + add_exp(req=exp["req"] if isinstance(exp["req"], str) else json.dumps(exp["req"]), resp=exp["resp"], tag=tag) + for exp in exps + ] + await asyncio.gather(*tasks) + + +async def add_exps_from_file(tag: str, filepath: Path): + """Loads experiences from a file and adds them to the experience pool. + + Args: + tag: A tag for categorizing the experiences. + filepath: Path to the file containing experiences. 
+ + """ + + exps = await load_file(filepath) + await add_exps(exps, tag) + + +def query_exps_count(): + """Queries and logs the total count of experiences in the pool.""" + exp_manager = get_exp_manager() + count = exp_manager.get_exps_count() + logger.info(f"Experiences Count: {count}") + + +async def main(): + await add_exps_from_file("TeamLeader.llm_cached_aask", EXAMPLE_DATA_PATH / "exp_pool/team_leader_exps.json") + await add_exps_from_file("Engineer2.llm_cached_aask", EXAMPLE_DATA_PATH / "exp_pool/engineer_exps.json") + query_exps_count() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/exp_pool/load_exps_from_log.py b/examples/exp_pool/load_exps_from_log.py new file mode 100644 index 0000000000..77eeff6dd7 --- /dev/null +++ b/examples/exp_pool/load_exps_from_log.py @@ -0,0 +1,85 @@ +"""Load and save experiences from the log file.""" + +import json +from pathlib import Path + +from metagpt.exp_pool import get_exp_manager +from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience +from metagpt.logs import logger + + +def load_exps(log_file_path: str) -> list[Experience]: + """Loads experiences from a log file. + + Args: + log_file_path (str): The path to the log file. + + Returns: + list[Experience]: A list of Experience objects loaded from the log file. + """ + + if not Path(log_file_path).exists(): + logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}") + return + + exps = [] + with open(log_file_path, "r") as log_file: + for line in log_file: + if LOG_NEW_EXPERIENCE_PREFIX in line: + json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip() + exp_data = json.loads(json_str) + + exp = Experience(**exp_data) + exps.append(exp) + + logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}") + + return exps + + +def save_exps(exps: list[Experience]): + """Saves a list of experiences to the experience pool. + + Args: + exps (list[Experience]): The list of experiences to save. + """ + + if not exps: + logger.warning("`save_exps` called with an empty list of experiences.") + return + + manager = get_exp_manager() + manager.is_writable = True + + manager.create_exps(exps) + logger.info(f"Saved {len(exps)} experiences.") + + +def get_log_file_path() -> str: + """Retrieves the path to the log file. + + Returns: + str: The path to the log file. + + Raises: + ValueError: If the log file path cannot be found. + """ + + handlers = logger._core.handlers + + for handler in handlers.values(): + if "log" in handler._name: + return handler._name[1:-1] + + raise ValueError("Log file not found") + + +def main(): + log_file_path = get_log_file_path() + + exps = load_exps(log_file_path) + save_exps(exps) + + +if __name__ == "__main__": + main() diff --git a/examples/exp_pool/manager.py b/examples/exp_pool/manager.py new file mode 100644 index 0000000000..c9ec46da5d --- /dev/null +++ b/examples/exp_pool/manager.py @@ -0,0 +1,31 @@ +""" +Demonstrate the creation and querying of experiences. + +This script creates a new experience, logs its creation, and then queries for experiences matching the same request. 
+""" + +import asyncio + +from metagpt.exp_pool import get_exp_manager +from metagpt.exp_pool.schema import EntryType, Experience +from metagpt.logs import logger + + +async def main(): + # Define the simple request and response + req = "Simple req" + resp = "Simple resp" + + # Add the new experience + exp = Experience(req=req, resp=resp, entry_type=EntryType.MANUAL) + exp_manager = get_exp_manager() + exp_manager.create_exp(exp) + logger.info(f"New experience created for the request `{req}`.") + + # Query for experiences matching the request + exps = await exp_manager.query_exps(req) + logger.info(f"Got experiences: {exps}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/exp_pool/scorer.py b/examples/exp_pool/scorer.py new file mode 100644 index 0000000000..aafcee63ff --- /dev/null +++ b/examples/exp_pool/scorer.py @@ -0,0 +1,44 @@ +import asyncio + +from metagpt.exp_pool.scorers import SimpleScorer + +# Request to implement quicksort in Python +REQ = "Write a program to implement quicksort in python." + +# First response: Quicksort implementation without base case +RESP1 = """ +def quicksort(arr): + return quicksort([x for x in arr[1:] if x <= arr[0]]) + [arr[0]] + quicksort([x for x in arr[1:] if x > arr[0]]) +""" + +# Second response: Quicksort implementation with base case +RESP2 = """ +def quicksort(arr): + if len(arr) <= 1: + return arr + return quicksort([x for x in arr[1:] if x <= arr[0]]) + [arr[0]] + quicksort([x for x in arr[1:] if x > arr[0]]) +""" + + +async def simple(): + """Evaluates two quicksort implementations using SimpleScorer. + + Example: + { + "val": 3, + "reason": "The response attempts to implement quicksort but contains a critical flaw: it lacks a base case to terminate the recursion, which will lead to a maximum recursion depth exceeded error for non-empty lists. Additionally, the function does not handle empty lists properly. A correct implementation should include a base case to handle lists of length 0 or 1." + } + """ + + scorer = SimpleScorer() + + await scorer.evaluate(req=REQ, resp=RESP1) + await scorer.evaluate(req=REQ, resp=RESP2) + + +async def main(): + await simple() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mgx/__init__.py b/examples/mgx/__init__.py new file mode 100644 index 0000000000..f12b94354a --- /dev/null +++ b/examples/mgx/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/examples/mgx/run_mgx.py b/examples/mgx/run_mgx.py new file mode 100644 index 0000000000..86aa67ad71 --- /dev/null +++ b/examples/mgx/run_mgx.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import asyncio + +from metagpt.roles.di.mgx import MGX + +requirement = ( + # "design a game using Gym (an open source Python library), including a graphical interface and interactive gameplay" + # "帮我把pip的源设置成:https://pypi.tuna.tsinghua.edu.cn/simple" + # "This is a website url does not require login: https://demosc.chinaz.net/Files/DownLoad//moban/202404/moban7767 please write a similar web page,developed in vue language, The package.json dependency must be generated" + "I would like to imitate the website available at https://demosc.chinaz.net/Files/DownLoad//moban/202404/moban7767. Could you please browse through it?" 
+ # "Create a 2048 Game" +) + + +async def main(requirement: str = ""): + mgx = MGX(use_intent=True, tools=[""]) + await mgx.run(requirement) + + +if __name__ == "__main__": + asyncio.run(main(requirement)) diff --git a/examples/mgx_write_project_framework.py b/examples/mgx_write_project_framework.py new file mode 100644 index 0000000000..b43d97b850 --- /dev/null +++ b/examples/mgx_write_project_framework.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_project_framework.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import asyncio +import json +import uuid +from json import JSONDecodeError +from pathlib import Path +from typing import Dict, List + +import typer +from pydantic import BaseModel + +from metagpt.config2 import Config +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.context import Context +from metagpt.environment import Environment +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.logs import logger +from metagpt.roles import Architect +from metagpt.roles.di.team_leader import TeamLeader +from metagpt.schema import AIMessage, UserMessage +from metagpt.strategy.experience_retriever import TRDToolExpRetriever +from metagpt.utils.common import aread + +app = typer.Typer(add_completion=False) + + +class EnvBuilder(BaseModel): + context: Context + user_requirements: List[str] + actors: Dict[str, str] + technical_constraint: str + output_dir: Path + + def build(self) -> Environment: + env = MGXEnv(context=self.context) + team_leader = TeamLeader() + architect = Architect(experience_retriever=TRDToolExpRetriever()) + + # Prepare context + use_case_actors = "".join([f"- {v}: {k}\n" for k, v in self.actors.items()]) + msg = """ +The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram. +## Actor, System, External System +{use_case_actors} + """ + architect.rc.memory.add(AIMessage(content=msg.format(use_case_actors=use_case_actors))) + + # Prepare technical requirements + msg = """ +"Additional Technical Requirements" specifies the additional technical requirements that the generated software framework code must meet. +## Additional Technical Requirements +{technical_requirements} +""" + architect.rc.memory.add(AIMessage(content=msg.format(technical_requirements=self.technical_constraint))) + + env.add_roles([team_leader, architect]) + return env + + +async def develop( + context: Context, + user_requirement_filename: str, + actors_filename: str, + constraint_filename: str, + output_dir: str, +): + output_dir = Path(output_dir) if output_dir else DEFAULT_WORKSPACE_ROOT / uuid.uuid4().hex + + v = await aread(filename=user_requirement_filename) + try: + user_requirements = json.loads(v) + except JSONDecodeError: + user_requirements = [v] + v = await aread(filename=actors_filename) + actors = json.loads(v) + technical_constraint = await aread(filename=constraint_filename) + env_builder = EnvBuilder( + context=context, + user_requirements=user_requirements, + actors=actors, + technical_constraint=technical_constraint, + output_dir=output_dir, + ) + env = env_builder.build() + msg = """ +Given the user requirement of "User Requirements", write out the software framework. 
+## User Requirements +{user_requirements} + """ + env.publish_message( + UserMessage(content=msg.format(user_requirements="\n".join(user_requirements)), send_to="Bob"), + user_defined_recipient="Bob", + ) + + while not env.is_idle: + await env.run() + + +@app.command() +def startup( + user_requirement_filename: str = typer.Argument(..., help="The filename of the user requirements."), + actors_filename: str = typer.Argument(..., help="The filename of UML use case actors description."), + llm_config: str = typer.Option(default="", help="Low-cost LLM config"), + constraint_filename: str = typer.Option(default="", help="What technical dependency constraints are."), + output_dir: str = typer.Option(default="", help="Output directory."), +): + if llm_config and Path(llm_config).exists(): + config = Config.from_yaml_file(Path(llm_config)) + else: + logger.info("GPT 4 turbo is recommended") + config = Config.default() + ctx = Context(config=config) + + asyncio.run(develop(ctx, user_requirement_filename, actors_filename, constraint_filename, output_dir)) + + +if __name__ == "__main__": + app() diff --git a/examples/search_enhanced_qa.py b/examples/search_enhanced_qa.py new file mode 100644 index 0000000000..9eb5449a49 --- /dev/null +++ b/examples/search_enhanced_qa.py @@ -0,0 +1,27 @@ +""" +This script demonstrates how to use the SearchEnhancedQA action to answer questions +by leveraging web search results. It showcases a simple example of querying about +the current weather in Beijing. + +The SearchEnhancedQA action combines web search capabilities with natural language +processing to provide informative answers to user queries. +""" + +import asyncio + +from metagpt.actions.search_enhanced_qa import SearchEnhancedQA + + +async def main(): + """Runs a sample query through SearchEnhancedQA and prints the result.""" + + action = SearchEnhancedQA() + + query = "What is the weather like in Beijing today?" + answer = await action.run(query) + + print(f"The answer to '{query}' is:\n\n{answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/serialize_model.py b/examples/serialize_model.py new file mode 100644 index 0000000000..2423efef8c --- /dev/null +++ b/examples/serialize_model.py @@ -0,0 +1,25 @@ +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.logs import logger + + +def main(): + """Demonstrates serialization and deserialization using SerializationMixin. + + This example creates an instance of MGXEnv, serializes it to a file, + and then deserializes it back to an instance. + + If executed correctly, the following log messages will be output: + MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json + MGXEnv deserialization successful. Instance created from file: /.../workspace/storage/MGXEnv.json + The instance is MGXEnv() + """ + + env = MGXEnv() + env.serialize() + + env: MGXEnv = MGXEnv.deserialize() + logger.info(f"The instance is {repr(env)}") + + +if __name__ == "__main__": + main() diff --git a/examples/write_novel.py b/examples/write_novel.py index a6e9ce05d8..f49918fbba 100644 --- a/examples/write_novel.py +++ b/examples/write_novel.py @@ -50,9 +50,9 @@ async def generate_novel(): "Fill the empty nodes with your own ideas. Be creative! Use your own words!" "I will tip you $100,000 if you write a good novel." 
) - novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM()) + novel_node = await ActionNode.from_pydantic(Novel).fill(req=instruction, llm=LLM()) chap_node = await ActionNode.from_pydantic(Chapters).fill( - context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM() + req=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM() ) print(chap_node.instruct_content) diff --git a/examples/write_project_framework.py b/examples/write_project_framework.py new file mode 100644 index 0000000000..8d23695a7c --- /dev/null +++ b/examples/write_project_framework.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_project_framework.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import asyncio +import json +import uuid +from pathlib import Path +from typing import Dict, List + +import typer + +from metagpt.actions.requirement_analysis.framework import ( + EvaluateFramework, + WriteFramework, + save_framework, +) +from metagpt.actions.requirement_analysis.trd import ( + CompressExternalInterfaces, + DetectInteraction, + EvaluateTRD, + WriteTRD, +) +from metagpt.config2 import Config +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.context import Context +from metagpt.logs import logger +from metagpt.utils.common import aread + +app = typer.Typer(add_completion=False) + + +async def _write_trd( + context: Context, actors: Dict[str, str], user_requirements: List[str], available_external_interfaces: str +) -> (str, str): + detect_interaction = DetectInteraction(context=context) + write_trd = WriteTRD(context=context) + evaluate_trd = EvaluateTRD(context=context) + use_case_actors = "".join([f"- {v}: {k}\n" for k, v in actors.items()]) + legacy_user_requirements = [] + legacy_user_requirements_interaction_events = [] + legacy_user_requirements_trd = "" + for ix, r in enumerate(user_requirements): + is_pass = False + evaluation_conclusion = "" + interaction_events = "" + trd = "" + while not is_pass and (context.cost_manager.total_cost < context.cost_manager.max_budget): + interaction_events = await detect_interaction.run( + user_requirements=r, + use_case_actors=use_case_actors, + legacy_interaction_events=interaction_events, + evaluation_conclusion=evaluation_conclusion, + ) + if ix == 0: + trd = await write_trd.run( + user_requirements=r, + use_case_actors=use_case_actors, + available_external_interfaces=available_external_interfaces, + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + previous_version_trd=trd, + ) + else: + trd = await write_trd.run( + user_requirements=r, + use_case_actors=use_case_actors, + available_external_interfaces=available_external_interfaces, + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + previous_version_trd=trd, + legacy_user_requirements="\n".join(legacy_user_requirements), + legacy_user_requirements_trd=legacy_user_requirements_trd, + legacy_user_requirements_interaction_events="\n".join(legacy_user_requirements_interaction_events), + ) + evaluation = await evaluate_trd.run( + user_requirements=r, + use_case_actors=use_case_actors, + trd=trd, + interaction_events=interaction_events, + legacy_user_requirements_interaction_events="\n".join(legacy_user_requirements_interaction_events), + ) + is_pass = evaluation.is_pass + evaluation_conclusion = evaluation.conclusion + 
legacy_user_requirements.append(r) + legacy_user_requirements_interaction_events.append(interaction_events) + legacy_user_requirements_trd = trd + + return use_case_actors, legacy_user_requirements_trd + + +async def _write_framework(context: Context, use_case_actors: str, trd: str, acknowledge: str, constraint: str) -> str: + write_framework = WriteFramework(context=context) + evaluate_framework = EvaluateFramework(context=context) + is_pass = False + framework = "" + evaluation_conclusion = "" + while not is_pass and (context.cost_manager.total_cost < context.cost_manager.max_budget): + try: + framework = await write_framework.run( + use_case_actors=use_case_actors, + trd=trd, + acknowledge=acknowledge, + legacy_output=framework, + evaluation_conclusion=evaluation_conclusion, + additional_technical_requirements=constraint, + ) + except Exception as e: + logger.info(f"{e}") + break + evaluation = await evaluate_framework.run( + use_case_actors=use_case_actors, + trd=trd, + acknowledge=acknowledge, + legacy_output=framework, + additional_technical_requirements=constraint, + ) + is_pass = evaluation.is_pass + evaluation_conclusion = evaluation.conclusion + return framework + + +async def develop( + context: Context, + user_requirement_filename: str, + actors_filename: str, + acknowledge_filename: str, + constraint_filename: str, + output_dir: str, +): + output_dir = Path(output_dir) if output_dir else DEFAULT_WORKSPACE_ROOT / uuid.uuid4().hex + + v = await aread(filename=user_requirement_filename) + user_requirements = json.loads(v) + v = await aread(filename=actors_filename) + actors = json.loads(v) + acknowledge = await aread(filename=acknowledge_filename) + technical_constraint = await aread(filename=constraint_filename) + + # Compress acknowledge + compress_acknowledge = CompressExternalInterfaces(context=context) + available_external_interfaces = await compress_acknowledge.run(acknowledge=acknowledge) + + # Write TRD + use_case_actors, trd = await _write_trd( + context=context, + actors=actors, + user_requirements=user_requirements, + available_external_interfaces=available_external_interfaces, + ) + + # Write framework + framework = await _write_framework( + context=context, + use_case_actors=use_case_actors, + trd=trd, + acknowledge=acknowledge, + constraint=technical_constraint, + ) + + # Save + file_list = await save_framework(dir_data=framework, trd=trd, output_dir=output_dir) + logger.info(f"Output:\n{file_list}") + + +@app.command() +def startup( + user_requirement_filename: str = typer.Argument(..., help="The filename of the user requirements."), + actors_filename: str = typer.Argument(..., help="The filename of UML use case actors description."), + acknowledge_filename: str = typer.Argument(..., help="External interfaces declarations."), + llm_config: str = typer.Option(default="", help="Low-cost LLM config"), + constraint_filename: str = typer.Option(default="", help="What technical dependency constraints are."), + output_dir: str = typer.Option(default="", help="Output directory."), + investment: float = typer.Option(default=15.0, help="Dollar amount to invest in the AI company."), +): + if llm_config and Path(llm_config).exists(): + config = Config.from_yaml_file(Path(llm_config)) + else: + logger.info("GPT 4 turbo is recommended") + config = Config.default() + ctx = Context(config=config) + ctx.cost_manager.max_budget = investment + + asyncio.run( + develop(ctx, user_requirement_filename, actors_filename, acknowledge_filename, constraint_filename, output_dir) + ) + + +if 
__name__ == "__main__": + app() diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1b93213f77..8733947f5a 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -22,7 +22,6 @@ SerializationMixin, TestingContext, ) -from metagpt.utils.project_repo import ProjectRepo class Action(SerializationMixin, ContextMixin, BaseModel): @@ -36,12 +35,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel): desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - @property - def repo(self) -> ProjectRepo: - if not self.context.repo: - self.context.repo = ProjectRepo(self.context.git_repo) - return self.context.repo - @property def prompt_schema(self): return self.config.prompt_schema @@ -97,10 +90,15 @@ async def _run_action_node(self, *args, **kwargs): msgs = args[0] context = "## History Messages\n" context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))]) - return await self.node.fill(context=context, llm=self.llm) + return await self.node.fill(req=context, llm=self.llm) async def run(self, *args, **kwargs): """Run action""" if self.node: return await self._run_action_node(*args, **kwargs) raise NotImplementedError("The run method should be implemented in a subclass.") + + def override_context(self): + """Set `private_context` and `context` to the same `Context` object.""" + if not self.private_context: + self.private_context = self.context diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 31e4cc0fc5..c1de166565 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -17,7 +17,9 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls -from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT +from metagpt.exp_pool import exp_cache +from metagpt.exp_pool.serializers import ActionNodeSerializer from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess @@ -113,7 +115,7 @@ class ReviseMode(Enum): """ -def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): +def dict_to_markdown(d, prefix=MARKDOWN_TITLE_PREFIX, kv_sep="\n", postfix="\n"): markdown_str = "" for key, value in d.items(): markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" @@ -465,9 +467,11 @@ async def simple_fill( return self + @exp_cache(serializer=ActionNodeSerializer()) async def fill( self, - context, + *, + req, llm, schema="json", mode="auto", @@ -478,7 +482,7 @@ async def fill( ): """Fill the node(s) with mode. - :param context: Everything we should know when filling node. + :param req: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. :param schema: json/markdown, determine example and output format. 
- raw: free form text
@@ -497,7 +501,7 @@ async def fill(
         :return: self
         """
         self.set_llm(llm)
-        self.set_context(context)
+        self.set_context(req)
         if self.schema:
             schema = self.schema
diff --git a/metagpt/actions/analyze_requirements.py b/metagpt/actions/analyze_requirements.py
new file mode 100644
index 0000000000..86088d824e
--- /dev/null
+++ b/metagpt/actions/analyze_requirements.py
@@ -0,0 +1,76 @@
+from metagpt.actions import Action
+
+ANALYZE_REQUIREMENTS = """
+# Example
+{examples}
+
+----------------
+
+# Requirements
+{requirements}
+
+# Instructions
+{instructions}
+
+# Output Format
+{output_format}
+
+Follow the instructions and output format. Do not include any additional content.
+"""
+
+EXAMPLES = """
+Example 1
+Requirements:
+创建一个贪吃蛇,只需要给出设计文档和代码
+Outputs:
+[User Restrictions] : 只需要给出设计文档和代码.
+[Language Restrictions] : The response, message and instruction must be in Chinese.
+[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
+
+Example 2
+Requirements:
+Create 2048 game using Python. Do not write PRD.
+Outputs:
+[User Restrictions] : Do not write PRD.
+[Language Restrictions] : The response, message and instruction must be in English.
+[Programming Language] : Python
+
+Example 3
+Requirements:
+You must ignore create PRD and TRD. Help me write a schedule display program for the Paris Olympics.
+Outputs:
+[User Restrictions] : You must ignore create PRD and TRD.
+[Language Restrictions] : The response, message and instruction must be in English.
+[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
+"""
+
+INSTRUCTIONS = """
+You must output in the same language as the Requirements.
+First, determine the natural language you must respond in; it should be consistent with the language used in the requirement description. If the requirements specify a particular language, follow that instruction. The default language for responses is English.
+Second, extract the restrictions in the requirements, specifically the steps. Do not include detailed demand descriptions; focus only on the restrictions.
+Third, if the requirements describe software development, extract the programming language. If no specific programming language is required, use HTML (*.html), CSS (*.css), and JavaScript (*.js).
+
+Note:
+1. if there are no restrictions, requirements_restrictions must be ""
+2. if the requirements are not about software development, programming language must be ""
+"""
+
+OUTPUT_FORMAT = """
+[User Restrictions] : the restrictions in the requirements
+[Language Restrictions] : The response, message and instruction must be in {{language}}
+[Programming Language] : Your program must use ...
+"""
+
+
+class AnalyzeRequirementsRestrictions(Action):
+    """Analyze the restrictions and the required response language in the given requirements."""
+
+    name: str = "AnalyzeRequirementsRestrictions"
+
+    async def run(self, requirements, instructions=INSTRUCTIONS, output_format=OUTPUT_FORMAT):
+        """Analyze the constraints and the language used in the requirements."""
+        prompt = ANALYZE_REQUIREMENTS.format(
+            examples=EXAMPLES, requirements=requirements, instructions=instructions, output_format=output_format
+        )
+        rsp = await self.llm.aask(prompt)
+        return rsp
diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py
index 5ed31bed8c..8f0f52266f 100644
--- a/metagpt/actions/debug_error.py
+++ b/metagpt/actions/debug_error.py
@@ -9,13 +9,15 @@
 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
""" import re +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -47,6 +49,8 @@ class DebugError(Action): i_context: RunCodeContext = Field(default_factory=RunCodeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs) -> str: output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename) @@ -59,9 +63,7 @@ async def run(self, *args, **kwargs) -> str: return "" logger.info(f"Debug and rewrite {self.i_context.test_filename}") - code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get( - filename=self.i_context.code_filename - ) + code_doc = await self.repo.srcs.get(filename=self.i_context.code_filename) if not code_doc: return "" test_doc = await self.repo.tests.get(filename=self.i_context.test_filename) @@ -70,6 +72,6 @@ async def run(self, *args, **kwargs) -> str: prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr) rsp = await self._aask(prompt) - code = CodeParser.parse_code(block="", text=rsp) + code = CodeParser.parse_code(text=rsp) return code diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index e5f038c7c8..68a66d5a49 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -8,12 +8,15 @@ 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. 2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ import json from pathlib import Path -from typing import Optional +from typing import List, Optional, Union -from metagpt.actions import Action, ActionOutput +from pydantic import BaseModel, Field + +from metagpt.actions import Action from metagpt.actions.design_api_an import ( DATA_STRUCTURES_AND_INTERFACES, DESIGN_API_NODE, @@ -24,8 +27,18 @@ ) from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO from metagpt.logs import logger -from metagpt.schema import Document, Documents, Message +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter, GalleryReporter NEW_REQ_TEMPLATE = """ ### Legacy Content @@ -36,6 +49,7 @@ """ +@register_tool(include_functions=["run"]) class WriteDesign(Action): name: str = "" i_context: Optional[str] = None @@ -44,21 +58,98 @@ class WriteDesign(Action): "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." 
)
+    repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
+    input_args: Optional[BaseModel] = Field(default=None, exclude=True)
+
+    async def run(
+        self,
+        with_messages: List[Message] = None,
+        *,
+        user_requirement: str = "",
+        prd_filename: str = "",
+        legacy_design_filename: str = "",
+        extra_info: str = "",
+        output_pathname: str = "",
+        **kwargs,
+    ) -> Union[AIMessage, str]:
+        """
+        Write a system design.
+
+        Args:
+            user_requirement (str): The user's requirements for the system design.
+            prd_filename (str, optional): The filename of the Product Requirement Document (PRD).
+            legacy_design_filename (str, optional): The filename of the legacy design document.
+            extra_info (str, optional): Additional information to be included in the system design.
+            output_pathname (str, optional): The output file path of the document.
+
+        Returns:
+            str: The file path of the generated system design.
+
+        Example:
+            # Write a new system design and save to the path name.
+            >>> user_requirement = "Write system design for a snake game"
+            >>> extra_info = "Your extra information"
+            >>> output_pathname = "snake_game/docs/system_design.json"
+            >>> action = WriteDesign()
+            >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname)
+            >>> print(result)
+            System Design filename: "/absolute/path/to/snake_game/docs/system_design.json"
+
+            # Rewrite an existing system design and save to the path name.
+            >>> user_requirement = "Write system design for a snake game, include new features such as a web UI"
+            >>> extra_info = "Your extra information"
+            >>> legacy_design_filename = "/absolute/path/to/snake_game/docs/system_design.json"
+            >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design_new.json"
+            >>> action = WriteDesign()
+            >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname)
+            >>> print(result)
+            System Design filename: "/absolute/path/to/snake_game/docs/system_design_new.json"

-    async def run(self, with_messages: Message, schema: str = None):
-        # Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory.
-        changed_prds = self.repo.docs.prd.changed_files
-        # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
-        # changes.
-        changed_system_designs = self.repo.docs.system_design.changed_files
+            # Write a new system design with the given PRD(Product Requirement Document) and save to the path name.
+            >>> user_requirement = "Write system design for a snake game based on the PRD at /absolute/path/to/snake_game/docs/prd.json"
+            >>> extra_info = "Your extra information"
+            >>> prd_filename = "/absolute/path/to/snake_game/docs/prd.json"
+            >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design.json"
+            >>> action = WriteDesign()
+            >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, output_pathname=output_pathname)
+            >>> print(result)
+            System Design filename: "/absolute/path/to/snake_game/docs/system_design.json"
+
+            # Rewrite an existing system design with the given PRD(Product Requirement Document) and save to the path name.
+ >>> user_requirement = "Write system design for a snake game, include new features such as a web UI" + >>> extra_info = "Your extra information" + >>> prd_filename = "/absolute/path/to/snake_game/docs/prd.json" + >>> legacy_design_filename = "/absolute/path/to/snake_game/docs/system_design.json" + >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design_new.json" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname) + >>> print(result) + System Design filename: "/absolute/path/to/snake_game/docs/system_design_new.json" + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + prd_filename=prd_filename, + legacy_design_filename=legacy_design_filename, + extra_info=extra_info, + output_pathname=output_pathname, + ) + + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_prds = self.input_args.changed_prd_filenames + changed_system_designs = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] # For those PRDs and design documents that have undergone changes, regenerate the design content. changed_files = Documents() - for filename in changed_prds.keys(): + for filename in changed_prds: doc = await self._update_system_design(filename=filename) changed_files.docs[filename] = doc - for filename in changed_system_designs.keys(): + for filename in changed_system_designs: if filename in changed_files.docs: continue doc = await self._update_system_design(filename=filename) @@ -67,54 +158,122 @@ async def run(self, with_messages: Message, schema: str = None): logger.info("Nothing has changed.") # Wait until all files under `docs/system_designs/` are processed before sending the publish message, # leaving room for global optimization in subsequent steps. - return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files) + kvs = self.input_args.model_dump() + kvs["changed_system_design_filenames"] = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + return AIMessage( + content="Designing is complete. 
" + + "\n".join( + list(self.repo.docs.system_design.changed_files.keys()) + + list(self.repo.resources.data_api_design.changed_files.keys()) + + list(self.repo.resources.seq_flow.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput"), + cause_by=self, + ) async def _new_system_design(self, context): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm) + node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) return node async def _merge(self, prd_doc, system_design_doc): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm) + node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) system_design_doc.content = node.instruct_content.model_dump_json() return system_design_doc async def _update_system_design(self, filename) -> Document: - prd = await self.repo.docs.prd.get(filename) - old_system_design_doc = await self.repo.docs.system_design.get(filename) - if not old_system_design_doc: - system_design = await self._new_system_design(context=prd.content) - doc = await self.repo.docs.system_design.save( - filename=filename, - content=system_design.instruct_content.model_dump_json(), - dependencies={prd.root_relative_path}, - ) - else: - doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc) - await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path}) - await self._save_data_api_design(doc) - await self._save_seq_flow(doc) - await self.repo.resources.system_design.save_pdf(doc=doc) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + prd = await Document.load(filename=filename, project_path=self.repo.workdir) + old_system_design_doc = await self.repo.docs.system_design.get(root_relative_path.name) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "design"}, "meta") + if not old_system_design_doc: + system_design = await self._new_system_design(context=prd.content) + doc = await self.repo.docs.system_design.save( + filename=prd.filename, + content=system_design.instruct_content.model_dump_json(), + dependencies={prd.root_relative_path}, + ) + else: + doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc) + await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path}) + await self._save_data_api_design(doc) + await self._save_seq_flow(doc) + md = await self.repo.resources.system_design.save_pdf(doc=doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") return doc - async def _save_data_api_design(self, design_doc): + async def _save_data_api_design(self, design_doc, output_filename: Path = None): m = json.loads(design_doc.content) data_api_design = m.get(DATA_STRUCTURES_AND_INTERFACES.key) or m.get(REFINED_DATA_STRUCTURES_AND_INTERFACES.key) if not data_api_design: return - pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("") + pathname = output_filename or self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path( + design_doc.filename + ).with_suffix("") await self._save_mermaid_file(data_api_design, pathname) logger.info(f"Save class view to {str(pathname)}") - async def _save_seq_flow(self, design_doc): + async def _save_seq_flow(self, design_doc, output_filename: Path = None): m = 
json.loads(design_doc.content)
         seq_flow = m.get(PROGRAM_CALL_FLOW.key) or m.get(REFINED_PROGRAM_CALL_FLOW.key)
         if not seq_flow:
             return
-        pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
+        pathname = output_filename or self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(
+            design_doc.filename
+        ).with_suffix("")
         await self._save_mermaid_file(seq_flow, pathname)
         logger.info(f"Saving sequence flow to {str(pathname)}")

     async def _save_mermaid_file(self, data: str, pathname: Path):
         pathname.parent.mkdir(parents=True, exist_ok=True)
         await mermaid_to_file(self.config.mermaid.engine, data, pathname)
+        image_path = pathname.parent / f"{pathname.name}.svg"
+        if image_path.exists():
+            await GalleryReporter().async_report(image_path, "path")
+
+    async def _execute_api(
+        self,
+        user_requirement: str = "",
+        prd_filename: str = "",
+        legacy_design_filename: str = "",
+        extra_info: str = "",
+        output_pathname: str = "",
+    ) -> str:
+        prd_content = ""
+        if prd_filename:
+            prd_filename = rectify_pathname(path=prd_filename, default_filename="prd.json")
+            prd_content = await aread(filename=prd_filename)
+        context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format(
+            user_requirement=to_markdown_code_block(user_requirement),
+            extra_info=to_markdown_code_block(extra_info),
+            prd=to_markdown_code_block(prd_content),
+        )
+        async with DocsReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report({"type": "design"}, "meta")
+            if not legacy_design_filename:
+                node = await self._new_system_design(context=context)
+                design = Document(content=node.instruct_content.model_dump_json())
+            else:
+                old_design_content = await aread(filename=legacy_design_filename)
+                design = await self._merge(
+                    prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content)
+                )
+
+            if not output_pathname:
+                output_pathname = self.config.workspace.path / "docs" / "system_design.json"
+            elif not Path(output_pathname).is_absolute():
+                output_pathname = self.config.workspace.path / output_pathname
+            output_pathname = rectify_pathname(path=output_pathname, default_filename="system_design.json")
+            await awrite(filename=output_pathname, data=design.content)
+            output_filename = output_pathname.parent / f"{output_pathname.stem}-class-diagram"
+            await self._save_data_api_design(design_doc=design, output_filename=output_filename)
+            output_filename = output_pathname.parent / f"{output_pathname.stem}-sequence-diagram"
+            await self._save_seq_flow(design_doc=design, output_filename=output_filename)
+            md_output_filename = output_pathname.with_suffix(".md")
+            await save_json_to_markdown(content=design.content, output_filename=md_output_filename)
+            await reporter.async_report(md_output_filename, "path")
+        return f'System Design filename: "{str(output_pathname)}". \n The System Design has been completed.'
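The reworked `WriteDesign.run` above is now registered as a tool with a direct-API mode: calling it without `with_messages` routes through `_execute_api`, which writes the design JSON plus class- and sequence-diagram files. A minimal standalone sketch of that path follows — the requirement text and output path are illustrative placeholders, and it assumes a configured default LLM:

```python
import asyncio

from metagpt.actions.design_api import WriteDesign


async def main():
    # With no `with_messages`, run() takes the keyword-only direct-API path
    # (_execute_api) instead of the message-driven repo workflow.
    action = WriteDesign()
    result = await action.run(
        user_requirement="Write a system design for a snake game",  # illustrative requirement
        output_pathname="snake_game/docs/system_design.json",  # relative paths resolve under config.workspace.path
    )
    # Per the docstring above, the return value looks like:
    # 'System Design filename: "/absolute/path/to/.../system_design.json". ...'
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
```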
diff --git a/metagpt/actions/design_api_an.py b/metagpt/actions/design_api_an.py index 5977cbd958..0de17f32c1 100644 --- a/metagpt/actions/design_api_an.py +++ b/metagpt/actions/design_api_an.py @@ -13,7 +13,7 @@ IMPLEMENTATION_APPROACH = ActionNode( key="Implementation approach", expected_type=str, - instruction="Analyze the difficult points of the requirements, select the appropriate open-source framework", + instruction="Analyze the difficult points of the requirements, select the appropriate open-source framework.", example="We will ...", ) @@ -33,8 +33,8 @@ FILE_LIST = ActionNode( key="File list", expected_type=List[str], - instruction="Only need relative paths. ALWAYS write a main.py or app.py here", - example=["main.py", "game.py"], + instruction="Only need relative paths. Succinctly designate the correct entry file for your project based on the programming language: use main.js for JavaScript, main.py for Python, and so on for other languages.", + example=["a.js", "b.py", "c.css", "d.html"], ) REFINED_FILE_LIST = ActionNode( diff --git a/metagpt/actions/di/ask_review.py b/metagpt/actions/di/ask_review.py index 041011e80d..ecbbd992ea 100644 --- a/metagpt/actions/di/ask_review.py +++ b/metagpt/actions/di/ask_review.py @@ -3,7 +3,7 @@ from typing import Tuple from metagpt.actions import Action -from metagpt.logs import logger +from metagpt.logs import get_human_input, logger from metagpt.schema import Message, Plan @@ -50,7 +50,7 @@ async def run( "Please type your review below:\n" ) - rsp = input(prompt) + rsp = await get_human_input(prompt) if rsp.lower() in ReviewConst.EXIT_WORDS: exit() diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py index 0cf16b70f6..01019b4931 100644 --- a/metagpt/actions/di/execute_nb_code.py +++ b/metagpt/actions/di/execute_nb_code.py @@ -13,9 +13,10 @@ import nbformat from nbclient import NotebookClient -from nbclient.exceptions import CellTimeoutError, DeadKernelError +from nbclient.exceptions import CellExecutionComplete, CellTimeoutError, DeadKernelError +from nbclient.util import ensure_async from nbformat import NotebookNode -from nbformat.v4 import new_code_cell, new_markdown_cell, new_output +from nbformat.v4 import new_code_cell, new_markdown_cell, new_output, output_from_msg from rich.box import MINIMAL from rich.console import Console, Group from rich.live import Live @@ -25,29 +26,79 @@ from metagpt.actions import Action from metagpt.logs import logger +from metagpt.utils.report import NotebookReporter + +INSTALL_KEEPLEN = 500 +INI_CODE = """import warnings +import logging + +root_logger = logging.getLogger() +root_logger.setLevel(logging.ERROR) +warnings.filterwarnings('ignore')""" + + +class RealtimeOutputNotebookClient(NotebookClient): + """Realtime output of Notebook execution.""" + + def __init__(self, *args, notebook_reporter=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.notebook_reporter = notebook_reporter or NotebookReporter() + + async def _async_poll_output_msg(self, parent_msg_id: str, cell: NotebookNode, cell_index: int) -> None: + """Implement a feature to enable sending messages.""" + assert self.kc is not None + while True: + msg = await ensure_async(self.kc.iopub_channel.get_msg(timeout=None)) + await self._send_msg(msg) + + if msg["parent_header"].get("msg_id") == parent_msg_id: + try: + # Will raise CellExecutionComplete when completed + self.process_message(msg, cell, cell_index) + except CellExecutionComplete: + return + + async def _send_msg(self, msg: dict): + 
msg_type = msg.get("header", {}).get("msg_type") + if msg_type not in ["stream", "error", "execute_result"]: + return + + await self.notebook_reporter.async_report(output_from_msg(msg), "content") class ExecuteNbCode(Action): """execute notebook code block, return result to llm, and display it.""" nb: NotebookNode - nb_client: NotebookClient + nb_client: RealtimeOutputNotebookClient = None console: Console interaction: str timeout: int = 600 - def __init__( - self, - nb=nbformat.v4.new_notebook(), - timeout=600, - ): + def __init__(self, nb=nbformat.v4.new_notebook(), timeout=600): super().__init__( nb=nb, - nb_client=NotebookClient(nb, timeout=timeout), timeout=timeout, console=Console(), interaction=("ipython" if self.is_ipython() else "terminal"), ) + self.reporter = NotebookReporter() + self.set_nb_client() + self.init_called = False + + async def init_code(self): + if not self.init_called: + await self.run(INI_CODE) + self.init_called = True + + def set_nb_client(self): + self.nb_client = RealtimeOutputNotebookClient( + self.nb, + timeout=self.timeout, + resources={"metadata": {"path": self.config.workspace.path}}, + notebook_reporter=self.reporter, + coalesce_streams=True, + ) async def build(self): if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): @@ -82,7 +133,7 @@ async def reset(self): # sleep 1s to wait for the kernel to be cleaned up completely await asyncio.sleep(1) await self.build() - self.nb_client = NotebookClient(self.nb, timeout=self.timeout) + self.set_nb_client() def add_code_cell(self, code: str): self.nb.cells.append(new_code_cell(source=code)) @@ -106,7 +157,7 @@ def add_output_to_cell(self, cell: NotebookNode, output: str): else: cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output))) - def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]: + def parse_outputs(self, outputs: list[str], keep_len: int = 5000) -> Tuple[bool, str]: """Parses the outputs received from notebook execution.""" assert isinstance(outputs, list) parsed_output, is_success = [], True @@ -135,9 +186,12 @@ def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, is_success = False output_text = remove_escape_and_color_codes(output_text) + if is_success: + output_text = remove_log_and_warning_lines(output_text) # The useful information of the exception is at the end, # the useful information of normal output is at the begining. - output_text = output_text[:keep_len] if is_success else output_text[-keep_len:] + if "" not in output_text: + output_text = output_text[:keep_len] if is_success else output_text[-keep_len:] parsed_output.append(output_text) return is_success, ",".join(parsed_output) @@ -172,6 +226,8 @@ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str """set timeout for run code. returns the success or failure of the cell execution, and an optional error message. 
""" + await self.reporter.async_report(cell, "content") + try: await self.nb_client.async_execute_cell(cell, cell_index) return self.parse_outputs(self.nb.cells[-1].outputs) @@ -193,29 +249,45 @@ async def run(self, code: str, language: Literal["python", "markdown"] = "python """ self._display(code, language) - if language == "python": - # add code to the notebook - self.add_code_cell(code=code) - - # build code executor - await self.build() - - # run code - cell_index = len(self.nb.cells) - 1 - success, outputs = await self.run_cell(self.nb.cells[-1], cell_index) + async with self.reporter: + if language == "python": + # add code to the notebook + self.add_code_cell(code=code) + + # build code executor + await self.build() + + # run code + cell_index = len(self.nb.cells) - 1 + success, outputs = await self.run_cell(self.nb.cells[-1], cell_index) + + if "!pip" in code: + success = False + outputs = outputs[-INSTALL_KEEPLEN:] + elif "git clone" in code: + outputs = outputs[:INSTALL_KEEPLEN] + "..." + outputs[-INSTALL_KEEPLEN:] + + elif language == "markdown": + # add markdown content to markdown cell in a notebook. + self.add_markdown_cell(code) + # return True, beacuse there is no execution failure for markdown cell. + outputs, success = code, True + else: + raise ValueError(f"Only support for language: python, markdown, but got {language}, ") - if "!pip" in code: - success = False + file_path = self.config.workspace.path / "code.ipynb" + nbformat.write(self.nb, file_path) + await self.reporter.async_report(file_path, "path") return outputs, success - elif language == "markdown": - # add markdown content to markdown cell in a notebook. - self.add_markdown_cell(code) - # return True, beacuse there is no execution failure for markdown cell. - return code, True - else: - raise ValueError(f"Only support for language: python, markdown, but got {language}, ") + +def remove_log_and_warning_lines(input_str: str) -> str: + delete_lines = ["[warning]", "warning:", "[cv]", "[info]"] + result = "\n".join( + [line for line in input_str.split("\n") if not any(dl in line.lower() for dl in delete_lines)] + ).strip() + return result def remove_escape_and_color_codes(input_str: str): diff --git a/metagpt/actions/di/run_command.py b/metagpt/actions/di/run_command.py new file mode 100644 index 0000000000..510bb5d920 --- /dev/null +++ b/metagpt/actions/di/run_command.py @@ -0,0 +1,5 @@ +from metagpt.actions import Action + + +class RunCommand(Action): + """A dummy RunCommand action used as a symbol only""" diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py index 711e56d39b..4d21b2cec3 100644 --- a/metagpt/actions/di/write_analysis_code.py +++ b/metagpt/actions/di/write_analysis_code.py @@ -30,7 +30,9 @@ async def _debug_with_reflection(self, context: list[Message], working_memory: l ) rsp = await self._aask(reflection_prompt, system_msgs=[REFLECTION_SYSTEM_MSG]) - reflection = json.loads(CodeParser.parse_code(block=None, text=rsp)) + reflection = json.loads(CodeParser.parse_code(text=rsp)) + if "```python" in reflection["improved_impl"]: + reflection["improved_impl"] = CodeParser.parse_code(text=reflection["improved_impl"], lang="python") return reflection["improved_impl"] @@ -41,6 +43,7 @@ async def run( tool_info: str = "", working_memory: list[Message] = None, use_reflection: bool = False, + memory: list[Message] = None, **kwargs, ) -> str: structual_prompt = STRUCTUAL_PROMPT.format( @@ -50,14 +53,15 @@ async def run( ) working_memory = working_memory or [] - 
context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory) + memory = memory or [] + context = self.llm.format_msg(memory + [Message(content=structual_prompt, role="user")] + working_memory) # LLM call if use_reflection: code = await self._debug_with_reflection(context=context, working_memory=working_memory) else: rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs) - code = CodeParser.parse_code(block=None, text=rsp) + code = CodeParser.parse_code(text=rsp, lang="python") return code @@ -69,5 +73,5 @@ async def run(self, plan: Plan) -> dict: code_written = "\n\n".join(code_written) prompt = CHECK_DATA_PROMPT.format(code_written=code_written) rsp = await self._aask(prompt) - code = CodeParser.parse_code(block=None, text=rsp) + code = CodeParser.parse_code(text=rsp) return code diff --git a/metagpt/actions/di/write_plan.py b/metagpt/actions/di/write_plan.py index 2dbe3f0e7f..efea9f526f 100644 --- a/metagpt/actions/di/write_plan.py +++ b/metagpt/actions/di/write_plan.py @@ -16,38 +16,38 @@ from metagpt.strategy.task_type import TaskType from metagpt.utils.common import CodeParser +PROMPT_TEMPLATE: str = """ +# Context: +{context} +# Available Task Types: +{task_type_desc} +# Task: +Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks. +If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan. +If you encounter errors on the current task, revise and output the current single task only. +Output a list of jsons following the format: +```json +[ + {{ + "task_id": str = "unique identifier for a task in plan, can be an ordinal", + "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task", + "instruction": "what you should do in this task, one short phrase or sentence.", + "task_type": "type of this task, should be one of Available Task Types.", + }}, + ... +] +``` +""" -class WritePlan(Action): - PROMPT_TEMPLATE: str = """ - # Context: - {context} - # Available Task Types: - {task_type_desc} - # Task: - Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks. - If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan. - If you encounter errors on the current task, revise and output the current single task only. - Output a list of jsons following the format: - ```json - [ - {{ - "task_id": str = "unique identifier for a task in plan, can be an ordinal", - "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task", - "instruction": "what you should do in this task, one short phrase or sentence", - "task_type": "type of this task, should be one of Available Task Types", - }}, - ... 
-    ]
-    ```
-    """
+class WritePlan(Action):
     async def run(self, context: list[Message], max_tasks: int = 5) -> str:
         task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
-        prompt = self.PROMPT_TEMPLATE.format(
+        prompt = PROMPT_TEMPLATE.format(
             context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
         )
         rsp = await self._aask(prompt)
-        rsp = CodeParser.parse_code(block=None, text=rsp)
+        rsp = CodeParser.parse_code(text=rsp)
         return rsp
diff --git a/metagpt/actions/extract_readme.py b/metagpt/actions/extract_readme.py
new file mode 100644
index 0000000000..69f5503a9a
--- /dev/null
+++ b/metagpt/actions/extract_readme.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Module Description: This script defines the ExtractReadMe class, which is an action to learn from the contents of
+    a README.md file.
+Author: mashenquan
+Date: 2024-3-20
+"""
+from pathlib import Path
+from typing import Optional
+
+from pydantic import Field
+
+from metagpt.actions import Action
+from metagpt.const import GRAPH_REPO_FILE_REPO
+from metagpt.schema import Message
+from metagpt.utils.common import aread
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
+
+
+class ExtractReadMe(Action):
+    """
+    An action to extract summary, installation, configuration, usages from the contents of a README.md file.
+
+    Attributes:
+        graph_db (Optional[GraphRepository]): A graph database repository.
+        install_to_path (Optional[str]): The path to install the repository to.
+    """
+
+    graph_db: Optional[GraphRepository] = None
+    install_to_path: Optional[str] = Field(default="/TO/PATH")
+    _readme: Optional[str] = None
+    _filename: Optional[str] = None
+
+    async def run(self, with_messages=None, **kwargs):
+        """
+        Implementation of `Action`'s `run` method.
+
+        Args:
+            with_messages (Optional[Type]): An optional argument specifying messages to react to.
+        """
+        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
+        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
+        summary = await self._summarize()
+        await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_SUMMARY, object_=summary)
+        install = await self._extract_install()
+        await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_INSTALL, object_=install)
+        conf = await self._extract_configuration()
+        await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_CONFIG, object_=conf)
+        usage = await self._extract_usage()
+        await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_USAGE, object_=usage)
+
+        await self.graph_db.save()
+
+        return Message(content="", cause_by=self)
+
+    async def _summarize(self) -> str:
+        readme = await self._get()
+        summary = await self.llm.aask(
+            readme,
+            system_msgs=[
+                "You are a tool that can summarize a git repository's README.md file.",
+                "Return a summary of what the repository is.",
+            ],
+            stream=False,
+        )
+        return summary
+
+    async def _extract_install(self) -> str:
+        await self._get()
+        install = await self.llm.aask(
+            self._readme,
+            system_msgs=[
+                "You are a tool that can install a git repository according to its README.md file.",
+                "Return a bash code block of markdown including:\n"
+                f"1. git clone the repository to the directory `{self.install_to_path}`;\n"
+                f"2. cd `{self.install_to_path}`;\n"
+                f"3. install the repository.",
+            ],
+            stream=False,
+        )
+        return install
+
+    async def _extract_configuration(self) -> str:
+        await self._get()
+        configuration = await self.llm.aask(
+            self._readme,
+            system_msgs=[
+                "You are a tool that can configure a git repository according to its README.md file.",
+                "Return a bash code block of markdown to configure the repository if necessary; otherwise return"
+                " an empty bash code block of markdown.",
+            ],
+            stream=False,
+        )
+        return configuration
+
+    async def _extract_usage(self) -> str:
+        await self._get()
+        usage = await self.llm.aask(
+            self._readme,
+            system_msgs=[
+                "You are a tool that can summarize all usages of a git repository according to its README.md file.",
+                "Return a list of markdown code blocks that demonstrate the usage of the repository.",
+            ],
+            stream=False,
+        )
+        return usage
+
+    async def _get(self) -> str:
+        if self._readme is not None:
+            return self._readme
+        root = Path(self.i_context).resolve()
+        filename = None
+        for file_path in root.iterdir():
+            if file_path.is_file() and file_path.stem == "README":
+                filename = file_path
+                break
+        if not filename:
+            return ""
+        self._readme = await aread(filename=filename, encoding="utf-8")
+        self._filename = str(filename)
+        return self._readme
diff --git a/metagpt/actions/generate_questions.py b/metagpt/actions/generate_questions.py
index c96a376493..bf0ba62773 100644
--- a/metagpt/actions/generate_questions.py
+++ b/metagpt/actions/generate_questions.py
@@ -22,4 +22,4 @@ class GenerateQuestions(Action):
     name: str = "GenerateQuestions"

     async def run(self, context) -> ActionNode:
-        return await QUESTIONS.fill(context=context, llm=self.llm)
+        return await QUESTIONS.fill(req=context, llm=self.llm)
diff --git a/metagpt/actions/import_repo.py b/metagpt/actions/import_repo.py
new file mode 100644
index 0000000000..82aa916f46
--- /dev/null
+++ b/metagpt/actions/import_repo.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+
+This script defines an action to import a Git repository into the MetaGPT project format, enabling incremental
+    appending of requirements.
+The MetaGPT project format encompasses a structured representation of project data compatible with MetaGPT's
+    capabilities, facilitating the integration of Git repositories into MetaGPT workflows while allowing for the gradual
+    addition of requirements.
+
+"""
+import json
+import re
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from metagpt.actions import Action
+from metagpt.actions.extract_readme import ExtractReadMe
+from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
+from metagpt.const import GRAPH_REPO_FILE_REPO
+from metagpt.logs import logger
+from metagpt.schema import Message
+from metagpt.tools.libs.git import git_clone
+from metagpt.utils.common import (
+    aread,
+    awrite,
+    list_files,
+    parse_json_code_block,
+    split_namespace,
+)
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.file_repository import FileRepository
+from metagpt.utils.git_repository import GitRepository
+from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
+from metagpt.utils.project_repo import ProjectRepo
+
+
+class ImportRepo(Action):
+    """
+    An action to import a Git repository into a graph database and create related artifacts.
+ + Attributes: + repo_path (str): The URL of the Git repository to import. + graph_db (Optional[GraphRepository]): The output graph database of the Git repository. + rid (str): The output requirement ID. + """ + + repo_path: str # input, git repo url. + graph_db: Optional[GraphRepository] = None # output. graph db of the git repository + rid: str = "" # output, requirement ID. + + async def run(self, with_messages: List[Message] = None, **kwargs) -> Message: + """ + Runs the import process for the Git repository. + + Args: + with_messages (List[Message], optional): Additional messages to include. + **kwargs: Additional keyword arguments. + + Returns: + Message: A message indicating the completion of the import process. + """ + await self._create_repo() + await self._create_prd() + await self._create_system_design() + self.context.git_repo.archive(comments="Import") + + async def _create_repo(self): + path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path) + self.repo_path = str(path) + self.config.project_path = path + self.context.git_repo = GitRepository(local_path=path, auto_init=True) + self.context.repo = ProjectRepo(self.context.git_repo) + self.context.src_workspace = await self._guess_src_workspace() + await awrite( + filename=self.context.repo.workdir / ".src_workspace", + data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)), + ) + + async def _create_prd(self): + action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context) + await action.run() + graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY) + prd = {"Project Name": self.context.repo.workdir.name} + for r in rows: + if Path(r.subject).stem == "README": + prd["Original Requirements"] = r.object_ + break + self.rid = FileRepository.new_filename() + await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd)) + + async def _create_system_design(self): + action = RebuildClassView( + name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context + ) + await action.run() + rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile") + class_view_filename = rows[0].object_ + logger.info(f"class view:{class_view_filename}") + + rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) + tag = "__name__:__main__" + entries = [] + src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir) + for r in rows: + if tag in r.subject: + path = split_namespace(r.subject)[0] + elif tag in r.object_: + path = split_namespace(r.object_)[0] + else: + continue + if Path(path).is_relative_to(src_workspace): + entries.append(Path(path)) + main_entry = await self._guess_main_entry(entries) + full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry) + action = RebuildSequenceView(context=self.context, i_context=str(full_path)) + try: + await action.run() + except Exception as e: + logger.warning(f"{e}, use the last successful version.") + files = list_files(self.context.repo.resources.data_api_design.workdir) + pattern = re.compile(r"[^a-zA-Z0-9]") + name = re.sub(pattern, "_", str(main_entry)) + filename = Path(name).with_suffix(".sequence_diagram.mmd") + postfix = str(filename) + sequence_files = [i for i in files 
if postfix in str(i)] + content = await aread(filename=sequence_files[0]) + await self.context.repo.resources.data_api_design.save( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content + ) + await self._save_system_design() + + async def _save_system_design(self): + class_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".class_diagram.mmd" + ) + sequence_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd" + ) + file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace) + data = { + "Data structures and interfaces": class_view.content, + "Program call flow": sequence_view.content, + "File list": [str(i) for i in file_list], + } + await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data)) + + async def _guess_src_workspace(self) -> Path: + files = list_files(self.context.repo.workdir) + dirs = [i.parent for i in files if i.name == "__init__.py"] + distinct = set() + for i in dirs: + done = False + for j in distinct: + if i.is_relative_to(j): + done = True + break + if j.is_relative_to(i): + break + if not done: + distinct = {j for j in distinct if not j.is_relative_to(i)} + distinct.add(i) + if len(distinct) == 1: + return list(distinct)[0] + prompt = "\n".join([f"- {str(i)}" for i in distinct]) + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool to choose the source code path from a list of paths based on the directory name.", + "You should identify the source code path among paths such as unit test path, examples path, etc.", + "Return a markdown JSON object containing:\n" + '- a "src" field containing the source code path;\n' + '- a "reason" field containing explaining why other paths is not the source code path\n', + ], + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + src: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"src_workspace: {data.src}") + return Path(data.src) + + async def _guess_main_entry(self, entries: List[Path]) -> Path: + if len(entries) == 1: + return entries[0] + + file_list = "## File List\n" + file_list += "\n".join([f"- {i}" for i in entries]) + + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE) + usage = "## Usage\n" + for r in rows: + if Path(r.subject).stem == "README": + usage += r.object_ + + prompt = file_list + "\n---\n" + usage + rsp = await self.llm.aask( + prompt, + system_msgs=[ + 'You are a tool to choose the source file path from "File List" which is used in "Usage".', + 'You choose the source file path based on the name of file and the class name and package name used in "Usage".', + "Return a markdown JSON object containing:\n" + '- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n' + '- a "reason" field explaining why.', + ], + stream=False, + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + filename: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"main: {data.filename}") + return Path(data.filename) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index ab069dc11a..393c483cc5 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -9,12 +9,14 @@ """ import shutil from pathlib 
import Path -from typing import Optional +from typing import Dict, Optional -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, UserRequirement from metagpt.const import REQUIREMENT_FILENAME +from metagpt.logs import logger +from metagpt.schema import AIMessage +from metagpt.utils.common import any_to_str from metagpt.utils.file_repository import FileRepository -from metagpt.utils.git_repository import GitRepository from metagpt.utils.project_repo import ProjectRepo @@ -23,12 +25,19 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" i_context: Optional[str] = None + key_descriptions: Optional[Dict[str, str]] = None + send_to: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not self.key_descriptions: + self.key_descriptions = {"project_path": 'the project path if exists in "Original Requirement"'} @property def config(self): return self.context.config - def _init_repo(self): + def _init_repo(self) -> ProjectRepo: """Initialize the Git environment.""" if not self.config.project_path: name = self.config.project_name or FileRepository.new_filename() @@ -37,16 +46,45 @@ def _init_repo(self): path = Path(self.config.project_path) if path.exists() and not self.config.inc: shutil.rmtree(path) - self.config.project_path = path - self.context.git_repo = GitRepository(local_path=path, auto_init=True) - self.context.repo = ProjectRepo(self.context.git_repo) + self.context.kwargs.project_path = path + self.context.kwargs.inc = self.config.inc + return ProjectRepo(path) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" - self._init_repo() + user_requirements = [i for i in with_messages if i.cause_by == any_to_str(UserRequirement)] + if not self.config.project_path and user_requirements and self.key_descriptions: + args = await user_requirements[0].parse_resources(llm=self.llm, key_descriptions=self.key_descriptions) + for k, v in args.items(): + if not v or k in ["resources", "reason"]: + continue + self.context.kwargs.set(k, v) + logger.info(f"{k}={v}") + if self.context.kwargs.project_path: + self.config.update_via_cli( + project_path=self.context.kwargs.project_path, + project_name="", + inc=False, + reqa_file=self.context.kwargs.reqa_file or "", + max_auto_summarize_code=0, + ) + + repo = self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. - doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) # Send a Message notification to the WritePRD action, instructing it to process requirements using # `docs/requirement.txt` and `docs/prd/`. 
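The repository bootstrap in `_init_repo` reduces to one rule: reuse the project directory on incremental runs, wipe it otherwise. A minimal sketch of that decision, with `workspace`, `project_name`, and `inc` standing in for the config fields used above:

```python
import shutil
from pathlib import Path

def resolve_project_path(workspace: Path, project_path: str, project_name: str, inc: bool) -> Path:
    """Pick the project directory; start clean unless running incrementally."""
    if not project_path:
        # No explicit path: derive one from the project name under the workspace.
        project_path = workspace / (project_name or "generated_project")
    path = Path(project_path)
    if path.exists() and not inc:
        shutil.rmtree(path)  # non-incremental runs begin from an empty folder
    return path
```

The original additionally falls back to `FileRepository.new_filename()` when no project name is given; the constant here is only a placeholder.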
- return ActionOutput(content=doc.content, instruct_content=doc) + return AIMessage( + content="", + instruct_content=AIMessage.create_instruct_value( + kvs={ + "project_path": str(repo.workdir), + "requirements_filename": str(repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(repo.docs.prd.workdir / i) for i in repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ), + cause_by=self, + send_to=self.send_to, + ) diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index 04cc954d28..0a7eb6581e 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -22,4 +22,4 @@ class PrepareInterview(Action): name: str = "PrepareInterview" async def run(self, context): - return await QUESTIONS.fill(context=context, llm=self.llm) + return await QUESTIONS.fill(req=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 67a614d6f6..1ce94cd993 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -8,17 +8,30 @@ 1. Divide the context into three components: legacy code, unit test code, and console log. 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ import json -from typing import Optional +from pathlib import Path +from typing import List, Optional, Union + +from pydantic import BaseModel, Field from metagpt.actions.action import Action -from metagpt.actions.action_output import ActionOutput from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger -from metagpt.schema import Document, Documents +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter NEW_REQ_TEMPLATE = """ ### Legacy Content @@ -29,19 +42,67 @@ """ +@register_tool(include_functions=["run"]) class WriteTasks(Action): name: str = "CreateTasks" i_context: Optional[str] = None + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + design_filename: str = "", + output_pathname: str = "", + **kwargs, + ) -> Union[AIMessage, str]: + """ + Write a project schedule given a project system design file. + + Args: + user_requirement (str, optional): A string specifying the user's requirements. Defaults to an empty string. + design_filename (str): The output file path of the document. Defaults to an empty string. + output_pathname (str, optional): The output path name of file that the project schedule should be saved to. + **kwargs: Additional keyword arguments. + + Returns: + str: Path to the generated project schedule. + + Example: + # Write a project schedule with a given system design. 
+ >>> design_filename = "/absolute/path/to/snake_game/docs/system_design.json" + >>> output_pathname = "/absolute/path/to/snake_game/docs/project_schedule.json" + >>> user_requirement = "Write project schedule for a snake game following these requirements:..." + >>> action = WriteTasks() + >>> result = await action.run(user_requirement=user_requirement, design_filename=design_filename, output_pathname=output_pathname) + >>> print(result) + The project schedule is at /absolute/path/to/snake_game/docs/project_schedule.json + + # Write a project schedule with a user requirement. + >>> user_requirement = "Write project schedule for a snake game following these requirements: ..." + >>> output_pathname = "/absolute/path/to/snake_game/docs/project_schedule.json" + >>> action = WriteTasks() + >>> result = await action.run(user_requirement=user_requirement, output_pathname=output_pathname) + >>> print(result) + The project schedule is at /absolute/path/to/snake_game/docs/project_schedule.json + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, design_filename=design_filename, output_pathname=output_pathname + ) - async def run(self, with_messages): - changed_system_designs = self.repo.docs.system_design.changed_files - changed_tasks = self.repo.docs.task.changed_files + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_system_designs = self.input_args.changed_system_design_filenames + changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] change_files = Documents() # Rewrite the system designs that have undergone changes based on the git head diff under # `docs/system_designs/`. for filename in changed_system_designs: task_doc = await self._update_tasks(filename=filename) - change_files.docs[filename] = task_doc + change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`. for filename in changed_tasks: @@ -54,31 +115,50 @@ async def run(self, with_messages): logger.info("Nothing has changed.") # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for # global optimization in subsequent steps. - return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files) + kvs = self.input_args.model_dump() + kvs["changed_task_filenames"] = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + kvs["python_package_dependency_filename"] = str(self.repo.workdir / PACKAGE_REQUIREMENTS_FILENAME) + return AIMessage( + content="WBS is completed. " + + "\n".join( + [PACKAGE_REQUIREMENTS_FILENAME] + + list(self.repo.docs.task.changed_files.keys()) + + list(self.repo.resources.api_spec_and_task.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTaskOutput"), + cause_by=self, + ) async def _update_tasks(self, filename): - system_design_doc = await self.repo.docs.system_design.get(filename) - task_doc = await self.repo.docs.task.get(filename) - if task_doc: - task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc) - await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path}) - else: - rsp = await self._run_new_tasks(context=system_design_doc.content) - task_doc = await self.repo.docs.task.save( - filename=filename, - content=rsp.instruct_content.model_dump_json(), - dependencies={system_design_doc.root_relative_path}, - ) - await self._update_requirements(task_doc) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + system_design_doc = await Document.load(filename=filename, project_path=self.repo.workdir) + task_doc = await self.repo.docs.task.get(root_relative_path.name) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "task"}, "meta") + if task_doc: + task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc) + await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path}) + else: + rsp = await self._run_new_tasks(context=system_design_doc.content) + task_doc = await self.repo.docs.task.save( + filename=system_design_doc.filename, + content=rsp.instruct_content.model_dump_json(), + dependencies={system_design_doc.root_relative_path}, + ) + await self._update_requirements(task_doc) + md = await self.repo.resources.api_spec_and_task.save_pdf(doc=task_doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") return task_doc - async def _run_new_tasks(self, context): - node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) + async def _run_new_tasks(self, context: str): + node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) return node async def _merge(self, system_design_doc, task_doc) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content) - node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema) + node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) task_doc.content = node.instruct_content.model_dump_json() return task_doc @@ -94,3 +174,28 @@ async def _update_requirements(self, doc): continue packages.add(pkg) await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages)) + + async def _execute_api( + self, user_requirement: str = "", design_filename: str = "", output_pathname: str = "" + ) -> str: + context = to_markdown_code_block(user_requirement) + if design_filename: + design_filename = rectify_pathname(path=design_filename, default_filename="system_design.md") + content = await aread(filename=design_filename) + context += to_markdown_code_block(content) + + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "task"}, "meta") + node = await self._run_new_tasks(context) + file_content = node.instruct_content.model_dump_json() + + if not output_pathname: + output_pathname = self.config.workspace.path / "docs" / "project_schedule.json" + elif not Path(output_pathname).is_absolute(): + output_pathname = self.config.workspace.path / output_pathname + output_pathname = rectify_pathname(path=output_pathname, default_filename="project_schedule.json") + await awrite(filename=output_pathname, data=file_content) + md_output_filename = output_pathname.with_suffix(".md") + await save_json_to_markdown(content=file_content, output_filename=md_output_filename) + await reporter.async_report(md_output_filename, "path") + return f'Project Schedule filename: "{str(output_pathname)}"' diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index 0417c0ce4a..7131a6c99b 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -10,9 +10,9 @@ from metagpt.actions.action_node import ActionNode REQUIRED_PYTHON_PACKAGES = ActionNode( - key="Required Python packages", + key="Required packages", expected_type=List[str], - instruction="Provide required Python packages in requirements.txt format.", + instruction="Provide required packages. The response language should correspond to the context and requirements.", example=["flask==1.1.2", "bcrypt==3.2.0"], ) @@ -27,7 +27,9 @@ key="Logic Analysis", expected_type=List[List[str]], instruction="Provide a list of files with the classes/methods/functions to be implemented, " - "including dependency analysis and imports.", + "including dependency analysis and imports. " + "Ensure consistency between System Design and Logic Analysis; the files must match exactly. " + "If the file is written in Vue or React, use Tailwind CSS for styling.", example=[ ["game.py", "Contains Game class and ... functions"], ["main.py", "Contains main function, from game import Game"], diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index ff030ec878..64f003f919 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -14,7 +14,6 @@ import aiofiles from metagpt.actions import Action -from metagpt.config2 import config from metagpt.const import ( AGGREGATION, COMPOSITION, @@ -40,7 +39,7 @@ class RebuildClassView(Action): graph_db: Optional[GraphRepository] = None - async def run(self, with_messages=None, format=config.prompt_schema): + async def run(self, with_messages=None, format=None): """ Implementation of `Action`'s `run` method. @@ -48,6 +47,7 @@ async def run(self, with_messages=None, format=config.prompt_schema): with_messages (Optional[Type]): An optional argument specifying messages to react to. format (str): The format for the prompt schema.
""" + format = format if format else self.config.prompt_schema graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) repo_parser = RepoParser(base_directory=Path(self.i_context)) diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index 0e67de9086..e23487511b 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -18,7 +18,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import Action -from metagpt.config2 import config from metagpt.const import GRAPH_REPO_FILE_REPO from metagpt.logs import logger from metagpt.repo_parser import CodeBlockInfo, DotClassInfo @@ -84,7 +83,7 @@ class RebuildSequenceView(Action): graph_db: Optional[GraphRepository] = None - async def run(self, with_messages=None, format=config.prompt_schema): + async def run(self, with_messages=None, format=None): """ Implementation of `Action`'s `run` method. @@ -92,6 +91,7 @@ async def run(self, with_messages=None, format=config.prompt_schema): with_messages (Optional[Type]): An optional argument specifying messages to react to. format (str): The format for the prompt schema. """ + format = format if format else self.config.prompt_schema graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) if not self.i_context: @@ -244,15 +244,6 @@ async def _rebuild_use_case(self, ns_class_name: str): class_view = await self._get_uml_class_view(ns_class_name) source_code = await self._get_source_code(ns_class_name) - # prompt_blocks = [ - # "## Instruction\n" - # "You are a python code to UML 2.0 Use Case translator.\n" - # 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n' - # "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not " - # 'conflict with the information in "Mermaid Class Views".\n' - # 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external ' - # "system interactions with the internal system.\n" - # ] prompt_blocks = [] block = "## Participants\n" for p in participants: @@ -340,6 +331,7 @@ async def _rebuild_sequence_view(self, ns_class_name: str): system_msgs=[ "You are a Mermaid Sequence Diagram translator in function detail.", "Translate the markdown text to a Mermaid Sequence Diagram.", + "Response must be concise.", "Return a markdown mermaid code block.", ], stream=False, @@ -440,7 +432,7 @@ async def _get_source_code(self, ns_class_name: str) -> str: rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO) filename = split_namespace(ns_class_name=ns_class_name)[0] if not rows: - src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename) + src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename) if not src_filename: return "" return await aread(filename=src_filename, encoding="utf-8") @@ -450,7 +442,7 @@ async def _get_source_code(self, ns_class_name: str) -> str: ) @staticmethod - def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: + def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: 
""" Convert package name to the full path of the module. @@ -466,7 +458,7 @@ def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: "metagpt/management/skill_manager.py", then the returned value will be "/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py" """ - if re.match(r"^/.+", pathname): + if re.match(r"^/.+", str(pathname)): return pathname files = list_files(root=root) postfix = "/" + str(pathname) diff --git a/metagpt/actions/requirement_analysis/__init__.py b/metagpt/actions/requirement_analysis/__init__.py new file mode 100644 index 0000000000..d196bafeeb --- /dev/null +++ b/metagpt/actions/requirement_analysis/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from metagpt.actions.requirement_analysis.evaluate_action import EvaluationData, EvaluateAction + +__all__ = [EvaluationData, EvaluateAction] diff --git a/metagpt/actions/requirement_analysis/evaluate_action.py b/metagpt/actions/requirement_analysis/evaluate_action.py new file mode 100644 index 0000000000..376c73f2c9 --- /dev/null +++ b/metagpt/actions/requirement_analysis/evaluate_action.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_action.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from typing import Optional + +from pydantic import BaseModel +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.utils.common import CodeParser, general_after_log, to_markdown_code_block + + +class EvaluationData(BaseModel): + """Model to represent evaluation data. + + Attributes: + is_pass (bool): Indicates if the evaluation passed or failed. + conclusion (Optional[str]): Conclusion or remarks about the evaluation. + """ + + is_pass: bool + conclusion: Optional[str] = None + + +class EvaluateAction(Action): + """The base class for an evaluation action. + + This class provides methods to evaluate prompts using a specified language model. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _evaluate(self, prompt: str) -> (bool, str): + """Evaluates a given prompt. + + Args: + prompt (str): The prompt to be evaluated. + + Returns: + tuple: A tuple containing: + - bool: Indicates if the evaluation passed. + - str: The JSON string containing the evaluation data. + """ + rsp = await self.llm.aask(prompt) + json_data = CodeParser.parse_code(text=rsp, lang="json") + data = EvaluationData.model_validate_json(json_data) + return data.is_pass, to_markdown_code_block(val=json_data, type_="json") + + async def _vote(self, prompt: str) -> EvaluationData: + """Evaluates a prompt multiple times and returns the consensus. + + Args: + prompt (str): The prompt to be evaluated. + + Returns: + EvaluationData: An object containing the evaluation result and a summary of evaluations. 
+ """ + evaluations = {} + for i in range(3): + vote, evaluation = await self._evaluate(prompt) + val = evaluations.get(vote, []) + val.append(evaluation) + if len(val) > 1: + return EvaluationData(is_pass=vote, conclusion="\n".join(val)) + evaluations[vote] = val diff --git a/metagpt/actions/requirement_analysis/framework/__init__.py b/metagpt/actions/requirement_analysis/framework/__init__.py new file mode 100644 index 0000000000..5e06530887 --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/__init__.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import json +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional, Union, List + +from pydantic import BaseModel + +from metagpt.actions.requirement_analysis.framework.evaluate_framework import EvaluateFramework +from metagpt.actions.requirement_analysis.framework.write_framework import WriteFramework +from metagpt.config2 import Config +from metagpt.utils.common import awrite + + +async def save_framework( + dir_data: str, trd: Optional[str] = None, output_dir: Optional[Union[str, Path]] = None +) -> List[str]: + """ + Saves framework data to files based on input JSON data and optionally saves a TRD (technical requirements document). + + Args: + dir_data (str): JSON data in string format enclosed in triple backticks ("```json" "...data..." "```"). + trd (str, optional): Technical requirements document content to be saved. Defaults to None. + output_dir (Union[str, Path], optional): Output directory path where files will be saved. If not provided, + a default directory is created based on the current timestamp and a random UUID suffix. + + Returns: + List[str]: List of file paths where data was saved. + + Raises: + Any exceptions raised during file writing operations. + + Notes: + - JSON data should be provided in the format "```json ...data... ```". + - The function ensures that paths and filenames are correctly formatted and creates necessary directories. + + Example: + ```python + dir_data = "```json\n[{\"path\": \"/folder\", \"filename\": \"file1.txt\", \"content\": \"Some content\"}]\n```" + trd = "Technical requirements document content." + output_dir = '/path/to/output/dir' + saved_files = await save_framework(dir_data, trd, output_dir) + print(saved_files) + ``` + """ + output_dir = ( + Path(output_dir) + if output_dir + else Config.default().workspace.path / (datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:8]) + ) + output_dir.mkdir(parents=True, exist_ok=True) + + json_data = dir_data.removeprefix("```json").removesuffix("```") + items = json.loads(json_data) + + class Data(BaseModel): + path: str + filename: str + content: str + + if trd: + pathname = output_dir / "TRD.md" + await awrite(filename=pathname, data=trd) + + files = [] + for i in items: + v = Data.model_validate(i) + if v.path and v.path[0] == "/": + v.path = "." 
+ v.path + pathname = output_dir / v.path + pathname.mkdir(parents=True, exist_ok=True) + pathname = pathname / v.filename + await awrite(filename=pathname, data=v.content) + files.append(str(pathname)) + return files + + +__all__ = ["WriteFramework", "EvaluateFramework"] diff --git a/metagpt/actions/requirement_analysis/framework/evaluate_framework.py b/metagpt/actions/requirement_analysis/framework/evaluate_framework.py new file mode 100644 index 0000000000..2f92396583 --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/evaluate_framework.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_framework.py +@Desc : The implementation of Chapter 2.1.8 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + +from metagpt.actions.requirement_analysis import EvaluateAction, EvaluationData +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class EvaluateFramework(EvaluateAction): + """EvaluateFramework deals with the following situations: + 1. Given a TRD and the software framework based on the TRD, evaluate the quality of the software framework. + """ + + async def run( + self, + *, + use_case_actors: str, + trd: str, + acknowledge: str, + legacy_output: str, + additional_technical_requirements: str, + ) -> EvaluationData: + """ + Run the evaluation of the software framework based on the provided TRD and related parameters. + + Args: + use_case_actors (str): A description of the actors involved in the use case. + trd (str): The Technical Requirements Document (TRD) that outlines the requirements for the software framework. + acknowledge (str): External acknowledgement information related to the framework. + legacy_output (str): The previous version of the software framework returned by `WriteFramework`. + additional_technical_requirements (str): Additional technical requirements that need to be considered during evaluation. + + Returns: + EvaluationData: An object containing the results of the evaluation. + + Example: + >>> evaluate_framework = EvaluateFramework() + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> acknowledge = "## Interfaces\\n..." + >>> framework = '{"path":"balabala", "filename":"...", ...' + >>> constraint = "Using Java language, ..." + >>> evaluation = await evaluate_framework.run( + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> acknowledge=acknowledge, + >>> legacy_output=framework, + >>> additional_technical_requirements=constraint, + >>> ) + >>> is_pass = evaluation.is_pass + >>> print(is_pass) + True + >>> evaluation_conclusion = evaluation.conclusion + >>> print(evaluation_conclusion) + Balabala... + """
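`run` delegates to `_vote` from `EvaluateAction`, whose consensus rule is simply "first verdict to collect two votes wins", with three rounds guaranteeing a majority by pigeonhole. A standalone sketch of the same idea, with `fake_evaluate` as a stand-in for `_evaluate`:

```python
import asyncio
import random
from collections import defaultdict
from typing import Dict, List, Tuple

async def fake_evaluate(prompt: str) -> Tuple[bool, str]:
    """Stand-in for EvaluateAction._evaluate: one verdict plus its rationale."""
    verdict = random.choice([True, False])
    return verdict, f"evaluation with is_pass={verdict}"

async def vote(prompt: str) -> Tuple[bool, str]:
    """Return the first verdict that two of at most three rounds agree on."""
    ballots: Dict[bool, List[str]] = defaultdict(list)
    for _ in range(3):
        verdict, rationale = await fake_evaluate(prompt)
        ballots[verdict].append(rationale)
        if len(ballots[verdict]) > 1:
            return verdict, "\n".join(ballots[verdict])
    raise AssertionError("unreachable: two of three booleans must match")

# e.g. asyncio.run(vote("evaluate this TRD"))
```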
+ prompt = PROMPT.format( + use_case_actors=use_case_actors, + trd=to_markdown_code_block(val=trd), + acknowledge=to_markdown_code_block(val=acknowledge), + legacy_output=to_markdown_code_block(val=legacy_output), + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._vote(prompt) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## Legacy TRD +{trd} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that evaluates the quality of framework code based on the TRD content; +You need to refer to the content of the "Legacy TRD" section to check for any errors or omissions in the framework code found in "Legacy Outputs"; +The content of "Actor, System, External System" provides an explanation of the actors and systems that appear in the UML Use Case diagram; +Information about the external system missing from the "Legacy TRD" can be found in the "Acknowledge" section; +Which interfaces defined in "Acknowledge" are used in the "Legacy TRD"? +Do not implement an interface from the "Acknowledge" section unless it is used in the "Legacy TRD"; you can check whether two interfaces are the same by looking at their ID or URL; +Parts not mentioned in the "Legacy TRD" will be handled by other TRDs; therefore, processes not present in the "Legacy TRD" are considered ready; +"Additional Technical Requirements" specifies the additional technical requirements that the generated software framework code must meet; +Do the parameters of the interfaces of the external system used in the code comply with their specifications in 'Acknowledge'? +Is there a lack of necessary configuration files? +Return a markdown JSON object with: +- an "issues" key containing a string list of natural text about the issues that need to be addressed, found in the "Legacy Outputs", if any exist; each issue found must provide a detailed description and include reasons; +- a "conclusion" key containing the evaluation conclusion; +- a "misalignment" key containing a string list of natural text detailing the misalignments with the "Legacy TRD"; +- an "is_pass" key containing a true boolean value if there is no issue in the "Legacy Outputs"; +""" diff --git a/metagpt/actions/requirement_analysis/framework/write_framework.py b/metagpt/actions/requirement_analysis/framework/write_framework.py new file mode 100644 index 0000000000..2aa03f4473 --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/write_framework.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_framework.py +@Desc : The implementation of Chapter 2.1.8 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import json + +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class WriteFramework(Action): + """WriteFramework deals with the following situations: + 1. Given a TRD, write out the software framework. + """
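`WriteFramework` and `EvaluateFramework` are designed to be alternated: the evaluator's conclusion feeds back into the next writing round through `legacy_output` and `evaluation_conclusion`. A hedged sketch of that loop; the round limit and the input strings are assumptions, not part of this patch:

```python
from metagpt.actions.requirement_analysis.framework import (
    EvaluateFramework,
    WriteFramework,
)

async def refine_framework(use_case_actors: str, trd: str, acknowledge: str, constraint: str, max_rounds: int = 3) -> str:
    """Alternate writing and evaluating until the framework passes review."""
    framework, conclusion = "", ""
    write, evaluate = WriteFramework(), EvaluateFramework()
    for _ in range(max_rounds):
        framework = await write.run(
            use_case_actors=use_case_actors,
            trd=trd,
            acknowledge=acknowledge,
            legacy_output=framework,
            evaluation_conclusion=conclusion,
            additional_technical_requirements=constraint,
        )
        evaluation = await evaluate.run(
            use_case_actors=use_case_actors,
            trd=trd,
            acknowledge=acknowledge,
            legacy_output=framework,
            additional_technical_requirements=constraint,
        )
        if evaluation.is_pass:
            break
        conclusion = evaluation.conclusion or ""
    return framework
```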
+ + async def run( + self, + *, + use_case_actors: str, + trd: str, + acknowledge: str, + legacy_output: str, + evaluation_conclusion: str, + additional_technical_requirements: str, + ) -> str: + """ + Run the action to generate a software framework based on the provided TRD and related information. + + Args: + use_case_actors (str): Description of the use case actors involved. + trd (str): Technical Requirements Document detailing the requirements. + acknowledge (str): External acknowledgement information related to the framework. + legacy_output (str): Previous version of the software framework returned by `WriteFramework.run`. + evaluation_conclusion (str): Conclusion from the evaluation of the requirements. + additional_technical_requirements (str): Any additional technical requirements. + + Returns: + str: The generated software framework as a string. + + Example: + >>> write_framework = WriteFramework() + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> acknowledge = "## Interfaces\\n..." + >>> legacy_output = '{"path":"balabala", "filename":"...", ...' + >>> evaluation_conclusion = "Balabala..." + >>> constraint = "Using Java language, ..." + >>> framework = await write_framework.run( + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> acknowledge=acknowledge, + >>> legacy_output=legacy_output, + >>> evaluation_conclusion=evaluation_conclusion, + >>> additional_technical_requirements=constraint, + >>> ) + >>> print(framework) + {"path":"balabala", "filename":"...", ... + + """ + acknowledge = await self._extract_external_interfaces(trd=trd, knowledge=acknowledge) + prompt = PROMPT.format( + use_case_actors=use_case_actors, + trd=to_markdown_code_block(val=trd), + acknowledge=to_markdown_code_block(val=acknowledge), + legacy_output=to_markdown_code_block(val=legacy_output), + evaluation_conclusion=evaluation_conclusion, + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._write(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write(self, prompt: str) -> str: + rsp = await self.llm.aask(prompt) + # Do not use `CodeParser` here. + tags = ["```json", "```"] + bix = rsp.find(tags[0]) + eix = rsp.rfind(tags[1]) + if bix >= 0: + rsp = rsp[bix : eix + len(tags[1])] + json_data = rsp.removeprefix("```json").removesuffix("```") + json.loads(json_data) # validate + return json_data + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _extract_external_interfaces(self, trd: str, knowledge: str) -> str: + prompt = f"## TRD\n{to_markdown_code_block(val=trd)}\n\n## Knowledge\n{to_markdown_code_block(val=knowledge)}\n" + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool that removes impurities from articles; you can remove irrelevant content from articles.", + 'Identify which interfaces are used in "TRD"?
Remove the relevant content of the interfaces NOT used in "TRD" from "Knowledge" and return the simplified content of "Knowledge".', + ], + ) + return rsp + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## TRD +{trd} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that generates software framework code based on TRD. +The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The descriptions of the interfaces of the external system used in the "TRD" can be found in the "Acknowledge" section; Do not implement the interface of the external system in "Acknowledge" section until it is used in "TRD"; +"Legacy Outputs" contains the software framework code generated by you last time, which you can improve by addressing the issues raised in "Evaluation Conclusion"; +"Additional Technical Requirements" specifies the additional technical requirements that the generated software framework code must meet; +Develop the software framework based on the "TRD", the output files should include: +- The `README.md` file should include: + - The folder structure diagram of the entire project; + - Correspondence between classes, interfaces, and functions with the content in the "TRD" section; + - Prerequisites if necessary; + - Installation if necessary; + - Configuration if necessary; + - Usage if necessary; +- The `CLASS.md` file should include the class diagram in PlantUML format based on the "TRD"; +- The `SEQUENCE.md` file should include the sequence diagram in PlantUML format based on the "TRD"; +- The source code files that implement the "TRD" and "Additional Technical Requirements"; do not add comments to source code files; +- The configuration files that required by the source code files, "TRD" and "Additional Technical Requirements"; + +Return a markdown JSON object list, each object containing: +- a "path" key with a value specifying its path; +- a "filename" key with a value specifying its file name; +- a "content" key with a value containing its file content; +""" diff --git a/metagpt/actions/requirement_analysis/requirement/__init__.py b/metagpt/actions/requirement_analysis/requirement/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metagpt/actions/requirement_analysis/requirement/pic2txt.py b/metagpt/actions/requirement_analysis/requirement/pic2txt.py new file mode 100644 index 0000000000..b8f236dacb --- /dev/null +++ b/metagpt/actions/requirement_analysis/requirement/pic2txt.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/27 +@Author : mashenquan +@File : pic2txt.py +""" +import json +from pathlib import Path +from typing import List + +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import encode_image, general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class Pic2Txt(Action): + """Pic2Txt deal with the following situations: + Given some pictures depicting user requirements alongside contextual description, write out the intact textual user requirements. 
+ + async def run( + self, + *, + image_paths: List[str], + textual_user_requirement: str = "", + legacy_output: str = "", + evaluation_conclusion: str = "", + additional_technical_requirements: str = "", + ) -> str: + """ + Given some pictures depicting user requirements alongside contextual description, write out the intact textual user requirements + + Args: + image_paths (List[str]): A list of file paths to the input image(s) depicting user requirements. + textual_user_requirement (str, optional): The textual user requirements that accompany the given images, if any. + legacy_output (str, optional): The intact textual user requirements generated by you last time, if any. + evaluation_conclusion (str, optional): Conclusion or evaluation based on the processed requirements. + additional_technical_requirements (str, optional): Any supplementary technical details relevant to the process. + + Returns: + str: Textual representation of user requirements extracted from the provided image(s). + + Raises: + ValueError: If the image_paths list is empty. + OSError: If there is an issue accessing or reading the image files. + + Example: + >>> images = ["requirements/pic/1.png", "requirements/pic/2.png", "requirements/pic/3.png"] + >>> textual_user_requirements = "User requirement paragraph 1 ..., ![](1.png). paragraph 2...![](2.png)..." + >>> action = Pic2Txt() + >>> intact_textual_user_requirements = await action.run(image_paths=images, textual_user_requirement=textual_user_requirements) + >>> print(intact_textual_user_requirements) + "User requirement paragraph 1 ..., ![...](1.png) This picture describes... paragraph 2...![...](2.png)..." + + """ + descriptions = {} + for i in image_paths: + filename = Path(i) + base64_image = encode_image(filename) + rsp = await self._pic2txt( + "Generate a paragraph of text based on the content of the image, keeping the language of the text consistent with the language in the image.", + base64_image=base64_image, + ) + descriptions[filename.name] = rsp + + prompt = PROMPT.format( + textual_user_requirement=textual_user_requirement, + acknowledge=to_markdown_code_block(val=json.dumps(descriptions), type_="json"), + legacy_output=to_markdown_code_block(val=legacy_output), + evaluation_conclusion=evaluation_conclusion, + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._write(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write(self, prompt: str) -> str: + rsp = await self.llm.aask(prompt) + return rsp + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _pic2txt(self, prompt: str, base64_image: str) -> str: + rsp = await self.llm.aask(prompt, images=base64_image) + return rsp + + +PROMPT = """ +## Textual User Requirements +{textual_user_requirement} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that generates intact textual user requirements given a few textual fragments of user requirements and some fragments of UI pictures.
+The content of "Textual User Requirements" provides a few textual fragments of user requirements; +The content of "Acknowledge" provides the descriptions of the pictures used in "Textual User Requirements"; +"Legacy Outputs" contains the intact textual user requirements generated by you last time, which you can improve by addressing the issues raised in "Evaluation Conclusion"; +"Additional Technical Requirements" specifies the additional technical requirements that the generated textual user requirements must meet; +You need to merge the text content of the corresponding images in "Acknowledge" into the "Textual User Requirements" to generate a complete, natural and coherent description of the user requirements; +Return the intact textual user requirements according to the given fragments of the user requirements in "Textual User Requirements" and the UI pictures; +""" diff --git a/metagpt/actions/requirement_analysis/trd/__init__.py b/metagpt/actions/requirement_analysis/trd/__init__.py new file mode 100644 index 0000000000..4603532c42 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + + +from metagpt.actions.requirement_analysis.trd.detect_interaction import DetectInteraction +from metagpt.actions.requirement_analysis.trd.evaluate_trd import EvaluateTRD +from metagpt.actions.requirement_analysis.trd.write_trd import WriteTRD +from metagpt.actions.requirement_analysis.trd.compress_external_interfaces import CompressExternalInterfaces + +__all__ = ["CompressExternalInterfaces", "DetectInteraction", "WriteTRD", "EvaluateTRD"] diff --git a/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py b/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py new file mode 100644 index 0000000000..abaf6fc307 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : compress_external_interfaces.py +@Desc : The implementation of Chapter 2.1.5 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log + + +@register_tool(include_functions=["run"]) +class CompressExternalInterfaces(Action): + """CompressExternalInterfaces deals with the following situations: + 1. Given a natural text of acknowledgement, it extracts and compresses the information about external system interfaces. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def run( + self, + *, + acknowledge: str, + ) -> str: + """ + Extracts and compresses information about external system interfaces from a given acknowledgement text. + + Args: + acknowledge (str): A natural text of acknowledgement containing details about external system interfaces. + + Returns: + str: A compressed version of the information about external system interfaces. + + Example: + >>> compress_acknowledge = CompressExternalInterfaces() + >>> acknowledge = "## Interfaces\\n..."
+ >>> available_external_interfaces = await compress_acknowledge.run(acknowledge=acknowledge) + >>> print(available_external_interfaces) + ```json\n[\n{\n"id": 1,\n"inputs": {... + """ + return await self.llm.aask( + msg=acknowledge, + system_msgs=[ + "Extracts and compresses the information about external system interfaces.", + "Return a markdown JSON list of objects, each object containing:\n" + '- an "id" key containing the interface id;\n' + '- an "inputs" key containing a dict of input parameters that consist of name and description pairs;\n' + '- an "outputs" key containing a dict of returns that consist of name and description pairs;\n', + ], + ) diff --git a/metagpt/actions/requirement_analysis/trd/detect_interaction.py b/metagpt/actions/requirement_analysis/trd/detect_interaction.py new file mode 100644 index 0000000000..b771931941 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/detect_interaction.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : detect_interaction.py +@Desc : The implementation of Chapter 2.1.6 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class DetectInteraction(Action): + """DetectInteraction deal with the following situations: + 1. Given a natural text of user requirements, it identifies the interaction events and the participants of those interactions from the original text. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def run( + self, + *, + user_requirements: str, + use_case_actors: str, + legacy_interaction_events: str, + evaluation_conclusion: str, + ) -> str: + """ + Identifies interaction events and participants from the user requirements. + + Args: + user_requirements (str): A natural language text detailing the user's requirements. + use_case_actors (str): A description of the actors involved in the use case. + legacy_interaction_events (str): The previous version of the interaction events identified by you. + evaluation_conclusion (str): The external evaluation conclusions regarding the interactions events identified by you. + + Returns: + str: A string summarizing the identified interaction events and their participants. + + Example: + >>> detect_interaction = DetectInteraction() + >>> user_requirements = "User requirements 1. ..." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> previous_version_interaction_events = "['interaction ...', ...]" + >>> evaluation_conclusion = "Issues: ..." 
+ >>> interaction_events = await detect_interaction.run( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> legacy_interaction_events=previous_version_interaction_events, + >>> evaluation_conclusion=evaluation_conclusion, + >>> ) + >>> print(interaction_events) + "['interaction ...', ...]" + """ + msg = PROMPT.format( + use_case_actors=use_case_actors, + original_user_requirements=to_markdown_code_block(val=user_requirements), + previous_version_of_interaction_events=legacy_interaction_events, + the_evaluation_conclusion_of_previous_version_of_trd=evaluation_conclusion, + ) + return await self.llm.aask(msg=msg) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{original_user_requirements} + +## Legacy Interaction Events +{previous_version_of_interaction_events} + +## Evaluation Conclusion +{the_evaluation_conclusion_of_previous_version_of_trd} + +--- +You are a tool for capturing interaction events. +"Actor, System, External System" provides the possible participants of the interaction event; +"Legacy Interaction Events" is the contents of the interaction events that you output earlier; +Some descriptions in the "Evaluation Conclusion" relate to the content of "User Requirements", and these descriptions in the "Evaluation Conclusion" address some issues regarding the content of "Legacy Interaction Events"; +You need to capture the interaction events occurring in the description within the content of "User Requirements" word-for-word, including: +1. Who is interacting with whom. An interaction event has a maximum of 2 participants. If there are multiple participants, it indicates that multiple events are combined into one event and should be further split; +2. When an interaction event occurs, who is the initiator? What data did the initiator enter? +3. What data does the interaction event ultimately return according to the "User Requirements"? + +You can check the data flow described in the "User Requirements" to see if there are any missing interaction events; +Return a markdown JSON object list, each object of the list containing: +- a "name" key containing the name of the interaction event; +- a "participants" key containing a string list of the names of the two participants; +- an "initiator" key containing the name of the participant who initiates the interaction; +- an "input" key containing a natural text description of the input data; +""" diff --git a/metagpt/actions/requirement_analysis/trd/evaluate_trd.py b/metagpt/actions/requirement_analysis/trd/evaluate_trd.py new file mode 100644 index 0000000000..5c256ed075 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/evaluate_trd.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_trd.py +@Desc : The implementation of Chapter 2.1.6~2.1.7 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + +from metagpt.actions.requirement_analysis import EvaluateAction, EvaluationData +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class EvaluateTRD(EvaluateAction): + """EvaluateTRD deals with the following situations: + 1. Given a TRD, evaluates the quality and returns a conclusion.
+ """ + + async def run( + self, + *, + user_requirements: str, + use_case_actors: str, + trd: str, + interaction_events: str, + legacy_user_requirements_interaction_events: str = "", + ) -> EvaluationData: + """ + Evaluates the given TRD based on user requirements, use case actors, interaction events, and optionally external legacy interaction events. + + Args: + user_requirements (str): The requirements provided by the user. + use_case_actors (str): The actors involved in the use case. + trd (str): The TRD (Technical Requirements Document) to be evaluated. + interaction_events (str): The interaction events related to the user requirements and the TRD. + legacy_user_requirements_interaction_events (str, optional): External legacy interaction events tied to the user requirements. Defaults to an empty string. + + Returns: + EvaluationData: The conclusion of the TRD evaluation. + + Example: + >>> evaluate_trd = EvaluateTRD() + >>> user_requirements = "User requirements 1. ..." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> interaction_events = "['interaction ...', ...]" + >>> evaluation_conclusion = "Issues: ..." + >>> legacy_user_requirements_interaction_events = ["user requirements 1. ...", ...] + >>> evaluation = await evaluate_trd.run( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> interaction_events=interaction_events, + >>> legacy_user_requirements_interaction_events=str(legacy_user_requirements_interaction_events), + >>> ) + >>> is_pass = evaluation.is_pass + >>> print(is_pass) + True + >>> evaluation_conclusion = evaluation.conclusion + >>> print(evaluation_conclusion) + ## Conclusion\n balabalabala... + + """ + prompt = PROMPT.format( + use_case_actors=use_case_actors, + user_requirements=to_markdown_code_block(val=user_requirements), + trd=to_markdown_code_block(val=trd), + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + interaction_events=interaction_events, + ) + return await self._vote(prompt) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{user_requirements} + +## TRD Design +{trd} + +## External Interaction Events +{legacy_user_requirements_interaction_events} + +## Interaction Events +{legacy_user_requirements_interaction_events} +{interaction_events} + +--- +You are a tool to evaluate the TRD design.
+"Actor, System, External System" provides all possible participants in interaction events; +"User Requirements" provides the original requirements description; any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +"External Interaction Events" is provided by an external module for your use; its content also appears in the "Interaction Events" section; the content in "External Interaction Events" can be considered problem-free; +"External Interaction Events" provides some identified interaction events and the interacting participants based on part of the content of the "User Requirements"; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "User Requirements"; +"TRD Design" provides a comprehensive design of the implementation steps for the original requirements, incorporating the interaction events from "Interaction Events" and adding additional steps to connect the complete upstream and downstream data flows; +In order to integrate the full upstream and downstream data flow, the "TRD Design" allows for the inclusion of steps that do not appear in the original requirements description, but do not conflict with those explicitly described in the "User Requirements"; +Which interactions from "Interaction Events" correspond to which steps in "TRD Design"? Please provide reasons. +Which aspects of "TRD Design" and "Interaction Events" do not align with the descriptions in "User Requirements"? Please provide detailed descriptions and reasons. +If the descriptions in "User Requirements" are divided into multiple steps in "TRD Design" and "Interaction Events", it can be considered compliant with the descriptions in "User Requirements" as long as it does not conflict with them; +There is a possibility of missing details in the descriptions of "User Requirements". Any additional steps in "TRD Design" and "Interaction Events" are considered compliant with "User Requirements" as long as they do not conflict with the descriptions provided in "User Requirements"; +If there are interaction events with external systems in "TRD Design", you must explicitly specify the ID of the external interface to use for the interaction events; the input and output parameters of the used external interface must explicitly match the input and output of the interaction event; +Does the sequence of steps in "Interaction Events" cause performance or cost issues?
Please provide detailed descriptions and reasons; +If each step of "TRD Design" has input data, its input data is provided either by the output of the previous steps or by participants of "Actor, System, External System", and there should be no passive data; +Return a markdown JSON object with: +- an "issues" key containing a string list of natural text about the issues that need to be addressed, found in the "TRD Design" if any exist, each issue found must provide a detailed description and include reasons; +- a "conclusion" key containing the evaluation conclusion; +- a "correspondence_between" key containing the judgement detail of the natural text string list about the correspondence between "Interaction Events" and "TRD Design" steps; +- a "misalignment" key containing the judgement detail of the natural text string list about the misalignment with "User Requirements"; +- a "is_pass" key containing a true boolean value if there is not any issue in the "TRD Design"; +""" diff --git a/metagpt/actions/requirement_analysis/trd/write_trd.py b/metagpt/actions/requirement_analysis/trd/write_trd.py new file mode 100644 index 0000000000..bad93ea766 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/write_trd.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_trd.py +@Desc : The implementation of Chapter 2.1.6~2.1.7 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class WriteTRD(Action): + """WriteTRD deal with the following situations: + 1. Given some new user requirements, write out a new TRD(Technical Requirements Document). + 2. Given some incremental user requirements, update the legacy TRD. + """ + + async def run( + self, + *, + user_requirements: str = "", + use_case_actors: str, + available_external_interfaces: str, + evaluation_conclusion: str = "", + interaction_events: str, + previous_version_trd: str = "", + legacy_user_requirements: str = "", + legacy_user_requirements_trd: str = "", + legacy_user_requirements_interaction_events: str = "", + ) -> str: + """ + Handles the writing or updating of a Technical Requirements Document (TRD) based on user requirements. + + Args: + user_requirements (str): The new/incremental user requirements. + use_case_actors (str): Description of the actors involved in the use case. + available_external_interfaces (str): List of available external interfaces. + evaluation_conclusion (str, optional): The conclusion of the evaluation of the TRD written by you. Defaults to an empty string. + interaction_events (str): The interaction events related to the user requirements that you are handling. + previous_version_trd (str, optional): The previous version of the TRD written by you, for updating. + legacy_user_requirements (str, optional): Existing user requirements handled by an external object for your use. Defaults to an empty string. + legacy_user_requirements_trd (str, optional): The TRD associated with the existing user requirements handled by an external object for your use. Defaults to an empty string. 
+            legacy_user_requirements_interaction_events (str, optional): Interaction events related to the existing user requirements handled by an external object for your use. Defaults to an empty string.
+
+        Returns:
+            str: The newly created or updated TRD written by you.
+
+        Example:
+            >>> # Given new user requirements, write out a new TRD.
+            >>> user_requirements = "Write a 'snake game' TRD."
+            >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;"
+            >>> available_external_interfaces = "The available external interfaces returned by `CompressExternalInterfaces.run` are ..."
+            >>> previous_version_trd = "TRD ..."  # The last version of the TRD written out, if there is one.
+            >>> evaluation_conclusion = "Conclusion ..."  # The conclusion returned by `EvaluateTRD.run`, if there is one.
+            >>> interaction_events = "Interaction ..."  # The interaction events returned by `DetectInteraction.run`.
+            >>> write_trd = WriteTRD()
+            >>> new_version_trd = await write_trd.run(
+            >>>    user_requirements=user_requirements,
+            >>>    use_case_actors=use_case_actors,
+            >>>    available_external_interfaces=available_external_interfaces,
+            >>>    evaluation_conclusion=evaluation_conclusion,
+            >>>    interaction_events=interaction_events,
+            >>>    previous_version_trd=previous_version_trd,
+            >>> )
+            >>> print(new_version_trd)
+            ## Technical Requirements Document\n ...
+
+            >>> # Given incremental user requirements, update the legacy TRD.
+            >>> legacy_user_requirements = ["User requirements 1. ...", "User requirements 2. ...", ...]
+            >>> legacy_user_requirements_trd = "## Technical Requirements Document\\n ..."  # The TRD before integrating more user requirements.
+            >>> legacy_user_requirements_interaction_events = ["The interaction events list of user requirements 1 ...", "The interaction events list of user requirements 2 ...", ...]
+            >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;"
+            >>> available_external_interfaces = "The available external interfaces returned by `CompressExternalInterfaces.run` are ..."
+            >>> increment_requirements = "The incremental user requirements are ..."
+            >>> evaluation_conclusion = "Conclusion ..."  # The conclusion returned by `EvaluateTRD.run`, if there is one.
+            >>> previous_version_trd = "TRD ..."  # The last version of the TRD written out, if there is one.
+            >>> write_trd = WriteTRD()
+            >>> new_version_trd = await write_trd.run(
+            >>>    user_requirements=increment_requirements,
+            >>>    use_case_actors=use_case_actors,
+            >>>    available_external_interfaces=available_external_interfaces,
+            >>>    evaluation_conclusion=evaluation_conclusion,
+            >>>    interaction_events=interaction_events,
+            >>>    previous_version_trd=previous_version_trd,
+            >>>    legacy_user_requirements=str(legacy_user_requirements),
+            >>>    legacy_user_requirements_trd=legacy_user_requirements_trd,
+            >>>    legacy_user_requirements_interaction_events=str(legacy_user_requirements_interaction_events),
+            >>> )
+            >>> print(new_version_trd)
+            ## Technical Requirements Document\n ...
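Both private helpers that `run` dispatches to (shown next) wrap their LLM calls in tenacity retries with random exponential backoff. A minimal standalone sketch of that pattern, with a placeholder `fake_llm` coroutine standing in for the real LLM call:

```python
import asyncio

from tenacity import retry, stop_after_attempt, wait_random_exponential


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def ask_with_retries(prompt: str) -> str:
    # Retries up to 6 times, sleeping a random, exponentially growing interval (1-20s) between attempts.
    return await fake_llm(prompt)


async def fake_llm(prompt: str) -> str:
    # Placeholder, not a MetaGPT API; a real call would go to self.llm.aask.
    return f"TRD draft for: {prompt[:40]}"


asyncio.run(ask_with_retries("Write a 'snake game' TRD."))
```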
+ """ + if legacy_user_requirements: + return await self._write_incremental_trd( + use_case_actors=use_case_actors, + legacy_user_requirements=legacy_user_requirements, + available_external_interfaces=available_external_interfaces, + legacy_user_requirements_trd=legacy_user_requirements_trd, + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + incremental_user_requirements=user_requirements, + previous_version_trd=previous_version_trd, + evaluation_conclusion=evaluation_conclusion, + incremental_user_requirements_interaction_events=interaction_events, + ) + return await self._write_new_trd( + use_case_actors=use_case_actors, + original_user_requirement=user_requirements, + available_external_interfaces=available_external_interfaces, + legacy_trd=previous_version_trd, + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + ) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write_new_trd( + self, + *, + use_case_actors: str, + original_user_requirement: str, + available_external_interfaces: str, + legacy_trd: str, + evaluation_conclusion: str, + interaction_events: str, + ) -> str: + prompt = NEW_PROMPT.format( + use_case_actors=use_case_actors, + original_user_requirement=to_markdown_code_block(val=original_user_requirement), + available_external_interfaces=available_external_interfaces, + legacy_trd=to_markdown_code_block(val=legacy_trd), + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + ) + return await self.llm.aask(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write_incremental_trd( + self, + *, + use_case_actors: str, + legacy_user_requirements: str, + available_external_interfaces: str, + legacy_user_requirements_trd: str, + legacy_user_requirements_interaction_events: str, + incremental_user_requirements: str, + previous_version_trd: str, + evaluation_conclusion: str, + incremental_user_requirements_interaction_events: str, + ): + prompt = INCREMENTAL_PROMPT.format( + use_case_actors=use_case_actors, + legacy_user_requirements=to_markdown_code_block(val=legacy_user_requirements), + available_external_interfaces=available_external_interfaces, + legacy_user_requirements_trd=to_markdown_code_block(val=legacy_user_requirements_trd), + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + incremental_user_requirements=to_markdown_code_block(val=incremental_user_requirements), + previous_version_trd=to_markdown_code_block(val=previous_version_trd), + evaluation_conclusion=evaluation_conclusion, + incremental_user_requirements_interaction_events=incremental_user_requirements_interaction_events, + ) + return await self.llm.aask(prompt) + + +NEW_PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{original_user_requirement} + +## Available External Interfaces +{available_external_interfaces} + +## Legacy TRD +{legacy_trd} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Interaction Events +{interaction_events} + +--- +You are a TRD generator. 
+The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The content of "Available External Interfaces" provides the candidate steps, along with the inputs and outputs of each step; +"User Requirements" provides the original requirements description, any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +"Legacy TRD" provides the old version of the TRD based on the "User Requirements" and can serve as a reference for the new TRD; +"Evaluation Conclusion" provides a summary of the evaluation of the old TRD in the "Legacy TRD" and can serve as a reference for the new TRD; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "User Requirements"; +1. What inputs and outputs are described in the "User Requirements"? +2. How many steps are needed to achieve the inputs and outputs described in the "User Requirements"? Which actors from the "Actor, System, External System" section are involved in each step? What are the inputs and outputs of each step? Where is this output used, for example, as input for which interface or where it is required in the requirements, etc.? +3. Output a complete Technical Requirements Document (TRD): + 3.1. In the description, use the actor and system names defined in the "Actor, System, External System" section to describe the interactors; + 3.2. The content should include the original text of the requirements from "User Requirements"; + 3.3. In the TRD, each step can involve a maximum of two participants. If there are more than two participants, the step needs to be further split; + 3.4. In the TRD, each step must include detailed descriptions, inputs, outputs, participants, initiator, and the rationale for the step's existence. The rationale should reference the original text to justify it, such as specifying which interface requires the output of this step as parameters or where in the requirements this step is mandated, etc.; + 3.5. In the TRD, if you need to call interfaces of external systems, you must explicitly specify the interface IDs of the external systems you want to call; +""" + +INCREMENTAL_PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## Legacy User Requirements +{legacy_user_requirements} + +## Available External Interfaces +{available_external_interfaces} + +## The TRD of Legacy User Requirements +{legacy_user_requirements_trd} + + +## The Interaction Events of Legacy User Requirements +{legacy_user_requirements_interaction_events} + +## Incremental Requirements +{incremental_user_requirements} + +## Legacy TRD +{previous_version_trd} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Interaction Events +{incremental_user_requirements_interaction_events} + +--- +You are a TRD generator. 
+The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The content of "Available External Interfaces" provides the candidate steps, along with the inputs and outputs of each step; +"Legacy User Requirements" provides the original requirements description handled by other modules for your use; +"The TRD of Legacy User Requirements" is the TRD generated by other modules based on the "Legacy User Requirements" for your use; +"The Interaction Events of Legacy User Requirements" is the interaction events list generated by other modules based on the "Legacy User Requirements" for your use; +"Incremental Requirements" provides the original requirements description that you need to address, any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +The requirements in "Legacy User Requirements" combined with the "Incremental Requirements" form a complete set of requirements, therefore, you need to add the TRD portion of the "Incremental Requirements" to "The TRD of Legacy User Requirements", the added content must not conflict with the original content of "The TRD of Legacy User Requirements"; +"Legacy TRD" provides the old version of the TRD you previously wrote based on the "Incremental Requirements" and can serve as a reference for the new TRD; +"Evaluation Conclusion" provides a summary of the evaluation of the old TRD you generated in the "Legacy TRD", and the identified issues can serve as a reference for the new TRD you create; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "Incremental Requirements"; +1. What inputs and outputs are described in the "Incremental Requirements"? +2. How many steps are needed to achieve the inputs and outputs described in the "Incremental Requirements"? Which actors from the "Actor, System, External System" section are involved in each step? What are the inputs and outputs of each step? Where is this output used, for example, as input for which interface or where it is required in the requirements, etc.? +3. Output a complete Technical Requirements Document (TRD): + 3.1. In the description, use the actor and system names defined in the "Actor, System, External System" section to describe the interactors; + 3.2. The content should include the original text of the requirements from "User Requirements"; + 3.3. In the TRD, each step can involve a maximum of two participants. If there are more than two participants, the step needs to be further split; + 3.4. In the TRD, each step must include detailed descriptions, inputs, outputs, participants, initiator, and the rationale for the step's existence. The rationale should reference the original text to justify it, such as specifying which interface requires the output of this step as parameters or where in the requirements this step is mandated, etc. 
+ """ diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2a99a8d99e..99f72b076e 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -3,16 +3,17 @@ from __future__ import annotations import asyncio -from typing import Any, Callable, Optional, Union +from datetime import datetime +from typing import Any, Callable, Coroutine, Optional, Union from pydantic import TypeAdapter, model_validator from metagpt.actions import Action -from metagpt.config2 import config from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine from metagpt.utils.common import OutputParser +from metagpt.utils.parse_html import WebPage from metagpt.utils.text import generate_prompt_chunk, reduce_message_length LANG_PROMPT = "Please respond in {language}." @@ -43,9 +44,10 @@ {results} ### Requirements -Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \ -based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the -ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. +Please remove irrelevant search results that are not related to the query or topic. +If the query is time-sensitive or specifies a certain time frame, please also remove search results that are outdated or outside the specified time frame. Notice that the current time is {time_stamp}. +Then, sort the remaining search results based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. +Provide the ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. """ WEB_BROWSE_AND_SUMMARIZE_PROMPT = """### Requirements @@ -133,8 +135,8 @@ def gen_msg(): if len(remove) == 0: break - model_name = config.llm.model - prompt = reduce_message_length(gen_msg(), model_name, system_text, config.llm.max_token) + model_name = self.config.llm.model + prompt = reduce_message_length(gen_msg(), model_name, system_text, self.config.llm.max_token) logger.debug(prompt) queries = await self._aask(prompt, [system_text]) try: @@ -148,21 +150,25 @@ def gen_msg(): ret[query] = await self._search_and_rank_urls(topic, query, url_per_query) return ret - async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]: + async def _search_and_rank_urls( + self, topic: str, query: str, num_results: int = 4, max_num_results: int = None + ) -> list[str]: """Search and rank URLs based on a query. Args: topic: The research topic. query: The search query. num_results: The number of URLs to collect. + max_num_results: The max number of URLs to collect. Returns: A list of ranked URLs. 
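The ranking prompt above asks the model for a bare JSON list of indices, which can come back malformed. A defensive parse with a fallback to the original order is worth sketching (this is in the spirit of the try/except in `_search_and_rank_urls`, not the actual implementation):

```python
import json


def parse_rank_indices(reply: str, num_results: int) -> list[int]:
    """Parse an LLM reply like '[0, 1, 3]' into valid, de-duplicated result indices."""
    try:
        raw = json.loads(reply.strip())
        indices = [i for i in raw if isinstance(i, int) and 0 <= i < num_results]
    except (json.JSONDecodeError, TypeError):
        indices = list(range(num_results))  # fall back to the engine's original order
    seen, ordered = set(), []
    for i in indices:
        if i not in seen:
            seen.add(i)
            ordered.append(i)
    return ordered


assert parse_rank_indices("[0, 2, 2, 5]", 4) == [0, 2]
```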
""" - max_results = max(num_results * 2, 6) - results = await self.search_engine.run(query, max_results=max_results, as_string=False) + max_results = max_num_results or max(num_results * 2, 6) + results = await self._search_urls(query, max_results=max_results) _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results)) - prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results) + time_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results, time_stamp=time_stamp) logger.debug(prompt) indices = await self._aask(prompt) try: @@ -176,6 +182,15 @@ async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = results = self.rank_func(results) return [i["link"] for i in results[:num_results]] + async def _search_urls(self, query: str, max_results: int) -> list[dict[str, str]]: + """Use search_engine to get urls. + + Returns: + e.g. [{"title": "...", "link": "...", "snippet", "..."}] + """ + + return await self.search_engine.run(query, max_results=max_results, as_string=False) + class WebBrowseAndSummarize(Action): """Action class to explore the web and provide summaries of articles and webpages.""" @@ -202,6 +217,8 @@ async def run( *urls: str, query: str, system_text: str = RESEARCH_BASE_SYSTEM, + use_concurrent_summarization: bool = False, + per_page_timeout: Optional[float] = None, ) -> dict[str, str]: """Run the action to browse the web and provide summaries. @@ -210,18 +227,41 @@ async def run( urls: Additional URLs to browse. query: The research question. system_text: The system text. + use_concurrent_summarization: Whether to concurrently summarize the content of the webpage by LLM. + per_page_timeout: The maximum time for fetching a single page in seconds. Returns: A dictionary containing the URLs as keys and their summaries as values. 
""" - contents = await self.web_browser_engine.run(url, *urls) - if not urls: - contents = [contents] - - summaries = {} - prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}") - for u, content in zip([url, *urls], contents): - content = content.inner_text + contents = await self._fetch_web_contents(url, *urls, per_page_timeout=per_page_timeout) + + all_urls = [url] + list(urls) + summarize_tasks = [self._summarize_content(content, query, system_text) for content in contents] + summaries = await self._execute_summarize_tasks(summarize_tasks, use_concurrent_summarization) + result = {url: summary for url, summary in zip(all_urls, summaries) if summary} + + return result + + async def _fetch_web_contents( + self, url: str, *urls: str, per_page_timeout: Optional[float] = None + ) -> list[WebPage]: + """Fetch web contents from given URLs.""" + + contents = await self.web_browser_engine.run(url, *urls, per_page_timeout=per_page_timeout) + + return [contents] if not urls else contents + + async def _summarize_content(self, page: WebPage, query: str, system_text: str) -> str: + """Summarize web content.""" + try: + prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}") + + content = page.inner_text + + if self._is_content_invalid(content): + logger.warning(f"Invalid content detected for URL {page.url}: {content[:10]}...") + return None + chunk_summaries = [] for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096): logger.debug(prompt) @@ -231,18 +271,33 @@ async def run( chunk_summaries.append(summary) if not chunk_summaries: - summaries[u] = None - continue + return None if len(chunk_summaries) == 1: - summaries[u] = chunk_summaries[0] - continue + return chunk_summaries[0] content = "\n".join(chunk_summaries) prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content) summary = await self._aask(prompt, [system_text]) - summaries[u] = summary - return summaries + return summary + except Exception as e: + logger.error(f"Error summarizing content: {e}") + return None + + def _is_content_invalid(self, content: str) -> bool: + """Check if the content is invalid based on specific starting phrases.""" + + invalid_starts = ["Fail to load page", "Access Denied"] + + return any(content.strip().startswith(phrase) for phrase in invalid_starts) + + async def _execute_summarize_tasks(self, tasks: list[Coroutine[Any, Any, str]], use_concurrent: bool) -> list[str]: + """Execute summarize tasks either concurrently or sequentially.""" + + if use_concurrent: + return await asyncio.gather(*tasks) + + return [await task for task in tasks] class ConductResearch(Action): diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py new file mode 100644 index 0000000000..1427f9b195 --- /dev/null +++ b/metagpt/actions/search_enhanced_qa.py @@ -0,0 +1,292 @@ +"""Enhancing question-answering capabilities through search engine augmentation.""" + +from __future__ import annotations + +import json + +from pydantic import Field, PrivateAttr, model_validator + +from metagpt.actions import Action +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.tools.web_browser_engine import WebBrowserEngine +from metagpt.utils.common import CodeParser +from metagpt.utils.parse_html import WebPage +from metagpt.utils.report import ThoughtReporter + +REWRITE_QUERY_PROMPT = 
""" +Role: You are a highly efficient assistant that provide a better search query for web search engine to answer the given question. + +I will provide you with a question. Your task is to provide a better search query for web search engine. + +## Context +### Question +{q} + +## Format Example +```json +{{ + "query": "the better search query for web search engine.", +}} +``` + +## Instructions +- Understand the question given by the user. +- Provide a better search query for web search engine to answer the given question, your answer must be written in the same language as the question. +- When rewriting, if you are unsure of the specific time, do not include the time. + +## Constraint +Format: Just print the result in json format like **Format Example**. + +## Action +Follow **Instructions**, generate output and make sure it follows the **Constraint**. +""" + +SEARCH_ENHANCED_QA_SYSTEM_PROMPT = """ +You are a large language AI assistant built by MGX. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context. + +Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information. + +Do not include [citation:x] in your anwser, where x is a number. Other than code and specific names and citations, your answer must be written in the same language as the question. + +Here are the set of contexts: + +{context} + +Remember, don't blindly repeat the contexts verbatim. And here is the user question: +""" + + +@register_tool(include_functions=["run"]) +class SearchEnhancedQA(Action): + """Question answering and info searching through search engine.""" + + name: str = "SearchEnhancedQA" + desc: str = "Integrating search engine results to anwser the question." + + collect_links_action: CollectLinks = Field( + default_factory=CollectLinks, description="Action to collect relevant links from a search engine." + ) + web_browse_and_summarize_action: WebBrowseAndSummarize = Field( + default=None, + description="Action to explore the web and provide summaries of articles and webpages.", + ) + per_page_timeout: float = Field( + default=20, description="The maximum time for fetching a single page is in seconds. Defaults to 20s." + ) + java_script_enabled: bool = Field( + default=False, description="Whether or not to enable JavaScript in the web browser context. Defaults to False." + ) + user_agent: str = Field( + default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36 Edg/116.0.1938.81", + description="Specific user agent to use in browser.", + ) + extra_http_headers: dict = Field( + default={"sec-ch-ua": 'Chromium";v="125", "Not.A/Brand";v="24'}, + description="An object containing additional HTTP headers to be sent with every request.", + ) + max_chars_per_webpage_summary: int = Field( + default=4000, description="Maximum summary length for each web page content." + ) + max_search_results: int = Field( + default=10, + description="Maximum number of search results (links) to collect using the collect_links_action. 
This controls the number of potential sources for answering the question.", + ) + + _reporter: ThoughtReporter = PrivateAttr(ThoughtReporter()) + + @model_validator(mode="after") + def initialize(self): + if self.web_browse_and_summarize_action is None: + web_browser_engine = WebBrowserEngine.from_browser_config( + self.config.browser, + proxy=self.config.proxy, + java_script_enabled=self.java_script_enabled, + extra_http_headers=self.extra_http_headers, + user_agent=self.user_agent, + ) + + self.web_browse_and_summarize_action = WebBrowseAndSummarize(web_browser_engine=web_browser_engine) + + return self + + async def run(self, query: str, rewrite_query: bool = True) -> str: + """Answer a query by leveraging web search results. + + Args: + query (str): The original user query. + rewrite_query (bool): Whether to rewrite the query for better web search results. Defaults to True. + + Returns: + str: A detailed answer based on web search results. + + Raises: + ValueError: If the query is invalid. + """ + async with self._reporter: + await self._reporter.async_report({"type": "search", "stage": "init"}) + self._validate_query(query) + + processed_query = await self._process_query(query, rewrite_query) + context = await self._build_context(processed_query) + + return await self._generate_answer(processed_query, context) + + def _validate_query(self, query: str) -> None: + """Validate the input query. + + Args: + query (str): The query to validate. + + Raises: + ValueError: If the query is invalid. + """ + + if not query.strip(): + raise ValueError("Query cannot be empty or contain only whitespace.") + + async def _process_query(self, query: str, should_rewrite: bool) -> str: + """Process the query, optionally rewriting it.""" + + if should_rewrite: + return await self._rewrite_query(query) + + return query + + async def _rewrite_query(self, query: str) -> str: + """Write a better search query for web search engine. + + If the rewrite process fails, the original query is returned. + + Args: + query (str): The original search query. + + Returns: + str: The rewritten query if successful, otherwise the original query. + """ + + prompt = REWRITE_QUERY_PROMPT.format(q=query) + + try: + resp = await self._aask(prompt) + rewritten_query = self._extract_rewritten_query(resp) + + logger.info(f"Query rewritten: '{query}' -> '{rewritten_query}'") + return rewritten_query + except Exception as e: + logger.warning(f"Query rewrite failed. Returning original query. Error: {e}") + return query + + def _extract_rewritten_query(self, response: str) -> str: + """Extract the rewritten query from the LLM's JSON response.""" + + resp_json = json.loads(CodeParser.parse_code(response, lang="json")) + return resp_json["query"] + + async def _build_context(self, query: str) -> str: + """Construct a context string from web search citations. + + Args: + query (str): The search query. + + Returns: + str: Formatted context with numbered citations. + """ + + citations = await self._search_citations(query) + context = "\n\n".join([f"[[citation:{i+1}]] {c}" for i, c in enumerate(citations)]) + + return context + + async def _search_citations(self, query: str) -> list[str]: + """Perform web search and summarize relevant content. + + Args: + query (str): The search query. + + Returns: + list[str]: Summaries of relevant web content. 
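End to end, the action is a single-call pipeline: rewrite the query, collect and rank links, fetch and reduce the pages, then answer against the numbered citation context. A minimal usage sketch, assuming a configured default LLM and search engine:

```python
import asyncio

from metagpt.actions.search_enhanced_qa import SearchEnhancedQA


async def main():
    qa = SearchEnhancedQA()
    # rewrite_query=True first asks the LLM for a better search query, then answers from web results.
    answer = await qa.run("What is the latest stable Python release?", rewrite_query=True)
    print(answer)


asyncio.run(main())
```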
+ """ + + relevant_urls = await self._collect_relevant_links(query) + await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls}) + if not relevant_urls: + logger.warning(f"No relevant URLs found for query: {query}") + return [] + + logger.info(f"The Relevant links are: {relevant_urls}") + + web_summaries = await self._summarize_web_content(relevant_urls) + if not web_summaries: + logger.warning(f"No summaries generated for query: {query}") + return [] + + citations = list(web_summaries.values()) + + return citations + + async def _collect_relevant_links(self, query: str) -> list[str]: + """Search and rank URLs relevant to the query. + + Args: + query (str): The search query. + + Returns: + list[str]: Ranked list of relevant URLs. + """ + + return await self.collect_links_action._search_and_rank_urls( + topic=query, query=query, max_num_results=self.max_search_results + ) + + async def _summarize_web_content(self, urls: list[str]) -> dict[str, str]: + """Fetch and summarize content from given URLs. + + Args: + urls (list[str]): List of URLs to summarize. + + Returns: + dict[str, str]: Mapping of URLs to their summaries. + """ + + contents = await self._fetch_web_contents(urls) + + summaries = {} + await self._reporter.async_report( + {"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]} + ) + for content in contents: + url = content.url + inner_text = content.inner_text.replace("\n", "") + if self.web_browse_and_summarize_action._is_content_invalid(inner_text): + logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...") + continue + + summary = inner_text[: self.max_chars_per_webpage_summary] + summaries[url] = summary + + return summaries + + async def _fetch_web_contents(self, urls: list[str]) -> list[WebPage]: + return await self.web_browse_and_summarize_action._fetch_web_contents( + *urls, per_page_timeout=self.per_page_timeout + ) + + async def _generate_answer(self, query: str, context: str) -> str: + """Generate an answer using the query and context. + + Args: + query (str): The user's question. + context (str): Relevant information from web search. + + Returns: + str: Generated answer based on the context. + """ + + system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context) + + async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "search", "stage": "answer"}) + rsp = await self._aask(query, [system_prompt]) + return rsp diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index d21b62f83b..e3556caa7b 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -6,13 +6,16 @@ @Modified By: mashenquan, 2023/12/5. Archive the summarization content of issue discovery for use in WriteCode. 
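One detail worth flagging in `SearchEnhancedQA._summarize_web_content` above: despite the name, it never calls the LLM. Each page passes the invalid-content guard and is then truncated to `max_chars_per_webpage_summary` characters, which serves as its own "summary", trading fidelity for latency. A compact sketch of that guard-then-truncate step:

```python
from typing import Optional


def quick_page_summary(inner_text: str, max_chars: int = 4000) -> Optional[str]:
    """Cheap 'summary': reject obvious fetch failures, then hard-truncate the page text."""
    flat = inner_text.replace("\n", "")
    if any(flat.strip().startswith(p) for p in ("Fail to load page", "Access Denied")):
        return None  # mirrors the _is_content_invalid guard
    return flat[:max_chars]
```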
""" from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -90,6 +93,8 @@ class SummarizeCode(Action): name: str = "SummarizeCode" i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): @@ -101,11 +106,10 @@ async def run(self): design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name) task_pathname = Path(self.i_context.task_filename) task_doc = await self.repo.docs.task.get(filename=task_pathname.name) - src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs code_blocks = [] for filename in self.i_context.codes_filenames: - code_doc = await src_file_repo.get(filename) - code_block = f"```python\n{code_doc.content}\n```\n-----" + code_doc = await self.repo.srcs.get(filename) + code_block = f"```{get_markdown_code_block_type(filename)}\n{code_doc.content}\n```\n---\n" code_blocks.append(code_block) format_example = FORMAT_EXAMPLE prompt = PROMPT_TEMPLATE.format( diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index 81f66f9a14..3fec327838 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -9,7 +9,6 @@ from typing import Optional from metagpt.actions import Action -from metagpt.config2 import config from metagpt.logs import logger from metagpt.schema import Message @@ -26,7 +25,7 @@ def agent_description(self): @property def language(self): - return self.context.kwargs.language or config.language + return self.context.kwargs.language or self.config.language @property def prompt(self): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index feb15657d7..da25fe621c 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,18 +16,20 @@ """ import json +from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE -from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, get_markdown_code_block_type from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ NOTICE @@ -43,9 +45,7 @@ {task} ## Legacy Code -```Code {code} -``` ## Debug logs ```text @@ -60,9 +60,14 @@ ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py +... +``` +## Code: {demo_filename}.js +```javascript +// {demo_filename}.js ... 
``` @@ -83,18 +88,26 @@ class WriteCode(Action): name: str = "WriteCode" i_context: Document = Field(default_factory=Document) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: code_rsp = await self._aask(prompt) - code = CodeParser.parse_code(block="", text=code_rsp) + code = CodeParser.parse_code(text=code_rsp) return code async def run(self, *args, **kwargs) -> CodingContext: - bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME) + bug_feedback = None + if self.input_args and hasattr(self.input_args, "issue_filename"): + bug_feedback = await Document.load(self.input_args.issue_filename) coding_context = CodingContext.loads(self.i_context.content) + if not coding_context.code_plan_and_change_doc: + coding_context.code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get( + filename=coding_context.task_doc.filename + ) test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json") - requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME) + requirement_doc = await Document.load(self.input_args.requirements_filename) summary_doc = None if coding_context.design_doc and coding_context.design_doc.filename: summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename) @@ -103,29 +116,28 @@ async def run(self, *args, **kwargs) -> CodingContext: test_detail = RunCodeResult.loads(test_doc.content) logs = test_detail.stderr - if bug_feedback: - code_context = coding_context.code_doc.content - elif self.config.inc: + if self.config.inc or bug_feedback: code_context = await self.get_codes( coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True ) else: code_context = await self.get_codes( - coding_context.task_doc, - exclude=self.i_context.filename, - project_repo=self.repo.with_src_path(self.context.src_workspace), + coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo ) if self.config.inc: prompt = REFINED_TEMPLATE.format( user_requirement=requirement_doc.content if requirement_doc else "", - code_plan_and_change=str(coding_context.code_plan_and_change_doc), + code_plan_and_change=coding_context.code_plan_and_change_doc.content + if coding_context.code_plan_and_change_doc + else "", design=coding_context.design_doc.content if coding_context.design_doc else "", task=coding_context.task_doc.content if coding_context.task_doc else "", code=code_context, logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) else: @@ -136,15 +148,20 @@ async def run(self, *args, **kwargs) -> CodingContext: logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) logger.info(f"Writing {coding_context.filename}..") - code = await self.write_code(prompt) - if not coding_context.code_doc: - # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self.context.src_workspace if self.context.src_workspace else "" - coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) - 
coding_context.code_doc.content = code + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "code", "filename": coding_context.filename}, "meta") + code = await self.write_code(prompt) + if not coding_context.code_doc: + # avoid root_path pydantic ValidationError if use WriteCode alone + coding_context.code_doc = Document( + filename=coding_context.filename, root_path=str(self.repo.src_relative_path) + ) + coding_context.code_doc.content = code + await reporter.async_report(coding_context.code_doc, "document") return coding_context @staticmethod @@ -166,38 +183,35 @@ async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, if not task_doc.content: task_doc = project_repo.docs.task.get(filename=task_doc.filename) m = json.loads(task_doc.content) - code_filenames = m.get(TASK_LIST.key, []) if use_inc else m.get(REFINED_TASK_LIST.key, []) + code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, []) codes = [] src_file_repo = project_repo.srcs - # Incremental development scenario if use_inc: - src_files = src_file_repo.all_files - # Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange - old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace) - old_files = old_file_repo.all_files - # Get the union of the files in the src and old workspaces - union_files_list = list(set(src_files) | set(old_files)) - for filename in union_files_list: + for filename in src_file_repo.all_files: + code_block_type = get_markdown_code_block_type(filename) # Exclude the current file from the all code snippets if filename == exclude: # If the file is in the old workspace, use the old code # Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and # essential functionality is included for the project’s requirements - if filename in old_files and filename != "main.py": + if filename != "main.py": # Use old code - doc = await old_file_repo.get(filename=filename) + doc = await src_file_repo.get(filename=filename) # If the file is in the src workspace, skip it else: continue - codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====") + codes.insert( + 0, f"### The name of file to rewrite: `{filename}`\n```{code_block_type}\n{doc.content}```\n" + ) + logger.info(f"Prepare to rewrite `{filename}`") # The code snippets are generated from the src workspace else: doc = await src_file_repo.get(filename=filename) # If the file does not exist in the src workspace, skip it if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n") # Normal scenario else: @@ -208,6 +222,7 @@ async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, doc = await src_file_repo.get(filename=filename) if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + code_block_type = get_markdown_code_block_type(filename) + codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n") return "\n".join(codes) diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py index ce030b0e9b..4c3fd4c190 100644 --- a/metagpt/actions/write_code_an_draft.py +++ b/metagpt/actions/write_code_an_draft.py @@ -578,7 +578,7 @@ class WriteCodeAN(Action): async def run(self, context): 
self.llm.system_prompt = "You are an outstanding engineer and can implement any code" - return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json") async def main(): diff --git a/metagpt/actions/write_code_plan_and_change_an.py b/metagpt/actions/write_code_plan_and_change_an.py index a909469816..989df52f22 100644 --- a/metagpt/actions/write_code_plan_and_change_an.py +++ b/metagpt/actions/write_code_plan_and_change_an.py @@ -5,15 +5,16 @@ @Author : mannaandpoem @File : write_code_plan_and_change_an.py """ -import os -from typing import List +from typing import List, Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.actions.action_node import ActionNode from metagpt.logs import logger -from metagpt.schema import CodePlanAndChangeContext +from metagpt.schema import CodePlanAndChangeContext, Document +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo DEVELOPMENT_PLAN = ActionNode( key="Development Plan", @@ -162,9 +163,8 @@ def add_numbers(): {task} ## Legacy Code -```Code {code} -``` + ## Debug logs ```text @@ -179,9 +179,14 @@ def add_numbers(): ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py +... +``` +## Code: {demo_filename}.js +```javascript +// {demo_filename}.js ... ``` @@ -206,13 +211,15 @@ def add_numbers(): class WriteCodePlanAndChange(Action): name: str = "WriteCodePlanAndChange" i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs): self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to " "meticulously craft comprehensive incremental development plan and deliver detailed incremental change" - prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename) - design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename) - task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename) + prd_doc = await Document.load(filename=self.i_context.prd_filename) + design_doc = await Document.load(filename=self.i_context.design_filename) + task_doc = await Document.load(filename=self.i_context.task_filename) context = CODE_PLAN_AND_CHANGE_CONTEXT.format( requirement=f"```text\n{self.i_context.requirement}\n```", issue=f"```text\n{self.i_context.issue}\n```", @@ -222,11 +229,12 @@ async def run(self, *args, **kwargs): code=await self.get_old_codes(), ) logger.info("Writing code plan and change..") - return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json") async def get_old_codes(self) -> str: - self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path) - old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace) - old_codes = await old_file_repo.get_all() - codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes] + old_codes = await self.repo.srcs.get_all() + codes = [ + f"### File Name: 
`{code.filename}`\n```{get_markdown_code_block_type(code.filename)}\n{code.content}```\n" + for code in old_codes + ] return "\n".join(codes) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index ac6fe7045c..6a283f812e 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -7,16 +7,22 @@ @Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather than passing them in when calling the run function. """ +import asyncio +import os +from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action -from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger -from metagpt.schema import CodingContext -from metagpt.utils.common import CodeParser +from metagpt.schema import CodingContext, Document +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import CodeParser, aread, awrite +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ # System @@ -110,34 +116,48 @@ def handle_events(self): REWRITE_CODE_TEMPLATE = """ # Instruction: rewrite code based on the Code Review and Actions -## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} with triple quotes. Do your utmost to optimize THIS SINGLE FILE. Return all completed codes and prohibit the return of unfinished codes. -```Code +## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} using a Markdown code block, with the filename docstring preceding the code block. Do your utmost to optimize THIS SINGLE FILE. Return all completed codes and prohibit the return of unfinished codes. +```python ## {filename} ... ``` +or +```javascript +// {filename} +... 
+```
 """


 class WriteCodeReview(Action):
     name: str = "WriteCodeReview"
     i_context: CodingContext = Field(default_factory=CodingContext)
+    repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
+    input_args: Optional[BaseModel] = Field(default=None, exclude=True)

     @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-    async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
+    async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, doc):
+        filename = doc.filename
         cr_rsp = await self._aask(context_prompt + cr_prompt)
         result = CodeParser.parse_block("Code Review Result", cr_rsp)
         if "LGTM" in result:
             return result, None

         # if LBTM, rewrite code
-        rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
-        code_rsp = await self._aask(rewrite_prompt)
-        code = CodeParser.parse_code(block="", text=code_rsp)
+        async with EditorReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report(
+                {"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta"
+            )
+            rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
+            code_rsp = await self._aask(rewrite_prompt)
+            code = CodeParser.parse_code(text=code_rsp)
+            doc.content = code
+            await reporter.async_report(doc, "document")
         return result, code

     async def run(self, *args, **kwargs) -> CodingContext:
         iterative_code = self.i_context.code_doc.content
-        k = self.context.config.code_review_k_times or 1
+        k = self.context.config.code_validate_k_times or 1

         for i in range(k):
             format_example = FORMAT_EXAMPLE.format(filename=self.i_context.code_doc.filename)
@@ -145,7 +165,7 @@ async def run(self, *args, **kwargs) -> CodingContext:
             code_context = await WriteCode.get_codes(
                 self.i_context.task_doc,
                 exclude=self.i_context.filename,
-                project_repo=self.repo.with_src_path(self.context.src_workspace),
+                project_repo=self.repo,
                 use_inc=self.config.inc,
             )

@@ -155,7 +175,7 @@ async def run(self, *args, **kwargs) -> CodingContext:
                 "## Code Files\n" + code_context + "\n",
             ]
             if self.config.inc:
-                requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
+                requirement_doc = await Document.load(filename=self.input_args.requirements_filename)
                 insert_ctx_list = [
                     "## User New Requirements\n" + str(requirement_doc) + "\n",
                     "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
@@ -177,7 +197,7 @@ async def run(self, *args, **kwargs) -> CodingContext:
                 f"len(self.i_context.code_doc.content)={len2}"
             )
             result, rewrited_code = await self.write_code_review_and_rewrite(
-                context_prompt, cr_prompt, self.i_context.code_doc.filename
+                context_prompt, cr_prompt, self.i_context.code_doc
             )
             if "LBTM" in result:
                 iterative_code = rewrited_code
@@ -189,3 +209,97 @@
         # If rewrited_code is None (i.e. the original code was already perfect), return the code directly
         self.i_context.code_doc.content = iterative_code
         return self.i_context
+
+
+@register_tool(include_functions=["run"])
+class ValidateAndRewriteCode(Action):
+    """According to the design and task documents, validate the code to ensure it is complete and correct."""
+
+    name: str = "ValidateAndRewriteCode"
+
+    async def run(
+        self,
+        code_path: str,
+        system_design_input: str = "",
+        project_schedule_input: str = "",
+        code_validate_k_times: int = 2,
+    ) -> str:
+        """Validates the provided code based on the accompanying system design and project schedule documentation, returning the complete and correct 
code. + + Read the code from code_path, and write the final code to code_path. + If both system_design_input and project_schedule_input are absent, it will return and do nothing. + + Args: + code_path (str): The file path of the code snippet to be validated. This should be a string containing the path to the source code file. + system_design_input (str): Content or file path of the design document associated with the code. This should describe the system architecture, used in the code. It helps provide context for the validation process. + project_schedule_input (str): Content or file path of the task document describing what the code is intended to accomplish. This should outline the functional requirements or objectives of the code. + code_validate_k_times (int, optional): The number of iterations for validating and potentially rewriting the code. Defaults to 2. + + Returns: + str: The potentially corrected or approved code after validation. + + Example Usage: + # Example of how to call the run method with a code snippet and documentation + await ValidateAndRewriteCode().run( + code_path="/tmp/game.js", + system_design_input="/tmp/system_design.json", + project_schedule_input="/tmp/project_task_list.json" + ) + """ + if not system_design_input and not project_schedule_input: + logger.info( + "Both `system_design_input` and `project_schedule_input` are absent, ValidateAndRewriteCode will do nothing." + ) + return + + code, design_doc, task_doc = await asyncio.gather( + aread(code_path), self._try_aread(system_design_input), self._try_aread(project_schedule_input) + ) + code_doc = self._create_code_doc(code_path=code_path, code=code) + review_action = WriteCodeReview(i_context=CodingContext(filename=code_doc.filename)) + + context = "\n".join( + [ + "## System Design\n" + design_doc + "\n", + "## Task\n" + task_doc + "\n", + ] + ) + + for i in range(code_validate_k_times): + context_prompt = PROMPT_TEMPLATE.format(context=context, code=code, filename=code_path) + cr_prompt = EXAMPLE_AND_INSTRUCTION.format( + format_example=FORMAT_EXAMPLE.format(filename=code_path), + ) + logger.info(f"The {i+1}th time to CodeReview: {code_path}.") + result, rewrited_code = await review_action.write_code_review_and_rewrite( + context_prompt, cr_prompt, doc=code_doc + ) + + if "LBTM" in result: + code = rewrited_code + elif "LGTM" in result: + break + + await awrite(filename=code_path, data=code) + + return ( + f"The review and rewriting of the code in the file '{os.path.basename(code_path)}' has been completed." + + code + ) + + @staticmethod + async def _try_aread(input: str) -> str: + """Try to read from the path if it's a file; return input directly if not.""" + + if os.path.exists(input): + return await aread(input) + + return input + + @staticmethod + def _create_code_doc(code_path: str, code: str) -> Document: + """Create a Document to represent the code doc.""" + + path = Path(code_path) + + return Document(root_path=str(path.parent), filename=path.name, content=code) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index b66887164e..7a04520d6e 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -9,12 +9,16 @@ 2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality. 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. 
Implement Chapter 3 of RFC 236. """ from __future__ import annotations import json from pathlib import Path +from typing import List, Optional, Union + +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -33,10 +37,20 @@ REQUIREMENT_FILENAME, ) from metagpt.logs import logger -from metagpt.schema import BugFixContext, Document, Documents, Message -from metagpt.utils.common import CodeParser +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + CodeParser, + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter, GalleryReporter CONTEXT_TEMPLATE = """ ### Project Name @@ -58,6 +72,7 @@ """ +@register_tool(include_functions=["run"]) class WritePRD(Action): """WritePRD deal with the following situations: 1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated. @@ -65,10 +80,79 @@ class WritePRD(Action): 3. Requirement update: If the requirement is an update, the PRD document will be updated. """ - async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message: - """Run the action.""" - req: Document = await self.repo.requirement - docs: list[Document] = await self.repo.docs.prd.get_all() + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + output_pathname: str = "", + legacy_prd_filename: str = "", + extra_info: str = "", + **kwargs, + ) -> Union[AIMessage, str]: + """ + Write a Product Requirement Document. + + Args: + user_requirement (str): A string detailing the user's requirements. + output_pathname (str, optional): The output file path of the document. Defaults to "". + legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "". + extra_info (str, optional): Additional information to include in the document. Defaults to "". + **kwargs: Additional keyword arguments. + + Returns: + str: The file path of the generated Product Requirement Document. + + Example: + # Write a new PRD (Product Requirement Document) + >>> user_requirement = "Write a snake game" + >>> output_pathname = "snake_game/docs/prd.json" + >>> extra_info = "YOUR EXTRA INFO, if any" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, output_pathname=output_pathname, extra_info=extra_info) + >>> print(result) + PRD filename: "/absolute/path/to/snake_game/docs/prd.json" + + # Rewrite an existing PRD (Product Requirement Document) and save to a new path. 
+ >>> user_requirement = "Write PRD for a snake game, include new features such as a web UI" + >>> legacy_prd_filename = "/absolute/path/to/snake_game/docs/prd.json" + >>> output_pathname = "/absolute/path/to/snake_game/docs/prd_new.json" + >>> extra_info = "YOUR EXTRA INFO, if any" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, legacy_prd_filename=legacy_prd_filename, extra_info=extra_info) + >>> print(result) + PRD filename: "/absolute/path/to/snake_game/docs/prd_new.json" + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + output_pathname=output_pathname, + legacy_prd_filename=legacy_prd_filename, + extra_info=extra_info, + ) + + self.input_args = with_messages[-1].instruct_content + if not self.input_args: + self.repo = ProjectRepo(self.context.kwargs.project_path) + await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content) + self.input_args = AIMessage.create_instruct_value( + kvs={ + "project_path": self.context.kwargs.project_path, + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ) + else: + self.repo = ProjectRepo(self.input_args.project_path) + req = await Document.load(filename=self.input_args.requirements_filename) + docs: list[Document] = [ + await Document.load(filename=i, project_path=self.repo.workdir) for i in self.input_args.prd_filenames + ] + if not req: raise FileNotFoundError("No requirement document found.") @@ -81,49 +165,80 @@ async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message: # if requirement is related to other documents, update them, otherwise create a new one if related_docs := await self.get_related_docs(req, docs): logger.info(f"Requirement update detected: {req.content}") - return await self._handle_requirement_update(req, related_docs) + await self._handle_requirement_update(req=req, related_docs=related_docs) else: logger.info(f"New requirement detected: {req.content}") - return await self._handle_new_requirement(req) + await self._handle_new_requirement(req) + + kvs = self.input_args.model_dump() + kvs["changed_prd_filenames"] = [ + str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys()) + ] + kvs["project_path"] = str(self.repo.workdir) + kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME) + self.context.kwargs.project_path = str(self.repo.workdir) + return AIMessage( + content="PRD is completed. " + + "\n".join( + list(self.repo.docs.prd.changed_files.keys()) + + list(self.repo.resources.prd.changed_files.keys()) + + list(self.repo.resources.competitive_analysis.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput"), + cause_by=self, + ) - async def _handle_bugfix(self, req: Document) -> Message: + async def _handle_bugfix(self, req: Document) -> AIMessage: # ... bugfix logic ... 
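The message plumbing in this hunk replaces typed context objects with `AIMessage.create_instruct_value(kvs=..., class_name=...)` payloads, which downstream actions read back as attributes (e.g. `input_args.requirements_filename`). A hedged sketch of the equivalent pattern using a plain pydantic dynamic model; the real helper's internals are not shown in this diff:

```python
from pydantic import create_model


def create_instruct_value(kvs: dict, class_name: str):
    """Build an instance of a dynamically created pydantic model whose fields mirror `kvs`."""
    model = create_model(class_name, **{k: (type(v), v) for k, v in kvs.items()})
    return model()


out = create_instruct_value(
    {"project_path": "/tmp/snake_game", "requirements_filename": "/tmp/snake_game/docs/requirement.txt"},
    class_name="PrepareDocumentsOutput",
)
assert out.requirements_filename.endswith("requirement.txt")
```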
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content) await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="") - bug_fix = BugFixContext(filename=BUGFIX_FILENAME) - return Message( - content=bug_fix.model_dump_json(), - instruct_content=bug_fix, - role="", + return AIMessage( + content=f"A new issue is received: {BUGFIX_FILENAME}", cause_by=FixBug, - sent_from=self, + instruct_content=AIMessage.create_instruct_value( + { + "project_path": str(self.repo.workdir), + "issue_filename": str(self.repo.docs.workdir / BUGFIX_FILENAME), + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + }, + class_name="IssueDetail", + ), send_to="Alex", # the name of Engineer ) - async def _handle_new_requirement(self, req: Document) -> ActionOutput: - """handle new requirement""" + async def _new_prd(self, requirement: str) -> ActionNode: project_name = self.project_name - context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name) + context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name) exclude = [PROJECT_NAME.key] if project_name else [] - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema - await self._rename_workspace(node) - new_prd_doc = await self.repo.docs.prd.save( - filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json() - ) - await self._save_competitive_analysis(new_prd_doc) - await self.repo.resources.prd.save_pdf(doc=new_prd_doc) - return Documents.from_iterable(documents=[new_prd_doc]).to_action_output() + node = await WRITE_PRD_NODE.fill( + req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema + ) # schema=schema + return node + + async def _handle_new_requirement(self, req: Document) -> ActionOutput: + """handle new requirement""" + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "prd"}, "meta") + node = await self._new_prd(req.content) + await self._rename_workspace(node) + new_prd_doc = await self.repo.docs.prd.save( + filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json() + ) + await self._save_competitive_analysis(new_prd_doc) + md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") + return Documents.from_iterable(documents=[new_prd_doc]).to_action_output() async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput: # ... requirement update logic ... 
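+        # NOTE (editor's gloss, not part of the original patch): each PRD judged
+        # related to the new requirement is merged with it via _update_prd/_merge,
+        # and the refreshed documents are returned as a single ActionOutput.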
for doc in related_docs: - await self._update_prd(req, doc) + await self._update_prd(req=req, prd_doc=doc) return Documents.from_iterable(documents=related_docs).to_action_output() async def _is_bugfix(self, context: str) -> bool: if not self.repo.code_files_exists(): return False - node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm) + node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm) return node.get("issue_type") == "BUG" async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]: @@ -133,33 +248,39 @@ async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Do async def _is_related(self, req: Document, old_prd: Document) -> bool: context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content) - node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) + node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm) return node.get("is_relative") == "YES" async def _merge(self, req: Document, related_doc: Document) -> Document: if not self.project_name: self.project_name = Path(self.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content) - node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema) + node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema) related_doc.content = node.instruct_content.model_dump_json() await self._rename_workspace(node) return related_doc async def _update_prd(self, req: Document, prd_doc: Document) -> Document: - new_prd_doc: Document = await self._merge(req, prd_doc) - await self.repo.docs.prd.save_doc(doc=new_prd_doc) - await self._save_competitive_analysis(new_prd_doc) - await self.repo.resources.prd.save_pdf(doc=new_prd_doc) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "prd"}, "meta") + new_prd_doc: Document = await self._merge(req=req, related_doc=prd_doc) + await self.repo.docs.prd.save_doc(doc=new_prd_doc) + await self._save_competitive_analysis(new_prd_doc) + md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") return new_prd_doc - async def _save_competitive_analysis(self, prd_doc: Document): + async def _save_competitive_analysis(self, prd_doc: Document, output_filename: Path = None): m = json.loads(prd_doc.content) quadrant_chart = m.get(COMPETITIVE_QUADRANT_CHART.key) if not quadrant_chart: return - pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem + pathname = output_filename or self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem pathname.parent.mkdir(parents=True, exist_ok=True) await mermaid_to_file(self.config.mermaid.engine, quadrant_chart, pathname) + image_path = pathname.parent / f"{pathname.name}.svg" + if image_path.exists(): + await GalleryReporter().async_report(image_path, "path") async def _rename_workspace(self, prd): if not self.project_name: @@ -169,4 +290,36 @@ async def _rename_workspace(self, prd): ws_name = CodeParser.parse_str(block="Project Name", text=prd) if ws_name: self.project_name = ws_name - self.repo.git_repo.rename_root(self.project_name) + if self.repo: + self.repo.git_repo.rename_root(self.project_name) + + async def _execute_api( + self, user_requirement: str, output_pathname: str, legacy_prd_filename: str, extra_info: str + ) -> str: + content = "#### User Requirements\n{user_requirement}\n#### 
Extra Info\n{extra_info}\n".format( + user_requirement=to_markdown_code_block(val=user_requirement), + extra_info=to_markdown_code_block(val=extra_info), + ) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "prd"}, "meta") + req = Document(content=content) + if not legacy_prd_filename: + node = await self._new_prd(requirement=req.content) + new_prd = Document(content=node.instruct_content.model_dump_json()) + else: + content = await aread(filename=legacy_prd_filename) + old_prd = Document(content=content) + new_prd = await self._merge(req=req, related_doc=old_prd) + + if not output_pathname: + output_pathname = self.config.workspace.path / "docs" / "prd.json" + elif not Path(output_pathname).is_absolute(): + output_pathname = self.config.workspace.path / output_pathname + output_pathname = rectify_pathname(path=output_pathname, default_filename="prd.json") + await awrite(filename=output_pathname, data=new_prd.content) + competitive_analysis_filename = output_pathname.parent / f"{output_pathname.stem}-competitive-analysis" + await self._save_competitive_analysis(prd_doc=new_prd, output_filename=Path(competitive_analysis_filename)) + md_output_filename = output_pathname.with_suffix(".md") + await save_json_to_markdown(content=new_prd.content, output_filename=md_output_filename) + await reporter.async_report(md_output_filename, "path") + return f'PRD filename: "{str(output_pathname)}". The product requirement document (PRD) has been completed.' diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index 6a995e1840..81e16bcfa3 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : write_prd_an.py """ -from typing import List +from typing import List, Union from metagpt.actions.action_node import ActionNode @@ -19,8 +19,8 @@ PROGRAMMING_LANGUAGE = ActionNode( key="Programming Language", expected_type=str, - instruction="Python/JavaScript or other mainstream programming language.", - example="Python", + instruction="Mainstream programming language. If not specified in the requirements, use Vite, React, MUI, Tailwind CSS.", + example="Vite, React, MUI, Tailwind CSS", ) ORIGINAL_REQUIREMENTS = ActionNode( @@ -132,7 +132,7 @@ REFINED_REQUIREMENT_ANALYSIS = ActionNode( key="Refined Requirement Analysis", - expected_type=List[str], + expected_type=Union[List[str], str], instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project " "due to incremental development. 
Ensure the analysis comprehensively covers the new features and enhancements " "required for the refined project scope.", @@ -165,7 +165,7 @@ key="Anything UNCLEAR", expected_type=str, instruction="Mention any aspects of the project that are unclear and try to clarify them.", - example="", + example="Currently, all aspects of the project are clear.", ) ISSUE_TYPE = ActionNode( diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index db85129462..907a1e9901 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -36,4 +36,4 @@ class WriteReview(Action): name: str = "WriteReview" async def run(self, context): - return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 978fa20a6f..286d3ea135 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -45,7 +45,7 @@ async def write_code(self, prompt): code_rsp = await self._aask(prompt) try: - code = CodeParser.parse_code(block="", text=code_rsp) + code = CodeParser.parse_code(text=code_rsp) except Exception: # Handle the exception if needed logger.error(f"Can't parse the code: {code_rsp}") diff --git a/metagpt/base/__init__.py b/metagpt/base/__init__.py new file mode 100644 index 0000000000..a2fbe8eaff --- /dev/null +++ b/metagpt/base/__init__.py @@ -0,0 +1,8 @@ +from metagpt.base.base_env import BaseEnvironment +from metagpt.base.base_role import BaseRole + + +__all__ = [ + "BaseEnvironment", + "BaseRole", +] diff --git a/metagpt/base/base_env.py b/metagpt/base/base_env.py new file mode 100644 index 0000000000..361b8b58f2 --- /dev/null +++ b/metagpt/base/base_env.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base environment + +import typing +from abc import abstractmethod +from typing import Any, Optional + +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams +from metagpt.base.base_serialization import BaseSerialization + +if typing.TYPE_CHECKING: + from metagpt.schema import Message + + +class BaseEnvironment(BaseSerialization): + """Base environment""" + + @abstractmethod + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Implement this to get the initial observation""" + + @abstractmethod + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + """Implement this if you want to get partial observation from the env""" + + @abstractmethod + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """Implement this to feed an action and then get a new observation from the env""" + + @abstractmethod + def publish_message(self, message: "Message", peekable: bool = True) -> bool: + """Distribute the message to the recipients.""" + + @abstractmethod + async def run(self, k=1): + """Process all tasks at once""" diff --git a/metagpt/environment/base_env_space.py b/metagpt/base/base_env_space.py similarity index 100% rename from metagpt/environment/base_env_space.py rename to metagpt/base/base_env_space.py diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py new file mode 100644 index 0000000000..1f7f00fa23 --- /dev/null +++ b/metagpt/base/base_role.py @@ -0,0 +1,36 @@ +from abc import abstractmethod +from typing import Optional, Union + +from
metagpt.base.base_serialization import BaseSerialization + + +class BaseRole(BaseSerialization): + """Abstract base class for all roles.""" + + name: str + + @property + def is_idle(self) -> bool: + raise NotImplementedError + + @abstractmethod + def think(self): + """Consider what to do and decide on the next course of action.""" + raise NotImplementedError + + @abstractmethod + def act(self): + """Perform the current action.""" + raise NotImplementedError + + @abstractmethod + async def react(self) -> "Message": + """Entry to one of three strategies by which Role reacts to the observed Message.""" + + @abstractmethod + async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]: + """Observe, and think and act based on the results of the observation.""" + + @abstractmethod + def get_memories(self, k: int = 0) -> list["Message"]: + """Return the most recent k memories of this role.""" diff --git a/metagpt/base/base_serialization.py b/metagpt/base/base_serialization.py new file mode 100644 index 0000000000..8aff7f39e3 --- /dev/null +++ b/metagpt/base/base_serialization.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, model_serializer, model_validator + + +class BaseSerialization(BaseModel, extra="forbid"): + """ + Polymorphic subclass serialization / deserialization mixin. + - First of all, we need to know that pydantic is not designed for polymorphism. + - If Engineer is a subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need + to add the `class name` to Engineer. So we need Engineer to inherit SerializationMixin. + + More details: + - https://docs.pydantic.dev/latest/concepts/serialization/ + - https://github.com/pydantic/pydantic/discussions/7008, a discussion about avoiding `__get_pydantic_core_schema__` + """ + + __is_polymorphic_base = False + __subclasses_map__ = {} + + @model_serializer(mode="wrap") + def __serialize_with_class_type__(self, default_serializer) -> Any: + # default serializer, then append the `__module_class_name` field and return + ret = default_serializer(self) + ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + return ret + + @model_validator(mode="wrap") + @classmethod + def __convert_to_real_type__(cls, value: Any, handler): + if isinstance(value, dict) is False: + return handler(value) + + # it is a dict so make sure to remove the __module_class_name + # because we don't allow extra keywords but want to ensure + # e.g. Cat.model_validate(cat.model_dump()) works + class_full_name = value.pop("__module_class_name", None) + + # if it's not the polymorphic base we construct via default handler + if not cls.__is_polymorphic_base: + if class_full_name is None: + return handler(value) + elif str(cls) == f"<class '{class_full_name}'>": + return handler(value) + else: + # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") + pass + + # otherwise we lookup the correct polymorphic type and construct that + # instead + if class_full_name is None: + raise ValueError("Missing __module_class_name field") + + class_type = cls.__subclasses_map__.get(class_full_name, None) + + if class_type is None: + # TODO could try dynamic import + raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!") + + return class_type(**value) + + def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): + cls.__is_polymorphic_base = is_polymorphic_base +
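+        # NOTE (editor's gloss, not part of the original patch): registering each
+        # subclass under its fully qualified "module.QualName" key is what lets
+        # __convert_to_real_type__ rebuild the concrete subclass from the
+        # "__module_class_name" field during model_validate.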
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls + super().__init_subclass__(**kwargs) diff --git a/metagpt/config2.py b/metagpt/config2.py index f3273419f2..fd0cb09486 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -9,12 +9,17 @@ from pathlib import Path from typing import Dict, Iterable, List, Literal, Optional -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.exp_pool_config import ExperiencePoolConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig +from metagpt.configs.omniparse_config import OmniParseConfig from metagpt.configs.redis_config import RedisConfig +from metagpt.configs.role_custom_config import RoleCustomConfig +from metagpt.configs.role_zero_config import RoleZeroConfig from metagpt.configs.s3_config import S3Config from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig @@ -47,7 +52,10 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig - # Global Proxy. Will be used if llm.proxy is not set + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + + # Global Proxy. Not used by LLM, but by other tools such as browsers. proxy: str = "" # Tool Parameters @@ -62,9 +70,12 @@ class Config(CLIParams, YamlModel): # Misc Parameters repair_llm_output: bool = False prompt_schema: Literal["json", "markdown", "raw"] = "json" - workspace: WorkspaceConfig = WorkspaceConfig() + workspace: WorkspaceConfig = Field(default_factory=WorkspaceConfig) enable_longterm_memory: bool = False - code_review_k_times: int = 2 + code_validate_k_times: int = 2 + + # Experience Pool Parameters + exp_pool: ExperiencePoolConfig = Field(default_factory=ExperiencePoolConfig) # Will be removed in the future metagpt_tti_url: str = "" @@ -76,6 +87,14 @@ class Config(CLIParams, YamlModel): azure_tts_subscription_key: str = "" azure_tts_region: str = "" + # Role's custom configuration + roles: Optional[List[RoleCustomConfig]] = None + + # RoleZero's configuration + role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig) + + omniparse: Optional[OmniParseConfig] = None + @classmethod def from_home(cls, path): """Load config from ~/.metagpt/config2.yaml""" @@ -85,20 +104,20 @@ def from_home(cls, path): return Config.from_yaml_file(pathname) @classmethod - def default(cls): + def default(cls, reload: bool = False, **kwargs): """Load default config - Priority: env < default_config_paths - Inside default_config_paths, the latter one overwrites the former one """ - default_config_paths: List[Path] = [ + default_config_paths = ( METAGPT_ROOT / "config/config2.yaml", CONFIG_ROOT / "config2.yaml", - ] - - dicts = [dict(os.environ)] - dicts += [Config.read_yaml(path) for path in default_config_paths] - final = merge_dict(dicts) - return Config(**final) + ) + if reload or default_config_paths not in _CONFIG_CACHE: + dicts = [dict(os.environ), *(Config.read_yaml(path) for path in default_config_paths), kwargs] + final = merge_dict(dicts) + _CONFIG_CACHE[default_config_paths] = Config(**final) + return _CONFIG_CACHE[default_config_paths] @classmethod def from_llm_config(cls, llm_config: dict): @@ -148,4 +167,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict: return result -config = Config.default() +_CONFIG_CACHE = {} diff --git 
a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py index 2f8024f44d..fafbaeeb85 100644 --- a/metagpt/configs/browser_config.py +++ b/metagpt/configs/browser_config.py @@ -5,12 +5,23 @@ @Author : alexanderwu @File : browser_config.py """ +from enum import Enum from typing import Literal -from metagpt.tools import WebBrowserEngineType from metagpt.utils.yaml_model import YamlModel +class WebBrowserEngineType(Enum): + PLAYWRIGHT = "playwright" + SELENIUM = "selenium" + CUSTOM = "custom" + + @classmethod + def __missing__(cls, key): + """Default type conversion""" + return cls.CUSTOM + + class BrowserConfig(YamlModel): """Config for Browser""" diff --git a/metagpt/configs/compress_msg_config.py b/metagpt/configs/compress_msg_config.py new file mode 100644 index 0000000000..c46334c125 --- /dev/null +++ b/metagpt/configs/compress_msg_config.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class CompressType(Enum): + """ + Compression Type for messages. Used to compress messages under token limit. + - "": No compression. Default value. + - "post_cut_by_msg": Keep as many latest messages as possible. + - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message. + - "pre_cut_by_msg": Keep as many earliest messages as possible. + - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message. + """ + + NO_COMPRESS = "" + POST_CUT_BY_MSG = "post_cut_by_msg" + POST_CUT_BY_TOKEN = "post_cut_by_token" + PRE_CUT_BY_MSG = "pre_cut_by_msg" + PRE_CUT_BY_TOKEN = "pre_cut_by_token" + + def __missing__(self, key): + return self.NO_COMPRESS + + @classmethod + def get_type(cls, type_name): + for member in cls: + if member.value == type_name: + return member + return cls.NO_COMPRESS + + @classmethod + def cut_types(cls): + return [member for member in cls if "cut" in member.value] diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py new file mode 100644 index 0000000000..f9b41b9dc9 --- /dev/null +++ b/metagpt/configs/embedding_config.py @@ -0,0 +1,54 @@ +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class EmbeddingType(Enum): + OPENAI = "openai" + AZURE = "azure" + GEMINI = "gemini" + OLLAMA = "ollama" + + +class EmbeddingConfig(YamlModel): + """Config for Embedding. 
+ + Examples: + --------- + api_type: "openai" + api_key: "YOUR_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" + + api_type: "azure" + api_key: "YOUR_API_KEY" + base_url: "YOUR_BASE_URL" + api_version: "YOUR_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" + + api_type: "gemini" + api_key: "YOUR_API_KEY" + + api_type: "ollama" + base_url: "YOUR_BASE_URL" + model: "YOUR_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" + """ + + api_type: Optional[EmbeddingType] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + + model: Optional[str] = None + embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model + + @field_validator("api_type", mode="before") + @classmethod + def check_api_type(cls, v): + if v == "": + return None + return v diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py new file mode 100644 index 0000000000..a4a2d5d417 --- /dev/null +++ b/metagpt/configs/exp_pool_config.py @@ -0,0 +1,25 @@ +from enum import Enum + +from pydantic import Field + +from metagpt.utils.yaml_model import YamlModel + + +class ExperiencePoolRetrievalType(Enum): + BM25 = "bm25" + CHROMA = "chroma" + + +class ExperiencePoolConfig(YamlModel): + enabled: bool = Field( + default=False, + description="Flag to enable or disable the experience pool. When disabled, both reading and writing are ineffective.", + ) + enable_read: bool = Field(default=False, description="Enable to read from experience pool.") + enable_write: bool = Field(default=False, description="Enable to write to experience pool.") + persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") + retrieval_type: ExperiencePoolRetrievalType = Field( + default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
+ ) + use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.") + collection_name: str = Field(default="experience_pool", description="The collection name in chromadb") diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index af8f56372f..57913956ce 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -10,6 +10,7 @@ from pydantic import field_validator +from metagpt.configs.compress_msg_config import CompressType from metagpt.const import LLM_API_TIMEOUT from metagpt.utils.yaml_model import YamlModel @@ -31,6 +32,9 @@ class LLMType(Enum): MOONSHOT = "moonshot" MISTRAL = "mistral" YI = "yi" # lingyiwanwu + OPEN_ROUTER = "open_router" + DEEPSEEK = "deepseek" + SILICONFLOW = "siliconflow" def __missing__(self, key): return self.OPENAI @@ -83,6 +87,9 @@ class LLMConfig(YamlModel): # Cost Control calc_usage: bool = True + # Compress request messages under token limit + compress_type: CompressType = CompressType.NO_COMPRESS + @field_validator("api_key") @classmethod def check_llm_key(cls, v): diff --git a/metagpt/configs/mermaid_config.py b/metagpt/configs/mermaid_config.py index 50c8a18475..47f14f4cd0 100644 --- a/metagpt/configs/mermaid_config.py +++ b/metagpt/configs/mermaid_config.py @@ -13,7 +13,7 @@ class MermaidConfig(YamlModel): """Config for Mermaid""" - engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs" + engine: Literal["nodejs", "ink", "playwright", "pyppeteer", "none"] = "nodejs" path: str = "mmdc" # mmdc puppeteer_config: str = "" pyppeteer_path: str = "/usr/bin/google-chrome-stable" diff --git a/metagpt/configs/omniparse_config.py b/metagpt/configs/omniparse_config.py new file mode 100644 index 0000000000..ecae786970 --- /dev/null +++ b/metagpt/configs/omniparse_config.py @@ -0,0 +1,6 @@ +from metagpt.utils.yaml_model import YamlModel + + +class OmniParseConfig(YamlModel): + url: str = "" + timeout: int = 600 diff --git a/metagpt/configs/role_custom_config.py b/metagpt/configs/role_custom_config.py new file mode 100644 index 0000000000..581de605e6 --- /dev/null +++ b/metagpt/configs/role_custom_config.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/4/22 16:33 +@Author : Justin +@File : role_custom_config.py +""" +from metagpt.configs.llm_config import LLMConfig +from metagpt.utils.yaml_model import YamlModel + + +class RoleCustomConfig(YamlModel): + """custom config for roles + role: role's className or role's role_id + To be expanded + """ + + role: str = "" + llm: LLMConfig diff --git a/metagpt/configs/role_zero_config.py b/metagpt/configs/role_zero_config.py new file mode 100644 index 0000000000..91d554b2f4 --- /dev/null +++ b/metagpt/configs/role_zero_config.py @@ -0,0 +1,11 @@ +from pydantic import Field + +from metagpt.utils.yaml_model import YamlModel + + +class RoleZeroConfig(YamlModel): + enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.") + longterm_memory_persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") + memory_k: int = Field(default=200, description="The capacity of short-term memory.") + similarity_top_k: int = Field(default=5, description="The number of long-term memories to retrieve.") + use_llm_ranker: bool = Field(default=False, description="Whether to use LLM Reranker to get better result.") diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index e28b14c994..2c773b685b 100644 --- 
a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -5,17 +5,28 @@ @Author : alexanderwu @File : search_config.py """ +from enum import Enum from typing import Callable, Optional -from pydantic import Field +from pydantic import ConfigDict, Field -from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel +class SearchEngineType(Enum): + SERPAPI_GOOGLE = "serpapi" + SERPER_GOOGLE = "serper" + DIRECT_GOOGLE = "google" + DUCK_DUCK_GO = "ddg" + CUSTOM_ENGINE = "custom" + BING = "bing" + + class SearchConfig(YamlModel): """Config for Search""" + model_config = ConfigDict(extra="allow") + api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO api_key: str = "" cse_id: str = "" # for google diff --git a/metagpt/const.py b/metagpt/const.py index e4cebfd96c..9dc08b8d40 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -20,12 +20,6 @@ def get_metagpt_package_root(): """Get the root directory of the installed package.""" package_root = Path(metagpt.__file__).parent.parent - for i in (".git", ".project_root", ".gitignore"): - if (package_root / i).exists(): - break - else: - package_root = Path.cwd() - logger.info(f"Package root set to {str(package_root)}") return package_root @@ -40,6 +34,12 @@ def get_metagpt_root(): else: # Fallback to package root if no environment variable is set project_root = get_metagpt_package_root() + for i in (".git", ".project_root", ".gitignore"): + if (project_root / i).exists(): + break + else: + project_root = Path.cwd() + return project_root @@ -71,6 +71,11 @@ def get_metagpt_root(): TOOL_SCHEMA_PATH = METAGPT_ROOT / "metagpt/tools/schemas" TOOL_LIBS_PATH = METAGPT_ROOT / "metagpt/tools/libs" +# TEMPLATE PATH +TEMPLATE_FOLDER_PATH = METAGPT_ROOT / "template" +VUE_TEMPLATE_PATH = TEMPLATE_FOLDER_PATH / "vue_template" +REACT_TEMPLATE_PATH = TEMPLATE_FOLDER_PATH / "react_template" + # REAL CONSTS MEM_TTL = 24 * 30 * 3600 @@ -81,6 +86,8 @@ def get_metagpt_root(): MESSAGE_META_ROLE = "role" MESSAGE_ROUTE_TO_ALL = "<all>" MESSAGE_ROUTE_TO_NONE = "<none>" +MESSAGE_ROUTE_TO_SELF = "<self>" # Add this tag to replace `ActionOutput` + REQUIREMENT_FILENAME = "requirement.txt" BUGFIX_FILENAME = "bugfix.txt" @@ -103,12 +110,13 @@ def get_metagpt_root(): CODE_SUMMARIES_FILE_REPO = "docs/code_summary" CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary" RESOURCES_FILE_REPO = "resources" -SD_OUTPUT_FILE_REPO = "resources/sd_output" +SD_OUTPUT_FILE_REPO = DEFAULT_WORKSPACE_ROOT GRAPH_REPO_FILE_REPO = "docs/graph_repo" VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db" CLASS_VIEW_FILE_REPO = "docs/class_view" YAPI_URL = "http://yapi.deepwisdomai.com/" +SD_URL = "http://172.31.0.51:49094" DEFAULT_LANGUAGE = "English" DEFAULT_MAX_TOKENS = 1500 @@ -135,3 +143,25 @@ def get_metagpt_root(): # Timeout USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration.
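+# NOTE (editor's gloss, not part of the original patch): LLM_API_TIMEOUT below is
+# the default per-request timeout in seconds, imported by LLMConfig (see the
+# llm_config.py hunk above); USE_CONFIG_TIMEOUT = 0 means "defer to llm.timeout".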
LLM_API_TIMEOUT = 300 + +# Assistant alias +ASSISTANT_ALIAS = "response" + +# Markdown +MARKDOWN_TITLE_PREFIX = "## " + +# Reporter +METAGPT_REPORTER_DEFAULT_URL = os.environ.get("METAGPT_REPORTER_URL", "") + +# Metadata defines +AGENT = "agent" +IMAGES = "images" + +# SWE agent +SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/setup_default.sh" + +# experience pool +EXPERIENCE_MASK = "<experience>" + +# TeamLeader's name +TEAMLEADER_NAME = "Mike" diff --git a/metagpt/context.py b/metagpt/context.py index 2bd5412026..0769f78eb6 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -5,11 +5,12 @@ @Author : alexanderwu @File : context.py """ +from __future__ import annotations + import os -from pathlib import Path from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config from metagpt.configs.llm_config import LLMConfig, LLMType @@ -20,8 +21,6 @@ FireworksCostManager, TokenCostManager, ) -from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo class AttrDict(BaseModel): @@ -62,11 +61,8 @@ class Context(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) kwargs: AttrDict = AttrDict() - config: Config = Config.default() + config: Config = Field(default_factory=Config.default) - repo: Optional[ProjectRepo] = None - git_repo: Optional[GitRepository] = None - src_workspace: Optional[Path] = None cost_manager: CostManager = CostManager() _llm: Optional[BaseLLM] = None @@ -110,7 +106,6 @@ def serialize(self) -> Dict[str, Any]: Dict[str, Any]: A dictionary containing serialized data. """ return { - "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, "cost_manager": self.cost_manager.model_dump_json(), } @@ -123,13 +118,6 @@ def deserialize(self, serialized_data: Dict[str, Any]): """ if not serialized_data: return - workdir = serialized_data.get("workdir") - if workdir: - self.git_repo = GitRepository(local_path=workdir, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - src_workspace = self.git_repo.workdir / self.git_repo.workdir.name - if src_workspace.exists(): - self.src_workspace = src_workspace kwargs = serialized_data.get("kwargs") if kwargs: for k, v in kwargs.items(): diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index d2344fa1f5..75e2e79ef8 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -8,9 +8,9 @@ from pydantic import Field +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams class AndroidExtEnv(ExtEnv): diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 024c468776..03a4760c91 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -5,25 +5,25 @@ import asyncio from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union +from typing import Any, Dict, Iterable, Optional, Set, Union from gymnasium import spaces from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny,
model_validator +from metagpt.base import BaseEnvironment, BaseRole +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.context import Context from metagpt.environment.api.env_api import ( EnvAPIAbstract, ReadAPIRegistry, WriteAPIRegistry, ) -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.logs import logger +from metagpt.memory import Memory from metagpt.schema import Message from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to - -if TYPE_CHECKING: - from metagpt.roles.role import Role # noqa: F401 +from metagpt.utils.git_repository import GitRepository class EnvType(Enum): @@ -50,7 +50,7 @@ def mark_as_writeable(func): return func -class ExtEnv(BaseModel): +class ExtEnv(BaseEnvironment, BaseModel): """External Env to integrate actual game environment""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -129,9 +129,9 @@ class Environment(ExtEnv): model_config = ConfigDict(arbitrary_types_allowed=True) desc: str = Field(default="") # environment description - roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True) - member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True) - history: str = "" # For debug + roles: dict[str, SerializeAsAny[BaseRole]] = Field(default_factory=dict, validate_default=True) + member_addrs: Dict[BaseRole, Set] = Field(default_factory=dict, exclude=True) + history: Memory = Field(default_factory=Memory) # For debug context: Context = Field(default_factory=Context, exclude=True) def reset( @@ -153,20 +153,20 @@ def init_roles(self): self.add_roles(self.roles.values()) return self - def add_role(self, role: "Role"): + def add_role(self, role: BaseRole): """Add a role to the current environment.""" - self.roles[role.profile] = role + self.roles[role.name] = role role.set_env(self) role.context = self.context - def add_roles(self, roles: Iterable["Role"]): + def add_roles(self, roles: Iterable[BaseRole]): """Add a batch of roles to the current environment.""" for role in roles: - self.roles[role.profile] = role + self.roles[role.name] = role for role in roles: # setup system message with roles role.context = self.context @@ -190,7 +190,7 @@ def publish_message(self, message: Message, peekable: bool = True) -> bool: found = True if not found: logger.warning(f"Message no recipients: {message.dump()}") - self.history += f"\n{message}" # For debug + self.history.add(message) # For debug return True @@ -201,19 +201,22 @@ async def run(self, k=1): for _ in range(k): futures = [] for role in self.roles.values(): + if role.is_idle: + continue future = role.run() futures.append(future) - await asyncio.gather(*futures) + if futures: + await asyncio.gather(*futures) logger.debug(f"is idle: {self.is_idle}") - def get_roles(self) -> dict[str, "Role"]: + def get_roles(self) -> dict[str, BaseRole]: """Get all roles in the environment.""" return self.roles - def get_role(self, name: str) -> "Role": + def get_role(self, name: str) -> BaseRole: """Get the specified role in the environment.""" return self.roles.get(name, None) @@ -239,14 +242,6 @@ def set_addresses(self, obj, addresses): self.member_addrs[obj] = addresses def archive(self, auto_archive=True): - if auto_archive and self.context.git_repo: - self.context.git_repo.archive() - - @classmethod - def model_rebuild(cls, **kwargs): - from metagpt.roles.role import Role # noqa: F401 - - super().model_rebuild(**kwargs) - - -Environment.model_rebuild() + if auto_archive and
self.context.kwargs.get("project_path"): + git_repo = GitRepository(self.context.kwargs.project_path) + git_repo.archive() diff --git a/metagpt/environment/mgx/__init__.py b/metagpt/environment/mgx/__init__.py new file mode 100644 index 0000000000..2bcf8efd09 --- /dev/null +++ b/metagpt/environment/mgx/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py new file mode 100644 index 0000000000..a8fc0df9f4 --- /dev/null +++ b/metagpt/environment/mgx/mgx_env.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from metagpt.const import AGENT, IMAGES, MESSAGE_ROUTE_TO_ALL, TEAMLEADER_NAME +from metagpt.environment.base_env import Environment +from metagpt.logs import get_human_input +from metagpt.roles import Role +from metagpt.schema import Message, SerializationMixin +from metagpt.utils.common import extract_and_encode_images + + +class MGXEnv(Environment, SerializationMixin): + """MGX Environment""" + + direct_chat_roles: set[str] = set() # record direct chat: @role_name + + is_public_chat: bool = True + + def _publish_message(self, message: Message, peekable: bool = True) -> bool: + if self.is_public_chat: + message.send_to.add(MESSAGE_ROUTE_TO_ALL) + message = self.move_message_info_to_content(message) + return super().publish_message(message, peekable) + + def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool: + """let the team leader take over message publishing""" + message = self.attach_images(message) # for multi-modal message + + tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike + + if user_defined_recipient: + # human user's direct chat message to a certain role + for role_name in message.send_to: + if self.get_role(role_name).is_idle: + # User starts a new direct chat with a certain role, expecting a direct chat response from the role; Other roles including TL should not be involved. + # If the role is not idle, it means the user helps the role with its current work, in this case, we handle the role's response message as usual. + self.direct_chat_roles.add(role_name) + + self._publish_message(message) + # # bypass team leader, team leader only needs to know but not to react (commented out because TL doesn't understand the message well in actual experiments) + # tl.rc.memory.add(self.move_message_info_to_content(message)) + + elif message.sent_from in self.direct_chat_roles: + # if chat is not public, direct chat response from a certain role to human user, team leader and other roles in the env should not be involved, no need to publish + self.direct_chat_roles.remove(message.sent_from) + if self.is_public_chat: + self._publish_message(message) + + elif publicer == tl.profile: + if message.send_to == {"no one"}: + # skip the dummy message from team leader + return True + # message processed by team leader can be published now + self._publish_message(message) + + else: + # every regular message goes through team leader + message.send_to.add(tl.name) + self._publish_message(message) + + self.history.add(message) + + return True + + async def ask_human(self, question: str, sent_from: Role = None) -> str: + # NOTE: Can be overwritten in remote setting + rsp = await get_human_input(question) + return "Human response: " + rsp + + async def reply_to_human(self, content: str, sent_from: Role = None) -> str: + # NOTE: Can be overwritten in remote setting + return "SUCCESS, human has received your reply. 
Refrain from resending duplicate messages. If you no longer need to take action, use the command 'end' to stop." + + def move_message_info_to_content(self, message: Message) -> Message: + """Two things here: + 1. Convert role, since role field must be reserved for LLM API, and is limited to, for example, one of ["user", "assistant", "system"] + 2. Add sender and recipient info to content, making TL aware, since LLM API only takes content as input + """ + converted_msg = message.model_copy(deep=True) + if converted_msg.role not in ["system", "user", "assistant"]: + converted_msg.role = "assistant" + sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from + # When displaying send_to, change it to those who need to react and exclude those who only need to be aware, e.g.: + # send_to={<all>} -> Mike; send_to={Alice} -> Alice; send_to={Alice, <all>} -> Alice. + if converted_msg.send_to == {MESSAGE_ROUTE_TO_ALL}: + send_to = TEAMLEADER_NAME + else: + send_to = ", ".join({role for role in converted_msg.send_to if role != MESSAGE_ROUTE_TO_ALL}) + converted_msg.content = f"[Message] from {sent_from or 'User'} to {send_to}: {converted_msg.content}" + return converted_msg + + def attach_images(self, message: Message) -> Message: + if message.role == "user": + images = extract_and_encode_images(message.content) + if images: + message.add_metadata(IMAGES, images) + return message + + def __repr__(self): + return "MGXEnv()" diff --git a/metagpt/environment/minecraft/minecraft_env.py b/metagpt/environment/minecraft/minecraft_env.py index 0f39c9ccd9..2bf39095c6 100644 --- a/metagpt/environment/minecraft/minecraft_env.py +++ b/metagpt/environment/minecraft/minecraft_env.py @@ -11,7 +11,7 @@ from llama_index.vector_stores.chroma import ChromaVectorStore from pydantic import ConfigDict, Field -from metagpt.config2 import config as CONFIG +from metagpt.config2 import Config from metagpt.environment.base_env import Environment from metagpt.environment.minecraft.const import MC_CKPT_DIR from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv @@ -82,7 +82,7 @@ def set_mc_resume(self): persist_dir=f"{MC_CKPT_DIR}/skill/vectordb", ) - if CONFIG.resume: + if Config.default().resume: logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action") self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json") diff --git a/metagpt/environment/minecraft/minecraft_ext_env.py b/metagpt/environment/minecraft/minecraft_ext_env.py index 0436bc3aa0..fb43e97c9e 100644 --- a/metagpt/environment/minecraft/minecraft_ext_env.py +++ b/metagpt/environment/minecraft/minecraft_ext_env.py @@ -10,8 +10,8 @@ import requests from pydantic import ConfigDict, Field, model_validator +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.base_env import ExtEnv, mark_as_writeable -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.minecraft.const import ( MC_CKPT_DIR, MC_CORE_INVENTORY_ITEMS, diff --git a/metagpt/environment/stanford_town/env_space.py b/metagpt/environment/stanford_town/env_space.py index e100a29527..1741cccfe8 100644 --- a/metagpt/environment/stanford_town/env_space.py +++ b/metagpt/environment/stanford_town/env_space.py @@ -9,7 +9,7 @@ from gymnasium import spaces from pydantic import ConfigDict, Field, field_validator -from metagpt.environment.base_env_space import ( +from metagpt.base.base_env_space import ( BaseEnvAction, BaseEnvActionType,
BaseEnvObsParams, diff --git a/metagpt/environment/werewolf/werewolf_ext_env.py b/metagpt/environment/werewolf/werewolf_ext_env.py index 3f2508b069..d9644eb9b7 100644 --- a/metagpt/environment/werewolf/werewolf_ext_env.py +++ b/metagpt/environment/werewolf/werewolf_ext_env.py @@ -9,8 +9,8 @@ from pydantic import ConfigDict, Field +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.logs import logger diff --git a/metagpt/exp_pool/__init__.py b/metagpt/exp_pool/__init__.py new file mode 100644 index 0000000000..97d45a278b --- /dev/null +++ b/metagpt/exp_pool/__init__.py @@ -0,0 +1,6 @@ +"""Experience pool init.""" + +from metagpt.exp_pool.manager import get_exp_manager +from metagpt.exp_pool.decorator import exp_cache + +__all__ = ["get_exp_manager", "exp_cache"] diff --git a/metagpt/exp_pool/context_builders/__init__.py b/metagpt/exp_pool/context_builders/__init__.py new file mode 100644 index 0000000000..047558be03 --- /dev/null +++ b/metagpt/exp_pool/context_builders/__init__.py @@ -0,0 +1,7 @@ +"""Context builders init.""" + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder +from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder + +__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"] diff --git a/metagpt/exp_pool/context_builders/action_node.py b/metagpt/exp_pool/context_builders/action_node.py new file mode 100644 index 0000000000..891b898be3 --- /dev/null +++ b/metagpt/exp_pool/context_builders/action_node.py @@ -0,0 +1,30 @@ +"""Action Node context builder.""" + +from typing import Any + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +ACTION_NODE_CONTEXT_TEMPLATE = """ +{req} + +### Experiences +----- +{exps} +----- + +## Instruction +Consider **Experiences** to generate a better answer. +""" + + +class ActionNodeContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> str: + """Builds the action node context string. + + If there are no experiences, returns the original `req`; + otherwise returns context with `req` and formatted experiences. + """ + + exps = self.format_exps() + + return ACTION_NODE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req diff --git a/metagpt/exp_pool/context_builders/base.py b/metagpt/exp_pool/context_builders/base.py new file mode 100644 index 0000000000..691d51c8c5 --- /dev/null +++ b/metagpt/exp_pool/context_builders/base.py @@ -0,0 +1,41 @@ +"""Base context builder.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + +EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, which scored: {score}.""" + + +class BaseContextBuilder(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + exps: list[Experience] = [] + + @abstractmethod + async def build(self, req: Any) -> Any: + """Build context from req. + + Do not modify `req`. If modification is necessary, use copy.deepcopy to create a copy first. + """ + + def format_exps(self) -> str: + """Format experiences into a numbered list of strings. + + Example: + 1. Given the request: req1, We can get the response: resp1, which scored: 8. + 2. 
Given the request: req2, We can get the response: resp2, which scored: 9. + + Returns: + str: The formatted experiences as a string. + """ + + result = [] + for i, exp in enumerate(self.exps, start=1): + score_val = exp.metric.score.val if exp.metric and exp.metric.score else "N/A" + result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=score_val)) + + return "\n".join(result) diff --git a/metagpt/exp_pool/context_builders/role_zero.py b/metagpt/exp_pool/context_builders/role_zero.py new file mode 100644 index 0000000000..cbda72fc58 --- /dev/null +++ b/metagpt/exp_pool/context_builders/role_zero.py @@ -0,0 +1,39 @@ +"""RoleZero context builder.""" + +import copy +from typing import Any + +from metagpt.const import EXPERIENCE_MASK +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + + +class RoleZeroContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> list[dict]: + """Builds the role zero context string. + + Note: + 1. The expected format for `req`, e.g., [{...}, {"role": "user", "content": "context"}]. + 2. Returns the original `req` if it is empty. + 3. Creates a copy of req and replaces the example content in the copied req with actual experiences. + """ + + if not req: + return req + + exps = self.format_exps() + if not exps: + return req + + req_copy = copy.deepcopy(req) + + req_copy[-1]["content"] = self.replace_example_content(req_copy[-1].get("content", ""), exps) + + return req_copy + + def replace_example_content(self, text: str, new_example_content: str) -> str: + return self.fill_experience(text, new_example_content) + + @staticmethod + def fill_experience(text: str, new_example_content: str) -> str: + replaced_text = text.replace(EXPERIENCE_MASK, new_example_content) + return replaced_text diff --git a/metagpt/exp_pool/context_builders/simple.py b/metagpt/exp_pool/context_builders/simple.py new file mode 100644 index 0000000000..d7b8d0be9a --- /dev/null +++ b/metagpt/exp_pool/context_builders/simple.py @@ -0,0 +1,26 @@ +"""Simple context builder.""" + + +from typing import Any + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +SIMPLE_CONTEXT_TEMPLATE = """ +## Context + +### Experiences +----- +{exps} +----- + +## User Requirement +{req} + +## Instruction +Consider **Experiences** to generate a better answer. 
+""" + + +class SimpleContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> str: + return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=self.format_exps()) diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py new file mode 100644 index 0000000000..d49c13e95d --- /dev/null +++ b/metagpt/exp_pool/decorator.py @@ -0,0 +1,229 @@ +"""Experience Decorator.""" + +import asyncio +import functools +from typing import Any, Callable, Optional, TypeVar + +from pydantic import BaseModel, ConfigDict, model_validator + +from metagpt.config2 import Config +from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder +from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager +from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge +from metagpt.exp_pool.schema import ( + LOG_NEW_EXPERIENCE_PREFIX, + Experience, + Metric, + QueryType, + Score, +) +from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer +from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer +from metagpt.logs import logger +from metagpt.utils.async_helper import NestAsyncio +from metagpt.utils.exceptions import handle_exception + +ReturnType = TypeVar("ReturnType") + + +def exp_cache( + _func: Optional[Callable[..., ReturnType]] = None, + query_type: QueryType = QueryType.SEMANTIC, + manager: Optional[ExperienceManager] = None, + scorer: Optional[BaseScorer] = None, + perfect_judge: Optional[BasePerfectJudge] = None, + context_builder: Optional[BaseContextBuilder] = None, + serializer: Optional[BaseSerializer] = None, + tag: Optional[str] = None, +): + """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience. + + Note: + 1. This can be applied to both synchronous and asynchronous functions. + 2. The function must have a `req` parameter, and it must be provided as a keyword argument. + 3. If `config.exp_pool.enabled` is False, the decorator will just directly execute the function. + 4. If `config.exp_pool.enable_write` is False, the decorator will skip evaluating and saving the experience. + 5. If `config.exp_pool.enable_read` is False, the decorator will skip reading from the experience pool. + + + Args: + _func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache(). + query_type: The type of query to be used when fetching experiences. + manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`. + scorer: Evaluate experience. Default to `SimpleScorer()`. + perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`. + context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`. + serializer: Serializes the request and the function's return value for storage, deserializes the stored response back to the function's return value. Defaults to `SimpleSerializer()`. + tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`. 
+ """ + + def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + @functools.wraps(func) + async def get_or_create(args: Any, kwargs: Any) -> ReturnType: + config = Config.default() + + if not config.exp_pool.enabled: + rsp = func(*args, **kwargs) + return await rsp if asyncio.iscoroutine(rsp) else rsp + + handler = ExpCacheHandler( + func=func, + args=args, + kwargs=kwargs, + query_type=query_type, + exp_manager=manager, + exp_scorer=scorer, + exp_perfect_judge=perfect_judge, + context_builder=context_builder, + serializer=serializer, + tag=tag, + ) + + await handler.fetch_experiences() + + if exp := await handler.get_one_perfect_exp(): + return exp + + await handler.execute_function() + + if config.exp_pool.enable_write: + await handler.process_experience() + + return handler._raw_resp + + return ExpCacheHandler.choose_wrapper(func, get_or_create) + + return decorator(_func) if _func else decorator + + +class ExpCacheHandler(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + func: Callable + args: Any + kwargs: Any + query_type: QueryType = QueryType.SEMANTIC + exp_manager: Optional[ExperienceManager] = None + exp_scorer: Optional[BaseScorer] = None + exp_perfect_judge: Optional[BasePerfectJudge] = None + context_builder: Optional[BaseContextBuilder] = None + serializer: Optional[BaseSerializer] = None + tag: Optional[str] = None + + _exps: list[Experience] = None + _req: str = "" + _resp: str = "" + _raw_resp: Any = None + _score: Score = None + + @model_validator(mode="after") + def initialize(self): + """Initialize default values for optional parameters if they are None. + + This is necessary because the decorator might pass None, which would override the default values set by Field. + """ + + self._validate_params() + + self.exp_manager = self.exp_manager or get_exp_manager() + self.exp_scorer = self.exp_scorer or SimpleScorer() + self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge() + self.context_builder = self.context_builder or SimpleContextBuilder() + self.serializer = self.serializer or SimpleSerializer() + self.tag = self.tag or self._generate_tag() + + self._req = self.serializer.serialize_req(**self.kwargs) + + return self + + async def fetch_experiences(self): + """Fetch experiences by query_type.""" + + self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag) + logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'") + + async def get_one_perfect_exp(self) -> Optional[Any]: + """Get a potentially perfect experience, and resolve resp.""" + + for exp in self._exps: + if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): + logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'") + return self.serializer.deserialize_resp(exp.resp) + + return None + + async def execute_function(self): + """Execute the function, and save resp.""" + + self._raw_resp = await self._execute_function() + self._resp = self.serializer.serialize_resp(self._raw_resp) + + @handle_exception + async def process_experience(self): + """Process experience. + + Evaluates and saves experience. + Use `handle_exception` to ensure robustness, do not stop subsequent operations. 
+ """ + + await self.evaluate_experience() + self.save_experience() + + async def evaluate_experience(self): + """Evaluate the experience, and save the score.""" + + self._score = await self.exp_scorer.evaluate(self._req, self._resp) + + def save_experience(self): + """Save the new experience.""" + + exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score)) + self.exp_manager.create_exp(exp) + self._log_exp(exp) + + @staticmethod + def choose_wrapper(func, wrapped_func): + """Choose how to run wrapped_func based on whether the function is asynchronous.""" + + async def async_wrapper(*args, **kwargs): + return await wrapped_func(args, kwargs) + + def sync_wrapper(*args, **kwargs): + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(wrapped_func(args, kwargs)) + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + def _validate_params(self): + if "req" not in self.kwargs: + raise ValueError("`req` must be provided as a keyword argument.") + + def _generate_tag(self) -> str: + """Generates a tag for the self.func. + + "ClassName.method_name" if the first argument is a class instance, otherwise just "function_name". + """ + + if self.args and hasattr(self.args[0], "__class__"): + cls_name = type(self.args[0]).__name__ + return f"{cls_name}.{self.func.__name__}" + + return self.func.__name__ + + async def _build_context(self) -> str: + self.context_builder.exps = self._exps + + return await self.context_builder.build(self.kwargs["req"]) + + async def _execute_function(self): + self.kwargs["req"] = await self._build_context() + + if asyncio.iscoroutinefunction(self.func): + return await self.func(*self.args, **self.kwargs) + + return self.func(*self.args, **self.kwargs) + + def _log_exp(self, exp: Experience): + log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"}) + + logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py new file mode 100644 index 0000000000..35de17079b --- /dev/null +++ b/metagpt/exp_pool/manager.py @@ -0,0 +1,242 @@ +"""Experience Manager.""" + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, ConfigDict, Field + +from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType +from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType +from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception + +if TYPE_CHECKING: + from metagpt.rag.engines import SimpleEngine + + +class ExperienceManager(BaseModel): + """ExperienceManager manages the lifecycle of experiences, including CRUD and optimization. + + Args: + config (Config): Configuration for managing experiences. + _storage (SimpleEngine): Engine to handle the storage and retrieval of experiences. + _vector_store (ChromaVectorStore): The actual place where vectors are stored. 
+ """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + config: Config = Field(default_factory=Config.default) + + _storage: Any = None + + @property + def storage(self) -> "SimpleEngine": + if self._storage is None: + logger.info(f"exp_pool config: {self.config.exp_pool}") + + self._storage = self._resolve_storage() + + return self._storage + + @storage.setter + def storage(self, value): + self._storage = value + + @property + def is_readable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + + @is_readable.setter + def is_readable(self, value: bool): + self.config.exp_pool.enable_read = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True + + @property + def is_writable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + + @is_writable.setter + def is_writable(self, value: bool): + self.config.exp_pool.enable_write = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True + + @handle_exception + def create_exp(self, exp: Experience): + """Adds an experience to the storage if writing is enabled. + + Args: + exp (Experience): The experience to add. + """ + + self.create_exps([exp]) + + @handle_exception + def create_exps(self, exps: list[Experience]): + """Adds multiple experiences to the storage if writing is enabled. + + Args: + exps (list[Experience]): A list of experiences to add. + """ + if not self.is_writable: + return + + self.storage.add_objs(exps) + self.storage.persist(self.config.exp_pool.persist_path) + + @handle_exception(default_return=[]) + async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]: + """Retrieves and filters experiences. + + Args: + req (str): The query string to retrieve experiences. + tag (str): Optional tag to filter the experiences by. + query_type (QueryType): Default semantic to vector matching. exact to same matching. + + Returns: + list[Experience]: A list of experiences that match the args. + """ + + if not self.is_readable: + return [] + + nodes = await self.storage.aretrieve(req) + exps: list[Experience] = [node.metadata["obj"] for node in nodes] + + # TODO: filter by metadata + if tag: + exps = [exp for exp in exps if exp.tag == tag] + + if query_type == QueryType.EXACT: + exps = [exp for exp in exps if exp.req == req] + + return exps + + @handle_exception + def delete_all_exps(self): + """Delete the all experiences.""" + + if not self.is_writable: + return + + self.storage.clear(persist_dir=self.config.exp_pool.persist_path) + + def get_exps_count(self) -> int: + """Get the total number of experiences.""" + + return self.storage.count() + + def _resolve_storage(self) -> "SimpleEngine": + """Selects the appropriate storage creation method based on the configured retrieval type.""" + + storage_creators = { + ExperiencePoolRetrievalType.BM25: self._create_bm25_storage, + ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage, + } + + return storage_creators[self.config.exp_pool.retrieval_type]() + + def _create_bm25_storage(self) -> "SimpleEngine": + """Creates or loads BM25 storage. + + This function attempts to create a new BM25 storage if the specified + document store path does not exist. If the path exists, it loads the + existing BM25 storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with BM25 storage. 
+    def _resolve_storage(self) -> "SimpleEngine":
+        """Selects the appropriate storage creation method based on the configured retrieval type."""
+
+        storage_creators = {
+            ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
+            ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
+        }
+
+        return storage_creators[self.config.exp_pool.retrieval_type]()
+
+    def _create_bm25_storage(self) -> "SimpleEngine":
+        """Creates or loads BM25 storage.
+
+        This function attempts to create a new BM25 storage if the specified
+        document store path does not exist. If the path exists, it loads the
+        existing BM25 storage.
+
+        Returns:
+            SimpleEngine: An instance of SimpleEngine configured with BM25 storage.
+
+        Raises:
+            ImportError: If required modules are not installed.
+        """
+
+        try:
+            from metagpt.rag.engines import SimpleEngine
+            from metagpt.rag.schema import BM25IndexConfig, BM25RetrieverConfig
+        except ImportError:
+            raise ImportError("To use the experience pool, you need to install the rag module.")
+
+        persist_path = Path(self.config.exp_pool.persist_path)
+        docstore_path = persist_path / "docstore.json"
+
+        ranker_configs = self._get_ranker_configs()
+
+        if not docstore_path.exists():
+            logger.debug(f"Path `{docstore_path}` does not exist, creating a new bm25 storage.")
+            exps = [Experience(req="req", resp="resp")]
+
+            retriever_configs = [BM25RetrieverConfig(create_index=True, similarity_top_k=DEFAULT_SIMILARITY_TOP_K)]
+
+            storage = SimpleEngine.from_objs(
+                objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs
+            )
+            return storage
+
+        logger.debug(f"Path `{docstore_path}` exists, loading the existing bm25 storage.")
+        retriever_configs = [BM25RetrieverConfig(similarity_top_k=DEFAULT_SIMILARITY_TOP_K)]
+        storage = SimpleEngine.from_index(
+            BM25IndexConfig(persist_path=persist_path),
+            retriever_configs=retriever_configs,
+            ranker_configs=ranker_configs,
+        )
+
+        return storage
+
+    def _create_chroma_storage(self) -> "SimpleEngine":
+        """Creates Chroma storage.
+
+        Returns:
+            SimpleEngine: An instance of SimpleEngine configured with Chroma storage.
+
+        Raises:
+            ImportError: If required modules are not installed.
+        """
+
+        try:
+            from metagpt.rag.engines import SimpleEngine
+            from metagpt.rag.schema import ChromaRetrieverConfig
+        except ImportError:
+            raise ImportError("To use the experience pool, you need to install the rag module.")
+
+        retriever_configs = [
+            ChromaRetrieverConfig(
+                persist_path=self.config.exp_pool.persist_path,
+                collection_name=self.config.exp_pool.collection_name,
+                similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
+            )
+        ]
+        ranker_configs = self._get_ranker_configs()
+
+        storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
+
+        return storage
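For reference, a hedged sketch of selecting the Chroma backend instead of the default BM25 (requires the optional rag dependencies and an embedding setup; not part of the patch):

```python
# All imports below exist in this patch; the override is purely illustrative.
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.manager import ExperienceManager

config = Config.default()
config.exp_pool.enabled = True
config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
manager = ExperienceManager(config=config)
storage = manager.storage  # lazily resolved via _resolve_storage()
```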
+ """ + + from metagpt.rag.schema import LLMRankerConfig + + return [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] if self.config.exp_pool.use_llm_ranker else [] + + +_exp_manager = None + + +def get_exp_manager() -> ExperienceManager: + global _exp_manager + if _exp_manager is None: + _exp_manager = ExperienceManager() + return _exp_manager diff --git a/metagpt/exp_pool/perfect_judges/__init__.py b/metagpt/exp_pool/perfect_judges/__init__.py new file mode 100644 index 0000000000..d8796c7c85 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/__init__.py @@ -0,0 +1,6 @@ +"""Perfect judges init.""" + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge + +__all__ = ["BasePerfectJudge", "SimplePerfectJudge"] diff --git a/metagpt/exp_pool/perfect_judges/base.py b/metagpt/exp_pool/perfect_judges/base.py new file mode 100644 index 0000000000..2935229931 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/base.py @@ -0,0 +1,20 @@ +"""Base perfect judge.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + + +class BasePerfectJudge(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + """ diff --git a/metagpt/exp_pool/perfect_judges/simple.py b/metagpt/exp_pool/perfect_judges/simple.py new file mode 100644 index 0000000000..37ede95c39 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/simple.py @@ -0,0 +1,27 @@ +"""Simple perfect judge.""" + + +from pydantic import ConfigDict + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.schema import MAX_SCORE, Experience + + +class SimplePerfectJudge(BasePerfectJudge): + model_config = ConfigDict(arbitrary_types_allowed=True) + + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + + Returns: + bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise. 
+ """ + + if not exp.metric or not exp.metric.score: + return False + + return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py new file mode 100644 index 0000000000..fea48a7f7d --- /dev/null +++ b/metagpt/exp_pool/schema.py @@ -0,0 +1,76 @@ +"""Experience schema.""" +import time +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + +MAX_SCORE = 10 + +DEFAULT_SIMILARITY_TOP_K = 2 + +LOG_NEW_EXPERIENCE_PREFIX = "New experience: " + + +class QueryType(str, Enum): + """Type of query experiences.""" + + EXACT = "exact" + SEMANTIC = "semantic" + + +class ExperienceType(str, Enum): + """Experience Type.""" + + SUCCESS = "success" + FAILURE = "failure" + INSIGHT = "insight" + + +class EntryType(Enum): + """Experience Entry Type.""" + + AUTOMATIC = "Automatic" + MANUAL = "Manual" + + +class Score(BaseModel): + """Score in Metric.""" + + val: int = Field(default=1, description="Value of the score, Between 1 and 10, higher is better.") + reason: str = Field(default="", description="Reason for the value.") + + +class Metric(BaseModel): + """Experience Metric.""" + + time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.") + money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.") + score: Score = Field(default=None, description="Score, with value and reason.") + + +class Trajectory(BaseModel): + """Experience Trajectory.""" + + plan: str = Field(default="", description="The plan.") + action: str = Field(default="", description="Action for the plan.") + observation: str = Field(default="", description="Output of the action.") + reward: int = Field(default=0, description="Measure the action.") + + +class Experience(BaseModel): + """Experience.""" + + req: str = Field(..., description="") + resp: str = Field(..., description="The type is string/json/code.") + metric: Optional[Metric] = Field(default=None, description="Metric.") + exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.") + entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.") + tag: str = Field(default="", description="Tagging experience.") + traj: Optional[Trajectory] = Field(default=None, description="Trajectory.") + timestamp: Optional[float] = Field(default_factory=time.time) + uuid: Optional[UUID] = Field(default_factory=uuid4) + + def rag_key(self): + return self.req diff --git a/metagpt/exp_pool/scorers/__init__.py b/metagpt/exp_pool/scorers/__init__.py new file mode 100644 index 0000000000..caa845b143 --- /dev/null +++ b/metagpt/exp_pool/scorers/__init__.py @@ -0,0 +1,6 @@ +"""Scorers init.""" + +from metagpt.exp_pool.scorers.base import BaseScorer +from metagpt.exp_pool.scorers.simple import SimpleScorer + +__all__ = ["BaseScorer", "SimpleScorer"] diff --git a/metagpt/exp_pool/scorers/base.py b/metagpt/exp_pool/scorers/base.py new file mode 100644 index 0000000000..97cac49925 --- /dev/null +++ b/metagpt/exp_pool/scorers/base.py @@ -0,0 +1,15 @@ +"""Base scorer.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Score + + +class BaseScorer(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def evaluate(self, req: str, resp: str) -> Score: + """Evaluates the quality of a response 
diff --git a/metagpt/exp_pool/scorers/simple.py b/metagpt/exp_pool/scorers/simple.py
new file mode 100644
index 0000000000..4b060aac4f
--- /dev/null
+++ b/metagpt/exp_pool/scorers/simple.py
@@ -0,0 +1,65 @@
+"""Simple scorer."""
+
+import json
+
+from pydantic import Field
+
+from metagpt.exp_pool.schema import Score
+from metagpt.exp_pool.scorers.base import BaseScorer
+from metagpt.llm import LLM
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.utils.common import CodeParser
+
+SIMPLE_SCORER_TEMPLATE = """
+Role: You are a highly efficient assistant, tasked with evaluating a response to a given request. The response is generated by a large language model (LLM).
+
+I will provide you with a request and a corresponding response. Your task is to assess this response and provide a score from a human perspective.
+
+## Context
+### Request
+{req}
+
+### Response
+{resp}
+
+## Format Example
+```json
+{{
+    "val": "the value of the score, int from 1 to 10, higher is better.",
+    "reason": "an explanation supporting the score."
+}}
+```
+
+## Instructions
+- Understand the request and response given by the user.
+- Evaluate the response based on its quality relative to the given request.
+- Provide a score from 1 to 10, where 10 is the best.
+- Provide a reason supporting your score.
+
+## Constraint
+Format: Just print the result in json format like **Format Example**.
+
+## Action
+Follow the instructions, generate output and make sure it follows the **Constraint**.
+"""
+
+
+class SimpleScorer(BaseScorer):
+    llm: BaseLLM = Field(default_factory=LLM)
+
+    async def evaluate(self, req: str, resp: str) -> Score:
+        """Evaluates the quality of a response relative to a given request, as scored by an LLM.
+
+        Args:
+            req (str): The request.
+            resp (str): The response.
+
+        Returns:
+            Score: An object containing the score (1-10) and the reasoning.
+        """
+
+        prompt = SIMPLE_SCORER_TEMPLATE.format(req=req, resp=resp)
+        rsp = await self.llm.aask(prompt)
+        rsp_json = json.loads(CodeParser.parse_code(rsp, lang="json"))
+
+        return Score(**rsp_json)
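A small illustrative driver for this scorer (requires a configured LLM; the resulting `Score` is whatever the model returns in the JSON format above):

```python
import asyncio

from metagpt.exp_pool.scorers import SimpleScorer

async def demo():
    score = await SimpleScorer().evaluate(req="What is 2+2?", resp="4")
    print(score.val, score.reason)

asyncio.run(demo())
```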
+ """ + + prompt = SIMPLE_SCORER_TEMPLATE.format(req=req, resp=resp) + resp = await self.llm.aask(prompt) + resp_json = json.loads(CodeParser.parse_code(resp, lang="json")) + + return Score(**resp_json) diff --git a/metagpt/exp_pool/serializers/__init__.py b/metagpt/exp_pool/serializers/__init__.py new file mode 100644 index 0000000000..8e1045588e --- /dev/null +++ b/metagpt/exp_pool/serializers/__init__.py @@ -0,0 +1,9 @@ +"""Serializers init.""" + +from metagpt.exp_pool.serializers.base import BaseSerializer +from metagpt.exp_pool.serializers.simple import SimpleSerializer +from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer +from metagpt.exp_pool.serializers.role_zero import RoleZeroSerializer + + +__all__ = ["BaseSerializer", "SimpleSerializer", "ActionNodeSerializer", "RoleZeroSerializer"] diff --git a/metagpt/exp_pool/serializers/action_node.py b/metagpt/exp_pool/serializers/action_node.py new file mode 100644 index 0000000000..7746d6be47 --- /dev/null +++ b/metagpt/exp_pool/serializers/action_node.py @@ -0,0 +1,36 @@ +"""ActionNode Serializer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Type + +# Import ActionNode only for type checking to avoid circular imports +if TYPE_CHECKING: + from metagpt.actions.action_node import ActionNode + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class ActionNodeSerializer(SimpleSerializer): + def serialize_resp(self, resp: ActionNode) -> str: + return resp.instruct_content.model_dump_json() + + def deserialize_resp(self, resp: str) -> ActionNode: + """Customized deserialization, it will be triggered when a perfect experience is found. + + ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'. + """ + + class InstructContent: + def __init__(self, json_data): + self.json_data = json_data + + def model_dump_json(self): + return self.json_data + + from metagpt.actions.action_node import ActionNode + + action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="") + action_node.instruct_content = InstructContent(resp) + + return action_node diff --git a/metagpt/exp_pool/serializers/base.py b/metagpt/exp_pool/serializers/base.py new file mode 100644 index 0000000000..c09488e121 --- /dev/null +++ b/metagpt/exp_pool/serializers/base.py @@ -0,0 +1,29 @@ +"""Base serializer.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class BaseSerializer(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + def serialize_req(self, **kwargs) -> str: + """Serializes the request for storage. + + Do not modify kwargs. If modification is necessary, use copy.deepcopy to create a copy first. + Note that copy.deepcopy may raise errors, such as TypeError: cannot pickle '_thread.RLock' object. + """ + + @abstractmethod + def serialize_resp(self, resp: Any) -> str: + """Serializes the function's return value for storage. + + Do not modify resp. The rest is the same as `serialize_req`. 
+ """ + + @abstractmethod + def deserialize_resp(self, resp: str) -> Any: + """Deserializes the stored response back to the function's return value""" diff --git a/metagpt/exp_pool/serializers/role_zero.py b/metagpt/exp_pool/serializers/role_zero.py new file mode 100644 index 0000000000..89dd73f391 --- /dev/null +++ b/metagpt/exp_pool/serializers/role_zero.py @@ -0,0 +1,58 @@ +"""RoleZero Serializer.""" + +import copy +import json + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class RoleZeroSerializer(SimpleSerializer): + def serialize_req(self, **kwargs) -> str: + """Serialize the request for database storage, ensuring it is a string. + + Only extracts the necessary content from `req` because `req` may be very lengthy and could cause embedding errors. + + Args: + req (list[dict]): The request to be serialized. Example: + [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}, + {"role": "user", "content": "context"}, + ] + + Returns: + str: The serialized request as a JSON string. + """ + req = kwargs.get("req", []) + + if not req: + return "" + + filtered_req = self._filter_req(req) + + if state_data := kwargs.get("state_data"): + filtered_req.append({"role": "user", "content": state_data}) + + return json.dumps(filtered_req) + + def _filter_req(self, req: list[dict]) -> list[dict]: + """Filter the `req` to include only necessary items. + + Args: + req (list[dict]): The original request. + + Returns: + list[dict]: The filtered request. + """ + + filtered_req = [copy.deepcopy(item) for item in req if self._is_useful_content(item["content"])] + + return filtered_req + + def _is_useful_content(self, content: str) -> bool: + """Currently, only the content of the file is considered, and more judgments can be added later.""" + + if "Command Editor.read executed: file_path" in content: + return True + + return False diff --git a/metagpt/exp_pool/serializers/simple.py b/metagpt/exp_pool/serializers/simple.py new file mode 100644 index 0000000000..ebd06e0e0c --- /dev/null +++ b/metagpt/exp_pool/serializers/simple.py @@ -0,0 +1,22 @@ +"""Simple Serializer.""" + +from typing import Any + +from metagpt.exp_pool.serializers.base import BaseSerializer + + +class SimpleSerializer(BaseSerializer): + def serialize_req(self, **kwargs) -> str: + """Just use `str` to convert the request object into a string.""" + + return str(kwargs.get("req", "")) + + def serialize_resp(self, resp: Any) -> str: + """Just use `str` to convert the response object into a string.""" + + return str(resp) + + def deserialize_resp(self, resp: str) -> Any: + """Just return the string response as it is.""" + + return resp diff --git a/metagpt/ext/cr/__init__.py b/metagpt/ext/cr/__init__.py new file mode 100644 index 0000000000..2bcf8efd09 --- /dev/null +++ b/metagpt/ext/cr/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/cr/actions/code_review.py b/metagpt/ext/cr/actions/code_review.py new file mode 100644 index 0000000000..0235dc2c60 --- /dev/null +++ b/metagpt/ext/cr/actions/code_review.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import json +import re +from pathlib import Path + +import aiofiles +from unidiff import PatchSet + +from metagpt.actions.action import Action +from metagpt.ext.cr.utils.cleaner import ( + add_line_num_on_patch, + get_code_block_from_patch, + rm_patch_useless_part, +) +from metagpt.ext.cr.utils.schema import Point +from metagpt.logs import logger +from 
diff --git a/metagpt/ext/cr/actions/code_review.py b/metagpt/ext/cr/actions/code_review.py
new file mode 100644
index 0000000000..0235dc2c60
--- /dev/null
+++ b/metagpt/ext/cr/actions/code_review.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc :
+
+import json
+import re
+from pathlib import Path
+
+import aiofiles
+from unidiff import PatchSet
+
+from metagpt.actions.action import Action
+from metagpt.ext.cr.utils.cleaner import (
+    add_line_num_on_patch,
+    get_code_block_from_patch,
+    rm_patch_useless_part,
+)
+from metagpt.ext.cr.utils.schema import Point
+from metagpt.logs import logger
+from metagpt.utils.common import parse_json_code_block
+from metagpt.utils.report import EditorReporter
+
+CODE_REVIEW_PROMPT_TEMPLATE = """
+NOTICE
+Let's think and work step by step.
+With the given pull-request(PR) Patch, and referenced Points(Code Standards), you should compare each point with the code one-by-one within 4000 tokens.
+
+Line numbers have been prepended to each line of the Patch for ease of reading, but the review should focus on newly added code inside the `Patch` (lines starting with a line number and '+').
+Each point starts with its id, followed by the point description.
+
+## Patch
+```
+{patch}
+```
+
+## Points
+{points}
+
+## Output Format
+```json
+[
+    {{
+        "commented_file": "The file path which you give a comment from the patch",
+        "comment": "The Chinese comment on the code which does not meet the point description, together with modification suggestions",
+        "code_start_line": "the code start line number like `10` in the Patch of current comment",
+        "code_end_line": "the code end line number like `15` in the Patch of current comment",
+        "point_id": "The point id which the `comment` references to"
+    }}
+]
+```
+
+CodeReview guidelines:
+- Generate a `comment` for code that does not meet the point description.
+- Each `comment` should be restricted inside the `commented_file`.
+- Try to provide diverse and insightful comments across different `commented_file`.
+- Don't suggest adding a docstring unless it's truly necessary.
+- If the same code error occurs multiple times, it cannot be omitted; all places need to be identified. But don't duplicate the same comment at the same place!
+- Every line of code in the patch needs to be carefully checked; do not skip any occurrence. It is necessary to find all the places.
+- The `comment` and `point_id` in the Output must correspond to and belong to the same one `Point`.
+
+Strictly Observe:
+Just print the PR Patch comments in json format like **Output Format**.
+And the output JSON must be able to be parsed by json.loads() without any errors.
+"""
+
+CODE_REVIEW_CONFIRM_SYSTEM_PROMPT = """
+You are a professional engineer with a {code_language} stack, good at judging code review comment results. Let's think and work step by step.
+"""
+
+CODE_REVIEW_CONFIRM_TEMPLATE = """
+## Code
+```
+{code}
+```
+## Code Review Comments
+{comment}
+
+## Description of Defects
+{desc}
+
+## Reference Example for Judgment
+{example}
+
+## Your Task:
+1. First, check if the code meets the requirements and does not violate any defects. If it meets the requirements and does not violate any defects, print `False` and do not proceed with further judgment.
+2. Based on the `Reference Example for Judgment` provided, determine if the `Code` and `Code Review Comments` match. If they match, print `True`; otherwise, print `False`.
+
+Note: Your output should only be `True` or `False` without any explanations.
+"""
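`CodeReview.run` below consumes a `unidiff.PatchSet`; for reference, one can be built from a raw diff string with the standard unidiff API:

```python
from unidiff import PatchSet

raw_diff = """\
--- a/app.py
+++ b/app.py
@@ -1,2 +1,3 @@
 import os
+print("debug")
 x = 1
"""
patch = PatchSet(raw_diff)
print(patch[0].path)  # "app.py"
```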
+""" + + +class CodeReview(Action): + name: str = "CodeReview" + + def format_comments(self, comments: list[dict], points: list[Point], patch: PatchSet): + new_comments = [] + logger.debug(f"original comments: {comments}") + for cmt in comments: + try: + if cmt.get("commented_file").endswith(".py"): + points = [p for p in points if p.language == "Python"] + elif cmt.get("commented_file").endswith(".java"): + points = [p for p in points if p.language == "Java"] + else: + continue + for p in points: + point_id = int(cmt.get("point_id", -1)) + if point_id == p.id: + code_start_line = cmt.get("code_start_line") + code_end_line = cmt.get("code_end_line") + code = get_code_block_from_patch(patch, code_start_line, code_end_line) + + new_comments.append( + { + "commented_file": cmt.get("commented_file"), + "code": code, + "code_start_line": code_start_line, + "code_end_line": code_end_line, + "comment": cmt.get("comment"), + "point_id": p.id, + "point": p.text, + "point_detail": p.detail, + } + ) + break + except Exception: + pass + + logger.debug(f"new_comments: {new_comments}") + return new_comments + + async def confirm_comments(self, patch: PatchSet, comments: list[dict], points: list[Point]) -> list[dict]: + points_dict = {point.id: point for point in points} + new_comments = [] + for cmt in comments: + try: + point = points_dict[cmt.get("point_id")] + + code_start_line = cmt.get("code_start_line") + code_end_line = cmt.get("code_end_line") + # 如果代码位置为空的话,那么就将这条记录丢弃掉 + if not code_start_line or not code_end_line: + logger.info("False") + continue + + # 代码增加上下文,提升confirm的准确率 + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 3)), str(int(code_end_line) + 3) + ) + pattern = r"^[ \t\n\r(){}[\];,]*$" + if re.match(pattern, code): + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 5)), str(int(code_end_line) + 5) + ) + code_language = "Java" + code_file_ext = cmt.get("commented_file", ".java").split(".")[-1] + if code_file_ext == ".java": + code_language = "Java" + elif code_file_ext == ".py": + code_language = "Python" + prompt = CODE_REVIEW_COMFIRM_TEMPLATE.format( + code=code, + comment=cmt.get("comment"), + desc=point.text, + example=point.yes_example + "\n" + point.no_example, + ) + system_prompt = [CODE_REVIEW_COMFIRM_SYSTEM_PROMPT.format(code_language=code_language)] + resp = await self.llm.aask(prompt, system_msgs=system_prompt) + if "True" in resp or "true" in resp: + new_comments.append(cmt) + except Exception: + logger.info("False") + logger.info(f"original comments num: {len(comments)}, confirmed comments num: {len(new_comments)}") + return new_comments + + async def cr_by_points(self, patch: PatchSet, points: list[Point]): + comments = [] + valid_patch_count = 0 + for patched_file in patch: + if not patched_file: + continue + if patched_file.path.endswith(".py"): + points = [p for p in points if p.language == "Python"] + valid_patch_count += 1 + elif patched_file.path.endswith(".java"): + points = [p for p in points if p.language == "Java"] + valid_patch_count += 1 + else: + continue + group_points = [points[i : i + 3] for i in range(0, len(points), 3)] + for group_point in group_points: + points_str = "id description\n" + points_str += "\n".join([f"{p.id} {p.text}" for p in group_point]) + prompt = CODE_REVIEW_PROMPT_TEMPLATE.format(patch=str(patched_file), points=points_str) + resp = await self.llm.aask(prompt) + json_str = parse_json_code_block(resp)[0] + comments_batch = json.loads(json_str) + if comments_batch: + 
+                    patched_file_path = patched_file.path
+                    for c in comments_batch:
+                        c["commented_file"] = patched_file_path
+                    comments.extend(comments_batch)
+
+        if valid_patch_count == 0:
+            raise ValueError("Only code reviews for Python and Java languages are supported.")
+
+        return comments
+
+    async def run(self, patch: PatchSet, points: list[Point], output_file: str):
+        patch: PatchSet = rm_patch_useless_part(patch)
+        patch: PatchSet = add_line_num_on_patch(patch)
+
+        result = []
+        async with EditorReporter(enable_llm_stream=True) as reporter:
+            log_cr_output_path = Path(output_file).with_suffix(".log")
+            await reporter.async_report(
+                {"src_path": str(log_cr_output_path), "filename": log_cr_output_path.name}, "meta"
+            )
+            comments = await self.cr_by_points(patch=patch, points=points)
+            log_cr_output_path.parent.mkdir(exist_ok=True, parents=True)
+            async with aiofiles.open(log_cr_output_path, "w", encoding="utf-8") as f:
+                await f.write(json.dumps(comments, ensure_ascii=False, indent=2))
+            await reporter.async_report(log_cr_output_path)
+
+        if len(comments) != 0:
+            comments = self.format_comments(comments, points, patch)
+            comments = await self.confirm_comments(patch=patch, comments=comments, points=points)
+            for comment in comments:
+                # Keep only comments whose located code is non-empty and not pure whitespace.
+                if comment["code"] and not comment["code"].isspace():
+                    result.append(comment)
+
+        async with EditorReporter() as reporter:
+            src_path = output_file
+            cr_output_path = Path(output_file)
+            await reporter.async_report(
+                {"type": "CodeReview", "src_path": src_path, "filename": cr_output_path.name}, "meta"
+            )
+            async with aiofiles.open(cr_output_path, "w", encoding="utf-8") as f:
+                await f.write(json.dumps(comments, ensure_ascii=False, indent=2))
+            await reporter.async_report(cr_output_path)
+        return result
diff --git a/metagpt/ext/cr/actions/modify_code.py b/metagpt/ext/cr/actions/modify_code.py
new file mode 100644
index 0000000000..820bdae4a1
--- /dev/null
+++ b/metagpt/ext/cr/actions/modify_code.py
@@ -0,0 +1,112 @@
+import datetime
+import itertools
+import re
+from pathlib import Path
+from typing import Optional
+
+from unidiff import PatchSet
+
+from metagpt.actions.action import Action
+from metagpt.ext.cr.utils.cleaner import (
+    add_line_num_on_patch,
+    get_code_block_from_patch,
+    rm_patch_useless_part,
+)
+from metagpt.utils.common import CodeParser
+from metagpt.utils.report import EditorReporter
+
+SYSTEM_MSGS_PROMPT = """
+You're an adaptive software developer who excels at refining code based on user inputs. You're proficient in creating Git patches to represent code modifications.
+"""
+
+MODIFY_CODE_PROMPT = """
+NOTICE
+With the given pull-request(PR) Patch and referenced Comments(Code Standards), you should modify the code according to the Comments.
+
+Line numbers have been prepended to each line of the Patch for ease of reading, but the modification should focus on newly added code inside the `Patch` (lines starting with a line number and '+').
+
+## Patch
+```
+{patch}
+```
+
+## Comments
+{comments}
+
+## Output Format
+
+
+
+Code Modification guidelines:
+- Look at `point_detail` and modify the code according to it; use `code_start_line` and `code_end_line` to locate the problematic code, then fix it as described by `point_detail` in the Comments. Strictly: the fix plan given by `point_detail` in every comment must be handled.
+- Create a patch that satisfies the git patch standard; your fixes need to be marked with '+' and '-'. But notice: don't change the hunk header!
+- Do not print line numbers in the new patch code.
+
+Just print the Patch in the format like **Output Format**.
+"""
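A hypothetical end-to-end driver for the fixer defined below (sketch only; the `comments` dicts are the confirmed ones produced by `CodeReview`, the `pr` id and output directory are illustrative, and a configured LLM is required):

```python
import asyncio

from unidiff import PatchSet
from metagpt.ext.cr.actions.modify_code import ModifyCode

async def demo(patch: PatchSet, comments: list[dict]):
    action = ModifyCode(pr="1234")  # the pr id is only used to build the output dir
    new_patch = await action.run(patch, comments, output_dir="./cr_fixes")
    print(new_patch)
```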
+
+
+class ModifyCode(Action):
+    name: str = "Modify Code"
+    pr: str
+
+    async def run(self, patch: PatchSet, comments: list[dict], output_dir: Optional[str] = None) -> str:
+        patch: PatchSet = rm_patch_useless_part(patch)
+        patch: PatchSet = add_line_num_on_patch(patch)
+
+        # Enrich each comment with the code block it points at.
+        for comment in comments:
+            code_start_line = comment.get("code_start_line")
+            code_end_line = comment.get("code_end_line")
+            # If the code location is empty, discard this record
+            if code_start_line and code_end_line:
+                code = get_code_block_from_patch(
+                    patch, str(max(1, int(code_start_line) - 3)), str(int(code_end_line) + 3)
+                )
+                pattern = r"^[ \t\n\r(){}[\];,]*$"
+                if re.match(pattern, code):
+                    code = get_code_block_from_patch(
+                        patch, str(max(1, int(code_start_line) - 5)), str(int(code_end_line) + 5)
+                    )
+                # Add context around the code to improve the accuracy of the fix
+                comment["code"] = code
+            # Drop the free-form comment produced during CR; the predefined fix plan should be used instead
+            comment.pop("comment")
+
+        # Group the comments by the commented_file field
+        comments.sort(key=lambda x: x["commented_file"])
+        grouped_comments = {
+            key: list(group) for key, group in itertools.groupby(comments, key=lambda x: x["commented_file"])
+        }
+        resp = None
+        for patched_file in patch:
+            patch_target_file_name = str(patched_file.path).split("/")[-1]
+            if patched_file.path not in grouped_comments:
+                continue
+            comments_prompt = ""
+            index = 1
+            # Wrap each grouped comment in an indexed tag for the prompt.
+            for grouped_comment in grouped_comments[patched_file.path]:
+                comments_prompt += f"""
+                <comment_{index}>
+                {grouped_comment}
+                </comment_{index}>\n
+                """
+                index += 1
+            prompt = MODIFY_CODE_PROMPT.format(patch=patched_file, comments=comments_prompt)
+            output_dir = (
+                Path(output_dir)
+                if output_dir
+                else self.config.workspace.path / "modify_code" / str(datetime.date.today()) / self.pr
+            )
+            patch_file = output_dir / f"{patch_target_file_name}.patch"
+            patch_file.parent.mkdir(exist_ok=True, parents=True)
+            async with EditorReporter(enable_llm_stream=True) as reporter:
+                await reporter.async_report(
+                    {"type": "Patch", "src_path": str(patch_file), "filename": patch_file.name}, "meta"
+                )
+                resp = await self.llm.aask(msg=prompt, system_msgs=[SYSTEM_MSGS_PROMPT])
+                resp = CodeParser.parse_code(resp, "diff")
+                with open(patch_file, "w", encoding="utf-8") as file:
+                    file.writelines(resp)
+                await reporter.async_report(patch_file)
+        return resp
diff --git a/metagpt/ext/cr/points.json b/metagpt/ext/cr/points.json
new file mode 100644
index 0000000000..f0920caccf
--- /dev/null
+++ b/metagpt/ext/cr/points.json
@@ -0,0 +1,656 @@
+[
+    {
+        "id": 1,
+        "text": "Avoid unused temporary variables",
+        "language": "Java",
+        "detail": "Defect type: Avoid unused temporary variables; Corresponding Fixer: UnusedLocalVariableFixer; Fix solution: Delete unused temporary variables",
+        "yes_example": "Examples of being judged as 'avoid unused temporary variables'",
+        "no_example": "Examples that cannot be judged as 'avoiding unused temporary variables'\n\npublic void setTransientVariablesLocal(Map transientVariables) {\n    throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}\nThis code's 'transientVariables' is a function parameter rather than a temporary variable. 
Although 'transientVariables' is not used or referenced, this cannot be judged as 'avoiding unused temporary variables'\n\n\n\npublic class TriggerCmd extends NeedsActiveExecutionCmd {\n protected Map transientVariables;\n public TriggerCmd(Map transientVariables) {\n this.transientVariables = transientVariables;\n }\n}\nIn the above code, 'transientVariables' is not a temporary variable; it is a class attribute and is used in the constructor, so this cannot be judged as 'avoiding unused temporary variables'\n" + }, + { + "id": 2, + "text": "Do not use System.out.println to print", + "language": "Java", + "detail": "Defect type: Do not use System.out.println to print; Corresponding Fixer: SystemPrintlnFixer; Fixing solution: Comment out the System.out.println code", + "yes_example": "Example of being judged as 'Do not use System.out.println for printing'", + "no_example": "Examples that cannot be judged as 'Do not use System.out.println to print'\n\nthrow new IllegalStateException(\"There is no authenticated user, we need a user authenticated to find tasks\");\nThe above code is throwing an exception, not using 'System.out.print', so this cannot be judged as 'Do not use System.out.println to print'\n" + }, + { + "id": 3, + "text": "Avoid unused formal parameters in functions", + "language": "Java", + "detail": "Defect type: Avoid unused formal parameters in functions; Fix solution: Ignore", + "yes_example": "Examples of being judged as 'avoiding unused formal parameters' in functions\n\n\npublic void setTransientVariablesLocal(Map transientVariables) {\n throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}In this code, the formal parameter \"transientVariables\" does not appear in the function body, so this is judged as 'avoiding unused formal parameters'\n\n\n\nprotected void modifyFetchPersistencePackageRequest(PersistencePackageRequest ppr, Map pathVars) {}\nIn this code, the formal parameters \"ppr\" and \"pathVars\" do not appear in the function body, so this is judged as 'avoiding unused formal parameters'\n", + "no_example": "Examples that cannot be judged as 'avoiding unused parameters in functions'\n\npublic String processFindForm(@RequestParam(value = \"pageNo\", defaultValue = \"1\") int pageNo) {\n\tlastName = owner.getLastName();\n\treturn addPaginationModel(pageNo, paginationModel, lastName, ownersResults);\n}In this code, the parameter 'pageNo' is used within the current function 'processFindForm' in the statement 'return addPaginationModel(pageNo, paginationModel, lastName, ownersResults);', although pageNo is not used for logical calculations, it is used as a parameter in a function call to another function, so this cannot be judged as 'avoiding unused parameters in functions'\n\n\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}In this code, the parameter 'date' is referenced in the statement 'System.out.println(\"Formatted date: \" + sdf.format(date))', so this cannot be judged as 'avoiding unused parameters in functions'\n" + }, + { + "id": 4, + "text": "if statement block cannot be empty", + "language": "Java", + "detail": "Defect type: if statement block cannot be empty; Corresponding Fixer: EmptyIfStmtFixer; Fixing solution: delete the if statement block or handle the logic appropriately or comment to explain why it is empty", + "yes_example": "Examples of being judged as 'if statement block cannot be 
empty'\n\npublic void emptyIfStatement() {\n\tif (getSpecialties().isEmpty()) {\n\t}\n}\nThis code's if statement block is empty, so it is judged as 'if statement block cannot be empty'\n\n\n\npublic void judgePersion() {\n\tif (persion != null) {\n\t\t// judge persion if not null\n\t}\n}\nAlthough this code's if statement block has content, the '// judge persion if not null' is just a code comment, and there is no actual logic code inside the if statement block, so it is judged as 'if statement block cannot be empty'\n", + "no_example": "Example that cannot be judged as 'if statement block cannot be empty'" + }, + { + "id": 5, + "text": "Loop body cannot be empty", + "language": "Java", + "detail": "Defect type: loop body cannot be empty; Corresponding Fixer: EmptyStatementNotInLoopFixer; Repair solution: delete the corresponding while, for, foreach loop body or add appropriate logical processing or comment explaining why it is empty", + "yes_example": "Examples of being judged as 'Loop body cannot be empty'\n\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t}\n}\nThis code's for loop body is empty, so it is judged as 'Loop body cannot be empty'\n\n\n\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t// this is a code example\n\t}\n}\nThe while loop body in this code is not empty, but the content is just a code comment with no logical content, so it is judged as 'Loop body cannot be empty'\n\n\n\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t\n\t}\n}\nThe while loop body in this code is empty, so it is judged as 'Loop body cannot be empty'\n", + "no_example": "Example that cannot be judged as 'loop body cannot be empty'\n\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t\ta = 1;\n\t\tif (a == 1) {\n\t\t\tretrun a;\n\t\t}\n\t}\n}\nThe content of the for loop in the above code is not empty, and the content is not entirely code comments, so this cannot be judged as 'loop body cannot be empty'\n" + }, + { + "id": 6, + "text": "Avoid using printStackTrace(), and instead use logging to record.", + "language": "Java", + "detail": "Defect type: Avoid using printStackTrace(), should use logging to record; Repair solution: Use logging to record", + "yes_example": "Example of being judged as 'Avoid using printStackTrace(), should use logging to record'", + "no_example": "### Example that cannot be judged as 'avoid using printStackTrace(), should use logging to record'\n\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) {\n\t\tlogging.info(\"info\");\n\t}\n}\nThis code uses logging in the catch statement, so it cannot be judged as 'avoid using printStackTrace(), should use logging to record'\n" + }, + { + "id": 7, + "text": "The catch block cannot be empty", + "language": "Java", + "detail": "Defect type: catch block cannot be empty; Corresponding Fixer: EmptyCatchBlockFixer; Fix solution: Add a comment inside the catch block", + "yes_example": "Examples of being judged as 'catch block cannot be empty'\n\n\n\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n String str = null;\n str.length();\n} catch (NullPointerException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\npublic class EmptyCatchExample {\n public static void main(String[] 
args) {\n try {\n // Attempt to divide by zero to trigger an exception\n int result = 10 / 0;\n } catch (ArithmeticException e) {\n \n }\n }\n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n Object obj = \"string\";\n Integer num = (Integer) obj;\n} catch (ClassCastException e) {\n\t\n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n", + "no_example": "Examples that cannot be judged as 'catch block cannot be empty'\n\npersionNum = 1\ntry {\n\treturn True;\n} catch (Exception e) {\n\t// If the number of people is 1, return false\n\tif (persionNum == 1){\n\t\treturn False;\n\t}\n}This catch statement is not empty, so it cannot be judged as 'catch block cannot be empty'\n\n\n\ntry {\n\tthrow new Exception(\"Fake exception\");\n} catch (Exception e) {\n\te.printStackTrace();\n}Although this catch statement only has 'e.printStackTrace();', it is indeed not empty, so it cannot be judged as 'catch block cannot be empty'\n" + }, + { + "id": 8, + "text": "Avoid unnecessary tautologies/contradictions", + "language": "Java", + "detail": "Defect type: Avoid unnecessary true/false judgments; Corresponding Fixer: UnconditionalIfStatement Fixer; Fixing solution: Delete true/false judgment logic", + "yes_example": "Examples of being judged as 'avoiding unnecessary always true/always false judgments'", + "no_example": "Examples that cannot be judged as 'avoiding unnecessary always true/always false judgments'" + }, + { + "id": 9, + "text": "In a switch statement, default must be placed at the end", + "language": "Java", + "detail": "Defect type: The default in switch must be placed at the end; Corresponding Fixer: DefaultLabelNotLastInSwitchStmtFixer; Fixing solution: Place default at the end in switch", + "yes_example": "Example of being judged as 'default in switch must be placed at the end'", + "no_example": "Example that cannot be judged as 'the default in switch must be placed at the end'" + }, + { + "id": 10, + "text": "Comparison of String without using equals() function", + "language": "Java", + "detail": "Defect type: Not using the equals() function to compare Strings; Corresponding Fixer: UnSynStaticDateFormatter Fixer; Fix solution: Use the equals() function to compare Strings", + "yes_example": "Examples of being judged as 'not using the equals() function to compare Strings'\n\n\nif (existingPet != null && existingPet.getName() == petName) {\n result.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}\nIn this code, both existingPet.getName() and petName are strings, but the comparison in the if statement uses == instead of equals() to compare the strings, so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString isOk = \"ok\";\nif (\"ok\" == isOk) {\n result.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}\nIn this code, isOk is a string, but in the if statement, it is compared with \"ok\" using ==, not using equals() to compare the strings, it should use \"ok\".equals(isOk), so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString str1 = \"Hello\";\nString str2 = \"Hello\";\nif (str1 == str2) {\n System.out.println(\"str1 and str2 reference the same object\");\n} else {\n System.out.println(\"str1 and str2 
reference different objects\");\n}\nIn this code, if (str1 == str2) uses == to compare str1 and str2, not using equals() to compare the strings, it should use str1.equals(str2), so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString str = \"This is string\";\nif (str == \"This is not str\") {\n return str;\n}\nIn this code, if (str == \"This is not str\") uses == to compare the strings, not using equals() to compare the strings, it should use \"This is not str\".equals(str), so this is judged as 'not using the equals() function to compare Strings'.\n", + "no_example": "Examples that cannot be judged as 'not using the equals() function to compare Strings'\n\n\nif (PROPERTY_VALUE_YES.equalsIgnoreCase(readWriteReqNode))\n formProperty.setRequired(true);\nIn this code, both PROPERTY_VALUE_YES and readWriteReqNode are strings. The comparison between PROPERTY_VALUE_YES and readWriteReqNode in the if statement uses equalsIgnoreCase (case-insensitive string comparison), which is also in line with using the equals() function to compare Strings. Therefore, this cannot be judged as 'not using the equals() function to compare Strings'\n\n\n\nString isOk = \"ok\";\nif (\"ok\".equals(isOk)) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}In this code, isOk is a string. In the if statement, the comparison with \"ok\" uses the equals() function to compare Strings, so this cannot be judged as 'not using the equals() function to compare Strings'\n" + }, + { + "id": 11, + "text": "Prohibit the direct use of string output for exceptions in logs, please use placeholders to pass the exception object", + "language": "Java", + "detail": "Defect type: Do not directly output exceptions as strings in logs, use placeholders to pass the exception object; Corresponding Fixer: ConcatExceptionFixer; Fix solution: Use placeholders to pass the exception object", + "yes_example": "Example of being judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'\n\ntry {\n listenersNode = objectMapper.readTree(listenersNode.asText());\n} catch (Exception e) {\n LOGGER.info(\"Listeners node can not be read\", e);\n}In this code, the log output content is directly concatenated using the string \"Listeners node can not be read\". When outputting exceptions in logs, placeholders should be used to output exception information, rather than directly concatenating strings. 
Therefore, this is judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n", + "no_example": "Examples that cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects':\n\n\nPerson person = personService.getPerson(1);\nif (person == null) {\n LOGGER.error(PERSION_NOT_EXIT);\n}\nIn this code, PERSION_NOT_EXIT is a user-defined exception constant representing that the person does not exist, and it does not directly use the string 'person not exit' for concatenation, so this cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n\n\n\ntry {\n a = a + 1;\n} catch (Exception e) {\n Person person = personService.getPerson(1);\n LOGGER.info(person);\n}\nIn this code, the log output does not directly use string concatenation, but rather uses the Person object for output, so this cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n" + }, + { + "id": 12, + "text": "The finally block cannot be empty", + "language": "Java", + "detail": "Defect type: finally block cannot be empty; Corresponding Fixer: EmptyFinallyBlockFixer; Fix solution: Delete the empty finally block", + "yes_example": "Examples of being judged as 'finally block cannot be empty'\n\n\n\ntry {\n Persion persion = persionService.getPersion(1);\n return persion;\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n System.out.println(\"Inside try block\");\n} finally {\n // Empty finally block with no statements, this is a defect\n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n int result = 10 / 0;\n} catch (ArithmeticException e) {\n e.printStackTrace();\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n String str = null;\n System.out.println(str.length());\n} catch (NullPointerException e) {\n e.printStackTrace();\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n e.printStackTrace();\n} finally {\n // Finally block with only comments\n // This is an empty finally block\n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n e.printStackTrace();\n} finally {\n // Finally block with only empty lines\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n", + "no_example": "Example that cannot be judged as 'finally block cannot be empty'\n\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){\n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}\nThis code's finally block contains non-comment content 'return null;', so this cannot be judged as 'finally block cannot be empty'\n" + }, + { + "id": 13, + "text": "try block cannot be empty", + "language": "Java", + "detail": "Defect type: try block cannot be empty; Corresponding Fixer: EmptyTryBlockFixer; Fix solution: Delete 
the entire try statement", + "yes_example": "Examples of being judged as 'try block cannot be empty'\n\npublic void getPersion() {\n\ttry {\n\n\t}\n\treturn null;\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\npublic void demoFinallyBlock() {\n\ttry {\n\n\t} finally {\n\t\treturn null;\n\t}\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\ntry {\n \n} catch (Exception e) {\n e.printStackTrace();\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\ntry {\n // try block with only comments\n\t\n} catch (Exception e) {\n e.printStackTrace();\n}This code's try block contains only comments and blank lines, which can also be considered as having no content in the try block, so it is judged as 'try block cannot be empty'\n", + "no_example": "### Example that cannot be judged as 'try block cannot be empty'\n\ntry {\n\ta = a + 1;\n} catch (Exception e) {\n\te.printStackTrace();\n}\nThis code snippet contains non-comment content 'return null;' in the try block, so it cannot be judged as 'try block cannot be empty'\n" + }, + { + "id": 14, + "text": "Avoid unnecessary NULL or null checks on objects", + "language": "Java", + "detail": "Defect type: Avoid unnecessary NULL or null checks on objects; Corresponding Fixer: LogicalOpNpeFixer; Fix solution: Remove the logic of unnecessary NULL checks on objects", + "yes_example": "Examples of being judged as 'avoiding unnecessary NULL or null checks':", + "no_example": "Example that cannot be judged as 'avoiding unnecessary NULL or null checks'\n\nCat cat = catService.get(1);\nif (cat != null){\n\tretrun cat;\n}In this code, the object 'cat' is obtained through the service and it is uncertain whether it is null or not, so the condition 'cat != null' in the if statement is necessary, therefore this cannot be judged as 'avoiding unnecessary NULL or null checks'\n" + }, + { + "id": 15, + "text": "Avoid return in finally block", + "language": "Java", + "detail": "Defect type: Avoid return in finally block; Repair solution: No need for repair", + "yes_example": "Example judged as 'avoid return in finally block'", + "no_example": "Example that cannot be judged as 'avoiding return in finally block'\n\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\tLOGGER.info(PERSION_NOT_EXIT);\n\t}\n}\nThis code's finally block does not contain 'return', so it cannot be judged as 'avoiding return in finally block'\n" + }, + { + "id": 16, + "text": "Avoid empty static initialization", + "language": "Java", + "detail": "Defect type: Avoid empty static initialization; Corresponding Fixer: EmptyInitializerFixer; Fix solution: Delete the entire empty initialization block", + "yes_example": "Examples of being judged as 'Avoid empty static initialization'", + "no_example": "Example that cannot be judged as 'avoiding empty static initialization'\n\npublic class Cat {\n\tstatic {\n\t\t// Static initialization block\n\t\tcat = null;\n\t}\n}\nThis code has a static block with content, not empty, and the static initialization block contains non-commented code with actual logic, so this cannot be judged as 'avoiding empty static initialization'\n" + }, + { + "id": 17, + "text": "Avoid risks of improper use of calendar", + "language": "Java", + "detail": "Defect type: Avoid improper usage risks of calendar classes; Fix solution: Use LocalDate from the java.time 
package in Java 8 and above", + "yes_example": "Examples of being judged as 'avoiding improper use of calendar class risks'\n\nprivate static final Calendar calendar = new GregorianCalendar(2020, Calendar.JANUARY, 1);\nThe Calendar and GregorianCalendar in this code are not thread-safe, so this is judged as 'avoiding improper use of calendar class risks'\n", + "no_example": "Examples that cannot be judged as 'avoiding improper use of calendar class risks'" + }, + { + "id": 18, + "text": "To convert a collection to an array, you must use the toArray(T[] array) method of the collection, passing in an array of the exact same type, with a size equal to list.size()", + "language": "Java", + "detail": "Defect type: When converting a collection to an array, you must use the toArray(T[] array) method of the collection, passing an array of the exact same type, with a size equal to list.size(); Corresponding Fixer: ClassCastExpWithToArrayFixer; Repair solution: Use the toArray(T[] array) method of the collection, and pass an array of the exact same type", + "yes_example": "Example judged as 'When converting a collection to an array, you must use the collection's toArray(T[] array) method, passing an array of exactly the same type, with the size being list.size()'", + "no_example": "Example that cannot be judged as 'using the method of converting a collection to an array, you must use the toArray(T[] array) of the collection, passing in an array of exactly the same type, and the size is list.size()':" + }, + { + "id": 19, + "text": "Prohibit the use of NULL or null for comparison in equals()", + "language": "Java", + "detail": "Defect type: Prohibit using NULL or null for comparison in equals(); Corresponding Fixer: EqualsNullFixer; Fixing solution: Use Object's null check function for comparison", + "yes_example": "Examples of being judged as 'Prohibited to use NULL or null for comparison in equals()'", + "no_example": "Examples that cannot be judged as 'prohibiting the use of NULL or null for comparison in equals()'" + }, + { + "id": 20, + "text": "switch statement block cannot be empty", + "language": "Java", + "detail": "Defect type: switch statement block cannot be empty; Corresponding Fixer: EmptySwitchStatementsFix; Fix solution: Delete the entire empty switch statement block", + "yes_example": "Examples of being judged as 'switch statement block cannot be empty'\n\nswitch (number) {\n \n}This code is a switch statement block, but it contains no content, so it is judged as 'switch statement block cannot be empty'\n\n\n\nswitch (number) {\n // This is a switch statement block\n}This code is a switch statement block, which contains content, but the content is only comments without actual logic, so it is judged as 'switch statement block cannot be empty'\n", + "no_example": "Example that cannot be judged as 'switch statement block cannot be empty'\n\nswitch (number) {\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n}\nThis code is a switch statement block that contains content, and the content includes non-commented code with actual logic, so it cannot be judged as 'switch statement block cannot be empty'.\n" + }, + { + "id": 21, + "text": "When performing type coercion, no spaces are needed between the right parenthesis and the coercion value.", + "detail": "Defect type: When performing type coercion, no space is required between the right parenthesis and the 
coercion value; Fix solution: When performing type coercion, no space is required between the right parenthesis and the coercion value.", + "language": "Java", + "yes_example": "Examples judged as 'When performing type casting, no space is needed between the closing parenthesis and the cast value'\n\nint a = (int) 3.0;\nIn this code there is a space between the right parenthesis of the cast and the value 3.0, so it is judged as 'When performing type casting, no space is needed between the closing parenthesis and the cast value'\n\n\n\nlong b = (long) 5;\nIn this code there is a space between the right parenthesis of the cast and the value 5, so it is judged as 'When performing type casting, no space is needed between the closing parenthesis and the cast value'\n", + "no_example": "Examples that cannot be judged as 'When performing type coercion, no spaces are required between the right parenthesis and the coercion value'\n\nint a = (int)3.0;\nIn this code there is no space between the right parenthesis of the cast and the value, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 22, + "text": "Method parameters must have a space after the comma when defined and passed", + "detail": "Defect type: In the definition and passing of method parameters, a space must be added after the comma for multiple parameters; Repair solution: In the definition and passing of method parameters, a space must be added after the comma for multiple parameters.", + "language": "Java", + "yes_example": "Example of being judged as 'Method parameters must have a space after the comma when defined and passed'\n\npublic void exampleMethod(int a,int b,int c) {}\nIn this code there is no space after the commas between the parameters, so it is judged as 'Method parameters must have a space after the comma when defined and passed'\n", + "no_example": "Examples that cannot be judged as 'Method parameters must have a space after the comma both in definition and when passed'\n\npublic void exampleMethod(int a, int b, int c) {}\nIn this code a space follows each comma between the parameters, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 23, + "text": "Prohibit the use of the BigDecimal(double) constructor to convert a double value to a BigDecimal object", + "detail": "Defect type: Do not use the constructor BigDecimal(double) to convert a double value to a BigDecimal object; Repair solution: It is recommended to use the valueOf method of BigDecimal.", + "language": "Java", + "yes_example": "Example of being judged as 'Prohibited to use the constructor BigDecimal(double) to convert a double value to a BigDecimal object'\n\nBigDecimal value = new BigDecimal(0.1);\nThis code uses the BigDecimal(double) constructor; because double cannot represent 0.1 exactly, the resulting value is not exactly 0.1, so it is judged as 'Prohibited to use the constructor BigDecimal(double) to convert a double value to a BigDecimal object'\n", + "no_example": "Examples that cannot be considered as 'prohibiting the use of the BigDecimal(double) constructor to convert a double value to a BigDecimal object'\n\nBigDecimal value = BigDecimal.valueOf(0.1);\nThis code uses the recommended valueOf method instead of the BigDecimal(double) constructor, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 24, + "text": "No extra semicolons allowed", + "detail": "Defect type: extra semicolon; Fix solution: remove extra semicolon", + "yes_example": "Example of being judged as 'cannot have extra semicolons'\n\nint a = 1;;\nThe second semicolon in this code is redundant, so it is judged as 'cannot have extra semicolons'\n", + "no_example": "Examples that cannot be judged as 'cannot have extra semicolons'\n\nwhile (true) {\n\ta = a + 1;\n\tbreak;\n}Every semicolon in this code is required, so it cannot be judged as 'cannot have extra semicolons'\n" + }, + { + "id": 25, + "text": "Non-thread-safe SimpleDateFormat usage must be synchronized at the function or code block level", + "detail": "Defect type: Non-thread-safe SimpleDateFormat usage; Fix solution: Add synchronized modifier at the function or code block level or use other thread-safe methods", + "yes_example": "Example of 'Non-thread-safe SimpleDateFormat usage, must be used with synchronized at the function or block level'\n\npublic class DateUtil {\n\tprivate static final SimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tpublic static String format(Date date) {\n\t\treturn sdf.format(date);\n\t}\n}\nThis code shares a static SimpleDateFormat across threads without any synchronized protection, so it is judged as 'Non-thread-safe SimpleDateFormat usage, must be used with synchronized at the function or block level'\n", + "no_example": "Example that cannot be judged as 'Unsafe use of SimpleDateFormat, which must be used at the function or code block level with synchronized':\n\npublic synchronized void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}\nThe method 'formatDate' in this code is declared synchronized, ensuring thread safety, so it cannot be judged as 'Unsafe use of SimpleDateFormat, which must be used at the function or code block level with synchronized'.\n" + }, + { + "id": 26, + "text": "Naming does not follow the camel case specification. 
Class names should use UpperCamelCase style, while method names, parameter names, member variables, and local variables should all use lowerCamelCase style.", + "detail": "Defect type: Not following camel case naming convention; Fix solution: Class names should use UpperCamelCase style, method names, parameter names, member variables, and local variables should use lowerCamelCase style.", + "language": "Java", + "yes_example": "Examples of being judged as 'not following the camel case naming convention'\n\npublic class myClass {\n private int MyVariable;\n public void MyMethod() {}\n}\nThis code does not follow the camel case naming convention for class names, member variables, and method names, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'not following the camel case naming convention'\n\npublic class MyClass {\n private int myVariable;\n public void myMethod() {}\n}\nThe class name, member variable, and method name in this code all follow the camel case naming convention, so it cannot be judged as a naming convention issue.\n" + }, + { + "id": 27, + "text": "Abstract class names start with Abstract or Base; exception class names end with Exception; test class names begin with the name of the class they are testing and end with Test", + "detail": "Defect type: Naming convention; Solution: Abstract class names should start with Abstract or Base, exception class names should end with Exception, and test class names should start with the name of the class they are testing and end with Test.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming conventions'\n\npublic class MyAbstractClass {}\npublic class MyExceptionClass {}\npublic class TestMyClass {}\nThe naming of the abstract class, exception class, and test class in this code does not conform to the conventions, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'\n\npublic abstract class AbstractUserService {}\npublic class UserNotFoundException extends Exception {}\npublic class UserServiceTest {}\nHere the abstract class name starts with Abstract, the exception class name ends with Exception, and the test class name starts with the name of the class under test and ends with Test, so this cannot be judged as a naming convention issue.\n" + }, + { + "id": 28, + "text": "Avoid adding the 'is' prefix to any boolean type variables in POJO classes", + "detail": "Defect type: Naming convention; Fix solution: Do not prefix boolean variables in POJO classes with 'is'.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming convention' issues\n\npublic class User {\n private boolean isActive;\n}\nIn this code, the boolean type variable has the 'is' prefix, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'\n\npublic class User {\n private boolean active;\n}\nIn this code, the boolean type variable has no 'is' prefix, so it cannot be judged as a naming convention issue.\n" + }, + { + "id": 29, + "text": "Eliminate completely non-standard English abbreviations to avoid confusion when interpreting them.", + "detail": "Defect type: Naming conventions; Solution: Avoid using non-standard English abbreviations to ensure code readability.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming conventions'\n\npublic class CfgMgr {\n private int cnt;\n}\nIn this code, the class name and variable name use non-standard English abbreviations, so they are judged as naming convention issues.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'\n\npublic class ConfigManager {\n private int count;\n}\nIn this code, the class name and variable name use complete words instead of non-standard abbreviations, so they cannot be judged as naming convention issues.\n" + }, + { + "id": 30, + "text": "Avoid using magic characters and numbers, they should be declared as constants", + "detail": "Defect type: Avoid using magic characters and numbers, they should be declared as constants; Fix solution: Define magic values as constants.", + "language": "Java", + "yes_example": "Examples of being judged as 'avoiding magic characters and numbers, 
should be declared as constants'", + "no_example": "Examples that cannot be judged as 'avoiding magic characters and numbers, should be declared as constants'" + }, + { + "id": 31, + "text": "When assigning values to long or Long, use uppercase L after the number, not lowercase l. The suffix for floating-point numbers should be uppercase D or F.", + "detail": "Defect type: Code specification; Repair solution: Use uppercase L when assigning values to long or Long, and use uppercase D or F as suffixes for floating-point type values.", + "language": "Java", + "yes_example": "Examples of being judged as 'code specification'", + "no_example": "Examples that cannot be judged as 'code specification'" + }, + { + "id": 32, + "text": "If the curly braces are empty, simply write {} without line breaks or spaces inside the braces; if it is a non-empty code block, then: 1) Do not line break before the left curly brace. 2) Line break after the left curly brace. 3) Line break before the right curly brace. 4) Do not line break after the right curly brace if there is code like 'else' following it; the right curly brace indicating termination must be followed by a line break.", + "detail": "Defect type: code formatting; Fix solution: follow the curly brace usage standard.", + "language": "Java", + "yes_example": "Example of being judged as 'code format'", + "no_example": "Examples that cannot be judged as 'code format' issues\n\npublic class BracketExample {\n public void method() {\n if (true) {\n // do something\n }\n }\n}\nThe use of curly braces in this code is in accordance with the standards, so it cannot be judged as a code format issue.\n" + }, + { + "id": 33, + "text": "No space is needed between the left parenthesis and the adjacent character; no space is needed between the right parenthesis and the adjacent character; and a space is required before the left brace.", + "detail": "Defect type: code formatting; Fix solution: follow the usage rules for brackets and spaces.", + "language": "Java", + "yes_example": "Example of being judged as 'code format'\n\npublic class SpaceExample {\n public void method (){\n }\n}\nThe use of brackets and spaces in this code does not conform to the standard, so it is judged as a code format issue.\n", + "no_example": "Examples that cannot be judged as 'code specification'\n\npublic class SpaceExample {\n public void method() {}\n}\nThis code uses brackets and spaces in accordance with the specification, so it cannot be judged as a code format issue.\n" + }, + { + "id": 34, + "text": "Reserved words such as if / for / while / switch / do must be separated from the parentheses on both sides by spaces.", + "detail": "Defect type: code format; Fix solution: add spaces between reserved words and parentheses.", + "language": "Java", + "yes_example": "Example of being judged as 'code specification'\n\npublic class KeywordExample {\n public void method() {\n if(true) {\n }\n }\n}\nIn this code, there is no space between the if keyword and the parentheses, so it is judged as a code formatting issue.\n", + "no_example": "Examples that cannot be judged as 'code specification'" + }, + { + "id": 35, + "text": "All value comparisons between integer wrapper class objects should be done using the equals method", + "detail": "Defect type: Code specification; Repair solution: Use the equals method for value comparison between integer wrapper class objects.", + "language": "Java", + "yes_example": "Examples of being judged as 'code specification'", + "no_example": "### Example that 
cannot be judged as 'code specification'\n\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a.equals(b)) {\n }\n }\n}\nIn this code, the equals method is used to compare integer wrapper class objects, so it cannot be judged as a code specification issue.\n" + }, + { + "id": 36, + "text": "For comparing BigDecimal values, the compareTo() method should be used instead of the equals() method.", + "detail": "Defect type: The equality comparison of BigDecimal should use the compareTo() method instead of the equals() method; Fix solution: Use the compareTo() method for comparison.", + "language": "Java", + "yes_example": "Example of being judged as 'For BigDecimal equality comparison, the compareTo() method should be used instead of the equals() method'\n\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.equals(b)) {\n // This comparison returns false because the equals() method also compares precision\n}\n", + "no_example": "Examples that cannot be judged as 'For BigDecimal equality comparison, the compareTo() method should be used instead of the equals() method'\n\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.compareTo(b) == 0) {\n // compareTo() ignores precision, so this comparison is true\n}\nThis code uses the compareTo() method for the comparison, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 37, + "text": "Prohibit having both isXxx() and getXxx() methods for the same attribute xxx in a POJO class.", + "detail": "Defect type: Duplicate getter methods in POJO class; Fix solution: Ensure only one getter method exists.", + "language": "Java", + "yes_example": "Example of being judged as 'Prohibited to have both isXxx() and getXxx() methods for the corresponding attribute xxx in a POJO class'\n\npublic class User {\n private Boolean active;\n public Boolean isActive() {\n return active;\n }\n public Boolean getActive() {\n return active;\n }\n}\nIn this code, both isActive() and getActive() exist for the same attribute 'active', so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'Prohibiting the existence of both isXxx() and getXxx() methods for the corresponding attribute xxx in a POJO class'\n\npublic class User {\n private Boolean active;\n public Boolean getActive() {\n return active;\n }\n}\nIn this code, only one getter method exists for the attribute 'active', so it cannot be judged as a violation of this rule\n" + }, + { + "id": 38, + "text": "When formatting dates, use the lowercase 'y' uniformly to represent the year in the pattern.", + "detail": "Defect type: date formatting error; Fix solution: use lowercase y to represent the year.", + "language": "Java", + "yes_example": "Example judged as 'When formatting dates, use lowercase y for the year in the pattern'\n\nSimpleDateFormat sdf = new SimpleDateFormat(\"YYYY-MM-dd\");\nIn this pattern the uppercase 'YYYY' represents the week-based year and can produce wrong results around the turn of the year, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'When formatting dates, use lowercase y for the year in the pattern'\n\nSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\nThis pattern uses the lowercase 'y' to represent the year, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 39, + "text": "Prohibited from using in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp.", + "detail": "Defect type: used date classes from the java.sql package; Fix solution: use date classes from the java.time package.", + "language": "Java", + "yes_example": "Examples of being judged as \"Prohibited from using in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp\"\n\njava.sql.Date date = new java.sql.Date(System.currentTimeMillis());\nThis code uses java.sql.Date, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'Prohibited to use in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp'\n\nLocalDate date = LocalDate.now();\nThis code uses LocalDate from the java.time package instead of the java.sql date classes, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 40, + "text": "Determine if all elements within a collection are empty using the isEmpty() method, rather than using the size() == 0 approach.", + "detail": "Defect type: Incorrect method for checking empty collection; Fix solution: Use isEmpty() method.", + "language": "Java", + "yes_example": "Example of being judged as 'To determine if all elements within a collection are empty, use the isEmpty() method instead of the size() == 0 approach'\n\nList<String> list = new ArrayList<>();\nif (list.size() == 0) {\n // Empty logic\n}\n", + "no_example": "Examples that cannot be considered as 'judging whether all elements within a set are empty using the isEmpty() method instead of the size() == 
0 approach'" + }, + { + "id": 41, + "text": "Whenever you override equals, you must also override hashCode.", + "detail": "Defect type: hashCode method not overridden; Fix solution: Override both equals and hashCode methods.", + "language": "Java", + "yes_example": "An example where it is judged that 'if you override equals, you must also override hashCode'", + "no_example": "An example where it cannot be judged as 'Whenever you override equals, you must also override hashCode'" + }, + { + "id": 42, + "text": "When using the Map methods keySet() / values() / entrySet() to return a collection object, you cannot perform element addition operations on it, otherwise a UnsupportedOperationException will be thrown.", + "detail": "Defect type: Adding operations to the collections returned by keySet() / values() / entrySet() of a Map; Repair solution: Avoid adding operations to these collections.", + "language": "Java", + "yes_example": "Example of being judged as 'When using the Map methods keySet() / values() / entrySet() to return a collection object, you cannot perform element addition operations on it, otherwise a UnsupportedOperationException exception will be thrown'", + "no_example": "Example that cannot be judged as 'When using the methods keySet() / values() / entrySet() of Map to return a collection object, it is not allowed to perform element addition operations on it, otherwise a UnsupportedOperationException will be thrown'" + }, + { + "id": 43, + "text": "Do not perform element removal / addition operations within a foreach loop. Use the iterator method for removing elements. If concurrent operations are required, the iterator must be synchronized.", + "detail": "Defect type: performing remove / add operations on elements within a foreach loop; Repair solution: use iterator to perform remove operations on elements.", + "language": "Java", + "yes_example": "Example of being judged as 'Do not perform element remove / add operations within a foreach loop. Use the iterator method for removing elements; if concurrent operations are required, the iterator must be synchronized.'", + "no_example": "Example that cannot be judged as 'Do not perform element remove / add operations inside a foreach loop. Use the iterator method for removing elements. 
If concurrent operations are required, the iterator should be synchronized.'\n\nList<String> list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nIterator<String> iterator = list.iterator();\nwhile (iterator.hasNext()) {\n String s = iterator.next();\n if (s.equals(\"a\")) {\n iterator.remove();\n }\n}\n" + }, + { + "id": 44, + "text": "Class, class attributes, and class methods must use Javadoc specifications for comments, using the format /** content */, and must not use the // xxx format.", + "detail": "Defect type: Comments do not conform to Javadoc standards; Solution: Use Javadoc-compliant comment format.", + "language": "Java", + "yes_example": "Examples of being judged as 'class, class attribute, class method annotations must use Javadoc specification, using the format /** content */, not using the // xxx method'\n\npublic class User {\n // user name\n private String name;\n // get the user name\n public String getName() {\n return name;\n }\n}\nIn this code, the class attribute and class method are commented with the // xxx format instead of Javadoc /** content */, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'Class, class attribute, and class method comments must follow the Javadoc specification, using the /** content */ format, not the // xxx format'\n\npublic class User {\n /** user name */\n private String name;\n /** get the user name */\n public String getName() {\n return name;\n }\n}\nIn this code, the attribute and method comments use the Javadoc /** content */ format, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 45, + "text": "All abstract methods (including methods in interfaces) must be annotated with Javadoc comments", + "detail": "Defect type: All abstract methods (including methods in interfaces) must be annotated with Javadoc; Repair solution: Add Javadoc comments to all abstract methods (including methods in interfaces), in addition to the return value, parameter exception description, it must also indicate what the method does and what function it implements.", + "language": "Java", + "yes_example": "Example of being judged as 'All abstract methods (including methods in interfaces) must be annotated with Javadoc'\n\npublic interface UserService {\n User getUser(Long id);\n}\nIn this code, the abstract method getUser has no Javadoc comment, so it is judged as a violation of this rule\n", + "no_example": "Example that cannot be judged as 'all abstract methods (including methods in interfaces) must be annotated with Javadoc comments'\n\npublic interface UserService {\n /**\n * Get a user by id.\n *\n * @param id the user id\n * @return the matching user\n */\n User getUser(Long id);\n}\nIn this code, the abstract method is annotated with a Javadoc comment describing what it does, its parameter, and its return value, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 46, + "text": "Usage guidelines for single-line and multi-line comments within methods", + "detail": "Defect type: Improper use of comments; Repair solution: Single-line comments inside the method, start a new line above the commented statement, use // for comments. 
Multi-line comments inside the method use /* */ comments, and pay attention to aligning with the code.", + "language": "Java", + "yes_example": "### Examples of being judged as 'Improper Use of Comments'\n\npublic void exampleMethod() {\n int a = 1; // Initialize variable a\n int b = 2; /* Initialize variable b */\n}\nThe single-line and multi-line comments in this code are not used according to the standard, so they are judged as improper use of comments.\n", + "no_example": "Examples that cannot be judged as 'improper use of comments'\n\npublic void exampleMethod() {\n // Initialize variable a\n int a = 1;\n /*\n * Initialize variable b\n */\n int b = 2;\n}\nThis code uses single-line and multi-line comments according to the standard, so it cannot be judged as improper use of comments.\n" + }, + { + "id": 47, + "text": "All enumeration type fields must have comments", + "detail": "Defect type: Enumeration type field lacks comments; Fix plan: Add comments to all enumeration type fields to explain the purpose of each data item.", + "language": "Java", + "yes_example": "Example of being judged as 'Enumeration type field lacks comments'\n\npublic enum Status {\n ACTIVE,\n INACTIVE\n}\nThe enumeration type fields in this code are not commented, so they are judged as lacking comments for enumeration type fields.\n", + "no_example": "Examples that cannot be judged as 'missing comments for enum fields'\n\npublic enum Status {\n /**\n * Active status\n */\n ACTIVE,\n /**\n * Inactive status\n */\n INACTIVE\n}\nThis code has comments for the enum fields, so it cannot be judged as missing comments for enum fields.\n" + }, + { + "id": 48, + "text": "The finally block must close resource objects and stream objects.", + "detail": "Defect type: resource objects and stream objects are not closed in the finally block; Fix solution: Close resource objects and stream objects in the finally block, and use try-catch for exceptions.", + "language": "Java", + "yes_example": "Example of being judged as 'resource object, stream object not closed in finally block'\n\nFileInputStream fis = new FileInputStream(\"data.txt\");\nfis.read();\nfis.close();\nIn this code, the stream is closed outside of a finally block; if an exception occurs, the stream is never closed, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'resource objects, stream objects not closed in the finally block'\n\nFileInputStream fis = null;\ntry {\n fis = new FileInputStream(\"data.txt\");\n fis.read();\n} finally {\n if (fis != null) {\n try {\n fis.close();\n } catch (IOException e) {\n LOGGER.error(\"close failed\", e);\n }\n }\n}\nIn this code, the stream object is closed in the finally block with its own try-catch, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 49, + "text": "Constant names should be in all uppercase, with words separated by underscores.", + "detail": "Defect type: Constant naming is not standardized; Fix solution: Constant names should be all uppercase, words separated by underscores, and strive for complete and clear semantic expression, do not be afraid of long names.", + "language": "Java", + "yes_example": "Examples of being judged as 'Constant names should be in all uppercase, with words separated by underscores'\n\nprivate static final int maxRetryCount = 3;\nThis constant name is not all uppercase with words separated by underscores, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'constant names should be all uppercase, with words separated by underscores'\n\nprivate static final int MAX_RETRY_COUNT = 3;\nThis constant name is all uppercase with words separated by underscores, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 50, + "text": "Spaces are required on both sides of any binary or ternary operator.", + "detail": "Defect type: Lack of space around operators; Fix solution: Any binary or ternary operator should have a space on both sides.", + "language": "Java", + "yes_example": "Examples of being judged as 'Any binary or ternary operator must have spaces on both sides'\n\nint max = a>b?a:b;\nIn this code, the binary operator > and the ternary operator ? : have no spaces around them, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'any binary, ternary operator needs a space on both sides'\n\nint max = a > b ? a : b;\nIn this code, all binary and ternary operators have spaces on both sides, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 51, + "text": "Avoid using from import *", + "detail": "Defect type: Avoid using 'from import *', importing everything can cause naming conflicts; Solution: Each sub-dependency used should be imported separately.", + 
"language": "Python", + "yes_example": "Example of being judged as 'avoid using from import *'", + "no_example": "Examples that cannot be judged as 'avoid using from import *'" + }, + { + "id": 52, + "text": "Avoid using the __import__() function to dynamically import modules", + "detail": "Defect type: Avoid using __import__() function to dynamically import modules; Repair solution: Use standard import statements.", + "language": "Python", + "yes_example": "Example of being judged as 'dynamically importing modules using the __import__() function'", + "no_example": "Examples that cannot be judged as 'dynamically importing modules using the __import__() function'" + }, + { + "id": 53, + "text": "Import statements are not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports.", + "detail": "Defect type: Import statements are not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports; Solution: Group import statements in order.", + "language": "Python", + "yes_example": "Examples of being judged as 'import statements not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports'", + "no_example": "Example that cannot be judged as 'import statements not grouped in the order of standard library imports, related third-party imports, local application/library specific imports'" + }, + { + "id": 54, + "text": "Avoid unused function parameters", + "detail": "Defect type: Avoid unused function parameters; Fix solution: Remove unused function parameters.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid unused function parameters'", + "no_example": "Examples that cannot be judged as 'avoiding unused function parameters'" + }, + { + "id": 55, + "text": "Use is not None to check if a variable is not None", + "detail": "Defect type: Not using 'is not None' to check if a variable is not None; Fix solution: Use 'is not None' to check.", + "language": "Python", + "yes_example": "Example of being judged as 'not using is not None to check if a variable is not None'", + "no_example": "Examples that cannot be judged as 'not using is not None to check if a variable is not None'" + }, + { + "id": 56, + "text": "Avoid using == or != to compare the equivalence of object instances", + "detail": "Defect type: Using == or != to compare object instances for equivalence; Fix solution: Should use equals for comparison.", + "language": "Python", + "yes_example": "Example of being judged as 'using == or != to compare the equivalence of object instances'", + "no_example": "Examples that cannot be judged as 'using == or != to compare the equivalence of object instances'" + }, + { + "id": 57, + "text": "Avoid using single-letter variable names, use descriptive variable names", + "detail": "Defect type: Avoid using single-letter variable names, use descriptive variable names; Fix solution: Use descriptive variable names.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using single-letter variable names, use descriptive variable names'", + "no_example": "Examples that cannot be judged as 'avoid using single-letter variable names, use descriptive variable names'" + }, + { + "id": 58, + "text": "Constant names use all uppercase letters and separate words with underscores", + "detail": "Defect type: Constant naming does not use all uppercase letters or does not 
use underscores to separate; Repair solution: Use all uppercase letters for constant naming and separate with underscores.", + "language": "Python", + "yes_example": "Example of being judged as 'Constant naming not using all uppercase letters and separated by underscores'\n\nmax_retry_count = 3\nThis module-level constant is named in lowercase instead of all uppercase with underscores, so it is judged as a violation of this rule\n", + "no_example": "Examples that cannot be judged as 'constant naming not using all uppercase letters and separated by underscores'\n\nMAX_RETRY_COUNT = 3\nThis constant name is all uppercase with words separated by underscores, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 59, + "text": "Class names should use camel case (CamelCase)", + "detail": "Defect type: Class name not using camel case; Repair solution: Use camel case for class names.", + "language": "Python", + "yes_example": "Examples of being judged as 'class name not using CamelCase'\n\nclass my_class:\n pass\nThis class name does not use CamelCase, so it is judged as 'class name not using CamelCase'\n", + "no_example": "Examples that cannot be judged as 'class name not using CamelCase'\n\nclass MyClass:\n pass\nThis class name uses CamelCase, so it cannot be judged as 'class name not using CamelCase'\n" + }, + { + "id": 60, + "text": "Try to use the with statement to manage resources as much as possible", + "detail": "Defect type: Not using the with statement to manage resources; Fix solution: Use the with statement to manage resources.", + "language": "Python", + "yes_example": "Example of being judged as 'not using the with statement to manage resources'\n\nf = open('data.txt')\ndata = f.read()\nf.close()\nThis code opens a file without using the with statement, so it is judged as 'not using the with statement to manage resources'\n", + "no_example": "Examples that cannot be judged as 'not using the with statement to manage resources'\n\nwith open('data.txt') as f:\n data = f.read()\nThis code manages the file with the with statement, so it cannot be judged as 'not using the with statement to manage resources'\n" + }, + { + "id": 61, + "text": "Avoid using except or generic Exception to catch all exceptions, specify the exception type instead.", + "detail": "Defect type: catch all exceptions; Fix solution: specify specific exception types.", + "language": "Python", + "yes_example": "Examples judged as 'catching all exceptions using except:' and 'throwing a generic Exception exception'\n\ntry:\n risky_operation()\nexcept Exception:\n pass\nThis code catches the generic Exception instead of a specific exception type, so it is judged as a violation of this rule\n", + "no_example": "Example that cannot be judged as 'using except: to catch all exceptions'\n\ntry:\n risky_operation()\nexcept ValueError as e:\n logging.error(e)\nThis code catches a specific exception type, so it cannot be judged as a violation of this rule\n" + }, + { + "id": 62, + "text": "Avoid manual string concatenation whenever possible", + "detail": "Defect type: manual string concatenation; Fix solution: use formatted strings or join method.", + "language": "Python", + "yes_example": "Examples of being judged as 'manual string concatenation'\n\nresult = ''\nfor name in names:\n result = result + name + ','\nThis code concatenates strings manually in a loop, so it is judged as 'manual string concatenation'\n", + "no_example": "Examples that cannot be judged as 'manual string concatenation'\n\nresult = ','.join(names)\nThis code uses the join method instead of manual concatenation, so it cannot be judged as 'manual string concatenation'\n" + }, + { + "id": 63, + "text": "Avoid using magic characters and numbers, should be declared as constants", + "detail": "Defect type: Using magic characters and numbers; Fix solution: Declare them as constants.", + "language": "Python", + "yes_example": "Examples of being judged as 'having magic characters and numbers'\n\nif status == 3:\n price = price * 0.8\nIn this code, 3 and 0.8 are magic numbers, so it is judged as 'having magic characters and numbers'\n", + "no_example": "Examples that cannot be judged as 'containing magic characters and numbers'\n\nSTATUS_CLOSED = 3\nDISCOUNT_RATE = 0.8\nif status == STATUS_CLOSED:\n price = price * DISCOUNT_RATE\nIn this code, the magic values are declared as constants, so it cannot be judged as 'containing magic characters and numbers'\n" + }, + { + "id": 64, + "text": "Boolean variable judgment does not require explicit comparison", + "detail": "Defect type: explicit comparison of boolean variables; fix solution: directly use boolean variables for judgment.", + "language": "Python", + "yes_example": "Examples of being judged as 'explicit comparison of boolean variables'\n\nif flag == True:\n run()\nThis code explicitly compares a boolean variable with True, so it is judged as 'explicit comparison of boolean variables'\n", + "no_example": "Examples that cannot be judged as 'explicit comparison of boolean variables'\n\nif flag:\n run()\nThis code uses the boolean variable directly in the judgment, so it cannot be judged as 'explicit comparison of boolean variables'\n" + }, + { + "id": 65, + "text": "Avoid using type() to check object types", + "detail": "Defect type: Avoid using type() to check object type; Fix solution: Use isinstance() function.", + "language": "Python", + "yes_example": "Example of being judged as 'avoid using type() to check object type'\n\nif type(x) == int:\n print(x)\nThis code uses type() to check the object type, so it is judged as 'avoid using type() to check object type'\n", + "no_example": "Examples that cannot be judged as 'avoid using type() to check object type'\n\nif isinstance(x, int):\n print(x)\nThis code uses the isinstance() function, so it cannot be judged as 'avoid using type() to check object type'\n" + }, + { + "id": 66, + "text": "Avoid using os.system() to call external commands", + "detail": "Defect type: Using os.system() to call external commands; Fix solution: Use the subprocess module.", + 
"language": "Python", + "yes_example": "Examples of being judged as 'using os.system() to call external commands'\nos.system('ls -l')\nos.system('ls -l')", + "no_example": "Examples that cannot be judged as 'using os.system() to call external commands'" + }, + { + "id": 67, + "text": "Create read-only properties using the @property decorator instead of modifying properties", + "detail": "Defect type: Creating modifiable properties using the @property decorator; Fix solution: Only use the @property decorator to create read-only properties.", + "language": "Python", + "yes_example": "Examples of being judged as 'using the @property decorator to create modifiable attributes'", + "no_example": "Examples that cannot be judged as 'using the @property decorator to create a modifiable attribute'" + }, + { + "id": 68, + "text": "When using indexing or slicing, do not add spaces inside the brackets or colons.", + "detail": "Defect type: adding spaces inside brackets or colons for indexing or slicing; Repair solution: remove spaces inside brackets or colons.", + "language": "Python", + "yes_example": "Examples judged as 'using spaces inside brackets or colons when using indexing or slicing'", + "no_example": "Examples that cannot be judged as 'adding spaces inside brackets or colons when using indexes or slices'" + }, + { + "id": 69, + "text": "Do not add a space before a comma, semicolon, or colon, but add a space after them", + "detail": "Defect type: adding a space before a comma, semicolon, or colon, or not adding a space after them; Fix solution: do not add a space before a comma, semicolon, or colon, but add a space after them.", + "language": "Python", + "yes_example": "Examples judged as 'adding a space before a comma, semicolon, or colon, or not adding a space after them'", + "no_example": "Examples that cannot be judged as 'adding a space before a comma, semicolon, or colon, or not adding a space after them'" + }, + { + "id": 70, + "text": "For binary operators, there should be spaces on both sides", + "detail": "Defect type: no spaces around binary operators; Fix solution: add spaces around binary operators", + "language": "Python", + "yes_example": "Example of being judged as 'no space around binary operator'", + "no_example": "Examples that cannot be judged as 'no space on both sides of the binary operator'" + }, + { + "id": 71, + "text": "Avoid using Python keywords as variable or function names", + "detail": "Defect type: Using Python keywords as variable names or function names; Repair solution: Use non-keyword names.", + "language": "Python", + "yes_example": "Examples of being judged as 'using Python keywords as variable names or function names'", + "no_example": "Examples that cannot be judged as 'using Python keywords as variable names or function names'\ndef my_function():\n pass\nnumber = 5" + }, + { + "id": 72, + "text": "Avoid using special characters as variable names/method names/class names, such as $ or @", + "detail": "Defect type: Using special characters as variable names/method names/class names; Repair solution: Use legal variable names.", + "language": "Python", + "yes_example": "Examples of being judged as 'using special characters as variable names/method names/class names, such as $ or @'", + "no_example": "Examples that cannot be judged as 'using special characters as variable names/method names/class names, such as $ or @'" + }, + { + "id": 73, + "text": "Avoid using raise to rethrow the current exception, as it will lose the original stack trace.", + "detail": 
"Defect type: Re-raise the current exception using raise; Fix solution: Use the raise ... from ... syntax.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using raise to rethrow the current exception, as it will lose the original stack trace'", + "no_example": "Examples that cannot be judged as 'avoid using raise to rethrow the current exception, as it will lose the original stack trace'" + }, + { + "id": 74, + "text": "Avoid using pass in except block, as it will catch and ignore the exception", + "detail": "Defect type: using pass in except block; Fix solution: handle the exception or log the error.", + "language": "Python", + "yes_example": "Examples of being judged as 'using pass in except block'", + "no_example": "Examples that cannot be judged as 'using pass in an except block'" + }, + { + "id": 75, + "text": "Avoid using assert statements to perform important runtime checks", + "detail": "Defect type: Using assert statements for important runtime checks; Fix solution: Use explicit condition checks and exception handling.", + "language": "Python", + "yes_example": "Example of being judged as 'using assert statements to perform important runtime checks'", + "no_example": "Examples that cannot be judged as 'using assert statements to perform important runtime checks'" + }, + { + "id": 76, + "text": "Avoid using eval() and exec(), these functions may bring security risks", + "detail": "Defect type: Use of eval() and exec() functions; Repair solution: Use secure alternatives.", + "language": "Python", + "yes_example": "Examples of being judged as 'using eval() and exec()'\n\n eval('print(1)') \n\n \n exec('a = 1') \n", + "no_example": "Examples that cannot be judged as 'using eval() and exec()'\n\ncompiled_code = compile('print(1)', '', 'exec')\nexec(compiled_code)\n" + }, + { + "id": 77, + "text": "Avoid using sys.exit(), use exceptions to control program exit instead.", + "detail": "Defect type: Avoid using sys.exit(), should use exceptions to control program exit; Repair solution: Use exceptions to control program exit.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using sys.exit(), should use exceptions to control program exit'", + "no_example": "Examples that cannot be judged as 'avoid using sys.exit(), should use exceptions to control program exit'" + }, + { + "id": 78, + "text": "Avoid using time.sleep() for thread synchronization, and instead use synchronization primitives such as locks or events.", + "detail": "Defect type: Using time.sleep() for thread synchronization; Fix solution: Use synchronization primitives.", + "language": "Python", + "yes_example": "Examples of being judged as 'using time.sleep() for thread synchronization'", + "no_example": "Examples that cannot be judged as 'using time.sleep() for thread synchronization'" + }, + { + "id": 79, + "text": "Avoid exceeding 79 characters per line of code", + "detail": "Defect type: Avoid exceeding 79 characters per line of code; Fix solution: Format long lines of code into multiple lines.", + "language": "Python", + "yes_example": "Example of being judged as 'avoiding more than 79 characters per line of code'", + "no_example": "Examples that cannot be judged as 'each line of code should not exceed 79 characters'" + }, + { + "id": 80, + "text": "Functions and class definitions at the module level are separated by two blank lines, and method definitions within a class are separated by one blank line", + "detail": "Defect type: There is no separation of two blank 
lines between function and class definitions at the module level, and no separation of one blank line between method definitions within the class; Solution: Add blank lines according to the specification.", + "language": "Python", + "yes_example": "Example of being judged as 'Functions at the module level are not separated by two blank lines, and method definitions within a class are not separated by one blank line'", + "no_example": "Examples that cannot be judged as 'There is no two blank lines between module-level function and class definitions, and no one blank line between method definitions inside a class'" + }, + { + "id": 81, + "text": "Use lowercase letters and underscores to separate variable and function names", + "detail": "Defect type: Variable and function naming do not conform to the lowercase letters and underscore separation method; Repair solution: Use lowercase letters and underscore separation method for naming.", + "language": "Python", + "yes_example": "Examples of being judged as 'not using lowercase letters and underscores to separate variable and function names'", + "no_example": "Examples that cannot be judged as 'naming variables and functions without using lowercase letters and underscores to separate them'" + }, + { + "id": 82, + "text": "It is not allowed to use the print() function to record logs, use the logging module, etc. to record logs", + "detail": "Defect type: Using the print() function to log; Fix solution: Use the logging module to log.", + "language": "Python", + "yes_example": "Examples of being judged as 'using the print() function to log'", + "no_example": "Examples that cannot be considered as 'using the print() function to log'" + } +] \ No newline at end of file diff --git a/metagpt/ext/cr/points_cn.json b/metagpt/ext/cr/points_cn.json new file mode 100644 index 0000000000..10fc951c07 --- /dev/null +++ b/metagpt/ext/cr/points_cn.json @@ -0,0 +1,656 @@ +[ + { + "id": 1, + "text": "避免未使用的临时变量", + "language": "Java", + "detail": "缺陷类型:避免未使用的临时变量;对应Fixer:UnusedLocalVariableFixer;修复方案:删除未使用的临时变量", + "yes_example": "### 被判定为\"避免未使用的临时变量\"的例子\n<例子1>\npublic String initCreationForm(Map model) {\n\t\tOwner owner = new Owner();\n\t\tmodel.put(\"owner\", owner);\n\t\tint unusedVar = 10;\n\t\treturn VIEWS_OWNER_CREATE_OR_UPDATE_FORM;\n\t}\n上述代码中unusedVar变量未被使用,所以这个被判定为\"避免未使用的临时变量\"\n\n<例子2>\nint unusedVariable = 10;\nSystem.out.println(\"Hello, World!\");\n这段代码的变量\"unusedVariable\"未被使用或者引用,所以这个不能判定为\"避免未使用的临时变量\"\n", + "no_example": "### 不能被判定为\"避免未使用的临时变量\"的例子\n<例子1>\npublic void setTransientVariablesLocal(Map transientVariables) {\nthrow new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}\n这段代码的\"transientVariables\"是函数参数而不是临时变量,虽然transientVariables没有被使用或者引用,但是这个也不能判定为\"避免未使用的临时变量\"\n\n\n<例子2>\npublic class TriggerCmd extends NeedsActiveExecutionCmd {\n protected Map transientVariables;\n public TriggerCmd(Map transientVariables) {\n this.transientVariables = transientVariables;\n }\n}\n上述代码中transientVariables不属于临时变量,它是类属性,且它在构造函数中被使用,所以这个不能被判定为\"避免未使用的临时变量\"\n" + }, + { + "id": 2, + "text": "不要使用 System.out.println 去打印", + "language": "Java", + "detail": "缺陷类型:不要使用 System.out.println 去打印;对应Fixer:SystemPrintlnFixer;修复方案:注释System.out.println代码", + "yes_example": "### 被判定为\"不要使用 System.out.println 去打印\"的例子\n<例子1>\nSystem.out.println(\"Initializing new owner form.\");\n上述代码使用了\"System.out.println\"进行打印,所以这个被判定为\"不要使用 System.out.println 去打印\"\n", + "no_example": "### 不能被判定为\"不要使用 System.out.println 
去打印\"的例子\n<例子1>\nthrow new IllegalStateException(\"There is no authenticated user, we need a user authenticated to find tasks\");\n上述代码是抛出异常的代码,没有使用\"System.out.print\",所以这个不能被判定为\"不要使用 System.out.println 去打印\"\n" + }, + { + "id": 3, + "text": "避免函数中未使用的形参", + "language": "Java", + "detail": "缺陷类型:避免函数中未使用的形参;修复方案:忽略", + "yes_example": "### 被判定为\"避免函数中未使用的形参\"的例子\n<例子1>\npublic void setTransientVariablesLocal(Map transientVariables) {\n throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}这段代码中的形参\"transientVariables\"未在函数体内出现,所以这个被判定为\"避免函数中未使用的形参\"\n\n\n<例子2>\nprotected void modifyFetchPersistencePackageRequest(PersistencePackageRequest ppr, Map pathVars) {}\n这段代码中的形参\"ppr\"和\"pathVars\"未在函数体内出现,所以这个被判定为\"避免函数中未使用的形参\"\n", + "no_example": "### 不能被判定为\"避免函数中未使用的形参\"的例子\n<例子1>\npublic String processFindForm(@RequestParam(value = \"pageNo\", defaultValue = \"1\") int pageNo) {\n\tlastName = owner.getLastName();\n\treturn addPaginationModel(pageNo, paginationModel, lastName, ownersResults);\n}这段代码中的形参\"pageNo\"在当前函数'processFindForm'内被'return addPaginationModel(pageNo, paginationModel, lastName, ownersResults);'这一句被使用,虽然pageNo没有被用于逻辑计算,但作为了函数调用其他函数的参数使用了,所以这个不能被判定为\"避免函数中未使用的形参\"\n\n<例子2>\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码中的形参date在System.out.println(\"Formatted date: \" + sdf.format(date))这一句中被引用到,所以这个不能被判定为\"避免函数中未使用的形参\"\n" + }, + { + "id": 4, + "text": "if语句块不能为空", + "language": "Java", + "detail": "缺陷类型:if 语句块不能为空;对应Fixer:EmptyIfStmtFixer;修复方案:删除if语句块 或 适当的逻辑处理 或 注释说明为何为空", + "yes_example": "### 被判定为\"if语句块不能为空\"的例子\n<例子1>\npublic void emptyIfStatement() {\n\tif (getSpecialties().isEmpty()) {\n\t}\n}这段代码中的if语句块内容是空的,所以这个被判定为\"if语句块不能为空\"\n\n\n<例子2>\npublic void judgePersion() {\n\tif (persion != null) {\n\t\t// judge persion if not null\n\t}\n}\n这段代码中的if语句块虽然有内容,但是\"// judge persion if not null\"只是代码注释,if语句块内并没有实际的逻辑代码,所以这个被判定为\"if语句块不能为空\"\n", + "no_example": "### 不能被判定为\"if语句块不能为空\"的例子\n<例子1>\npublic void judgePersion() {\n\tif (persion != null) {\n\t\treturn 0;\n\t}\n}这段代码中的if语句块里有内容,且里面有非注释代码的逻辑代码\"return 0;\",所以这个不能被判定为\"if语句块不能为空\"\n" + }, + { + "id": 5, + "text": "循环体不能为空", + "language": "Java", + "detail": "缺陷类型:循环体不能为空;对应Fixer:EmptyStatementNotInLoopFixer;修复方案:删除对应while、for、foreach 循环体 或 添加适当的逻辑处理或者注释说明为何为空", + "yes_example": "### 被判定为\"循环体不能为空\"的例子\n<例子1>\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t}\n}这段代码中的for循环体的内容是空的,所以这个被判定为\"循环体不能为空\"\n\n\n<例子2>\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t// this is a code example\n\t}\n}这段代码中的while循环体的内容虽然不是空的,但内容只是代码注释,无逻辑内容,所以这个被判定为\"循环体不能为空\"\n\n\n<例子3>\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t\n\t}\n}这段代码中的while循环体内容是空的,所以这个被判定为\"循环体不能为空\"\n", + "no_example": "### 不能被判定为\"循环体不能为空\"的例子\n<例子1>\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t\ta = 1;\n\t\tif (a == 1) {\n\t\t\tretrun a;\n\t\t}\n\t}\n}上述代码的for循环体的内容不为空,且内容不全是代码注释,所以这个不能被判定为\"循环体不能为空\"\n" + }, + { + "id": 6, + "text": "避免使用 printStackTrace(),应该使用日志的方式去记录", + "language": "Java", + "detail": "缺陷类型:避免使用 printStackTrace(),应该使 用日志的方式去记录;修复方案:用日志的方式去记录", + "yes_example": "### 被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"的例子\n<例子1>\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) 
{\n\t\te.printStackTrace();\n\t}\n}这段代码中的catch语句中使用了printStackTrace(),所以这个被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"\n", + "no_example": "### 不能被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"的例子\n<例子1>\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) {\n\t\tlogging.info(\"info\");\n\t}\n}这段代码的catch语句中使用的是日志记录的方式,所以这个不能被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"\n" + }, + { + "id": 7, + "text": "catch 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:catch 语句块不能为空;对应Fixer:EmptyCatchBlockFixer;修复方案:在catch里面添加注释", + "yes_example": "### 被判定为\"catch语句块不能为空\"的例子\n<例子1>\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n \n}\n这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n\n<例子2>\ntry {\n String str = null;\n str.length();\n} catch (NullPointerException e) {\n \n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n\n<例子3>\npublic class EmptyCatchExample {\n public static void main(String[] args) {\n try {\n // 尝试除以零引发异常\n int result = 10 / 0;\n } catch (ArithmeticException e) {\n \n }\n }\n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n<例子4>\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n \n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n<例子5>\ntry {\n Object obj = \"string\";\n Integer num = (Integer) obj;\n} catch (ClassCastException e) {\n\t\n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n", + "no_example": "### 不能被判定为\"catch语句块不能为空\"的例子\n<例子1>\npersionNum = 1\ntry {\n\treturn True;\n} catch (Exception e) {\n\t// 如果人数为1则返回false\n\tif (persionNum == 1){\n\t\treturn False;\n\t}\n}这段代码的catch语句中不为空,所以不能把这个被判定为\"catch语句块不能为空\"\n\n\n<例子2>\ntry {\n\tthrow new Exception(\"Fake exception\");\n} catch (Exception e) {\n\te.printStackTrace();\n}这段代码的catch语句中虽然只有\"e.printStackTrace();\"但确实不为空,所以不能把这个被判定为\"catch语句块不能为空\"\n" + }, + { + "id": 8, + "text": "避免不必要的永真/永假判断", + "language": "Java", + "detail": "缺陷类型:避免不必要的永真/永假判断;对应Fixer:UnconditionalIfStatement Fixer;修复方案:删除永真/永假判断逻辑", + "yes_example": "### 被判定为\"避免不必要的永真/永假判断\"的例子\n<例子1>\npublic void someMethod() {\n\twhile (true) {\n\t}\n}这段代码中的\"while (true)\"是一个使用true做判断条件,但是没有循环结束标记,所以这个被判定为\"避免不必要的永真/永假判断\"\n\n\n<例子2>\nif (true) {\n\tSystem.out.println(\"This is always true\");\n}这段代码中的\"if (true)\"是一个使用true条件做条件,但是没有循环结束标记,所以这个被判定为\"避免不必要的永真/永假判断\"\n\n\n<例子3>\na = 1;\nwhile(a > 0){\n\ta = a + 1\n}这段代码初始化a=1,是大于0的,while循环体的逻辑是每次加1,那么判断条件a > 0会永远是真的,不会退出循环,所以这个被判定为\"避免不必要的永真/永假判断\"\n<例子3>", + "no_example": "### 不能被判定为\"避免不必要的永真/永假判断\"的例子\n<例子1>\na = 0;\nwhile (a < 5) {\n\ta = a + 1;\n}这段代码中的a<5是一个判断,当执行了5次while语句中的逻辑a=a+1之后,a会满足a < 5,就会退出循环,所以这个能被判定为\"避免不必要的永真/永假判断\"\n" + }, + { + "id": 9, + "text": "switch 中 default 必须放在最后", + "language": "Java", + "detail": "缺陷类型:switch 中 default 必须放在最后;对应Fixer:DefaultLabelNotLastInSwitchStmtFixer;修复方案:switch 中 default 放在最后", + "yes_example": "### 被判定为\"switch 中 default 必须放在最后\"的例子\n<例子1>\nswitch (number) {\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tcase 2:\n\t\tSystem.out.println(\"Number two\");\n\t\tbreak;\n}这段代码是一个switch语句,但是里面的default没有放在最后,所以这个被判定为\"switch 中 default 必须放在最后\"\n", + "no_example": "### 不能被判定为\"switch 中 default 必须放在最后\"的例子\n<例子1>\nswitch (number) {\ncase 3:\n\tSystem.out.println(\"Number one\");\n\tbreak;\ncase 4:\n\tSystem.out.println(\"Number 
two\");\n\tbreak;\ndefault:\n\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\tbreak;\n}这段代码是一个switch语句且里面的default放在了最后,所以这个不能被判定为\"switch 中 default 必须放在最后\"\n" + }, + { + "id": 10, + "text": "未使用equals()函数对 String 作比较", + "language": "Java", + "detail": "缺陷类型:未使用equals()函数对 String 作比较;对应Fixer:UnSynStaticDateFormatter Fixer;修复方案:使用equals()函数对 String 作比较", + "yes_example": "### 被判定为\"未使用equals()函数对 String 作比较\"的例子\n<例子1>\nif (existingPet != null && existingPet.getName() == petName) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中所涉及的existingPet.getName()和petName均是字符串,但是在if语句里做比较的时候使用了==而没有使用equals()对string做比较,所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子2>\nString isOk = \"ok\";\nif (\"ok\" == isOk) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中的isOk是个字符串,但在if判断中与\"ok\"比较的时候使用的是==,未使用equals()对string做比较,应该使用\"ok\".equals(isOk),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子3>\nString str1 = \"Hello\";\nString str2 = \"Hello\";\nif (str1 == str2) {\n\tSystem.out.println(\"str1 和 str2 引用相同\");\n} else {\n\tSystem.out.println(\"str1 和 str2 引用不同\");\n}\n这段代码中的if (str1 == str2) 使用了==进行str1和str2的比较,未使用equals()对string做比较,应该使用str1.equals(str2),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子4>\nString str = \"This is string\";\nif (str == \"This is not str\") {\n\treturn str;\n}这段代码中的if (str == \"This is not str\")使用了==进行字符串比较,未使用equals()对string做比较,\"This is not str\".equals(str),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n", + "no_example": "### 不能被判定为\"未使用equals()函数对 String 作比较\"的例子\n<例子1>\nif (PROPERTY_VALUE_YES.equalsIgnoreCase(readWriteReqNode))\n formProperty.setRequired(true);\n这段代码中的PROPERTY_VALUE_YES和readWriteReqNode均是字符串,在if语句里比较PROPERTY_VALUE_YES和readWriteReqNode的使用的是equalsIgnoreCase(字符串比较忽略大小写),所以equalsIgnoreCase也是符合使用equals()函数对 String 作比较的,所以这个不能被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子2>\nString isOk = \"ok\";\nif (\"ok\".equals(isOk)) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中的isOk是个字符串,在if判断中与\"ok\"比较的时候使用的是equals()对string做比较,所以这个不能被判定为\"未使用equals()函数对 String 作比较\"\n" + }, + { + "id": 11, + "text": "禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象", + "language": "Java", + "detail": "缺陷类型:禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象 输出异常;对应Fixer:ConcatExceptionFixer;修复方案:使用占位符传递异常对象", + "yes_example": "### 被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"的例子\n<例子1>\ntry {\n listenersNode = objectMapper.readTree(listenersNode.asText());\n} catch (Exception e) {\n LOGGER.info(\"Listeners node can not be read\", e);\n}这段代码中日志输出内容内容是直接使用字符串\"Listeners node can not be read\"拼接,日志输出异常时,应使用占位符输出异常信息,而不是直接使用字符串拼接,所以这个被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n", + "no_example": "### 不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"的例子\n<例子1>\nPersion persion = persionService.getPersion(1);\nif (persion == null){\n\tLOGGER.error(PERSION_NOT_EXIT);\n}这段代码中的PERSION_NOT_EXIT是一个用户自定义的异常常量,代表persion不存在,没有直接使用字符串\"persion not exit\"拼接,所以这个不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n<例子1>\n\n<例子2>\ntry {\n a = a + 1;\n} catch (Exception e) {\n Persion persion = persionService.getPersion(1);\n LOGGER.info(persion);\n}这段代码中输出日志没有直接使用字符串拼接,而是使用的Persion对象输出,所以这个不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n" + }, + { + "id": 12, + "text": "finally 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:finally 语句块不能为空;对应Fixer:EmptyFinallyBlockFixer;修复方案:删除空 finally 语句块", + "yes_example": "### 被判定为\"finally 语句块不能为空\"的例子\n<例子1>\ntry {\n\tPersion persion = persionService.getPersion(1);\n\treturn persion;\n} finally 
{\n\t\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子2>\ntry {\n\tSystem.out.println(\"Inside try block\");\n} finally {\n\t// 空的finally块,没有任何语句,这是一个缺陷\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子3>\ntry {\n int result = 10 / 0;\n} catch (ArithmeticException e) {\n e.printStackTrace();\n} finally {\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子4>\ntry {\n String str = null;\n System.out.println(str.length());\n} catch (NullPointerException e) {\n e.printStackTrace();\n} finally {\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子5>\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n e.printStackTrace();\n} finally {\n // 只有注释的 finally 语句块\n // 这是一个空的 finally 块\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子6>\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n e.printStackTrace();\n} finally {\n // 只有空行的 finally 语句块\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"finally 语句块不能为空\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的finally语句块中有非注释意外的内容\"return null;\",所以这个不能被判定为\"finally 语句块不能为空\"\n" + }, + { + "id": 13, + "text": "try 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:try 语句块不能为空;对应Fixer:EmptyTryBlockFixer;修复方案:删除整个 try 语句", + "yes_example": "### 被判定为\"try 语句块不能为空\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\n\t}\n\treturn null;\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子2>\npublic void demoFinallyBlock() {\n\ttry {\n\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子3>\ntry {\n \n} catch (Exception e) {\n e.printStackTrace();\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子4>\ntry {\n // 只有注释的 try 语句块\n\t\n} catch (Exception e) {\n e.printStackTrace();\n}这段代码中的try语句块内只有注释和空行,也可以认定为这种情况是try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"try 语句块不能为空\"的例子\n<例子1>\ntry {\n\ta = a + 1;\n} catch (Exception e) {\n\te.printStackTrace();\n}\n这段代码中的try语句块中有非注释意外的内容\"return null;\",所以这个不能被判定为\"try 语句块不能为空\"\n" + }, + { + "id": 14, + "text": "避免对象进行不必要的 NULL或者null 检查", + "language": "Java", + "detail": "缺陷类型:避免对象进行不必要的 NULL或者null 检查;对应Fixer:LogicalOpNpeFixer;修复方案:删除对对象不必要的 NULL 检查的逻辑", + "yes_example": "### 被判定为\"避免对象进行不必要的 NULL或者null 检查\"的例子\n<例子1>\na = \"dog\";\nif (a != null){\n\treturn a;\n}这段代码中的对象a已经是确定的值\"dog\",所以if条件句的判断\"a != null\"是不必要的,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子2>\nif (authenticatedUserId != null && !authenticatedUserId.isEmpty() && userGroupManager!=null){\n\treturn authenticatedUserId;\n}这段代码中的\"authenticatedUserId != null\"和\"!authenticatedUserId.isEmpty()\"都是对\"authenticatedUserId\"的空判断,重复了,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子3>\nList list = new ArrayList<>();\nif (list != null) {\n list.add(1);\n}这段代码中的list已经被初始化,不需要进行 null 检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子4>\nif (this.type != null && this.type.getName() != null) {\n\tSystem.out.println(\"Type name is not null\");\n}这段代码中的对象type已经检查过非null,再次检查getName()是否为null是不必要的,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n\n<例子5>\nif (\"dog\".equals(null)){\n\treturn a;\n}这段代码中的\"dog\"是个确定的字符串,不需要进行null 检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子6>\nInteger num = 10;\nif (num != null) {\n System.out.println(num);\n}这段代码中的num 已经被初始化,不需要进行 null 
检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n", + "no_example": "### 不能被判定为\"避免对象进行不必要的 NULL或者null 检查\"的例子\n<例子1>\nCat cat = catService.get(1);\nif (cat != null){\n\tretrun cat;\n}这段代码中的对象\"cat\"是通过service获取到的,不确定是否为空,所以if条件句的判断的\"cat != null\"是必要的,所以这个不能被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n" + }, + { + "id": 15, + "text": "避免 finally 块中出现 return", + "language": "Java", + "detail": "缺陷类型:避免 finally 块中出现 return;修复方案:无需修复", + "yes_example": "### 被判定为\"避免 finally 块中出现 return\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的finally语句块内容包含\"return\",所以这个被判定为\"避免 finally 块中出现 return\"\n", + "no_example": "### 不能被判定为\"避免 finally 块中出现 return\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\tLOGGER.info(PERSION_NOT_EXIT);\n\t}\n}这段代码中的finally语句块中内容不包含\"return\",所以这个不能被判定为\"避免 finally 块中出现 return\"\n" + }, + { + "id": 16, + "text": "避免空的 static 初始化", + "language": "Java", + "detail": "缺陷类型:避免空的 static 初始化;对应Fixer:EmptyInitializerFixer;修复方案:删除整个空初始化块", + "yes_example": "### 被判定为\"避免空的 static 初始化\"的例子\n<例子1>\npublic class PetValidator implements Validator {\n\tstatic {\n\n\t}\n}这段代码中的static语句块没有内容,是空的,所以这个被判定为\"避免空的 static 初始化\"\n\n\n<例子2>\npublic class Persion {\n\tstatic {\n\t\t// 初始化的静态块\n\t}\n}这段代码中的static语句块是有内容的,不是空的,但是static初始化语句块中只有注释代码,没有实际的逻辑,所以这个被判定为\"避免空的 static 初始化\"\n", + "no_example": "### 不能被判定为\"避免空的 static 初始化\"的例子\n<例子1>\npublic class Cat {\n\tstatic {\n\t\t// 初始化的静态块\n\t\tcat = null;\n\t}\n}这段代码中的static语句块是有内容的,不是空的,且static初始化语句块中有非注释代码,有实际的逻辑,所以这个不能被判定为\"避免空的 static 初始化\"\n" + }, + { + "id": 17, + "text": "避免日历类用法不当风险", + "language": "Java", + "detail": "缺陷类型:避免日历类用法不当风险;修复方案:使用Java 8 及以上版本中的 java.time 包的LocalDate", + "yes_example": "### 被判定为\"避免日历类用法不当风险\"的例子\n<例子1>\nprivate static final Calendar calendar = new GregorianCalendar(2020, Calendar.JANUARY, 1);\n这段代码中的Calendar和GregorianCalendar是线程不安全的,所以这个被判定为\"避免日历类用法不当风险\"\n", + "no_example": "### 不能被判定为\"避免日历类用法不当风险\"的例子\n<例子1>\nprivate static final LocalDate calendar = LocalDate.of(2020, 1, 1);\n这段代码中的LocalDate使用的是Java 8 及以上版本中的 java.time 包,LocalDate 是不可变的并且是线程安全的,不会有线程安全和性能方面的问题,所以这个不能被判定为\"避免日历类用法不当风险\"\n" + }, + { + "id": 18, + "text": "使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()", + "language": "Java", + "detail": "缺陷类型:使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size();对应Fixer:ClassCastExpWithToArrayF ixer;修复方案:使用集合的toArray(T[]array),且传入的是类型完全一样的数组", + "yes_example": "### 被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"的例子\n<例子1>\nList stringList = new ArrayList<>();\nstringList.add(\"Apple\");\nstringList.add(\"Banana\");\nObject[] objectArray = stringList.toArray(new Object[5]);\n这段代码使用集合转数组的方法的时候使用了toArray(new Object[5]),但是传入的数组类型不一致,所以这个被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"\n", + "no_example": "### 不能被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"的例子\n<例子1>\nList stringList = new ArrayList<>();\nstringList.add(\"Apple\");\nstringList.add(\"Banana\");\nString[] stringArray = stringList.toArray(new String[stringList.size()]);\n这段代码使用集合转数组的方法的时候使用了toArray(new String[stringList.size()]),传入的是类型完全一样的数组,所以这个不能被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"\n" + }, + { + "id": 19, + "text": "禁止在 equals()中使用 
NULL或者null 做比较", + "language": "Java", + "detail": "缺陷类型:禁止在 equals()中使用 NULL或者null 做比较;对应Fixer:EqualsNullFixer;修复方案:使用Object的判空函数 做比较", + "yes_example": "### 被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"的例子\n<例子1>\nif (\"test\".equals(null)) {\n\tSystem.out.println(\"test\");\n}这段代码中if条件中的代码\"test\".equals(null)使用equals()函数与null进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子2>\nif (!rangeValues[1].equals(\"null\")) {\n\tmaxValue = new BigDecimal(rangeValues[1]);\n}这段代码中if条件中的代码!rangeValues[1].equals(\"null\")使用equals()函数与Nnull进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子3>\nString str1 = \"example\";\nif (str1.equals(\"null\")) {\n System.out.println(\"str1 is null\");\n}这段代码中if条件中的代码str1.equals(null)使用equals()函数与null进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子4>\nString str3 = \"example\";\nif (str3 != null && str3.equals(\"null\")) {\n System.out.println(\"str3 is null\");\n}这段代码中if条件中的代码str3.equals(\"null\")使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子5>\nInteger num1 = 10;\nif (num1.equals(null)) {\n System.out.println(\"num1 is null\");\n}这段代码中if条件中的代码num1.equals(null)使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子6>\nObject obj = new Object();\nif (obj.equals(null)) {\n System.out.println(\"obj is null\");\n}这段代码中if条件中的代码obj.equals(null)使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n", + "no_example": "### 不能被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"的例子\n<例子1>\na = \"test\";\nif (a.equals(\"test\")) {\n\tSystem.out.println(\"test\");\n}这段代码中if条件中的代码a.equals(\"test\")使用equals()函数与\"test\"进行了比较,所以这个不能被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n" + }, + { + "id": 20, + "text": "switch 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:switch 语句块不能为空;对应Fixer:EmptySwitchStatementsFix;修复方案:删除整个空 switch 语句块", + "yes_example": "### 被判定为\"switch 语句块不能为空\"的例子\n<例子1>\nswitch (number) {\n\t\n}这段代码是一个switch语句块,但是里面没有内容,所以这个被判定为\"switch 语句块不能为空\"\n\n\n<例子2>\nswitch (number) {\n\t// 这是一个switch语句块\n}这段代码是一个switch语句块,里面虽然有内容,但是内容仅仅是注释内容,没有实际的逻辑,所以这个被判定为\"switch 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"switch 语句块不能为空\"的例子\n<例子1>\nswitch (number) {\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n}这段代码是一个switch语句块,里面有内容,而且内容里有非注释的代码,有实际的逻辑,所以这个不能被判定为\"switch 语句块不能为空\"\n" + }, + { + "id": 21, + "text": "在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开", + "detail": "缺陷类型:在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开;修复方案:在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开。", + "language": "Java", + "yes_example": "### 被判定为\"在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开\"的例子\n<例子1>\nint a = (int) 3.0;\n\n<例子2>\nint b = (int) 4.0;\n\n<例子3>\nlong a = (long) 5;\n\n<例子4>\nstring a = (string) 3.5;\n\n<例子5>\nPersion a = (Persion) \"zhangsan\";\n", + "no_example": "### 不能被判定为\"在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开\"的例子\n<例子1>\nint a = (int)3.0;\n" + }, + { + "id": 22, + "text": "方法参数在定义和传入时,多个参数逗号后面必须加空格", + "detail": "缺陷类型:方法参数在定义和传入时,多个参数逗号后面必须加空格;修复方案:方法参数在定义和传入时,多个参数逗号后面必须加空格。", + "language": "Java", + "yes_example": "### 被判定为\"方法参数在定义和传入时,多个参数逗号后面必须加空格\"的例子\n<例子1>\npublic void exampleMethod(int a,int b,int c) {}\n", + "no_example": "### 不能被判定为\"方法参数在定义和传入时,多个参数逗号后面必须加空格\"的例子\n<例子1>\npublic void exampleMethod(int a, int b, int c) {}\n" + }, + { + "id": 23, + "text": "禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象", + "detail": "缺陷类型:禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象;修复方案:推荐使用 BigDecimal 的 valueOf 方法。", + 
"language": "Java", + "yes_example": "### 被判定为\"禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象\"的例子\n<例子1>\nBigDecimal bd = new BigDecimal(0.1);\n", + "no_example": "### 不能被判定为\"禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象\"的例子\n<例子1>\nBigDecimal bd = BigDecimal.valueOf(0.1);\n" + }, + { + "id": 24, + "text": "不能有多余的分号", + "detail": "缺陷类型:多余的分号;修复方案:删除多余的分号", + "yes_example": "### 被判定为\"不能有多余的分号\"的例子\n<例子1>\npublic void trigger(String executionId, Map processVariables) {\n commandExecutor.execute(new TriggerCmd(executionId, processVariables));\n}\n;\na = 1;\nb = 2;\nsum = a + b;\n这段代码中包含一个多余的分号\";\",所以这个被判定为\"不能有多余的分号\"\n", + "no_example": "### 不能被判定为\"不能有多余的分号\"的例子\n<例子1>\nwhile (True) {\n\ta = a + 1;\n\tbreak;\n}这段代码每个分号都是必须要的,所以这个能被判定为\"不能有多余的分号\"\n" + }, + { + "id": 25, + "text": "非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized", + "detail": "缺陷类型:非线程安全的 SimpleDateFormat 使用;修复方案:在函数或代码块级别加上synchronized修饰 或 使用其他线程安全的方式", + "yes_example": "### 被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"的例子\n<例子1>\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码中的函数formatDate在未使用synchronized同步修饰的情况下使用了SimpleDateFormat,这是线程不安全的,所以这个被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"\n", + "no_example": "### 不能被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"的例子\n<例子1>\npublic synchronized void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码是在synchronized同步块对函数'formatDate'进行保护,保证了线程安全,所以这个不能被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"\n" + }, + { + "id": 26, + "text": "未按驼峰命名规范进行命名,类名使用驼峰式UpperCamelCase风格, 方法名、参数名、成员变量、局部变量都统一使用lowerCamelCase风格", + "detail": "缺陷类型:未按驼峰命名规范进行命名;修复方案:类名使用UpperCamelCase风格,方法名、参数名、成员变量、局部变量使用lowerCamelCase风格。", + "language": "Java", + "yes_example": "### 被判定为\"未按驼峰命名规范进行命名\"的例子\n<例子1>\npublic class myClass {\n private int MyVariable;\n public void MyMethod() {}\n}\n这段代码中的类名、成员变量和方法名没有遵循驼峰命名法,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"未按驼峰命名规范进行命名\"的例子\n<例子1>\npublic class MyClass {\n private int myVariable;\n public void myMethod() {}\n}\n这段代码中的类名、成员变量和方法名都遵循了驼峰命名法,所以不能被判定为命名规范问题。\n" + }, + { + "id": 27, + "text": "抽象类命名使用 Abstract 或 Base 开头;异常类命名使用 Exception 结尾,测试类命名以它要测试的类的名称开始,以 Test 结尾", + "detail": "缺陷类型:命名规范;修复方案:抽象类命名使用 Abstract 或 Base 开头,异常类命名使用 Exception 结尾,测试类命名以它要测试的类的名称开始,以 Test 结尾。", + "language": "Java", + "yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class MyAbstractClass {}\npublic class MyExceptionClass {}\npublic class TestMyClass {}\n这段代码中的抽象类、异常类和测试类的命名不符合规范,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic abstract class AbstractMyClass {}\npublic class MyCustomException extends Exception {}\npublic class MyClassTest {}\n这段代码中的抽象类、异常类和测试类的命名都符合规范,所以不能被判定为命名规范问题。\n" + }, + { + "id": 28, + "text": "POJO 类中的任何布尔类型的变量,避免加\"is\" 前缀", + "detail": "缺陷类型:命名规范;修复方案:POJO 类中的布尔类型变量不要加 is 前缀。", + "language": "Java", + "yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class User {\n private boolean isActive;\n}\n这段代码中的布尔类型变量加了 is 前缀,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic class User {\n private boolean active;\n}\n这段代码中的布尔类型变量没有加 is 前缀,所以不能被判定为命名规范问题。\n" + }, + { + "id": 29, + "text": "杜绝完全不规范的英文缩写,避免望文不知义。", + "detail": "缺陷类型:命名规范;修复方案:避免使用不规范的英文缩写,确保代码可读性。", + "language": "Java", + 
"yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class CfgMgr {\n private int cnt;\n}\n这段代码中的类名和变量名使用了不规范的英文缩写,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic class ConfigManager {\n private int count;\n}\n这段代码中的类名和变量名没有使用不规范的英文缩写,所以不能被判定为命名规范问题。\n" + }, + { + "id": 30, + "text": "避免出现魔法字符和数字,应声明为常量", + "detail": "缺陷类型:避免出现魔法字符和数字,应声明为常量;修复方案:将魔法值定义为常量。", + "language": "Java", + "yes_example": "### 被判定为\"避免出现魔法字符和数字,应声明为常量\"的例子\n<例子1>\npublic class MagicNumberExample {\n public void calculate() {\n int result = 42 * 2;\n }\n}\n这段代码中直接使用了魔法值 42,所以被判定为代码规范问题。\n\n<例子2>\npublic class MagicNumberExample {\n public void calculate() {\n String result = \"This is a result\";\n }\n}\n这段代码中直接使用了魔法值 \"This is a result\",所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"避免出现魔法字符和数字,应声明为常量\"的例子\n<例子1>\npublic class MagicNumberExample {\n private static final int MULTIPLIER = 42;\n public void calculate() {\n int result = MULTIPLIER * 2;\n }\n}\n这段代码中将魔法值定义为了常量,所以不能被判定为代码规范问题。\n" + }, + { + "id": 31, + "text": "long 或 Long 赋值时,数值后使用大写 L,不能是小写 l,浮点数类型的数值后缀统一为大写的 D 或 F", + "detail": "缺陷类型:代码规范;修复方案:long 或 Long 赋值时使用大写 L,浮点数类型的数值后缀使用大写的 D 或 F。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class NumberExample {\n private long value = 1000l;\n private double pi = 3.14d;\n}\n这段代码中使用了小写的 l 和 d,所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class NumberExample {\n private long value = 1000L;\n private double pi = 3.14D;\n}\n这段代码中使用了大写的 L 和 D,所以不能被判定为代码规范问题。\n" + }, + { + "id": 32, + "text": "如果大括号内为空,简洁地写成{}即可,大括号中间无需换行和空格;如果是非空代码块,则:1)左大括号前不换行。2)左大括号后换行。3)右大括号前换行。4)右大括号后还有 else 等代码则不换行;表示终止的右大括号后必须换行。", + "detail": "缺陷类型:代码格式;修复方案:遵循大括号的使用规范。", + "language": "Java", + "yes_example": "### 被判定为\"代码格式\"的例子\n<例子1>\npublic class BracketExample{public void method(){\n if (true) {\n }}\n}\n这段代码中的大括号使用不符合规范,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码格式\"的例子\n<例子1>\npublic class BracketExample {\n public void method() {\n if (true) {\n // do something\n }\n }\n}\n这段代码中的大括号使用符合规范,所以不能被判定为代码格式问题。\n" + }, + { + "id": 33, + "text": "左小括号和右边相邻字符之间不需要空格;右小括号和左边相邻字符之间也不需要空格;而左大括号前需要加空格。", + "detail": "缺陷类型:代码格式;修复方案:遵循括号和空格的使用规范。", + "language": "Java", + "yes_example": "### 被判定为\"代码格式\"的例子\n<例子1>\npublic class SpaceExample {\n public void method (){\n }\n}\n这段代码中的括号和空格使用不符合规范,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class SpaceExample {\n public void method() {}\n}\n这段代码中的括号和空格使用符合规范,所以不能被判定为代码格式问题。\n" + }, + { + "id": 34, + "text": "if / for / while / switch / do 等保留字与左右括号之间都必须加空格。", + "detail": "缺陷类型:代码格式;修复方案:保留字与左右括号之间加空格。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class KeywordExample {\n public void method() {\n if(true) {\n }\n }\n}\n这段代码中的 if 关键字与括号之间没有空格,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class KeywordExample {\n public void method() {\n if (true) {\n }\n }\n}\n这段代码中的 if 关键字与括号之间有空格,所以不能被判定为代码格式问题。\n" + }, + { + "id": 35, + "text": "所有整型包装类对象之间值的比较,全部使用 equals 方法比较", + "detail": "缺陷类型:代码规范;修复方案:整型包装类对象之间的值比较使用 equals 方法。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a == b) {\n }\n }\n}\n这段代码中使用了 == 比较整型包装类对象,所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a.equals(b)) {\n }\n 
}\n}\n这段代码中使用了 equals 方法比较整型包装类对象,所以不能被判定为代码规范问题。\n" + }, + { + "id": 36, + "text": "BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法。", + "detail": "缺陷类型:BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法;修复方案:使用 compareTo() 方法进行比较。", + "language": "Java", + "yes_example": "### 被判定为\"BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法\"的例子\n<例子1>\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.equals(b)) {\n // 这段代码会返回 false,因为 equals() 方法会比较精度\n}\n", + "no_example": "### 不能被判定为\"BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法\"的例子\n<例子1>\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.compareTo(b) == 0) {\n // 这段代码会返回 true,因为 compareTo() 方法只比较数值\n}\n" + }, + { + "id": 37, + "text": "禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法。", + "detail": "缺陷类型:POJO 类中存在重复的 getter 方法;修复方案:确保只存在一个 getter 方法。", + "language": "Java", + "yes_example": "### 被判定为\"禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法\"的例子\n<例子1>\npublic class User {\n private boolean active;\n public boolean isActive() {\n return active;\n }\n public boolean getActive() {\n return active;\n }\n}\n", + "no_example": "### 不能被判定为\"禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法\"的例子\n<例子1>\npublic class User {\n private int age;\n public int getAge() {\n return age;\n }\n}\n" + }, + { + "id": 38, + "text": "日期格式化时,传入 pattern 中表示年份统一使用小写的 y。", + "detail": "缺陷类型:日期格式化错误;修复方案:使用小写的 y 表示年份。", + "language": "Java", + "yes_example": "### 被判定为\"日期格式化时,传入 pattern 中表示年份统一使用小写的 y\"的例子\n<例子1>\nSimpleDateFormat sdf = new SimpleDateFormat(\"YYYY-MM-dd\");\n", + "no_example": "### 不能被判定为\"日期格式化时,传入 pattern 中表示年份统一使用小写的 y\"的例子\n<例子1>\nSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n" + }, + { + "id": 39, + "text": "禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp。", + "detail": "缺陷类型:使用了 java.sql 包中的日期类;修复方案:使用 java.time 包中的日期类。", + "language": "Java", + "yes_example": "### 被判定为\"禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp\"的例子\n<例子1>\njava.sql.Date sqlDate = new java.sql.Date(System.currentTimeMillis());\n", + "no_example": "### 不能被判定为\"禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp\"的例子\n<例子1>\njava.time.LocalDate localDate = java.time.LocalDate.now();\n" + }, + { + "id": 40, + "text": "判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式。", + "detail": "缺陷类型:集合判空方式错误;修复方案:使用 isEmpty() 方法。", + "language": "Java", + "yes_example": "### 被判定为\"判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式\"的例子\n<例子1>\nList list = new ArrayList<>();\nif (list.size() == 0) {\n // 判空逻辑\n}\n", + "no_example": "### 不能被判定为\"判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式\"的例子\n<例子1>\nList list = new ArrayList<>();\nif (list.isEmpty()) {\n // 判空逻辑\n}\n" + }, + { + "id": 41, + "text": "只要重写 equals,就必须重写 hashCode。", + "detail": "缺陷类型:未重写 hashCode 方法;修复方案:同时重写 equals 和 hashCode 方法。", + "language": "Java", + "yes_example": "### 被判定为\"只要重写 equals,就必须重写 hashCode\"的例子\n<例子1>\npublic class User {\n private String name;\n @Override\n public boolean equals(Object o) {\n if (this == o) return true;\n if (o == null || getClass() != o.getClass()) return false;\n User user = (User) o;\n return Objects.equals(name, user.name);\n }\n}\n", + "no_example": "### 不能被判定为\"只要重写 equals,就必须重写 hashCode\"的例子\n<例子1>\npublic class User {\n private String name;\n @Override\n public boolean equals(Object o) {\n if (this == o) return true;\n if (o == null || getClass() != o.getClass()) return false;\n User user = (User) o;\n 
return Objects.equals(name, user.name);\n }\n @Override\n public int hashCode() {\n return Objects.hash(name);\n }\n}\n" + }, + { + "id": 42, + "text": "使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常。", + "detail": "缺陷类型:对 Map 的 keySet() / values() / entrySet() 返回的集合进行添加操作;修复方案:避免对这些集合进行添加操作。", + "language": "Java", + "yes_example": "### 被判定为\"使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常\"的例子\n<例子1>\nMap map = new HashMap<>();\nmap.put(\"key1\", \"value1\");\nSet keys = map.keySet();\nkeys.add(\"key2\");\n", + "no_example": "### 不能被判定为\"使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常\"的例子\n<例子1>\nMap map = new HashMap<>();\nmap.put(\"key1\", \"value1\");\nSet keys = map.keySet();\n// 不进行添加操作\n" + }, + { + "id": 43, + "text": "不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator", + "detail": "缺陷类型:在 foreach 循环中进行元素的 remove / add 操作;修复方案:使用 iterator 进行元素的 remove 操作。", + "language": "Java", + "yes_example": "### 被判定为\"不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator\"的例子\n<例子1>\nList list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nfor (String s : list) {\n if (s.equals(\"a\")) {\n list.remove(s);\n }\n}\n", + "no_example": "### 不能被判定为\"不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator\"的例子\n<例子1>\nList list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nIterator iterator = list.iterator();\nwhile (iterator.hasNext()) {\n String s = iterator.next();\n if (s.equals(\"a\")) {\n iterator.remove();\n }\n}\n" + }, + { + "id": 44, + "text": "类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式。", + "detail": "缺陷类型:注释不符合 Javadoc 规范;修复方案:使用 Javadoc 规范的注释格式。", + "language": "Java", + "yes_example": "### 被判定为\"类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式\"的例子\n<例子1>\npublic class Example {\n // 这是一个类注释\n private String name;\n // 这是一个属性注释\n public String getName() {\n return name;\n }\n // 这是一个方法注释\n}\n", + "no_example": "### 不能被判定为\"类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式\"的例子\n<例子1>\n/**\n * 这是一个类注释\n */\npublic class Example {\n /**\n * 这是一个属性注释\n */\n private String name;\n /**\n * 这是一个方法注释\n */\n public String getName() {\n return name;\n }\n}\n" + }, + { + "id": 45, + "text": "所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释", + "detail": "缺陷类型:所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释;修复方案:为所有的抽象方法(包括接口中的方法)添加 Javadoc 注释,除了返回值、参数异常说明外,还必须指出该方法做什么事情,实现什么功能。", + "language": "Java", + "yes_example": "### 被判定为\"所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释\"的例子\n<例子1>\npublic interface MyInterface {\n void doSomething();\n}\n这段代码中的接口方法 doSomething() 没有 Javadoc 注释,所以被判定为缺少 Javadoc 注释。\n", + "no_example": "### 不能被判定为\"所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释\"的例子\n<例子1>\n/**\n * 执行某个操作\n * @param param 参数说明\n * @return 返回值说明\n * @throws Exception 异常说明\n */\npublic interface MyInterface {\n void doSomething(String param) throws Exception;\n}\n这段代码中的接口方法 doSomething() 有完整的 Javadoc 注释,所以不能被判定为缺少 Javadoc 注释。\n" + }, + { + "id": 46, + "text": "方法内部单行注释和多行注释的使用规范", + "detail": "缺陷类型:注释使用不规范;修复方案:方法内部单行注释,在被注释语句上方另起一行,使用 // 注释。方法内部多行注释使用 /* */注释,注意与代码对齐。", + "language": "Java", + "yes_example": "### 被判定为\"注释使用不规范\"的例子\n<例子1>\npublic void exampleMethod() {\n int a = 1; // 初始化变量a\n int b = 2; /* 初始化变量b */\n}\n这段代码中的单行注释和多行注释没有按照规范使用,所以被判定为注释使用不规范。\n", + "no_example": "### 不能被判定为\"注释使用不规范\"的例子\n<例子1>\npublic void exampleMethod() {\n // 
初始化变量a\n int a = 1;\n /*\n * 初始化变量b\n */\n int b = 2;\n}\n这段代码中的单行注释和多行注释按照规范使用,所以不能被判定为注释使用不规范。\n" + }, + { + "id": 47, + "text": "所有的枚举类型字段必须要有注释", + "detail": "缺陷类型:枚举类型字段缺少注释;修复方案:为所有的枚举类型字段添加注释,说明每个数据项的用途。", + "language": "Java", + "yes_example": "### 被判定为\"枚举类型字段缺少注释\"的例子\n<例子1>\npublic enum Status {\n ACTIVE,\n INACTIVE\n}\n这段代码中的枚举类型字段没有注释,所以被判定为枚举类型字段缺少注释。\n", + "no_example": "### 不能被判定为\"枚举类型字段缺少注释\"的例子\n<例子1>\npublic enum Status {\n /**\n * 活跃状态\n */\n ACTIVE,\n /**\n * 非活跃状态\n */\n INACTIVE\n}\n这段代码中的枚举类型字段有注释,所以不能被判定为枚举类型字段缺少注释。\n" + }, + { + "id": 48, + "text": "finally 块必须对资源对象、流对象进行关闭", + "detail": "缺陷类型:资源对象、流对象未在 finally 块中关闭;修复方案:在 finally 块中对资源对象、流对象进行关闭,有异常也要做 try-catch。", + "language": "Java", + "yes_example": "### 被判定为\"资源对象、流对象未在 finally 块中关闭\"的例子\n<例子1>\npublic void readFile() {\n FileInputStream fis = null;\n try {\n fis = new FileInputStream(\"file.txt\");\n // 读取文件内容\n } catch (IOException e) {\n e.printStackTrace();\n }\n}\n这段代码中的 FileInputStream 对象没有在 finally 块中关闭,所以被判定为资源对象、流对象未在 finally 块中关闭。\n", + "no_example": "### 不能被判定为\"资源对象、流对象未在 finally 块中关闭\"的例子\n<例子1>\npublic void readFile() {\n FileInputStream fis = null;\n try {\n fis = new FileInputStream(\"file.txt\");\n // 读取文件内容\n } catch (IOException e) {\n e.printStackTrace();\n } finally {\n if (fis != null) {\n try {\n fis.close();\n } catch (IOException e) {\n e.printStackTrace();\n }\n }\n }\n}\n这段代码中的 FileInputStream 对象在 finally 块中关闭,所以不能被判定为资源对象、流对象未在 finally 块中关闭。\n" + }, + { + "id": 49, + "text": "常量命名应该全部大写,单词间用下划线隔开", + "detail": "缺陷类型:常量命名不规范;修复方案:常量命名应该全部大写,单词间用下划线隔开,力求语义表达完整清楚,不要嫌名字长。", + "language": "Java", + "yes_example": "### 被判定为\"常量命名应该全部大写,单词间用下划线隔开\"的例子\n<例子1>\npublic static final int maxCount = 100;\n", + "no_example": "### 不能被判定为\"常量命名应该全部大写,单词间用下划线隔开\"的例子\n<例子1>\npublic static final int MAX_COUNT = 100;\n" + }, + { + "id": 50, + "text": "任何二目、三目运算符的左右两边都需要加一个空格", + "detail": "缺陷类型:运算符两边缺少空格;修复方案:任何二目、三目运算符的左右两边都需要加一个空格。", + "language": "Java", + "yes_example": "### 被判定为\"任何二目、三目运算符的左右两边都需要加一个空格\"的例子\n<例子1>\nint a=b+c;\n", + "no_example": "### 不能被判定为\"任何二目、三目运算符的左右两边都需要加一个空格\"的例子\n<例子1>\nint a = b + c;\n" + }, + { + "id": 51, + "text": "避免使用from import *", + "detail": "缺陷类型:避免使用from import *,导入所有内容会造成命名冲突;修复方案:每个使用到的子依赖需分别导入。", + "language": "Python", + "yes_example": "### 被判定为\"避免使用from import *\"的例子\n<例子1>from math import * \n", + "no_example": "### 不能被判定为\"避免使用from import *\"的例子\n<例子1>from math import sqrt, pi \n" + }, + { + "id": 52, + "text": "避免使用__import__()函数动态导入模块", + "detail": "缺陷类型:避免使用__import__()函数动态导入模块;修复方案:使用标准的import语句。", + "language": "Python", + "yes_example": "### 被判定为\"使用__import__()函数动态导入模块\"的例子\n<例子1>module = __import__('math') \n", + "no_example": "### 不能被判定为\"使用__import__()函数动态导入模块\"的例子\n<例子1>import math \n" + }, + { + "id": 53, + "text": "导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组", + "detail": "缺陷类型:导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组;修复方案:按顺序分组导入语句。", + "language": "Python", + "yes_example": "### 被判定为'导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组'的例子\n<例子1>\nimport numpy as np\nimport os\nimport sys\nfrom my_local_module import my_function\n在这个样例中,先导入了第三方库,然后导入了标准库。\n\n<例子2>\nfrom my_project import my_local_function\nimport datetime\nimport requests\n在这个样例中,先导入了本地模块,然后导入了标准库。\n\n<例子3>\nimport os\nfrom my_project.local_module import some_function\nimport pandas as pd\nimport sys\nfrom another_local_module import another_function\nimport math\n在这个样例中,导入语句完全混乱,没有遵循任何顺序。\n\n<例子4>\nimport os\nimport requests\nimport sys\nimport numpy as np\nfrom local_package import 
local_module\n在这个样例中,导入标准库和第三方库交替进行。\n", + "no_example": "### 不能被判定为'导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组'的例子\n<例子1>import os \n\n import requests \n\n import mymodule \n" + }, + { + "id": 54, + "text": "避免未使用的函数形参", + "detail": "缺陷类型:避免未使用的函数形参;修复方案:移除未使用的函数形参。", + "language": "Python", + "yes_example": "### 被判定为'避免未使用的函数形参'的例子\n<例子1>def func(a, b): \n return a\n<例子2>def start_game(unused_param): \npuzzle = Puzzle() \npuzzle.solve()\n<例子3>def make_move(self, board):\npass \n\n<例子4>def move(self, direction):\npass \n", + "no_example": "### 不能被判定为'避免未使用的函数形参'的例子\n<例子1>def func(a): \n return a" + }, + { + "id": 55, + "text": "使用is not None来检查一个变量是否不是None", + "detail": "缺陷类型:未使用is not None来检查一个变量是否不是None;修复方案:使用is not None来检查。", + "language": "Python", + "yes_example": "### 被判定为'未使用is not None来检查一个变量是否不是None'的例子\n<例子1>if variable != None:\n pass", + "no_example": "### 不能被判定为'未使用is not None来检查一个变量是否不是None'的例子\n<例子1>if variable is not None:\n pass" + }, + { + "id": 56, + "text": "避免使用==或!=来比较对象实例的等价性", + "detail": "缺陷类型:使用==或!=来比较对象实例的等价性;修复方案:应使用equals比较。", + "language": "Python", + "yes_example": "### 被判定为'使用==或!=来比较对象实例的等价性'的例子\n<例子1>obj1 = MyClass() \n obj2 = MyClass() if obj1 == obj2: \n pass\n", + "no_example": "### 不能被判定为'使用==或!=来比较对象实例的等价性'的例子\n<例子1>obj1 = MyClass() \n obj2 = MyClass() if obj1.equals(obj2): \n pass\n\n<例子2>obj1 = 21 \n obj2 = 22 \n if obj1.equals(obj2):\n pass" + }, + { + "id": 57, + "text": "避免使用单字母变量名,使用描述性变量名", + "detail": "缺陷类型:避免使用单字母变量名,使用描述性变量名;修复方案:使用描述性变量名。", + "language": "Python", + "yes_example": "### 被判定为'避免使用单字母变量名,使用描述性变量名'的例子\n<例子1>x = 10 \n\n<例子2>y = 10 \n", + "no_example": "### 不能被判定为'避免使用单字母变量名,使用描述性变量名'的例子\n<例子1>count = 10 \n" + }, + { + "id": 58, + "text": "常量命名使用全大写字母,并用下划线分隔", + "detail": "缺陷类型:常量命名未使用全大写字母或未用下划线分隔;修复方案:常量命名使用全大写字母,并用下划线分隔。", + "language": "Python", + "yes_example": "### 被判定为'常量命名未使用全大写字母,并用下划线分隔'的例子\n<例子1>pi = 3.14159", + "no_example": "### 不能被判定为'常量命名未使用全大写字母,并用下划线分隔'的例子\n<例子1>PI = 3.14159\n<例子2>max_size = 1 \n max_size += 1" + }, + { + "id": 59, + "text": "类名应使用驼峰式命名(CamelCase)", + "detail": "缺陷类型:类名未使用驼峰式命名;修复方案:类名使用驼峰式命名。", + "language": "Python", + "yes_example": "### 被判定为'类名未使用驼峰式命名(CamelCase)'的例子\n<例子1>class my_class: \n pass\n<例子2>class my_class: \n def solve(self):\n pass", + "no_example": "### 不能被判定为'类名未使用驼峰式命名(CamelCase)'的例子\n<例子1>class MyClass: \n pass" + }, + { + "id": 60, + "text": "尽量使用with语句来管理资源", + "detail": "缺陷类型:未使用with语句来管理资源;修复方案:使用with语句来管理资源。", + "language": "Python", + "yes_example": "### 被判定为'未使用with语句来管理资源'的例子\n<例子1>file = open('file.txt', 'r') \n content = file.read() \n file.close()", + "no_example": "### 不能被判定为'未使用with语句来管理资源'的例子\n<例子1>with open('file.txt', 'r') as file: \n content = file.read()" + }, + { + "id": 61, + "text": "避免使用except 或 通用的Exception来捕获所有异常,应该指定异常类型", + "detail": "缺陷类型:捕获所有异常;修复方案:指定具体的异常类型。", + "language": "Python", + "yes_example": "### 被判定为'使用except:来捕获所有异常'的例子\n<例子1>try: \n # some code \n except: \n handle_error()\n### 被判定为'抛出通用的Exception异常'的例子\n<例子2>\n try:\n process_data(data) \n except: \n raise Exception('An error occurred') \n ", + "no_example": "### 不能被判定为'使用except:来捕获所有异常'的例子\n<例子1>try: \n # some code \n except ValueError: \n handle_value_error()" + }, + { + "id": 62, + "text": "尽量避免手动拼接字符串", + "detail": "缺陷类型:手动拼接字符串;修复方案:使用格式化字符串或join方法。", + "language": "Python", + "yes_example": "### 被判定为'手动拼接字符串'的例子\n<例子1>\n name = 'John' \n greeting = 'Hello, ' + name + '!' 
\n \n <例子2>greeting = '2048' + 'game' \n \n <例子3>pygame.display.set_caption('贪吃蛇' + '游戏')", + "no_example": "### 不能被判定为'手动拼接字符串'的例子\n<例子1>\n name = 'John' \n greeting = f'Hello, {name}!' \n" + }, + { + "id": 63, + "text": "避免出现魔法字符和数字,应声明为常量", + "detail": "缺陷类型:使用魔法字符和数字;修复方案:将其声明为常量。", + "language": "Python", + "yes_example": "### 被判定为'出现魔法字符和数字'的例子\n<例子1>\n if status == 1: \n print('Active')' \n\n<例子2>\n self.board = [[0] * 4 for _ in range(4)] \n self.score = 0\n<例子3>\ndef __init__(self, width=10, height=10, mines=15):\n\n<例子4>\nx, y = event.x // 20, event.y // 20\n\n<例子5>\nraise ValueError(\"余额不足\")\n\n<例子6>\ntransfer(bank, \"123\", \"456\", 200)\n\n<例子7>\nbank.add_account(Account(\"123\", 1000))\n", + "no_example": "### 不能被判定为'出现魔法字符和数字'的例子\n<例子1>\n ACTIVE_STATUS = 1 \n if status == ACTIVE_STATUS:\n print(ACTIVE_STATUS)' \n" + }, + { + "id": 64, + "text": "boolean变量判断无需显式比较", + "detail": "缺陷类型:显式比较boolean变量;修复方案:直接使用boolean变量进行判断。", + "language": "Python", + "yes_example": "### 被判定为'显式比较boolean变量'的例子\n<例子1>flag = True \n if flag == True: \n print('Flag is true')\n<例子2>if self.game.is_game_over() == True: \n return<例子3>if self.canvas.drawings ==True:", + "no_example": "### 不能被判定为'显式比较boolean变量'的例子\n<例子1>flag = True \n if flag: \n print('Flag is true') \n" + }, + { + "id": 65, + "text": "避免使用type()检查对象类型", + "detail": "缺陷类型:避免使用type()检查对象类型;修复方案:使用isinstance()函数。", + "language": "Python", + "yes_example": "### 被判定为'避免使用type()检查对象类型'的例子\n<例子1>\n if type(obj) == list: \n print('obj is a list')", + "no_example": "### 不能被判定为'避免使用type()检查对象类型'的例子\n<例子1>\n if isinstance(obj, list): \n print('obj is a list') \n" + }, + { + "id": 66, + "text": "避免使用os.system()来调用外部命令", + "detail": "缺陷类型:使用os.system()调用外部命令;修复方案:使用subprocess模块。", + "language": "Python", + "yes_example": "### 被判定为'使用os.system()来调用外部命令'的例子\n<例子1>os.system('ls -l')\n<例子2>os.system('ls -l')", + "no_example": "### 不能被判定为'使用os.system()来调用外部命令'的例子\n<例子1>import subprocess \n subprocess.run(['ls', '-l'])" + }, + { + "id": 67, + "text": "只使用@property装饰器创建只读属性,而非修改属性", + "detail": "缺陷类型:使用@property装饰器创建可修改属性;修复方案:只使用@property装饰器创建只读属性。", + "language": "Python", + "yes_example": "### 被判定为'使用@property装饰器来创建可修改属性'的例子\n<例子1>@property \n def value(self, new_value): \n self._value = new_value\n<例子2>@property \n def game_over(self): \n return self._is_game_over() \n def _is_game_over(self): \n pass", + "no_example": "### 不能被判定为'使用@property装饰器来创建可修改属性'的例子\n<例子1>@property \n def value(self): \n return self._value\n<例子2>@property \n def __str__(self): \n return 'Maze Game State'" + }, + { + "id": 68, + "text": "在使用索引或切片时,不要在方括号或冒号内加空格", + "detail": "缺陷类型:在索引或切片的方括号或冒号内加空格;修复方案:去掉方括号或冒号内的空格。", + "language": "Python", + "yes_example": "### 被判定为'在使用索引或切片时,在方括号或冒号内加空格'的例子\n<例子1>list = [1, 2, 3, 4] \n sublist = list[ 1 : 3 ]\n<例子2>start_point = self.canvas.drawings[ -1] \n<例子3>if head[ 0] < 0 or head[ 0] >= GRID_WIDTH or head[ 1] < 0 or head[ 1] >= GRID_HEIGHT:\n<例子4>for segment in self.snake[ 1:]:", + "no_example": "### 不能被判定为'在使用索引或切片时,在方括号或冒号内加空格'的例子\n<例子1>list = [1, 2, 3, 4] \n sublist = list[1:3]" + }, + { + "id": 69, + "text": "在逗号、分号或冒号前不要加空格,但在它们之后要加空格", + "detail": "缺陷类型:在逗号、分号或冒号前加空格或在它们之后不加空格;修复方案:在逗号、分号或冒号前不要加空格,但在它们之后要加空格。", + "language": "Python", + "yes_example": "### 被判定为'在逗号、分号或冒号前加空格,或没在它们之后加空格'的例子\n<例子1>if x == 4 : \n print(x , y)\n<例子2>if event.keysym == 'Up' or event.keysym == 'Down' or event.keysym == 'Left' or event.keysym == 'Right' :\n<例子3>x ,y = 1 ,2\n<例子4>def on_key_press(self , event) :\n<例子5>elif event.keysym == 'Down' ; 
\n<例子6>def update_status(self ,message: str) : \n pass ", + "no_example": "### 不能被判定为'在逗号、分号或冒号前加空格,或没在它们之后加空格'的例子\n<例子1>if x == 4: \n print(x, y)" + }, + { + "id": 70, + "text": "对于二元操作符,两边都应有空格", + "detail": "缺陷类型:二元操作符两边没有空格;修复方案:在二元操作符两边加空格", + "language": "Python", + "yes_example": "### 被判定为'二元操作符两边没有空格'的例子\n<例子1>a=b+1", + "no_example": "### 不能被判定为'二元操作符两边没有空格'的例子\n<例子1>a = b + 1\n<例子2>label = tk.Label(self.root, text=str(cell), bg='white')\n<例子3>label.grid(row=i, column=j)" + }, + { + "id": 71, + "text": "避免使用Python关键字作为变量名或函数名", + "detail": "缺陷类型:使用Python关键字作为变量名或函数名;修复方案:使用非关键字的名称。", + "language": "Python", + "yes_example": "### 被判定为'使用Python关键字作为变量名或函数名'的例子\n<例子1>def class(): \n pass\n<例子2>for = 5\n<例子3>def if(self): ", + "no_example": "### 不能被判定为'使用Python关键字作为变量名或函数名'的例子\n<例子1>def my_function(): \n pass\n<例子2>number = 5" + }, + { + "id": 72, + "text": "避免使用特殊字符作为变量名/方法名/类名,例如$或@", + "detail": "缺陷类型:使用特殊字符作为变量名/方法名/类名;修复方案:使用合法的变量名。", + "language": "Python", + "yes_example": "### 被判定为'使用特殊字符作为变量名/方法名/类名,例如$或@'的例子\n<例子1>my$var = 10\n<例子2>@var = 20\n<例子3>def add_score@(self, points): \n self.score += points\n<例子4>class @MyClass: \n pass\n<例子5>def mine@(self):", + "no_example": "### 不能被判定为'使用特殊字符作为变量名/方法名/类名,例如$或@'的例子\n<例子1>my_var = 10\n<例子2>var_20 = 20" + }, + { + "id": 73, + "text": "避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪", + "detail": "缺陷类型:使用raise重新抛出当前异常;修复方案:使用raise ... from ...语法。", + "language": "Python", + "yes_example": "### 被判定为'避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError: \n raise SomeException('新的异常信息')\n\n<例子2>\ntry:\n db.get_data()\nexcept ValueError as e:\n raise ValueError(\"Something went wrong!\")\n\n<例子3>\ntry:\n\traise Exception(\"形状添加失败\")\nexcept Exception as e:\n\tpass\n", + "no_example": "### 不能被判定为'避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n raise RuntimeError('Error occurred') from e \n\n<例子2>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n\tlogger.error(e)\n raise \n" + }, + { + "id": 74, + "text": "避免在except块中使用pass,这会捕获并忽略异常", + "detail": "缺陷类型:在except块中使用pass;修复方案:处理异常或记录日志。", + "language": "Python", + "yes_example": "### 被判定为'在except块中使用pass'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError: \n pass \n \n<例子2>\n try: \n 1 / 0 \n except ZeroDivisionError: \n pass \n", + "no_example": "### 不能被判定为'在except块中使用pass'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n logging.error('Error occurred: %s', e) \n" + }, + { + "id": 75, + "text": "避免使用assert语句来执行重要的运行时检查", + "detail": "缺陷类型:使用assert语句执行重要的运行时检查;修复方案:使用显式的条件检查和异常处理。", + "language": "Python", + "yes_example": "### 被判定为'使用assert语句来执行重要的运行时检查'的例子\n<例子1>\n def divide(a, b): \n assert b != 0 \n return a / b \n", + "no_example": "### 不能被判定为'使用assert语句来执行重要的运行时检查'的例子\n<例子1>\n def divide(a, b): \n if b == 0: \n raise ValueError('b cannot be zero') \n return a / b \n" + }, + { + "id": 76, + "text": "避免使用eval()和exec(),这些函数可能会带来安全风险", + "detail": "缺陷类型:使用eval()和exec()函数;修复方案:使用安全的替代方案。", + "language": "Python", + "yes_example": "### 被判定为'使用eval()和exec()'的例子\n<例子1>\n eval('print(1)') \n\n<例子2> \n exec('a = 1') \n", + "no_example": "### 不能被判定为'使用eval()和exec()'的例子\n<例子1>\n compiled_code = compile('print(1)', '', 'exec') \n exec(compiled_code) \n" + }, + { + "id": 77, + "text": "避免使用sys.exit(),应使用异常来控制程序的退出", + "detail": "缺陷类型:避免使用sys.exit(),应使用异常来控制程序的退出;修复方案:使用异常来控制程序的退出。", + "language": "Python", + "yes_example": "### 被判定为'避免使用sys.exit(),应使用异常来控制程序的退出'的例子\n<例子1>\n import sys\nsys.exit(1)\n\n<例子2>\n import 
sys \n sys.exit()\n\n<例子3>\nif event.type == pygame.QUIT:\n\tpygame.quit()\n\texit()\n\n<例子4>\n import sys \n sys.exit('退出程序'))\n", + "no_example": "### 不能被判定为'避免使用sys.exit(),应使用异常来控制程序的退出'的例子\n<例子1>\n raise SystemExit(1)\n" + }, + { + "id": 78, + "text": "避免使用time.sleep()进行线程同步,应使用同步原语,如锁或事件", + "detail": "缺陷类型:使用time.sleep()进行线程同步;修复方案:使用同步原语。", + "language": "Python", + "yes_example": "### 被判定为'使用time.sleep()进行线程同步'的例子\n<例子1>\n import time \n\n def worker(): \n time.sleep(1) \n\n<例子2>\n import time \n\n time.sleep(1) \n", + "no_example": "### 不能被判定为'使用time.sleep()进行线程同步'的例子\n<例子1>\n import threading \n\n event = threading.Event() \n\n def worker(): \n event.wait()\n" + }, + { + "id": 79, + "text": "每行代码避免超过79个字符", + "detail": "缺陷类型:每行代码避免超过79个字符;修复方案:将长行代码格式化为多行。", + "language": "Python", + "yes_example": "### 被判定为'每行代码避免超过79个字符'的例子\n<例子1>\n print('This is a very long line of code that exceeds the 79 characters limit........') \n", + "no_example": "### 不能被判定为'每行代码避免超过79个字符'的例子\n<例子1>\n print('This is a very long line of code that exceeds the 79 characters limit' + \n ' but it is split into two lines')\n" + }, + { + "id": 80, + "text": "模块级别的函数和类定义之间用两个空行分隔,类内部的方法定义之间用一个空行分隔", + "detail": "缺陷类型:模块级别的函数和类定义之间没有用两个空行分隔,类内部的方法定义之间没有用一个空行分隔;修复方案:按照规范添加空行。", + "language": "Python", + "yes_example": "### 被判定为'模块级别的函数和类定义之间没用两个空行分隔,类内部的方法定义之间没用一个空行分隔'的例子\n<例子1>\n def func1(): \n pass \n def func2(): \n pass \n\n<例子2>\n class MyClass: \n def method1(self): \n pass \n def method2(self): \n pass \n", + "no_example": "### 不能被判定为'模块级别的函数和类定义之间没用两个空行分隔,类内部的方法定义之间没用一个空行分隔'的例子\n<例子1>\n def func1(): \n pass \n\n\n def func2(): \n pass \n\n<例子2>\n class MyClass: \n def method1(self): \n pass \n\n def method2(self): \n pass \n" + }, + { + "id": 81, + "text": "使用小写字母和下划线分隔的方式命名变量和函数名", + "detail": "缺陷类型:变量和函数命名不符合小写字母和下划线分隔的方式;修复方案:使用小写字母和下划线分隔的方式命名。", + "language": "Python", + "yes_example": "### 被判定为'未使用小写字母和下划线分隔的方式命名变量和函数'的例子\n<例子1>\n def myFunction(): \n pass \n\n<例子2>\n myVariable = 10 \n\n<例子3>\n def Calculatesquareroot(self, x): \n return 1 \n", + "no_example": "### 不能被判定为'未使用小写字母和下划线分隔的方式命名变量和函数'的例子\n<例子1>\n def my_function(): \n pass \n\n<例子2>\n my_variable = 10 \n" + }, + { + "id": 82, + "text": "不允许使用print()函数来记录日志,使用logging模块等来记录日志", + "detail": "缺陷类型:使用print()函数记录日志;修复方案:使用logging模块记录日志。", + "language": "Python", + "yes_example": "### 被判定为'使用print()函数来记录日志'的例子\n<例子1>\n print('Error occurred') \n\n<例子2>\n print('打印的日志字符串内容') \n\n<例子3>\n task = 'xxx' \n print(task) \n\n<例子4>\n print(1)\n", + "no_example": "### 不能被判定为'使用print()函数来记录日志'的例子\n<例子1>\n import logging \n logging.error('Error occurred') \n" + } +] diff --git a/metagpt/ext/cr/utils/__init__.py b/metagpt/ext/cr/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metagpt/ext/cr/utils/cleaner.py b/metagpt/ext/cr/utils/cleaner.py new file mode 100644 index 0000000000..8fc0b798ca --- /dev/null +++ b/metagpt/ext/cr/utils/cleaner.py @@ -0,0 +1,68 @@ +"""Cleaner.""" + +from unidiff import Hunk, PatchedFile, PatchSet + +from metagpt.logs import logger + + +def rm_patch_useless_part(patch: PatchSet, used_suffix: list[str] = ["java", "py"]) -> PatchSet: + new_patch = PatchSet("") + useless_files = [] + for pfile in patch: + suffix = str(pfile.target_file).split(".")[-1] + if suffix not in used_suffix or pfile.is_removed_file: + useless_files.append(pfile.path) + continue + new_patch.append(pfile) + logger.info(f"total file num: {len(patch)}, used file num: {len(new_patch)}, useless_files: {useless_files}") + return 
new_patch
+
+
+def add_line_num_on_patch(patch: PatchSet, start_line_num: int = 1) -> PatchSet:
+    """Prefix every line of every hunk with a running line number, starting at `start_line_num`."""
+    new_patch = PatchSet("")
+    lineno = start_line_num
+    for pfile in patch:
+        new_pfile = PatchedFile(
+            source=pfile.source_file,
+            target=pfile.target_file,
+            source_timestamp=pfile.source_timestamp,
+            target_timestamp=pfile.target_timestamp,
+        )
+        for hunk in pfile:
+            arr = [str(line) for line in hunk]
+            new_hunk = Hunk(
+                src_start=hunk.source_start,
+                src_len=hunk.source_length,
+                tgt_start=hunk.target_start,
+                tgt_len=hunk.target_length,
+                section_header=hunk.section_header,
+            )
+
+            for line in arr:
+                line = f"{lineno} {line}"
+                lineno += 1
+                new_hunk.append(line)
+            new_pfile.append(new_hunk)
+        new_patch.append(new_pfile)
+    return new_patch
+
+
+def get_code_block_from_patch(patch: PatchSet, code_start_line: str, code_end_line: str) -> str:
+    """Return the block between the lines tagged `code_start_line` and `code_end_line`, tags stripped."""
+    line_arr = str(patch).split("\n")
+    code_arr = []
+    add_line_tag = False
+    for line in line_arr:
+        if line.startswith(f"{code_start_line} "):
+            add_line_tag = True
+
+        if add_line_tag:
+            new_line = " ".join(line.split(" ")[1:])  # rm line-no tag
+            code_arr.append(new_line)
+
+        if line.startswith(f"{code_end_line} "):
+            add_line_tag = False
+
+    return "\n".join(code_arr)
diff --git a/metagpt/ext/cr/utils/schema.py b/metagpt/ext/cr/utils/schema.py
new file mode 100644
index 0000000000..beb27a07f9
--- /dev/null
+++ b/metagpt/ext/cr/utils/schema.py
@@ -0,0 +1,20 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Point(BaseModel):
+    id: int = Field(default=0, description="ID of the point.")
+    text: str = Field(default="", description="Content of the point.")
+    language: Literal["Python", "Java"] = Field(
+        default="Python", description="The programming language that the point corresponds to."
+ ) + file_path: str = Field(default="", description="The file that the points come from.") + start_line: int = Field(default=0, description="The starting line number that the point refers to.") + end_line: int = Field(default=0, description="The ending line number that the point refers to.") + detail: str = Field(default="", description="File content from start_line to end_line.") + yes_example: str = Field(default="", description="yes of point examples") + no_example: str = Field(default="", description="no of point examples") + + def rag_key(self) -> str: + return self.text diff --git a/metagpt/ext/stanford_town/actions/st_action.py b/metagpt/ext/stanford_town/actions/st_action.py index 321676374d..48cda353cc 100644 --- a/metagpt/ext/stanford_town/actions/st_action.py +++ b/metagpt/ext/stanford_town/actions/st_action.py @@ -8,7 +8,6 @@ from typing import Any, Optional, Union from metagpt.actions.action import Action -from metagpt.config2 import config from metagpt.ext.stanford_town.utils.const import PROMPTS_DIR from metagpt.logs import logger @@ -62,13 +61,13 @@ async def _aask(self, prompt: str) -> str: async def _run_gpt35_max_tokens(self, prompt: str, max_tokens: int = 50, retry: int = 3): for idx in range(retry): try: - tmp_max_tokens_rsp = getattr(config.llm, "max_token", 1500) - setattr(config.llm, "max_token", max_tokens) + tmp_max_tokens_rsp = getattr(self.config.llm, "max_token", 1500) + setattr(self.config.llm, "max_token", max_tokens) self.llm.use_system_prompt = False # to make it behave like a non-chat completions llm_resp = await self._aask(prompt) - setattr(config.llm, "max_token", tmp_max_tokens_rsp) + setattr(self.config.llm, "max_token", tmp_max_tokens_rsp) logger.info(f"Action: {self.cls_name} llm _run_gpt35_max_tokens raw resp: {llm_resp}") if self._func_validate(llm_resp, prompt): return self._func_cleanup(llm_resp, prompt) diff --git a/metagpt/ext/stanford_town/roles/st_role.py b/metagpt/ext/stanford_town/roles/st_role.py index 79f58b07d2..592b78a8f4 100644 --- a/metagpt/ext/stanford_town/roles/st_role.py +++ b/metagpt/ext/stanford_town/roles/st_role.py @@ -16,7 +16,7 @@ from datetime import datetime, timedelta from operator import itemgetter from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import Optional from pydantic import ConfigDict, Field, field_validator, model_validator @@ -27,6 +27,7 @@ EnvObsParams, EnvObsType, ) +from metagpt.environment.stanford_town.stanford_town_env import StanfordTownEnv from metagpt.ext.stanford_town.actions.dummy_action import DummyAction, DummyMessage from metagpt.ext.stanford_town.actions.inner_voice_action import ( AgentWhisperThoughtAction, @@ -49,28 +50,15 @@ from metagpt.schema import Message from metagpt.utils.common import any_to_str -if TYPE_CHECKING: - from metagpt.environment.stanford_town.stanford_town_env import ( # noqa: F401 - StanfordTownEnv, - ) - class STRoleContext(RoleContext): model_config = ConfigDict(arbitrary_types_allowed=True) - env: "StanfordTownEnv" = Field(default=None, exclude=True) + env: StanfordTownEnv = Field(default=None, exclude=True) memory: AgentMemory = Field(default_factory=AgentMemory) scratch: Scratch = Field(default_factory=Scratch) spatial_memory: MemoryTree = Field(default_factory=MemoryTree) - @classmethod - def model_rebuild(cls, **kwargs): - from metagpt.environment.stanford_town.stanford_town_env import ( # noqa: F401 - StanfordTownEnv, - ) - - super(RoleContext, cls).model_rebuild(**kwargs) - class STRole(Role): # add a role's property structure to 
store role's age and so on like GA's Scratch. @@ -181,13 +169,13 @@ def save_into(self): logger.info(f"Role: {self.name} saved role's memory into {str(self.role_storage_path)}") - async def _observe(self, ignore_memory=False) -> int: + async def _observe(self) -> int: if not self.rc.env: return 0 news = [] if not news: news = self.rc.msg_buffer.pop_all() - old_messages = [] if ignore_memory else self.rc.memory.get() + old_messages = [] if not self.enable_memory else self.rc.memory.get() # Filter out messages of interest. self.rc.news = [ n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages @@ -635,6 +623,3 @@ async def _react(self) -> Message: time.sleep(0.5) return DummyMessage() - - -STRoleContext.model_rebuild() diff --git a/metagpt/ext/stanford_town/utils/utils.py b/metagpt/ext/stanford_town/utils/utils.py index 3aa0e80e8d..e09cce8fe3 100644 --- a/metagpt/ext/stanford_town/utils/utils.py +++ b/metagpt/ext/stanford_town/utils/utils.py @@ -13,7 +13,7 @@ from openai import OpenAI -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger @@ -48,6 +48,7 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True): def get_embedding(text, model: str = "text-embedding-ada-002"): + config = Config.default() text = text.replace("\n", " ") if not text: text = "this is blank" diff --git a/metagpt/learn/text_to_embedding.py b/metagpt/learn/text_to_embedding.py index f859ab6381..2b4adda801 100644 --- a/metagpt/learn/text_to_embedding.py +++ b/metagpt/learn/text_to_embedding.py @@ -6,12 +6,13 @@ @File : text_to_embedding.py @Desc : Text-to-Embedding skill, which provides text-to-embedding functionality. """ -import metagpt.config2 +from typing import Optional + from metagpt.config2 import Config from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding -async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config): +async def text_to_embedding(text, model="text-embedding-ada-002", config: Optional[Config] = None): """Text to embedding :param text: The text used for embedding. @@ -19,6 +20,7 @@ async def text_to_embedding(text, model="text-embedding-ada-002", config: Config :param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. """ + config = config if config else Config.default() openai_api_key = config.get_openai_llm().api_key proxy = config.get_openai_llm().proxy return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy) diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index 163859fc0d..9bfed532bf 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -7,8 +7,8 @@ @Desc : Text-to-Image skill, which provides text-to-image functionality. """ import base64 +from typing import Optional -import metagpt.config2 from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.llm import LLM @@ -17,7 +17,7 @@ from metagpt.utils.s3 import S3 -async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config): +async def text_to_image(text, size_type: str = "512x512", config: Optional[Config] = None): """Text to image :param text: The text used for image conversion. 
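The optional-config pattern above recurs in text_to_embedding (above) and text_to_speech (below): the module-level config singleton is replaced by an optional parameter resolved at call time. A minimal call-site sketch, assuming a configured config2.yaml with a valid OpenAI key:

    import asyncio

    from metagpt.config2 import Config
    from metagpt.learn.text_to_image import text_to_image

    async def main():
        # Omitting `config` falls back to Config.default() inside the skill.
        image_b64 = await text_to_image("a red panda", size_type="512x512")
        # An explicit Config can still be threaded through per call.
        return await text_to_image("a red panda", config=Config.default())

    asyncio.run(main())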
@@ -25,6 +25,7 @@ async def text_to_image(text, size_type: str = "512x512", config: Config = metag :param config: Config :return: The image data is returned in Base64 encoding. """ + config = config if config else Config.default() image_declaration = "data:image/png;base64," model_url = config.metagpt_tti_url diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 8dbd6d2436..9d3dba6853 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,7 +6,8 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ -import metagpt.config2 +from typing import Optional + from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.tools.azure_tts import oas3_azsure_tts @@ -20,7 +21,7 @@ async def text_to_speech( voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl", - config: Config = metagpt.config2.config, + config: Optional[Config] = None, ): """Text to speech For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` @@ -38,7 +39,7 @@ async def text_to_speech( :return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string. """ - + config = config if config else Config.default() subscription_key = config.azure_tts_subscription_key region = config.azure_tts_region if subscription_key and region: diff --git a/metagpt/logs.py b/metagpt/logs.py index 90bac21aaf..63c10fa2fc 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -6,14 +6,34 @@ @File : logs.py """ +from __future__ import annotations + +import asyncio +import inspect import sys +from contextvars import ContextVar from datetime import datetime from functools import partial +from typing import Any from loguru import logger as _logger +from pydantic import BaseModel, Field from metagpt.const import METAGPT_ROOT +LLM_STREAM_QUEUE: ContextVar[asyncio.Queue] = ContextVar("llm-stream") + + +class ToolLogItem(BaseModel): + type_: str = Field(alias="type", default="str", description="Data type of `value` field.") + name: str + value: Any + + +TOOL_LOG_END_MARKER = ToolLogItem( + type="str", name="end_marker", value="\x18\x19\x1B\x18" +) # A special log item to suggest the end of a stream log + def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None): """Adjust the log level to above level""" @@ -31,12 +51,93 @@ def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None def log_llm_stream(msg): + """ + Logs a message to the LLM stream. + + Args: + msg: The message to be logged. + + Notes: + If the LLM_STREAM_QUEUE has not been set (e.g., if `create_llm_stream_queue` has not been called), + the message will not be added to the LLM stream queue. 
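+
+    Example:
+        A minimal sketch (`put_nowait` does not require a running event loop):
+
+            queue = create_llm_stream_queue()
+            log_llm_stream("Hello")  # enqueued and echoed via _llm_stream_log
+            assert queue.get_nowait() == "Hello"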
+ """ + + queue = get_llm_stream_queue() + if queue: + queue.put_nowait(msg) _llm_stream_log(msg) +def log_tool_output(output: ToolLogItem | list[ToolLogItem], tool_name: str = ""): + """interface for logging tool output, can be set to log tool output in different ways to different places with set_tool_output_logfunc""" + _tool_output_log(output=output, tool_name=tool_name) + + +async def log_tool_output_async(output: ToolLogItem | list[ToolLogItem], tool_name: str = ""): + """async interface for logging tool output, used when output contains async object""" + await _tool_output_log_async(output=output, tool_name=tool_name) + + +async def get_human_input(prompt: str = ""): + """interface for getting human input, can be set to get input from different sources with set_human_input_func""" + if inspect.iscoroutinefunction(_get_human_input): + return await _get_human_input(prompt) + else: + return _get_human_input(prompt) + + def set_llm_stream_logfunc(func): global _llm_stream_log _llm_stream_log = func +def set_tool_output_logfunc(func): + global _tool_output_log + _tool_output_log = func + + +async def set_tool_output_logfunc_async(func): + # async version + global _tool_output_log_async + _tool_output_log_async = func + + +def set_human_input_func(func): + global _get_human_input + _get_human_input = func + + _llm_stream_log = partial(print, end="") + + +_tool_output_log = ( + lambda *args, **kwargs: None +) # a dummy function to avoid errors if set_tool_output_logfunc is not called + + +async def _tool_output_log_async(*args, **kwargs): + # async version + pass + + +def create_llm_stream_queue(): + """Creates a new LLM stream queue and sets it in the context variable. + + Returns: + The newly created asyncio.Queue instance. + """ + queue = asyncio.Queue() + LLM_STREAM_QUEUE.set(queue) + return queue + + +def get_llm_stream_queue(): + """Retrieves the current LLM stream queue from the context variable. + + Returns: + The asyncio.Queue instance if set, otherwise None. 
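+
+    Example:
+        A consumer-side sketch (assumes the producer has already called
+        `create_llm_stream_queue` in the same context):
+
+            queue = get_llm_stream_queue()
+            if queue is not None:
+                chunk = await queue.get()  # must run inside a coroutine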
+ """ + return LLM_STREAM_QUEUE.get(None) + + +_get_human_input = input # get human input from console by default diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index c58148ead1..8c2846d1d2 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -12,9 +12,9 @@ import re from typing import Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator -from metagpt.config2 import config +from metagpt.config2 import Config as _Config from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger from metagpt.provider import MetaGPTLLM @@ -32,6 +32,12 @@ class BrainMemory(BaseModel): last_talk: Optional[str] = None cacheable: bool = True llm: Optional[BaseLLM] = Field(default=None, exclude=True) + config: Optional[_Config] = None + + @field_validator("config") + @classmethod + def set_default_config(cls, config): + return config if config else _Config.default() class Config: arbitrary_types_allowed = True @@ -54,9 +60,8 @@ def get_knowledge(self) -> str: texts = [m.content for m in self.knowledge] return "\n".join(texts) - @staticmethod - async def loads(redis_key: str) -> "BrainMemory": - redis = Redis(config.redis) + async def loads(self, redis_key: str) -> "BrainMemory": + redis = Redis(self.config.redis) if not redis_key: return BrainMemory() v = await redis.get(key=redis_key) @@ -70,7 +75,7 @@ async def loads(redis_key: str) -> "BrainMemory": async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60): if not self.is_dirty: return - redis = Redis(config.redis) + redis = Redis(self.config.redis) if not redis_key: return False v = self.model_dump_json() @@ -140,7 +145,7 @@ async def _openai_summarize(self, llm, max_words=200, keep_language: bool = Fals return text summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) if summary: - await self.set_history_summary(history_summary=summary, redis_key=config.redis_key) + await self.set_history_summary(history_summary=summary, redis_key=self.config.redis_key) return summary raise ValueError(f"text too long:{text_length}") @@ -164,7 +169,7 @@ async def _metagpt_summarize(self, max_words=200): msgs.reverse() self.history = msgs self.is_dirty = True - await self.dumps(redis_key=config.redis.key) + await self.dumps(redis_key=self.config.redis.key) self.is_dirty = False return BrainMemory.to_metagpt_history_format(self.history) @@ -181,7 +186,7 @@ async def get_title(self, llm, max_words=5, **kwargs) -> str: summary = await self.summarize(llm=llm, max_words=500) - language = config.language + language = self.config.language command = f"Translate the above summary into a {language} title of less than {max_words} words." summaries = [summary, command] msg = "\n".join(summaries) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 580361d336..0707a36ea4 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,13 +7,14 @@ @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. 
""" from collections import defaultdict -from typing import DefaultDict, Iterable, Set +from typing import DefaultDict, Iterable, Optional, Set from pydantic import BaseModel, Field, SerializeAsAny from metagpt.const import IGNORED_MESSAGE_ID from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.exceptions import handle_exception class Memory(BaseModel): @@ -104,3 +105,8 @@ def get_by_actions(self, actions: Set) -> list[Message]: continue rsp += self.index[action] return rsp + + @handle_exception + def get_by_position(self, position: int) -> Optional[Message]: + """Returns the message at the given position if valid, otherwise returns None""" + return self.storage[position] diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py new file mode 100644 index 0000000000..a323208812 --- /dev/null +++ b/metagpt/memory/role_zero_memory.py @@ -0,0 +1,201 @@ +""" +This module implements a memory system combining short-term and long-term storage for AI role memory management. +It utilizes a RAG (Retrieval-Augmented Generation) engine for long-term memory storage and retrieval. +""" + +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import Field + +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME +from metagpt.logs import logger +from metagpt.memory import Memory +from metagpt.schema import LongTermMemoryItem, Message +from metagpt.utils.common import any_to_str +from metagpt.utils.exceptions import handle_exception + +if TYPE_CHECKING: + from llama_index.core.schema import NodeWithScore + + from metagpt.rag.engines import SimpleEngine + + +class RoleZeroLongTermMemory(Memory): + """ + Implements a memory system combining short-term and long-term storage using a RAG engine. + Transfers old memories to long-term storage when short-term capacity is reached. + Retrieves combined short-term and long-term memories as needed. + """ + + persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") + collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.") + memory_k: int = Field(default=200, description="The capacity of short-term memory.") + similarity_top_k: int = Field(default=5, description="The number of long-term memories to retrieve.") + use_llm_ranker: bool = Field(default=False, description="Whether to use LLM Reranker to get better result.") + + _rag_engine: Any = None + + @property + def rag_engine(self) -> "SimpleEngine": + if self._rag_engine is None: + self._rag_engine = self._resolve_rag_engine() + + return self._rag_engine + + def _resolve_rag_engine(self) -> "SimpleEngine": + """Lazy loading of the RAG engine components, ensuring they are only loaded when needed. + + It uses `Chroma` for retrieval and `LLMRanker` for ranking. 
+ """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + except ImportError: + raise ImportError("To use the RoleZeroMemory, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig( + persist_path=self.persist_path, + collection_name=self.collection_name, + similarity_top_k=self.similarity_top_k, + ) + ] + ranker_configs = [LLMRankerConfig()] if self.use_llm_ranker else [] + + rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return rag_engine + + def add(self, message: Message): + """Add a new message and potentially transfer it to long-term memory.""" + + super().add(message) + + if not self._should_use_longterm_memory_for_add(): + return + + self._transfer_to_longterm_memory() + + def get(self, k=0) -> list[Message]: + """Return recent memories and optionally combines them with related long-term memories.""" + + memories = super().get(k) + + if not self._should_use_longterm_memory_for_get(k=k): + return memories + + query = self._build_longterm_memory_query() + related_memories = self._fetch_longterm_memories(query) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + final_memories = related_memories + memories + + return final_memories + + def _should_use_longterm_memory_for_add(self) -> bool: + """Determines if long-term memory should be used for add.""" + + return self.count() > self.memory_k + + def _should_use_longterm_memory_for_get(self, k: int) -> bool: + """Determines if long-term memory should be used for get. + + Long-term memory is used if: + - k is not 0. + - The last message is from user requirement. + - The count of recent memories is greater than self.memory_k. + """ + + conds = [ + k != 0, + self._is_last_message_from_user_requirement(), + self.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self._add_to_longterm_memory(item) + + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent message before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.get_by_position(index) + + return LongTermMemoryItem(message=message) if message else None + + @handle_exception + def _add_to_longterm_memory(self, item: LongTermMemoryItem): + """Adds a long-term memory item to the RAG engine. + + If adding long-term memory fails, it will only log the error without interrupting program execution. + """ + + if not item or not item.message.content: + return + + self.rag_engine.add_objs([item]) + + @handle_exception(default_return=[]) + def _fetch_longterm_memories(self, query: str) -> list[Message]: + """Fetches long-term memories based on a query. + + If fetching long-term memories fails, it will return the default value (an empty list) without interrupting program execution. + + Args: + query (str): The query string to search for relevant memories. + + Returns: + list[Message]: A list of user and AI messages related to the query. 
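+
+        Example:
+            A usage sketch (illustrative; assumes older messages were already
+            moved over by `_transfer_to_longterm_memory`):
+
+                query = self._build_longterm_memory_query()
+                memories = self._fetch_longterm_memories(query)  # [] on failure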
+ """ + + if not query: + return [] + + nodes = self.rag_engine.retrieve(query) + items = self._get_items_from_nodes(nodes) + memories = [item.message for item in items] + + return memories + + def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]: + """Get items from nodes and arrange them in order of their `created_at`.""" + + items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes] + items.sort(key=lambda item: item.created_at) + + return items + + def _build_longterm_memory_query(self) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. + """ + + message = self._get_the_last_message() + + return message.content if message else "" + + def _get_the_last_message(self) -> Optional[Message]: + if not self.count(): + return None + + return self.get_by_position(-1) + + def _is_last_message_from_user_requirement(self) -> bool: + """Checks if the last message is from a user requirement or sent by the team leader.""" + + message = self._get_the_last_message() + + if not message: + return False + + is_user_message = message.is_user_message() + cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement) + sent_from_team_leader = message.sent_from == TEAMLEADER_NAME + + return is_user_message and (cause_by_user_requirement or sent_from_team_leader) diff --git a/metagpt/prompts/di/architect.py b/metagpt/prompts/di/architect.py new file mode 100644 index 0000000000..6d114abaa3 --- /dev/null +++ b/metagpt/prompts/di/architect.py @@ -0,0 +1,114 @@ +from metagpt.const import REACT_TEMPLATE_PATH, VUE_TEMPLATE_PATH + +SYSTEM_DESIGN_EXAMPLE = """ +```markdown +## Implementation approach": + +We will ... + +## File list + +- a.jsx +- b.jx +- c.py +- d.css +- e.html + +## Data structures and interfaces: + + +classDiagram + class Main { + <> + +main() str + } + class SearchEngine { + +search(query: str) str + } + class Index { + +create_index(data: dict) + +query_index(query: str) list + } + class Ranking { + +rank_results(results: list) list +} + +## Program call flow: + + +sequenceDiagram + participant M as Main + participant SE as SearchEngine + participant I as Index + participant R as Ranking + participant S as Summary + participant KB as KnowledgeBase + M->>SE: search(query) + SE->>I: query_index(query) + I->>KB: fetch_data(query) + KB-->>I: return data + + +## Anything UNCLEAR + +Clarification needed on third-party API integration, ... + +``` +""" + +ARCHITECT_INSTRUCTION = """ +You are an architect. Your task is to design a software system that meets the requirements. + +Note: +1. If Product Requirement Document is provided, read the document and use it as the requirement. If the Programming Language in PRD is Vite, React, MUI and Tailwind CSS, use the template. +2. Default programming language is Vite, React, MUI and Tailwind CSS. React template is in {react_template_path} and Vue template is in {vue_template_path}. +3. Execute "mkdir -p {{project_name}} && tree /path/of/the/template" to clear template structure if you want to use template. This must be a single response WITHOUT other commands. +4. The system design must adhere to the following rules: +4.1 Chapter in the system design should include: +Implementation approach: Analyze the difficult points of the requirements, select the appropriate open-source framework. +File list: Only need relative paths. 
+If using a template, index.html and the files in the src folder must be included.
+Data structures and interfaces: Use mermaid classDiagram code syntax, including classes, methods (__init__ etc.), and functions with type annotations; CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.
+Program call flow: Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE accurately, covering the CRUD AND INIT of each object; SYNTAX MUST BE CORRECT.
+Anything UNCLEAR: Mention unclear project aspects, then try to clarify them.
+4.2 System Design Format example:
+{system_design_example}
+5. Use Editor.write to write the system design in markdown format. The file path must be "{{project}}/docs/system_design.md". Use command_name "end" when the system design is finished.
+6. If not mentioned otherwise, always use Editor.write to write "Program call flow" to a new file named "{{project}}/docs/system_design-sequence-diagram.mermaid" and write "Data structures and interfaces" to a new file "{{project}}/docs/system_design-sequence-diagram.mermaid-class-diagram". Mermaid code only. Do not add "```mermaid".
+""".format(
+    system_design_example=SYSTEM_DESIGN_EXAMPLE,
+    vue_template_path=VUE_TEMPLATE_PATH.resolve().absolute(),
+    react_template_path=REACT_TEMPLATE_PATH.resolve().absolute(),
+)
+
+ARCHITECT_EXAMPLE = """
+## example 1
+Requirement: Create a system design for a 2048 game.
+Explanation: The user requires creating a system design. I have read the product requirement document, and no programming language is specified, so I will use Vite, React, MUI and Tailwind CSS.
+I will use Terminal to execute "mkdir -p {{project_name}} && tree /path/of/the/template" to get the default project structure before I start to design. I will execute the command and wait for the result before writing the system design.
+```json
+[
+    {
+        "command_name": "Terminal.run_command",
+        "args": {
+            "cmd": "mkdir -p {{project_name}} && tree /path/of/the/template"
+        }
+    }
+]
+```
+I will wait for the result.
+
+## example 2
+Requirement: Create a system design for a chatbot.
+Explanation: The user requires creating a system design. I have viewed the default project structure, so now I will use Editor.write to finish the system design.
+```json
+[
+    {
+        "command_name": "Editor.write",
+        "args": {
+            "path": "/absolute/path/to/{{project}}/docs/system_design.md",
+            "content": "(The system design content)"
+        }
+    }
+]
+```
+""".strip()
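A side note on the prompt above: since ARCHITECT_INSTRUCTION is built with str.format, single-brace fields like {system_design_example} are substituted at import time, while doubled braces such as {{project_name}} survive as literal placeholders for the LLM. A quick illustration (the path is made up):

```python
# .format() substitutes single-brace fields and collapses doubled braces
# into literal single braces in the rendered prompt.
template = 'React template is in {react_template_path}. Run "mkdir -p {{project_name}}".'
print(template.format(react_template_path="/opt/templates/react"))
# React template is in /opt/templates/react. Run "mkdir -p {project_name}".
```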
diff --git a/metagpt/prompts/di/data_analyst.py b/metagpt/prompts/di/data_analyst.py
new file mode 100644
index 0000000000..9f943b187a
--- /dev/null
+++ b/metagpt/prompts/di/data_analyst.py
@@ -0,0 +1,26 @@
+from metagpt.strategy.task_type import TaskType
+
+EXTRA_INSTRUCTION = """
+6. Carefully consider how you handle web tasks:
+    - Use SearchEnhancedQA for general information searching, i.e. querying search engines, such as googling news, weather, wiki, etc. Usually, no link is provided.
+    - Use Browser for reading, navigating, or in-domain searching within a specific website, such as reading a blog, searching for products on a given e-commerce link, or interacting with a web app.
+    - Use DataAnalyst.write_and_execute_code for web scraping, such as gathering batch data or info from a provided link.
+    - Write code to view the HTML content rather than using the Browser tool.
+    - Make sure the command_name is actually listed in Available Commands when you use the Browser tool.
+7. When you are making a plan, it is highly recommended to plan and append all tasks in the first response at once, except for the case in 7.1.
+7.1. When the requirement is inquiring about a pdf, docx, md, or txt document, read the document first using Editor.read WITHOUT a plan. After reading the document, use RoleZero.reply_to_human if the requirement can be answered straight away; otherwise, make a plan if further calculation is needed.
+8. Don't finish_current_task multiple times for the same task.
+9. Finish the current task promptly, such as when the code is written and executed successfully.
+10. When using the command 'end', add the command 'finish_current_task' before it.
+"""
+
+TASK_TYPE_DESC = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
+
+
+CODE_STATUS = """
+**Code written**:
+{code}
+
+**Execution status**: {status}
+**Execution result**: {result}
+"""
diff --git a/metagpt/prompts/di/engineer2.py b/metagpt/prompts/di/engineer2.py
new file mode 100644
index 0000000000..25cd595cd0
--- /dev/null
+++ b/metagpt/prompts/di/engineer2.py
@@ -0,0 +1,100 @@
+from metagpt.const import REACT_TEMPLATE_PATH, VUE_TEMPLATE_PATH
+from metagpt.prompts.di.role_zero import ROLE_INSTRUCTION
+
+EXTRA_INSTRUCTION = """
+You are an autonomous programmer.
+
+The special interface consists of a file editor that shows you 100 lines of a file at a time.
+
+You can use terminal commands (e.g., cat, ls, cd) by calling Terminal.run_command.
+
+You should carefully observe the behavior and results of the previous action, and avoid triggering repeated errors.
+
+In addition to the terminal, I also provide additional tools.
+
+If provided an issue link, your first action must be to navigate to the issue page using the Browser tool to understand the issue.
+
+You must check whether the repository exists at the current path. If it exists, navigate to the repository path. If the repository doesn't exist, download it and then navigate to it.
+All subsequent actions must be performed within this repository path. Do not leave this directory to execute any actions at any time.
+
+Note:
+
+1. If you open a file and need to get to an area around a specific line that is not in the first 100 lines, say line 583, don't just use the scroll_down command multiple times. Instead, use the Editor.goto_line command. It's much quicker.
+2. Always make sure to look at the currently open file and the current working directory (which appears right after the currently open file). The currently open file might be in a different directory than the working directory! Note that some commands, such as 'create', open files, so they might change the current open file.
+3. When using Editor.edit_file_by_replace, if there is no exact match, take the difference in indentation into consideration.
+4. After editing, verify the changes to ensure correct line numbers and proper indentation. Adhere to PEP8 standards for Python code.
+5. NOTE ABOUT THE EDIT COMMAND: Indentation really matters! When editing a file, make sure to insert appropriate indentation before each line! Ensure the code adheres to PEP8 standards. If an edit command fails, you can try to edit the file again to correct the indentation, but don't repeat the same command without changes.
+6. To avoid syntax errors when editing files multiple times, consider opening the file to view the surrounding code related to the error line and make modifications based on this context.
+7. Make sure to observe the currently open file and the current working directory, which is displayed right after the open file. The open file might be in a different directory than the working directory. Remember, commands like 'create' open files and might alter the current open file.
+8. Use search commands (`search_dir`, `search_file`, `find_file`) and navigation commands (`open_file`, `goto_line`) effectively to locate and modify files; the Editor tool can fully satisfy these needs.
+9. When the edit fails, try to enlarge the range of code.
+10. You must use the Editor.open_file command to open a file before using the Editor tool's edit command to modify it. When you open a file, any currently open file will be automatically closed.
+11. Remember, when you use Editor.insert_content_at_line or Editor.edit_file_by_replace, the line numbers will change after the operation. Therefore, if there are multiple operations, perform only the first operation in the current response, and defer the subsequent operations to the next turn.
+11.1 Do not use Editor.insert_content_at_line or Editor.edit_file_by_replace more than once per command list.
+12. If you choose Editor.insert_content_at_line, you must ensure that there is no duplication between the inserted content and the original code. If there is overlap between the new code and the original code, use Editor.edit_file_by_replace instead.
+13. If you choose Editor.edit_file_by_replace, the original code that needs to be replaced must start at the beginning of the line and end at the end of the line.
+14. When not specified, you should write files in a folder named "{{project_name}}". The project name is the name of the project that meets the user's requirements.
+15. When provided a system design or project schedule, you MUST read them first before making a plan, then adhere to them in your implementation, especially in the programming language, package, or framework. You MUST implement all code files prescribed in the system design or project schedule.
+16. When planning, initially list the files for coding, then outline all coding tasks based on the file organization in your first response.
+17. If you plan to read a file, do not include other plans in the same response.
+18. Write only one code file each time and provide its full implementation.
+19. When the requirement is simple, you don't need to create a plan; just do it right away.
+20. When using the editor, pay attention to the current directory. When you use editor tools, the paths must be either absolute or relative to the editor's current directory.
+21. When planning, consider whether images are needed. If you are developing a showcase website, start by using ImageGetter.get_image to obtain the necessary images.
+22. When planning, merge multiple tasks that operate on the same file into a single task. For example, create one task for writing unit tests for all functions in a class. Likewise, when using the editor, merge multiple operations on the same file into a single task.
+23. When creating unit tests for a code file, use Editor.read() to read the code file before planning, and create one task to write the unit tests for the whole file.
+24. The priority for selecting technology stacks: what is described in the System Design and Project Schedule > Vite, React, MUI and Tailwind CSS > native HTML.
+24.1. The React template is in "{react_template_path}" and the Vue template is in "{vue_template_path}".
+25. If Vite, Vue/React, MUI, and Tailwind CSS are used as the technology stack, or no programming language is specified in the documents or user requirement, follow these steps:
+25.1. Create the project folder if it does not exist. Use cmd "mkdir -p {{project_name}}".
+25.2. Copy a Vue/React template to your project folder, move into it, and list the files in it. Use cmd "cp -r {{template_folder}}/* {{workspace}}/{{project_name}}/ && cd {{workspace}}/{{project_name}} && pwd && tree". This must be a single response without other commands.
+25.3. Use Editor.read to read the content of the files in src and the index.html in the project root before making a plan.
+25.4. List the files that you need to rewrite and create when making a plan. Indicate clearly which file to rewrite or create in each task. "index.html" and all files in the src folder must always be rewritten. Use Tailwind CSS for styling. Notice that you are in {{project_name}}.
+25.5. After finishing the project, use "pnpm install && pnpm run build" to build it, then deploy the project to the public using the dist folder, which contains the built project.
+26. Engineer2.write_new_code is used to write or rewrite code; it will modify the whole file. Editor.edit_file_by_replace is used to edit a small part of the file.
+27. Deploy the project to the public after you install and build it; a folder named "dist" containing the built project will appear in the current directory after the build.
+28. Use Engineer2.write_new_code to rewrite the whole file when you have failed with Editor.edit_file_by_replace more than three times.
+""".format(
+    vue_template_path=VUE_TEMPLATE_PATH.resolve().absolute(),
+    react_template_path=REACT_TEMPLATE_PATH.resolve().absolute(),
+)
+CURRENT_STATE = """
+The current editor state is:
+(Current directory: {current_directory})
+(Open file: {editor_open_file})
+"""
+ENGINEER2_INSTRUCTION = ROLE_INSTRUCTION + EXTRA_INSTRUCTION.strip()
+
+WRITE_CODE_SYSTEM_PROMPT = """
+You are a world-class engineer; your goal is to write google-style, elegant, modular, readable, maintainable, fully functional, and ready-for-production code.
+
+Pay attention to the conversation history and the following constraints:
+1. When provided a system design, YOU MUST FOLLOW "Data structures and interfaces". DON'T CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
+2. When modifying code, rewrite the full code instead of updating or inserting a snippet.
+3. Write out EVERY CODE DETAIL, DON'T LEAVE TODO OR PLACEHOLDER.
+"""
+
+WRITE_CODE_PROMPT = """
+# User Requirement
+{user_requirement}
+
+# Plan Status
+{plan_status}
+
+# Current Coding File
+{file_path}
+
+# File Description
+{file_description}
+
+# Instruction
+Your task is to write the {file_name} according to the User Requirement. You must ensure the code is complete, correct, and bug-free.
+
+# Output
+While some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. DO NOT leave any TODO or placeholder.
+Output code in the following format:
+```
+your code
+```
+"""
diff --git a/metagpt/prompts/di/role_zero.py b/metagpt/prompts/di/role_zero.py
new file mode 100644
index 0000000000..5e3eb0a98b
--- /dev/null
+++ b/metagpt/prompts/di/role_zero.py
@@ -0,0 +1,267 @@
+from metagpt.const import EXPERIENCE_MASK
+
+ROLE_INSTRUCTION = """
+Based on the context, write a plan or modify an existing plan to achieve the goal. A plan consists of one to three tasks.
+If a plan is created, you should track the progress and update the plan accordingly, using commands such as Plan.finish_current_task, Plan.append_task, Plan.reset_task, Plan.replace_task, etc.
+When presented with a current task, tackle it using the available commands.
+Pay close attention to new user messages, review the conversation history, and use RoleZero.reply_to_human to respond to new user requirements.
+Note:
+1. If you keep encountering errors or unexpected situations, or you are not sure how to proceed, use RoleZero.ask_human to ask for help.
+2. Carefully review your progress on the current task. If your actions so far have not fulfilled the task instruction, continue with the current task. Otherwise, finish the current task explicitly with Plan.finish_current_task.
+3. Each time you finish a task, use RoleZero.reply_to_human to report your progress.
+4. Don't forget to append tasks first when all existing tasks are finished and new tasks are required.
+5. Avoid repeating tasks you have already completed, and end the loop when all requirements are met.
+"""

########################## ignore guidance

# Latest Observation
# {latest_observation}

# {thought_guidance}
# Finally, combine your thoughts, describe what you want to do concisely in 20 words, including which process you will take and whether you will end, then follow your thoughts to list the commands, adhering closely to the instructions provided.

###########################
+SYSTEM_PROMPT = """
+# Basic Info
+{role_info}
+
+# Data Structure
+class Task(BaseModel):
+    task_id: str = ""
+    dependent_task_ids: list[str] = []
+    instruction: str = ""
+    task_type: str = ""
+    assignee: str = ""
+
+# Available Task Types
+{task_type_desc}
+
+# Available Commands
+{available_commands}
+Special Command: Use {{"command_name": "end"}} to do nothing or indicate completion of all requirements and the end of actions.
+
+# Example
+{example}
+
+# Instruction
+{instruction}
+
+"""
+
+CMD_EXPERIENCE_MASK = f"""
+# Past Experience
+{EXPERIENCE_MASK}
+"""
+
+CMD_PROMPT = (
+    CMD_EXPERIENCE_MASK
+    + """
+# Tool State
+{current_state}
+
+# Current Plan
+{plan_status}
+
+# Current Task
+{current_task}
+
+# Response Language
+You must respond in {respond_language}.
+
+Pay close attention to the Example provided; you can reuse the example for your current situation if it fits.
+If you open a file, the line number is displayed at the front of each line.
+You may use any of the available commands to create a plan or update the plan. You may output multiple commands; they will be executed sequentially.
+If you finish the current task, you will automatically take the next task in the existing plan; use Plan.finish_current_task, DON'T append a new task.
+Review the latest plan's outcome, focusing on achievements. If your completed task matches the current one, consider it finished.
+Using Editor.insert_content_at_line and Editor.edit_file_by_replace more than once in the current command list is forbidden, because these commands are mutually exclusive and change the line numbers after execution.
+In your response, include at least one command. If you want to stop, use the {{"command_name":"end"}} command.
+
+# Your commands in a json array, in the following output format with correct command_name and args.
+Some text indicating your thoughts before the JSON is required, such as what tasks have been completed, what tasks are next, how you should update the plan status, respond to an inquiry, or seek help. Then output a json array of commands. You must output ONE and ONLY ONE json array.
+DON'T output multiple json arrays with thoughts between them.
+Output should adhere to the following format.
+```json
+[
+    {{
+        "command_name": "ClassName.method_name" or "function_name",
+        "args": {{"arg_name": arg_value, ...}}
+    }},
+    ...
+]
+```
+Notice: your output JSON data section must start with **```json [**
+"""
+)
+THOUGHT_GUIDANCE = """
+First, describe the actions you have taken recently.
+Second, describe the messages you have received recently, with a particular emphasis on messages from users. If necessary, develop a plan to address the new user requirements.
+Third, describe the plan status and the current task. Review the history: if the `Current Task` has been undertaken and completed by you or anyone, you MUST use the **Plan.finish_current_task** command to finish it first before taking any action; the command will automatically move you to the next task.
+Fourth, describe any necessary human interaction. Use **RoleZero.reply_to_human** to report your progress if you complete a task or the overall requirement; pay attention to the history and DON'T repeat reporting. Use **RoleZero.ask_human** if you have failed the current task, are unsure of the situation encountered, need any help from a human, or are executing repetitive commands but receiving repetitive feedback without making progress.
+Fifth, describe whether you should terminate. You should use the **end** command to terminate if any of the following is met:
+  - You have completed the overall user requirement
+  - All tasks are finished and the current task is empty
+  - You are repetitively replying to the human
+""".strip()
+
+REGENERATE_PROMPT = """
+Review and reflect on the history carefully, then provide a different response.
+Describe whether you should terminate using the **end** command, use **RoleZero.ask_human** to ask a human for help, or try a different approach and output different commands. You are NOT allowed to provide the same commands again.
+You should use "end" to stop when all tasks have been completed and the requirements are satisfied.
+Your reflection, then the commands in a json array:
+"""
+END_COMMAND = """
+```json
+[
+    {
+        "command_name": "end",
+        "args": {}
+    }
+]
+```
+"""
+
+SUMMARY_PROBLEM_WHEN_DUPLICATE = """You have encountered a problem that caused duplicate commands. Please tell me directly what is confusing or troubling you. Do NOT output any command. Output your problem in {language} and within 30 words."""
+ASK_HUMAN_GUIDANCE_FORMAT = """
+I am facing the following problem:
+{problem}
+Could you please provide me with some guidance? If you want to stop, please include "" in your guidance.
+"""
+ASK_HUMAN_COMMAND = [{"command_name": "RoleZero.ask_human", "args": {"question": ""}}]
+
+JSON_REPAIR_PROMPT = """
+## json data
+{json_data}
+
+## json decode error
+{json_decode_error}
+
+## Output Format
+```json
+
+```
+Do not use escape characters in the json data, particularly within file paths.
+Check whether there are any formatting issues with the JSON data; if so, please fix them.
+If no issues are detected, the original json data should be returned unchanged. Do not omit any information.
+Output the JSON data in a format that can be loaded by the json.loads() function.
+"""
+
+QUICK_THINK_SYSTEM_PROMPT = """
+{role_info}
+Your role is to determine the appropriate response category for the given request.
+
+# Response Categories
+## QUICK
+For straightforward questions or requests that can be answered directly.
+This includes common-sense inquiries, legal or logical questions, basic math, short coding tasks, multiple-choice questions, greetings, casual chat, daily planning, and inquiries about you or your team.
+
+## SEARCH
+For queries that require retrieving up-to-date or detailed information. This includes time-sensitive or location-specific questions like current events or weather. Use this only if the information isn't readily available.
+If a file or link is provided, you don't need to search for additional information.
+
+## TASK
+For requests that involve tool utilization, computer operations, multiple steps, or detailed instructions. Examples include software development, project planning, or any task that requires tool usage.
+
+## AMBIGUOUS
+For requests that are unclear, lack sufficient detail, or are outside the system's capabilities. Common characteristics of AMBIGUOUS requests:
+
+- Incomplete Information: Requests that imply complex tasks but lack critical details (e.g., "Redesign this logo" without specifying design requirements).
+- Vagueness: Broad, unspecified, or unclear requests that make it difficult to provide a precise answer.
+- Unrealistic Scope: Overly broad requests that are impossible to address meaningfully in a single response (e.g., "Tell me everything about...").
+- Missing Files: Requests that refer to specific documents, images, or data without providing them for reference. (When providing a file, website, or data, either the content, link, or path **must** be included.)
+
+**Note:** Before categorizing a request as TASK:
+1. Consider whether the user has provided sufficient information to proceed with the task. If the request is complex but lacks essential details, or lacks the mentioned files' content or path, it should fall under AMBIGUOUS.
+2. If the request is a "how-to" question that asks for a general plan, approach, or strategy, it should be categorized as QUICK.
+
+{examples}
+"""
+
+QUICK_THINK_PROMPT = """
+# Instruction
+Determine the previous message's intent.
+Respond with a concise thought, then provide the appropriate response category: QUICK, SEARCH, TASK, or AMBIGUOUS.
+
+# Format
+Thought: [Your thought here]
+Response Category: [QUICK/SEARCH/TASK/AMBIGUOUS]
+
+# Response:
+"""
+
+
+QUICK_THINK_EXAMPLES = """
+# Example
+
+1. Request: "How do I design an online document editing platform that supports real-time collaboration?"
+Thought: This is a direct query about platform design, answerable without additional resources.
+Response Category: QUICK.
+
+2. Request: "What's the difference between supervised and unsupervised learning in machine learning?"
+Thought: This is a general knowledge question that can be answered concisely.
+Response Category: QUICK.
+
+3. Request: "Please help me write a learning plan for Python web crawlers"
+Thought: Writing a learning plan is a daily planning task that can be answered directly.
+Response Category: QUICK.
+
+4. Request: "Can you help me find the latest research papers on deep learning?"
+Thought: The user needs current research, requiring a search for the most recent sources.
+Response Category: SEARCH.
+
+5. Request: "Build a personal website that runs the Game of Life simulation."
+Thought: This is a detailed software development task that requires multiple steps.
+Response Category: TASK.
+
+6. Request: "Summarize this document for me."
+Thought: The request mentions summarizing a document but doesn't provide the path or content of the document, making it impossible to fulfill.
+Response Category: AMBIGUOUS.
+
+7. Request: "Summarize this document for me: '/data/path/document.pdf'."
+Thought: The request mentions summarizing a document and has provided the path to the document. It can be done by reading the document using a tool and then summarizing it.
+Response Category: TASK.
+
+8. Request: "Optimize this process."
+Thought: The request is vague and lacks specifics, requiring clarification on the process to optimize.
+Response Category: AMBIGUOUS.
+
+9. Request: "Change the color of the text to blue in styles.css, add a new button to the web page, delete the old background image."
+Thought: The request is an incremental development task that requires modifying one or more files.
+Response Category: TASK.
+"""
+QUICK_RESPONSE_SYSTEM_PROMPT = """
+{role_info}
+However, you MUST respond to the user message by yourself directly; DON'T ask your team members.
+"""
+# A tag to indicate a message caused by quick think
+QUICK_THINK_TAG = "QuickThink"
+
+REPORT_TO_HUMAN_PROMPT = """
+## Example
+example 1:
+User requirement: create a 2048 game
+Reply: The development of the 2048 game has been completed. All files (index.html, style.css, and script.js) have been created and reviewed.
+
+example 2:
+User requirement: Crawl and extract all the herb names from the website. Tell me the number of herbs.
+Reply: The herb names have been successfully extracted. A total of 8 herb names were extracted.
+
+------------
+
+Carefully review the history and respond to the user in the expected language to meet their requirements.
+If you have any deliverables that are helpful in explaining the results (such as a deployment URL, files, metrics, quantitative results, etc.), provide brief descriptions of them.
+Your reply must be concise.
+You must respond in {respond_language}.
+Directly output your reply content. Do not add any output format.
+"""
+SUMMARY_PROMPT = """
+Summarize what you have accomplished lately. Be concise.
+If you produced any deliverables, include their short descriptions and file paths. If there are any metrics, URLs, or quantitative results, include them, too.
+If the deliverable is code, only output the file path.
+"""
+
+DETECT_LANGUAGE_PROMPT = """
+The requirement is:
+{requirement}
+
+Which natural language must you respond in?
+Output only the language type.
+"""
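A sketch of the repair loop that JSON_REPAIR_PROMPT above is designed for: parse the model output once, and on failure send the payload plus the decode error back for a single repair pass. The llm.aask call is an assumption standing in for however the caller normally queries its LLM:

```python
import json

from metagpt.prompts.di.role_zero import JSON_REPAIR_PROMPT


async def load_commands(raw: str, llm) -> list[dict]:
    """Parse a command array; on failure, ask the LLM to repair it once."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        prompt = JSON_REPAIR_PROMPT.format(json_data=raw, json_decode_error=str(e))
        repaired = await llm.aask(prompt)  # assumed LLM interface
        # Strip an optional ```json fence before parsing the repaired payload.
        repaired = repaired.strip().removeprefix("```json").removesuffix("```")
        return json.loads(repaired)
```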
diff --git a/metagpt/prompts/di/swe_agent.py b/metagpt/prompts/di/swe_agent.py
new file mode 100644
index 0000000000..86a0622148
--- /dev/null
+++ b/metagpt/prompts/di/swe_agent.py
@@ -0,0 +1,246 @@
+"""
+This code is adapted from the examples provided in the SWE-agent project.
+You can find the original examples from the SWE-agent project here:
+https://github.com/princeton-nlp/SWE-agent/tree/main/config/configs
+"""
+
+MINIMAL_EXAMPLE = """
+## Example of an action trajectory
+User Requirement and Issue: Fix the bug in the repo. Because the environment is not available, you DO NOT need to run or modify any existing test case files or add new test case files to ensure that the bug is fixed.
+
+### Read and understand the issue:
+Thought: First, I need to review the detailed information of this issue in order to understand the problem that needs fixing.
+{{
+    "command_name": "Browser.goto",
+    "args": {{
+        "url": "https://github.com/geekan/MetaGPT/issues/1275"
+    }}
+}}
+->
+
+### Locate the issue (Required): Locate the issue in the code by searching for the relevant file, function, or class, and open the file to view the code.
+Thought: I need to enter the repo path first.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "cd /workspace/MetaGPT"
+    }}
+}}
+->
+
+Thought: Let's start by locating the `openai_api.py` file.\nFirst, let's search for the `openai_api.py` file.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "find_file 'openai_api.py'"
+    }}
+}}
+->
+
+Thought: We have located the `openai_api.py` file. Let's start by opening it to apply the necessary changes.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "open '/workspace/MetaGPT/provider/openai_api.py'"
+    }}
+}}
+->
+
+### Fix the Bug (Required): Fix the bug in the code by editing the relevant function, class, or code snippet.
+Thought: Now that I've found the bug, let's fix it with an edit.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "edit 93:95 <<EOF
+(the corrected code)
+EOF"
+    }}
+}}
+->
+
+Thought: Due to a syntax error related to an undefined name 'Image', we need to address this issue even though it is not directly related to our work. Let's try importing the package to fix it.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "edit 14:14 <<EOF
+(the import statement)
+EOF"
+    }}
+}}
+->
+
+### Save the Changes (Required): After all changes have been made, save them to the repository.
+> You must choose one of the following two methods.
+
+#### Just save the changes locally; it only needs one action.
+Thought: The bug has been fixed. Let's submit the changes.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "submit"
+    }}
+}}
+->
+
+#### Save the changes and commit them to the remote repository.
+
+##### Push the changes from the local repository to the remote repository.
+Thought: All changes have been saved; let's push the code to the remote repository.
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "git push origin test-fix"
+    }}
+}}
+->
+
+##### Create a pull request (Optional): Merge the changes from the new branch into the master branch.
+Thought: Now that the changes have been pushed to the remote repository, due to the user's requirement, let's create a pull request to merge the changes into the master branch.
+[{{
+    "command_name": "git_create_pull",
+    "args": {{
+        "base": "master",
+        "head": "test-fix",
+        "base_repo_name": "garylin2099/MetaGPT",
+        "head_repo_name": "seeker-jie/MetaGPT",
+        "app_name": "github",
+        "title": "Fix Issue #1275: produced TypeError: openai.types.completion_usage.CompletionUsage() argument after ** must be a mapping, not NoneType",
+        "body": "This pull request addresses issue #1275 by ensuring that chunk.usage is not None before passing it to CompletionUsage."
+    }}
+}}]
+->
+
+### Finally
+Thought: All tasks have been done; let's end the conversation.
+{{
+    "command_name": "end"
+}}
+"""
+
+
+IMPORTANT_TIPS = """
+1. If you run a command and it doesn't work, try running a different command. A command that did not work once will not work the second time unless you modify it!
+
+2. If you open a file and need to get to an area around a specific line that is not in the first 100 lines, say line 583, don't just use the scroll_down command multiple times. Instead, use the goto 583 command. It's much quicker.
+
+3. Always make sure to look at the currently open file and the current working directory (which appears right after the currently open file). The currently open file might be in a different directory than the working directory! Note that some commands, such as 'create', open files, so they might change the current open file.
+4. When editing files, it is easy to accidentally specify a wrong line number or to write code with incorrect indentation. Always check the code after you issue an edit to make sure that it reflects what you wanted to accomplish. If it didn't, issue another command to fix it.
+
+5. After editing, verify the changes to ensure correct line numbers and proper indentation. Adhere to PEP8 standards for Python code.
+
+6. NOTE ABOUT THE EDIT COMMAND: Indentation really matters! When editing a file, make sure to insert appropriate indentation before each line! Ensure the code adheres to PEP8 standards. If an edit command fails, you can try to edit the file again to correct the indentation, but don't repeat the same command without changes.
+
+7. YOU CAN ONLY ENTER ONE COMMAND AT A TIME and must wait for feedback, so plan your commands carefully.
+
+8. You cannot use any interactive session commands (e.g. python, vim) in this environment, but you can write scripts and run them. E.g. you can write a python script and then run it with `python <script_name>.py`.
+
+9. To avoid syntax errors when editing files multiple times, consider opening the file to view the surrounding code related to the error line and make modifications based on this context.
+
+10. When using the `edit` command, remember it operates within a closed range. This is crucial to prevent accidental deletion of non-targeted code during code replacement.
+
+11. Make sure to observe the currently open file and the current working directory, which is displayed right after the open file. The open file might be in a different directory than the working directory. Remember, commands like 'create' open files and might alter the current open file.
+
+12. Use search commands (`search_dir`, `search_file`, `find_file`) and navigation commands (`open`, `goto`) effectively to locate and modify files. Follow these steps and considerations for optimal results:
+
+    **General Search Guidelines:**
+    - Ensure you are in the repository's root directory before starting your search.
+    - Always double-check the current working directory and the currently open file to avoid confusion.
+    - Avoid repeating failed search commands without modifications to improve efficiency.
+
+    **Strategies for Searching and Navigating Files:**
+
+    1. **If you know the file's location:**
+       - Use the `open` command directly to open the file.
+       - Use `search_file` to find the `search_term` within the currently open file.
+       - Alternatively, use the `goto` command to jump to the specified line.
+       - **Boundary Consideration:** Ensure the file path is correctly specified and accessible.
+
+    2. **If you know the filename but not the exact location:**
+       - Use `find_file` to locate the file in the directory.
+       - Use `open` to open the file once located.
+       - Use `search_file` to find the `search_term` within the file.
+       - Use `goto` to jump to the specified line if needed.
+       - **Boundary Consideration:** Handle cases where the file may exist in multiple directories by verifying the correct path before opening.
+
+    3. **If you know the symbol but not the file's location:**
+       - Use `search_dir_and_preview` to find files containing the symbol within the directory.
+       - Review the search results to identify the relevant file(s).
+       - Use `open` to open the identified file.
+       - Use `search_file` to locate the `search_term` within the open file.
+       - Use `goto` to jump to the specified line.
+       - **Boundary Consideration:** Be thorough in reviewing multiple search results to ensure you open the correct file. Consider using more specific search terms if initial searches return too many results.
+
+    **Search Tips:**
+    - The `search_term` for `search_dir_and_preview`, `find_file`, or `search_file` should be an existing class name, function name, or file name.
+    - Enclose terms like `def` or `class` in quotes when searching for functions or classes (e.g., `search_dir_and_preview 'def apow'` or `search_file 'class Pow'`).
+    - Use wildcard characters (`*`, `?`) in search terms to broaden or narrow down your search scope.
+    - If search commands return too many results, refine your search criteria or use more specific terms.
+    - If a search command fails, modify the search criteria, check for typos or incorrect paths, then try again.
+    - Use the feedback from observations or bash commands in the trajectory to guide adjustments to your search strategy.
+
+13. Save the code changes:
+    - If you need to submit changes to the remote repository, first use the regular git commit command to save the changes locally, then use git push to push them, and, if requested, use `git_create_pull` in Available Commands to create a pull request.
+
+    - If you don't need to submit code changes to the remote repository, use the command Bash.run('submit') to commit the changes locally.
+
+14. If provided an issue link, you MUST go to the issue page using the Browser tool to understand the issue before starting your fix.
+
+15. When the edit fails, try to enlarge the range of the edit.
+
+16. Once again, and this is critical: YOU CAN ONLY ENTER ONE COMMAND AT A TIME.
+"""
+
+NEXT_STEP_TEMPLATE = f"""
+SETTING: You are an autonomous programmer, and you're working directly in the command line with a special interface.
+
+The special interface consists of a file editor that shows you 100 lines of a file at a time.
+
+Please note that THE EDIT COMMAND REQUIRES PROPER INDENTATION. Pay attention to the original indentation when replacing the function.
+If you'd like to add the line '        print(x)' you must fully write that out, with all those spaces before the code! Indentation is important, and code that is not indented correctly will fail and require fixing before it can be run.
+Always review your changes post-edit to ensure they accurately reflect your intentions. If the changes are not as desired, don't hesitate to issue another command to correct them.
+
+Your output should always contain a section of reasoning and a command described in JSON format.
+
+Use \\n to represent line breaks, ensuring the command conforms to the JSON format and is displayed on a single line. Except for the `edit` command, each parameter of the command needs to be enclosed in single quotes.
+As shown in the example below:
+
+First I'll start by using ls to see what files are in the current directory. Then maybe we can look at some relevant files to see what they look like.
+
+```json
+{{
+    "command_name": "Bash.run",
+    "args": {{
+        "cmd": "ls -a"
+    }}
+}}
+```
+
+You should only include a *SINGLE* command in the command section and then wait for a response from the shell before continuing with more discussion and commands. Everything you include in the DISCUSSION section will be saved for future reference.
+If you'd like to issue two commands at once, PLEASE DO NOT DO THAT! Instead, first submit just the first command, and then, after receiving a response, you'll be able to issue the second command.
+Remember, YOU CAN ONLY ENTER ONE COMMAND AT A TIME.
+You should always wait for feedback after every command.
+
+You can use any bash commands you want (e.g., find, grep, cat, ls, cd) or any custom special tools (including `edit`) by calling Bash.run. Edit all the files you need.
+You should carefully observe the behavior and results of the previous action, and avoid triggering repeated errors.
+
+However, Bash.run does NOT support interactive session commands (e.g. python, vim), so please do not invoke them.
+
+In addition to the terminal, I also provide additional tools. If provided an issue link, you MUST navigate to the issue page using the Browser tool to understand the issue before starting your fix.
+
+# INSTRUCTIONS:
+Your first action must be to check whether the repository exists at the current path. If it exists, navigate to the repository path. If the repository doesn't exist, please download it and then navigate to it.
+All subsequent actions must be performed within this repository path. Do not leave this directory to execute any actions at any time.
+Your terminal session has started, and you can use any bash commands or the special interface to help you. Edit all the files you need.
+# Example of Output
+These examples are provided to demonstrate the expected output style, which consists of several stages including Locate the issue, Fix the bug, Test the fix (Optional), and Submit the changes. They are included to show you how to correctly use the interface. You do not need to follow exactly what is done in the Examples. The separator is "-----".
+----- Beginning of Examples -----
+{MINIMAL_EXAMPLE}
+----- End of Examples -----
+
+# IMPORTANT TIPS
+{IMPORTANT_TIPS}
+
+
+Avoid repeating the same command. Instead, think about the current situation and provide the next bash command to execute in JSON format:
+"""
+CURRENT_BASH_STATE = """
+# Output Next Step
+The current bash state is:
+(Open file: {open_file})
+(Current directory: {working_dir})
+"""
diff --git a/metagpt/prompts/di/team_leader.py b/metagpt/prompts/di/team_leader.py
new file mode 100644
index 0000000000..64e85473ed
--- /dev/null
+++ b/metagpt/prompts/di/team_leader.py
@@ -0,0 +1,63 @@
+from metagpt.prompts.di.role_zero import THOUGHT_GUIDANCE
+
+TL_INSTRUCTION = """
+You are a team leader, and you are responsible for drafting tasks and routing them to your team members.
+Your team members:
+{team_info}
+You should NOT assign consecutive tasks to the same team member; instead, assign an aggregated task (or the complete requirement) and let the team member decompose it.
+When drafting and routing tasks, ALWAYS include necessary or important info, such as path, link, or environment, inside the instruction to team members, because you are their sole info source.
+Each time you do something, reply to the human to let them know what you did.
+When creating a new plan involving multiple members, create all tasks at once.
+If a plan is created, you should track the progress based on team member feedback messages and update the plan accordingly, using Plan.finish_current_task, Plan.reset_task, Plan.replace_task, etc.
+You should use TeamLeader.publish_team_message to message team members, asking them to start their tasks. DON'T omit any necessary info, such as path, link, environment, programming language, framework, requirement, or constraint, from the original content, because you are their sole info source.
+Pay close attention to new user messages, review the conversation history, and use RoleZero.reply_to_human to respond to the user directly; DON'T ask your team members.
+Pay close attention to messages from team members. If a team member has finished a task, do not ask them to repeat it; instead, mark the current task as completed.
+Note:
+1. If the requirement is a pure DATA-RELATED requirement, such as web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image, etc., DON'T decompose it; assign a single task with the original user requirement as the instruction directly to the Data Analyst.
+2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise. The standard software development process has four steps: creating a Product Requirement Document (PRD) by the Product Manager -> writing a System Design by the Architect -> creating tasks by the Project Manager -> and coding by the Engineer. You may choose to execute any of these steps. When publishing a message to the Product Manager, you should directly copy the full original user requirement.
+2.1. If the requirement contains both a DATA-RELATED part as mentioned in 1 and a software development part as mentioned in 2, you should decompose the software development part and assign the resulting tasks to team members based on their expertise, and assign the DATA-RELATED part directly to Data Analyst David.
+2.2. For a software development requirement, estimate the complexity of the requirement before assignment, following the common industry practice of t-shirt sizing:
+    - XS: snake game, static personal homepage, basic calculator app
+    - S: Basic photo gallery, basic file upload system, basic feedback form
+    - M: Offline menu ordering system, news aggregator app
+    - L: Online booking system, inventory management system
+    - XL: Social media platform, e-commerce app, real-time multiplayer game
+    - For XS and S requirements, you don't need the standard software development process; you may directly ask the Engineer to write the code. Otherwise, estimate whether any part of the standard software development process may contribute to a better final code. If so, assign team members accordingly.
+3. If the task involves code review (CR) or code checking, you should assign it to the Engineer.
+4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members.
+5. If you think the requirement is not clear or is ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear.
+6. It is helpful for the Engineer to have both the system design and the project schedule for writing the code, so include the paths of both files (if available) and remind the Engineer to definitely read them when publishing a message to the Engineer.
+7. If the requirement is writing a TRD and software framework, you should assign it to the Architect. When publishing a message to the Architect, you should directly copy the full original user requirement.
+8. If the receiver message reads 'from {{team member}} to {{''}}', it indicates that someone has completed the current task. Note this in your thoughts.
+9. Do not use the 'end' command when the current task remains unfinished; instead, use the 'finish_current_task' command to indicate completion before switching to the next task.
+10. Do not use escape characters in json data, particularly within file paths.
+11. Analyze the capabilities of team members and assign tasks to them based on the user requirements. If the requirements ask you to ignore certain tasks, follow the requirements.
+12. If the user message is a question, use 'reply to human' to respond to the question, and then end.
+13. Instructions and replies must be in the same language.
+14. The default technology stack is Vite, React, MUI, and Tailwind CSS. A web app is the default option when developing software. If these technology stacks are used, ask the engineer to deploy the web app after project completion.
+15. You are the only one who decides the programming language for the software, so the instruction must contain the programming language.
+16. Data collection and web/software development are two separate tasks. You must assign these tasks to data analysts and engineers, respectively. Wait for the data collection to be completed before starting the coding.
+"""
+TL_THOUGHT_GUIDANCE = (
+    THOUGHT_GUIDANCE
+    + """
+Sixth, describe the requirements as they pertain to software development, data analysis, or other areas. If the requirement is software development and no specific restrictions are mentioned, you must create a Product Requirements Document (PRD), write a System Design document, develop a project schedule, and then begin coding. List the steps you will undertake. Plan these steps in a single response.
+Seventh, describe the technologies you must use.
+"""
+)
+TL_INFO = """
+{role_info}
+Your team members:
+{team_info}
+"""
+
+FINISH_CURRENT_TASK_CMD = """
+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    }
+]
+```
+"""
diff --git a/metagpt/prompts/di/write_analysis_code.py b/metagpt/prompts/di/write_analysis_code.py
index e5663d498d..55d5e77cde 100644
--- a/metagpt/prompts/di/write_analysis_code.py
+++ b/metagpt/prompts/di/write_analysis_code.py
@@ -1,4 +1,10 @@
-INTERPRETER_SYSTEM_MSG = """As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function."""
+INTERPRETER_SYSTEM_MSG = """
+As a data scientist, you need to help the user achieve their goal step by step in a continuous Jupyter notebook.
+Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.
+If you want to use shell commands such as git clone, pip install, navigating folders, reading files, etc., use the Terminal tool if available. DON'T use ! in a notebook block.
+Don't write all the code in one response; each time, just write code for one step or the current task.
+While some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response.
+"""
 
 STRUCTUAL_PROMPT = """
 # User Requirement
@@ -22,7 +28,10 @@
 ```
 """
 
-REFLECTION_SYSTEM_MSG = """You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation."""
+REFLECTION_SYSTEM_MSG = """
+You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation.
+When a ModuleNotFoundError occurs, always use the Terminal tool to install the required package before the refined code in the same cell.
+For example, `from metagpt.tools.libs.terminal import Terminal\nterminal = Terminal()\nawait terminal.run_command('pip install pandas')` before importing pandas.
+"""
 
 DEBUG_REFLECTION_EXAMPLE = '''
 [previous impl]:
diff --git a/metagpt/prompts/product_manager.py b/metagpt/prompts/product_manager.py
new file mode 100644
index 0000000000..1bc210740a
--- /dev/null
+++ b/metagpt/prompts/product_manager.py
@@ -0,0 +1,175 @@
+from metagpt.prompts.di.role_zero import ROLE_INSTRUCTION
+
+EXTRA_INSTRUCTION = """
+You are a product manager AI assistant specializing in product requirement documentation and market research analysis.
+Your work focuses on the analysis of problems and data.
+You should always output a document.
+
+## Core Tools
+1. Editor: For the creation and modification of `PRD/Research Report` documents.
+2. SearchEnhancedQA: The designated tool for collecting information from the internet; it MUST BE USED for searching.
+3. Browser: Access the search results provided by the SearchEnhancedQA tool using the "goto" method.
+
+## Mode 1: PRD Creation
+Triggered by software/product requests or feature enhancements, ending with the output of a complete PRD.
+
+### Required Fields
+1. Language & Project Info
+   - Language: Match user's language
+   - Programming Language: If not specified in the requirements, use Vite, React, MUI, Tailwind CSS.
+   - Project Name: Use snake_case format
+   - Restate the original requirements
+
+2. Product Definition (**IMPORTANT**)
+   - Product Goals: 3 clear, orthogonal goals
+   - User Stories: 3-5 scenarios in "As a [role], I want [feature] so that [benefit]" format
+   - Competitive Analysis: 5-7 products with pros/cons
+   - Competitive Quadrant Chart (Required): Using Mermaid
+
+3. Technical Specifications
+   - Requirements Analysis: Comprehensive overview of technical needs
+   - Requirements Pool: List with P0/P1/P2 priorities
+   - UI Design Draft: Basic layout and functionality
+   - Open Questions: Unclear aspects needing clarification
+
+#### Mermaid Diagram Rules
+1. Use mermaid quadrantChart syntax. Distribute scores evenly between 0 and 1
+2. Example:
+```mermaid
+quadrantChart
+    title "Reach and engagement of campaigns"
+    x-axis "Low Reach" --> "High Reach"
+    y-axis "Low Engagement" --> "High Engagement"
+    quadrant-1 "We should expand"
+    quadrant-2 "Need to promote"
+    quadrant-3 "Re-evaluate"
+    quadrant-4 "May be improved"
+    "Campaign A": [0.3, 0.6]
+    "Campaign B": [0.45, 0.23]
+    "Campaign C": [0.57, 0.69]
+    "Campaign D": [0.78, 0.34]
+    "Campaign E": [0.40, 0.34]
+    "Campaign F": [0.35, 0.78]
+    "Our Target Product": [0.5, 0.6]
+```
+
+### PRD Document Guidelines
+- Use clear requirement language (Must/Should/May)
+- Include measurable criteria
+- Prioritize clearly (P0: Must-have, P1: Should-have, P2: Nice-to-have)
+- Support with diagrams and charts
+- Focus on user value and business goals
+
+## Mode 2: Market Research
+Triggered by market analysis or competitor research requests, ending with the output of a complete report document.
+
+### **IMPORTANT** Information Collection Requirements
+
+You must follow this strict information gathering process:
+1. Keyword Generation Rules:
+   - Infer 3 distinct keyword groups based on user needs (infer directly instead of using tools).
+   - Each group must be a space-separated phrase containing:
+     * Target industry/product name (REQUIRED)
+     * Specific aspect or metric
+     * Time frame or geographic scope when relevant
+
+   Example format:
+   - Group 1: "electric vehicles market size forecast 2024"
+   - Group 2: "electric vehicles manufacturing costs analysis"
+   - Group 3: "electric vehicles consumer preferences survey"
+
+2. Search Process:
+   - For each keyword:
+     * Use the SearchEnhancedQA TOOL (SearchEnhancedQA.run) to collect the top 3 search results
+     * Remove duplicate URLs
+
+3. Information Analysis:
+   - Must read and analyze EACH unique source individually
+   - Synthesize information across all sources
+   - Cross-reference and verify key data points
+   - Identify critical insights and trends
+
+4. Quality Control:
+   - Verify data consistency across sources
+   - Fill information gaps with targeted additional research
+   - Ensure a balanced perspective from multiple sources
+
+
+### Report Structure
+1. Summary: Key findings and recommendations
+2. Industry Overview: Market size, trends, and structure
+3. Market Analysis: Segments, growth drivers, and challenges
+4. Competitor Landscape: Key players and positioning
+5. Target Audience Analysis: User segments and needs
+6. Pricing Analysis: Market rates and strategies
+7. Key Findings: Major insights and opportunities
+8. Strategic Recommendations: Action items
+9. Appendices: Supporting data
+
+
+### Final Report Requirements
+1. The report must be entirely focused on insights and analysis:
+   - No mention of research methodology
+   - No source tracking or process documentation
+   - Present only validated findings and conclusions
+
+2. Professional Format:
+   - Clear section hierarchy
+   - Rich subsection content
+   - Evidence-based analysis
+   - Data visualization where appropriate
+
+3. Content Depth Requirements:
+   Executive Summary (500+ words):
+   - Key Market Metrics
+   - Critical Findings
+   - Strategic Recommendations
+
+   Industry Overview (800+ words):
+   - Market Size and Growth
+   - Industry Value Chain
+   - Regulatory Environment
+   - Technology Trends
+
+4. Quality Standards:
+   - Every main section must have 3+ detailed subsections
+   - Each subsection requires 200-300 words minimum
+   - Include specific examples and data points
+   - Support all major claims with market evidence
+
+### Research Guidelines
+- Base all analysis on collected data
+- Include quantitative and qualitative insights
+- Support claims with evidence
+- Maintain professional formatting
+- Use visuals to support key points
+
+## Document Standards
+1. Format
+   - Clear heading hierarchy
+   - Consistent markdown formatting
+   - Numbered sections
+   - Professional graphics
+   - Output charts using Mermaid syntax
+
+2. Content
+   - Objective analysis
+   - Actionable insights
+   - Clear recommendations
+   - Supporting evidence
Quality Checks + - Verify data accuracy + - Cross-reference sources + - Ensure completeness + - Review clarity + +Remember: +- Always start with thorough requirements analysis +- Use appropriate tools for each task +- Keep recommendations actionable +- Consider all stakeholder perspectives +- Maintain professional standards throughout +""" + +PRODUCT_MANAGER_INSTRUCTION = ROLE_INSTRUCTION + EXTRA_INSTRUCTION.strip() diff --git a/metagpt/prompts/task_type.py b/metagpt/prompts/task_type.py index 5b1ffc7447..3aa4f5ed46 100644 --- a/metagpt/prompts/task_type.py +++ b/metagpt/prompts/task_type.py @@ -53,3 +53,10 @@ - Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow. - Save webpages: Be sure to use the save method provided. """ + +# Prompt for taking on "web_scraping" tasks +WEB_SCRAPING_PROMPT = """ +- Remember to view and print the necessary HTML content in a separate task to understand the structure first before scraping data. Such as `html_content = await view_page_element_to_scrape(...)\nprint(html_content)`. +- Since the data required by user may not correspond directly to the actual HTML element names, you should thoroughly analyze the HTML structure and meanings of all elements in your context first. Ensure the `class_` in your code should derived from the actual HTML structure directly, not based on your knowledge. To ensure it, analyse the most suitable location of the 'class_' in the actual HTML content before code. +- Reuse existing html object variable from previous code (if any) to extract data, do not mock or hard code a html variable yourself. 
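+- Example of the view-then-scrape flow (a minimal sketch; the URL, the element class, and the exact `view_page_element_to_scrape` parameters are illustrative assumptions, not taken from a real page):
+```python
+# Task 1 (separate task): view and print the real HTML structure first.
+html_content = await view_page_element_to_scrape(url="https://example.com/list", requirement="item names", keep_links=False)
+print(html_content)  # suppose this reveals: <div class="item-card">...</div>
+
+# Task 2: derive `class_` from the printed structure, then parse the same html_content.
+from bs4 import BeautifulSoup
+soup = BeautifulSoup(html_content, "html.parser")
+names = [el.get_text(strip=True) for el in soup.find_all(class_="item-card")]
+```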
+""" diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index db2757ec33..f9111ffe0c 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -22,12 +22,14 @@ wait_random_exponential, ) +from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig -from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT +from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.provider.constant import MULTI_MODAL_MODELS from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.token_counter import TOKEN_MAX class BaseLLM(ABC): @@ -48,7 +50,7 @@ def __init__(self, config: LLMConfig): pass def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: - if images: + if images and self.support_image_input(): # as gpt-4v, chat with image return self._user_msg_with_imgs(msg, images) else: @@ -65,7 +67,7 @@ def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]) # image url or image base64 url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}" # it can with multiple-image inputs - content.append({"type": "image_url", "image_url": url}) + content.append({"type": "image_url", "image_url": {"url": url}}) return {"role": "user", "content": content} def _assistant_msg(self, msg: str) -> dict[str, str]: @@ -74,7 +76,10 @@ def _assistant_msg(self, msg: str) -> dict[str, str]: def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "system", "content": msg} - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def support_image_input(self) -> bool: + return any([m in self.config.model for m in MULTI_MODAL_MODELS]) + + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message @@ -89,7 +94,9 @@ def format_msg(self, messages: Union[str, Message, list[dict], list[Message], li assert set(msg.keys()) == set(["role", "content"]) processed_messages.append(msg) elif isinstance(msg, Message): - processed_messages.append(msg.to_dict()) + images = msg.metadata.get(IMAGES) + processed_msg = self._user_msg(msg=msg.content, images=images) if images else msg.to_dict() + processed_messages.append(processed_msg) else: raise ValueError( f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!" 
@@ -147,7 +154,9 @@ async def aask( else: message.extend(msg) logger.debug(message) - rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout)) + compressed_message = self.compress_messages(message, compress_type=self.config.compress_type) + rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout)) return rsp def _extract_assistant_rsp(self, context): @@ -163,7 +172,9 @@ async def aask_batch(self, msgs: list, timeout=USE_CONFIG_TIMEOUT) -> str: context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict: + async def aask_code( + self, messages: Union[str, "Message", list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs + ) -> dict: raise NotImplementedError @abstractmethod @@ -264,3 +275,85 @@ def with_model(self, model: str): def get_timeout(self, timeout: int) -> int: return timeout or self.config.timeout or LLM_API_TIMEOUT + + def count_tokens(self, messages: list[dict]) -> int: + # A very rough heuristic for counting tokens, with reference to: + # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them + # https://platform.deepseek.com/api-docs/#token--token-usage + # The heuristic heavily overestimates for English text and should be overridden with an accurate token-count function in inherited classes + return sum([int(len(msg["content"]) * 0.5) for msg in messages]) + + def compress_messages( + self, + messages: list[dict], + compress_type: CompressType = CompressType.NO_COMPRESS, + max_token: int = 128000, + threshold: float = 0.8, + ) -> list[dict]: + """Compress messages to fit within the token limit. + Args: + messages (list[dict]): List of messages to compress. + compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS. + max_token (int, optional): Maximum token limit. Defaults to 128000. Ignored if the model's token limit is found in TOKEN_MAX. + threshold (float): Token limit threshold. Defaults to 0.8. Reserves 20% of the token limit for the completion message.
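+
+        Example (a minimal sketch; assumes a configured LLM instance `llm` and messages exceeding the model's limit):
+            >>> msgs = [{"role": "system", "content": "You are a helpful assistant."},
+            ...         {"role": "user", "content": "very long history ..."}]
+            >>> compressed = llm.compress_messages(msgs, compress_type=CompressType.POST_CUT_BY_MSG)
+            >>> # System messages are always kept; the latest user/assistant messages are kept within the budget.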
+ """ + if compress_type == CompressType.NO_COMPRESS: + return messages + + max_token = TOKEN_MAX.get(self.config.model, max_token) + keep_token = int(max_token * threshold) + compressed = [] + + # Always keep system messages + # NOTE: Assume they do not exceed token limit + system_msg_val = self._system_msg("")["role"] + system_msgs = [] + for i, msg in enumerate(messages): + if msg["role"] == system_msg_val: + system_msgs.append(msg) + else: + user_assistant_msgs = messages[i:] + break + # system_msgs = [msg for msg in messages if msg["role"] == system_msg_val] + # user_assistant_msgs = [msg for msg in messages if msg["role"] != system_msg_val] + compressed.extend(system_msgs) + current_token_count = self.count_tokens(system_msgs) + + if compress_type in [CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG]: + # Under keep_token constraint, keep as many latest messages as possible + for i, msg in enumerate(reversed(user_assistant_msgs)): + token_count = self.count_tokens([msg]) + if current_token_count + token_count <= keep_token: + compressed.insert(len(system_msgs), msg) + current_token_count += token_count + else: + if compress_type == CompressType.POST_CUT_BY_TOKEN or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token + truncated_content = msg["content"][-(keep_token - current_token_count) :] + compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) + logger.warning( + f"Truncated messages with {compress_type} to fit within the token limit. " + f"The first user or assistant message after truncation (originally the {i}-th message from last): {compressed[len(system_msgs)]}." + ) + break + + elif compress_type in [CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG]: + # Under keep_token constraint, keep as many earliest messages as possible + for i, msg in enumerate(user_assistant_msgs): + token_count = self.count_tokens([msg]) + if current_token_count + token_count <= keep_token: + compressed.append(msg) + current_token_count += token_count + else: + if compress_type == CompressType.PRE_CUT_BY_TOKEN or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token + truncated_content = msg["content"][: keep_token - current_token_count] + compressed.append({"role": msg["role"], "content": truncated_content}) + logger.warning( + f"Truncated messages with {compress_type} to fit within the token limit. " + f"The last user or assistant message after truncation (originally the {i}-th message): {compressed[-1]}." 
+ ) + break + + return compressed diff --git a/metagpt/provider/constant.py b/metagpt/provider/constant.py index dee78dc3bc..1e372b07f2 100644 --- a/metagpt/provider/constant.py +++ b/metagpt/provider/constant.py @@ -29,3 +29,9 @@ # tool_choice value for general_function_schema # https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice GENERAL_TOOL_CHOICE = {"type": "function", "function": {"name": "execute"}} + + +MULTI_MODAL_MODELS = [ + "gpt-4o", + "gpt-4o-mini", +] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e4b3a3f177..5c1b925033 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -22,7 +22,6 @@ from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -from metagpt.schema import Message class GeminiGenerativeModel(GenerativeModel): @@ -73,7 +72,7 @@ def _assistant_msg(self, msg: str) -> dict[str, str]: def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "user", "parts": [msg]} - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index dbfed72df2..8cb503572d 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -40,7 +40,19 @@ ) -@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL, LLMType.YI]) +@register_provider( + [ + LLMType.OPENAI, + LLMType.FIREWORKS, + LLMType.OPEN_LLM, + LLMType.MOONSHOT, + LLMType.MISTRAL, + LLMType.YI, + LLMType.OPEN_ROUTER, + LLMType.DEEPSEEK, + LLMType.SILICONFLOW, + ] +) class OpenAILLM(BaseLLM): """Check https://platform.openai.com/examples for examples""" @@ -91,10 +103,13 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFI if finish_reason: if hasattr(chunk, "usage"): # Some services have usage as an attribute of the chunk, such as Fireworks - usage = CompletionUsage(**chunk.usage) + usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage elif hasattr(chunk.choices[0], "usage"): # The usage of some services is an attribute of chunk.choices[0], such as Moonshot usage = CompletionUsage(**chunk.choices[0].usage) + if "openrouter.ai" in self.config.base_url and hasattr(chunk, "usage") and chunk.usage is not None: + # OpenRouter returns token usage directly from the API + usage = chunk.usage log_llm_stream("\n") full_reply_content = "".join(collected_messages) @@ -220,7 +235,7 @@ def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: # The response content is `code`, but it appears in the content instead of the arguments.
code_formats = "```" if message.content.startswith(code_formats) and message.content.endswith(code_formats): - code = CodeParser.parse_code(None, message.content) + code = CodeParser.parse_code(text=message.content) return {"language": "python", "code": code} # reponse is message return {"language": "markdown", "code": self.get_choice_text(rsp)} @@ -285,3 +300,9 @@ async def gen_image( img_url_or_b64 = item.url if resp_format == "url" else item.b64_json imgs.append(decode_image(img_url_or_b64)) return imgs + + def count_tokens(self, messages: list[dict]) -> int: + try: + return count_message_tokens(messages, self.config.model) + except: + return super().count_tokens(messages) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 5c58103089..8d78fcad79 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -2,9 +2,11 @@ import json import os -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any, List, Optional, Set, Union -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +import fsspec +from llama_index.core import SimpleDirectoryReader from llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding from llama_index.core.embeddings.mock_embed_model import MockEmbedding @@ -36,7 +38,12 @@ get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject -from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever +from metagpt.rag.retrievers.base import ( + DeletableRAGRetriever, + ModifiableRAGRetriever, + PersistableRAGRetriever, + QueryableRAGRetriever, +) from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseIndexConfig, @@ -63,7 +70,7 @@ def __init__( response_synthesizer: Optional[BaseSynthesizer] = None, node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, callback_manager: Optional[CallbackManager] = None, - index: Optional[BaseIndex] = None, + transformations: Optional[list[TransformComponent]] = None, ) -> None: super().__init__( retriever=retriever, @@ -71,7 +78,8 @@ def __init__( node_postprocessors=node_postprocessors, callback_manager=callback_manager, ) - self.index = index + self._transformations = transformations or self._default_transformations() + self._filenames = set() @classmethod def from_docs( @@ -83,6 +91,7 @@ def from_docs( llm: LLM = None, retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, + fs: Optional[fsspec.AbstractFileSystem] = None, ) -> "SimpleEngine": """From docs. @@ -96,19 +105,25 @@ def from_docs( llm: Must supported by llama index. Default OpenAI. retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. + fs: File system to use. 
""" if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files, fs=fs).load_data() cls._fix_document_metadata(documents) - index = VectorStoreIndex.from_documents( - documents=documents, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations = transformations or cls._default_transformations() + nodes = run_transformations(documents, transformations=transformations) + + return cls._from_nodes( + nodes=nodes, + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_objs( @@ -136,13 +151,16 @@ def from_objs( if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] - index = VectorStoreIndex( + nodes = cls.get_obj_nodes(objs) + + return cls._from_nodes( nodes=nodes, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_index( @@ -161,6 +179,13 @@ async def asearch(self, content: str, **kwargs) -> str: """Inplement tools.SearchInterface""" return await self.aquery(content) + def retrieve(self, query: QueryType) -> list[NodeWithScore]: + query_bundle = QueryBundle(query) if isinstance(query, str) else query + + nodes = super().retrieve(query_bundle) + self._try_reconstruct_obj(nodes) + return nodes + async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: """Allow query to be str.""" query_bundle = QueryBundle(query) if isinstance(query, str) else query @@ -169,21 +194,21 @@ async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: self._try_reconstruct_obj(nodes) return nodes - def add_docs(self, input_files: list[str]): + def add_docs(self, input_files: List[Union[str, Path]]): """Add docs to retriever. 
The retriever must have an add_nodes method.""" self._ensure_retriever_modifiable() - documents = SimpleDirectoryReader(input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data() self._fix_document_metadata(documents) - nodes = run_transformations(documents, transformations=self.index._transformations) + nodes = run_transformations(documents, transformations=self._transformations) self._save_nodes(nodes) def add_objs(self, objs: list[RAGObject]): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + nodes = self.get_obj_nodes(objs) self._save_nodes(nodes) def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): @@ -192,6 +217,65 @@ self._persist(str(persist_dir), **kwargs) + def count(self) -> int: + """Return the total node count.""" + self._ensure_retriever_queryable() + + return self.retriever.query_total_count() + + def clear(self, **kwargs): + """Delete all nodes from the retriever.""" + self._ensure_retriever_deletable() + + return self.retriever.clear(**kwargs) + + def delete_docs(self, input_files: List[Union[str, Path]]): + """Delete documents from the index and document store. + + Args: + input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted. + """ + doc_ids_to_delete = set() + filenames = {str(i) for i in input_files} + for doc_id, info in self.retriever._index.ref_doc_info.items(): + if info.metadata.get("file_path") in filenames: + doc_ids_to_delete.add(doc_id) + + for doc_id in doc_ids_to_delete: + self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True) + + @staticmethod + def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: + """Converts a list of RAGObjects to a list of ObjectNodes.""" + + return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + + @classmethod + def _from_nodes( + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + embed_model = cls._resolve_embed_model(embed_model, retriever_configs) + llm = llm or get_rag_llm() + + retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model) + rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] + + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + transformations=transformations, + ) + @classmethod def _from_index( cls, @@ -201,6 +285,7 @@ ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] @@ -208,7 +293,6 @@ retriever=retriever, node_postprocessors=rankers, response_synthesizer=get_response_synthesizer(llm=llm), - index=index, ) def _ensure_retriever_modifiable(self): @@ -217,6 +301,12 @@ def _ensure_retriever_persistable(self):
self._ensure_retriever_of_type(PersistableRAGRetriever) + def _ensure_retriever_queryable(self): + self._ensure_retriever_of_type(QueryableRAGRetriever) + + def _ensure_retriever_deletable(self): + self._ensure_retriever_of_type(DeletableRAGRetriever) + def _ensure_retriever_of_type(self, required_type: BaseRetriever): """Ensure that self.retriever is of required_type, or that at least one of its components is, if it's a SimpleHybridRetriever. @@ -259,3 +349,11 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = return MockEmbedding(embed_dim=1) return embed_model or get_rag_embedding() + + @staticmethod + def _default_transformations(): + return [SentenceSplitter()] + + @property + def filenames(self) -> Set[str]: + return self._filenames diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fbdfbf1a81..e58643efe5 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -26,6 +26,9 @@ def get_instance(self, key: Any, **kwargs) -> Any: if creator: return creator(**kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Creator not registered for key: {key}") @@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: - """Key is config, such as a pydantic model. + """Get instance by the type of key. - Call func by the type of key, and the key will be passed to func. + The key is a config object, such as a pydantic model; the creator is looked up by the key's type, and the key itself is passed to the creator. + Raises an exception if no creator is registered for the key. """ creator = self._creators.get(type(key)) if creator: return creator(key, **kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: - """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.""" + """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs. + + Returns None if the key is found in neither place. + """ if config is not None and hasattr(config, key): val = getattr(config, key) if val is not None: @@ -54,6 +64,4 @@ def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any if key in kwargs: return kwargs[key] - raise KeyError( - f"The key '{key}' is required but not provided in either configuration object or keyword arguments.
- ) + return None diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 4247db256f..19b8b36f67 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,37 +1,108 @@ """RAG Embedding Factory.""" +from __future__ import annotations + +from typing import Any, Optional from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding -from metagpt.config2 import config +from metagpt.config2 import Config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): - """Create LlamaIndex Embedding with MetaGPT's config.""" + """Create LlamaIndex Embedding with MetaGPT's embedding config.""" - def __init__(self): + def __init__(self, config: Optional[Config] = None): creators = { + EmbeddingType.OPENAI: self._create_openai, + EmbeddingType.AZURE: self._create_azure, + EmbeddingType.GEMINI: self._create_gemini, + EmbeddingType.OLLAMA: self._create_ollama, + # For backward compatibility LLMType.OPENAI: self._create_openai, LLMType.AZURE: self._create_azure, } super().__init__(creators) + self.config = config if config else Config.default() + + def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: + """Key is an EmbeddingType; defaults to the type resolved from the config.""" + return super().get_instance(key or self._resolve_embedding_type()) + + def _resolve_embedding_type(self) -> EmbeddingType | LLMType: + """Resolves the embedding type. + + If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. + Raises TypeError if the embedding type cannot be resolved.
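+
+        Example (illustrative; assumes `embedding.api_type` is set to gemini in config2.yaml):
+            >>> RAGEmbeddingFactory()._resolve_embedding_type()  # -> EmbeddingType.GEMINI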
+ """ + if self.config.embedding.api_type: + return self.config.embedding.api_type + + if self.config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return self.config.llm.api_type + + raise TypeError("To use RAG, please set your embedding in config2.yaml.") + + def _create_openai(self) -> "OpenAIEmbedding": + from llama_index.embeddings.openai import OpenAIEmbedding + + params = dict( + api_key=self.config.embedding.api_key or self.config.llm.api_key, + api_base=self.config.embedding.base_url or self.config.llm.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return OpenAIEmbedding(**params) + + def _create_azure(self) -> AzureOpenAIEmbedding: + params = dict( + api_key=self.config.embedding.api_key or self.config.llm.api_key, + azure_endpoint=self.config.embedding.base_url or self.config.llm.base_url, + api_version=self.config.embedding.api_version or self.config.llm.api_version, + ) + + self._try_set_model_and_batch_size(params) + + return AzureOpenAIEmbedding(**params) + + def _create_gemini(self) -> "GeminiEmbedding": + from llama_index.embeddings.gemini import GeminiEmbedding + + params = dict( + api_key=self.config.embedding.api_key, + api_base=self.config.embedding.base_url, + ) - def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: - """Key is LLMType, default use config.llm.api_type.""" - return super().get_instance(key or config.llm.api_type) + self._try_set_model_and_batch_size(params) - def _create_openai(self): - return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + return GeminiEmbedding(**params) - def _create_azure(self): - return AzureOpenAIEmbedding( - azure_endpoint=config.llm.base_url, - api_key=config.llm.api_key, - api_version=config.llm.api_version, + def _create_ollama(self) -> "OllamaEmbedding": + from llama_index.embeddings.ollama import OllamaEmbedding + + params = dict( + base_url=self.config.embedding.base_url, ) + self._try_set_model_and_batch_size(params) + + return OllamaEmbedding(**params) + + def _try_set_model_and_batch_size(self, params: dict): + """Set the model_name and embed_batch_size only when they are specified.""" + if self.config.embedding.model: + params["model_name"] = self.config.embedding.model + + if self.config.embedding.embed_batch_size: + params["embed_batch_size"] = self.config.embedding.embed_batch_size + + def _raise_for_key(self, key: Any): + raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") + -get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding +def get_rag_embedding(key: EmbeddingType = None, config: Optional[Config] = None): + return RAGEmbeddingFactory(config=config).get_rag_embedding(key) diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index a56471359e..f897af3ad0 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -48,7 +48,7 @@ def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 
17c499b766..bd252771ac 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -1,5 +1,5 @@ """RAG LLM.""" - +import asyncio from typing import Any from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW @@ -12,10 +12,9 @@ from llama_index.core.llms.callbacks import llm_completion_callback from pydantic import Field -from metagpt.config2 import config -from metagpt.llm import LLM +from metagpt.config2 import Config from metagpt.provider.base_llm import BaseLLM -from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -26,9 +25,34 @@ class RAGLLM(CustomLLM): """ model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.") - context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) - num_output: int = config.llm.max_token - model_name: str = config.llm.model + context_window: int = -1 + num_output: int = -1 + model_name: str = "" + + def __init__( + self, + model_infer: BaseLLM, + context_window: int = -1, + num_output: int = -1, + model_name: str = "", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + config = Config.default() + if context_window < 0: + context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) + + if num_output < 0: + num_output = config.llm.max_token + + if not model_name: + model_name = config.llm.model + + self.model_infer = model_infer + self.context_window = context_window + self.num_output = num_output + self.model_name = model_name @property def metadata(self) -> LLMMetadata: @@ -39,7 +63,8 @@ def metadata(self) -> LLMMetadata: @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs)) + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs)) @llm_completion_callback() async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: @@ -53,4 +78,6 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM: """Get llm that can be used by LlamaIndex.""" + from metagpt.llm import LLM + return RAGLLM(model_infer=model_infer or LLM()) diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 476fe8c1a6..c825c228ce 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -8,6 +8,8 @@ from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( BaseRankerConfig, + BGERerankConfig, + CohereRerankConfig, ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig, @@ -22,6 +24,8 @@ def __init__(self): LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker, ObjectRankerConfig: self._create_object_ranker, + CohereRerankConfig: self._create_cohere_rerank, + BGERerankConfig: self._create_bge_rerank, } super().__init__(creators) @@ -34,6 +38,7 @@ def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: config.llm = self._extract_llm(config, **kwargs) + return LLMRerank(**config.model_dump()) def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: @@ -45,6 +50,26 @@ def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRe ) return 
ColbertRerank(**config.model_dump()) + def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.cohere_rerank import CohereRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`" + ) + return CohereRerank(**config.model_dump()) + + def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.flag_embedding_reranker import ( + FlagEmbeddingReranker, + ) + except ImportError: + raise ImportError( + "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`" + ) + return FlagEmbeddingReranker(**config.model_dump()) + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: return ObjectSortPostprocessor(**config.model_dump()) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 65729002ea..6bc8e4ad5d 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,10 +1,14 @@ """RAG Retriever Factory.""" -import copy + +from functools import wraps import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.embeddings.mock_embed_model import MockEmbedding +from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -24,10 +28,25 @@ ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, FAISSRetrieverConfig, - IndexRetrieverConfig, ) +def get_or_build_index(build_index_func): + """Decorator to get or build an index. + + Gets the index via the `_extract_index` method; if none is found, builds one with build_index_func.
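+
+    A sketch of the intended usage (the method name is illustrative):
+
+        @get_or_build_index
+        def _build_some_index(self, config, **kwargs):
+            ...  # runs only when no prebuilt index is found in config/kwargs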
+ """ + + @wraps(build_index_func) + def wrapper(self, config, **kwargs): + index = self._extract_index(config, **kwargs) + if index is not None: + return index + return build_index_func(self, config, **kwargs) + + return wrapper + + class RetrieverFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" @@ -54,48 +73,85 @@ def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0] def _create_default(self, **kwargs) -> RAGRetriever: - return self._extract_index(**kwargs).as_retriever() + index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs) + + return index.as_retriever() def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_faiss_index(config, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - config.index = copy.deepcopy(self._extract_index(config, **kwargs)) + index = self._extract_index(config, **kwargs) + nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) - return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) + if index and not config.index: + config.index = index - def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: - db = chromadb.PersistentClient(path=str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + if not config.index and config.create_index: + config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1)) - vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) + + def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: + config.index = self._build_chroma_index(config, **kwargs) return ChromaRetriever(**config.model_dump()) def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: - vector_store = ElasticsearchStore(**config.store_config.model_dump()) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_es_index(config, **kwargs) return ElasticsearchRetriever(**config.model_dump()) def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) + def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]: + return self._val_from_config_or_kwargs("nodes", config, **kwargs) + + def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + + def _build_default_index(self, **kwargs) -> VectorStoreIndex: + index = VectorStoreIndex( + nodes=self._extract_nodes(**kwargs), + embed_model=self._extract_embed_model(**kwargs), + ) + + return index + + @get_or_build_index + def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> 
VectorStoreIndex: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + @get_or_build_index + def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex: + db = chromadb.PersistentClient(path=str(config.persist_path)) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + @get_or_build_index + def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + def _build_index_from_vector_store( - self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs ) -> VectorStoreIndex: storage_context = StorageContext.from_defaults(vector_store=vector_store) - old_index = self._extract_index(config, **kwargs) - new_index = VectorStoreIndex( - nodes=list(old_index.docstore.docs.values()), + index = VectorStoreIndex( + nodes=self._extract_nodes(config, **kwargs), storage_context=storage_context, - embed_model=old_index._embed_model, + embed_model=self._extract_embed_model(config, **kwargs), ) - return new_index + + return index get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/prompts/__init__.py b/metagpt/rag/prompts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metagpt/rag/prompts/default_prompts.py b/metagpt/rag/prompts/default_prompts.py new file mode 100644 index 0000000000..eadcaa7702 --- /dev/null +++ b/metagpt/rag/prompts/default_prompts.py @@ -0,0 +1,35 @@ +"""Set of default prompts.""" + +from llama_index.core.prompts.base import PromptTemplate +from llama_index.core.prompts.prompt_type import PromptType + +DEFAULT_CHOICE_SELECT_PROMPT_TMPL = """ +You are a highly efficient assistant, tasked with evaluating the relevance of a list of documents to a given question. + +I will provide you with a question along with a list of documents. Your task is to respond with the numbers of the documents you should consult to answer the question, in order of relevance, along with each document's relevance score. + + +## Question +{query_str} + +## Documents +{context_str} + +## Format Example +Doc: 9, Relevance: 7 + +## Instructions +- Understand the question. +- Evaluate the relevance between the question and the documents. +- The relevance score is a number from 1-10 based on how relevant you think the document is to the question. +- Do not include any documents that are not relevant to the question. +- If none of the documents provided contain information that directly answers the question, simply respond with "no relevant documents". + +## Constraint +Format: Print the result in the same format as the **Format Example**. + +## Action +Follow the instructions, generate the output, and make sure it follows the **Constraint**.
+""" + +DEFAULT_CHOICE_SELECT_PROMPT = PromptTemplate(DEFAULT_CHOICE_SELECT_PROMPT_TMPL, prompt_type=PromptType.CHOICE_SELECT) diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index a7b8368336..69475d6ea4 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -45,3 +45,31 @@ def __subclasshook__(cls, C): @abstractmethod def persist(self, persist_dir: str, **kwargs) -> None: """To support persist, must inplement this func""" + + +class QueryableRAGRetriever(RAGRetriever): + """Support querying total count.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is QueryableRAGRetriever: + return check_methods(C, "query_total_count") + return NotImplemented + + @abstractmethod + def query_total_count(self) -> int: + """To support querying total count, must implement this func.""" + + +class DeletableRAGRetriever(RAGRetriever): + """Support deleting all nodes.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is DeletableRAGRetriever: + return check_methods(C, "clear") + return NotImplemented + + @abstractmethod + def clear(self, **kwargs) -> int: + """To support deleting all nodes, must implement this func.""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 241820cf4a..4891fad504 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,4 +1,5 @@ """BM25 retriever.""" +from pathlib import Path from typing import Callable, Optional from llama_index.core import VectorStoreIndex @@ -36,12 +37,37 @@ def __init__( def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """Support add nodes.""" + self._nodes.extend(nodes) self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) - self._index.insert_nodes(nodes, **kwargs) + if self._index: + self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> None: """Support persist.""" - self._index.storage_context.persist(persist_dir) + + if self._index: + self._index.storage_context.persist(persist_dir) + + def query_total_count(self) -> int: + """Support query total count.""" + + return len(self._nodes) + + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + + self._delete_json_files(kwargs.get("persist_dir")) + self._nodes = [] + + @staticmethod + def _delete_json_files(directory: str): + """Delete all JSON files in the specified directory.""" + + if not directory: + return + + for file in Path(directory).glob("*.json"): + file.unlink() diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py index d41f375e4c..4d3d4469e5 100644 --- a/metagpt/rag/retrievers/chroma_retriever.py +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -2,11 +2,16 @@ from llama_index.core.retrievers import VectorIndexRetriever from llama_index.core.schema import BaseNode +from llama_index.vector_stores.chroma import ChromaVectorStore class ChromaRetriever(VectorIndexRetriever): """Chroma retriever.""" + @property + def vector_store(self) -> ChromaVectorStore: + return self._vector_store + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) @@ -15,3 +20,15 @@ def persist(self, persist_dir: str, **kwargs) -> None: """Support persist. 
Chromadb automatically saves, so there is no need to implement.""" + + def query_total_count(self) -> int: + """Support query total count.""" + + return self.vector_store._collection.count() + + def clear(self, **kwargs) -> None: + """Support deleting all nodes.""" + + ids = self.vector_store._collection.get()["ids"] + if ids: + self.vector_store._collection.delete(ids=ids) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 183f6e0c76..4180536a35 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,15 +1,21 @@ """RAG schemas.""" - +from enum import Enum from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, ClassVar, List, Literal, Optional, Union +from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.prompts import BasePromptTemplate from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from metagpt.config2 import Config +from metagpt.configs.embedding_config import EmbeddingType +from metagpt.logs import logger from metagpt.rag.interface import RAGObject +from metagpt.rag.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT class BaseRetrieverConfig(BaseModel): @@ -31,12 +37,36 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + config = Config.default() + self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( + config.embedding.api_type, 1536 + ) + if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: + logger.warning( + f"You didn't set dimensions in config when using {config.embedding.api_type}, defaulting to 1536" + ) + + return self class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" + create_index: bool = Field( + default=False, + description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.", + exclude=True, + ) _no_embedding: bool = PrivateAttr(default=True) @@ -45,6 +75,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class ElasticsearchStoreConfig(BaseModel): @@ -93,6 +126,9 @@ class LLMRankerConfig(BaseRankerConfig): default=None, description="The LLM to rerank with.
using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", ) + choice_select_prompt: Optional[BasePromptTemplate] = Field( + default=DEFAULT_CHOICE_SELECT_PROMPT, description="Choice select prompt." + ) class ColbertRerankConfig(BaseRankerConfig): @@ -101,6 +137,16 @@ keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") +class CohereRerankConfig(BaseRankerConfig): + model: str = Field(default="rerank-english-v3.0") + api_key: str = Field(default="YOUR_COHERE_API") + + +class BGERerankConfig(BaseRankerConfig): + model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") + use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") + + class ObjectRankerConfig(BaseRankerConfig): field_name: str = Field(..., description="field name of the object; the field's value must be comparable.") order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") @@ -130,6 +176,9 @@ class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class BM25IndexConfig(BaseIndexConfig): @@ -176,3 +225,51 @@ def get_obj_metadata(obj: RAGObject) -> dict: ) return metadata.model_dump() + + +class OmniParseType(str, Enum): + """OmniParseType""" + + PDF = "PDF" + DOCUMENT = "DOCUMENT" + + +class ParseResultType(str, Enum): + """The result type for the parser.""" + + TXT = "text" + MD = "markdown" + JSON = "json" + + +class OmniParseOptions(BaseModel): + """OmniParse Options config""" + + result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type") + parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type") + max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests") + num_workers: int = Field( + default=5, + gt=0, + lt=10, + description="Number of concurrent requests for multiple files", + ) + + +class OminParseImage(BaseModel): + image: str = Field(default="", description="image bytes as a str") + image_name: str = Field(default="", description="image name") + image_info: Optional[dict] = Field(default={}, description="image info") + + +class OmniParsedResult(BaseModel): + markdown: str = Field(default="", description="markdown text") + text: str = Field(default="", description="plain text") + images: Optional[List[OminParseImage]] = Field(default=[], description="images") + metadata: Optional[dict] = Field(default={}, description="metadata") + + @model_validator(mode="before") + def set_markdown(cls, values): + if not values.get("markdown"): + values["markdown"] = values.get("text") + return values diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 166f8cfd07..7ec937675a 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,13 +5,16 @@ @Author : alexanderwu @File : architect.py """ +from pydantic import Field -from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles.role import Role +from metagpt.actions.write_prd import WritePRD +from metagpt.prompts.di.architect import ARCHITECT_EXAMPLE, ARCHITECT_INSTRUCTION +from
metagpt.roles.di.role_zero import RoleZero +from metagpt.tools.libs.terminal import Terminal -class Architect(Role): +class Architect(RoleZero): """ Represents an Architect role in a software development process. @@ -24,16 +27,32 @@ class Architect(RoleZero): name: str = "Bob" profile: str = "Architect" - goal: str = "design a concise, usable, complete software system" + goal: str = "design a concise, usable, complete software system. Output the system design." constraints: str = ( "make sure the architecture is simple enough and use appropriate open source " "libraries. Use the same language as the user requirement" ) + terminal: Terminal = Field(default_factory=Terminal, exclude=True) + instruction: str = ARCHITECT_INSTRUCTION + tools: list[str] = [ + "Editor:write,read,similarity_search", + "RoleZero", + "Terminal:run_command", + ] def __init__(self, **kwargs) -> None: super().__init__(**kwargs) + + # NOTE: The following init setting will only be effective when self.use_fixed_sop is changed to True + self.enable_memory = False # Initialize actions specific to the Architect role self.set_actions([WriteDesign]) # Set events or actions the Architect should watch or be aware of self._watch({WritePRD}) + + def _retrieve_experience(self) -> str: + return ARCHITECT_EXAMPLE + + def _update_tool_execution(self): + self.tool_execution_map.update({"Terminal.run_command": self.terminal.run_command}) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py new file mode 100644 index 0000000000..54ee3864b6 --- /dev/null +++ b/metagpt/roles/di/data_analyst.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from typing import Annotated + +from pydantic import Field, model_validator + +from metagpt.actions.di.execute_nb_code import ExecuteNbCode +from metagpt.actions.di.write_analysis_code import CheckData, WriteAnalysisCode +from metagpt.logs import logger +from metagpt.prompts.di.data_analyst import ( + CODE_STATUS, + EXTRA_INSTRUCTION, + TASK_TYPE_DESC, +) +from metagpt.prompts.di.role_zero import ROLE_INSTRUCTION +from metagpt.prompts.di.write_analysis_code import DATA_INFO +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import Message, TaskResult +from metagpt.strategy.experience_retriever import ExpRetriever, KeywordExpRetriever +from metagpt.strategy.task_type import TaskType +from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender +from metagpt.tools.tool_registry import register_tool + + +@register_tool(include_functions=["write_and_exec_code"]) +class DataAnalyst(RoleZero): + name: str = "David" + profile: str = "DataAnalyst" + goal: str = "Take on any data-related tasks, such as data analysis, machine learning, deep learning, web browsing, web scraping, web searching, terminal operation, document QA & analysis, etc."
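+    # The runtime prompt combines the shared RoleZero instruction with DataAnalyst-specific guidance.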
+ instruction: str = ROLE_INSTRUCTION + EXTRA_INSTRUCTION + task_type_desc: str = TASK_TYPE_DESC + + tools: list[str] = [ + "Plan", + "DataAnalyst", + "RoleZero", + "Browser", + "Editor:write,read,similarity_search", + "SearchEnhancedQA", + ] + custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read,similarity_search"] + custom_tool_recommender: ToolRecommender = None + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() + + use_reflection: bool = True + write_code: WriteAnalysisCode = Field(default_factory=WriteAnalysisCode, exclude=True) + execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True) + + @model_validator(mode="after") + def set_custom_tool(self): + if self.custom_tools and not self.custom_tool_recommender: + self.custom_tool_recommender = BM25ToolRecommender(tools=self.custom_tools, force=True) + return self + + def _update_tool_execution(self): + self.tool_execution_map.update( + { + "DataAnalyst.write_and_exec_code": self.write_and_exec_code, + } + ) + + async def write_and_exec_code(self, instruction: str = ""): + """Write a code block for current task and execute it in an interactive notebook environment. + + Args: + instruction (optional, str): Additional hints or notes beyond the current task instruction; must be very concise and may be empty. Defaults to "". + """ + if self.planner.plan: + logger.info(f"Current task {self.planner.plan.current_task}") + + counter = 0 + success = False + await self.execute_code.init_code() + + # plan info + if self.planner.current_task: + # clear task results from the plan to save tokens, since they are already in memory + plan_status = self.planner.get_plan_status(exclude=["task_result"]) + plan_status += f"\nFurther Task Instruction: {instruction}" + else: + return "No current_task found now. Please use command Plan.append_task to add a task first." + + # tool info + if self.custom_tool_recommender: + plan = self.planner.plan + fixed = ["Terminal"] if "Terminal" in self.custom_tools else None + tool_info = await self.custom_tool_recommender.get_recommended_tool_info(fixed=fixed, plan=plan) + else: + tool_info = "" + + # data info + await self._check_data() + + while not success and counter < 3: + ### write code ### + logger.info("ready to WriteAnalysisCode") + use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial + + code = await self.write_code.run( + user_requirement=self.planner.plan.goal, + plan_status=plan_status, + tool_info=tool_info, + working_memory=self.rc.working_memory.get(), + use_reflection=use_reflection, + memory=self.rc.memory.get(self.memory_k), + ) + self.rc.working_memory.add(Message(content=code, role="assistant", cause_by=WriteAnalysisCode)) + + ### execute code ### + result, success = await self.execute_code.run(code) + print(result) + + self.rc.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode)) + + ### process execution result ### + counter += 1 + if success: + task_result = TaskResult(code=code, result=result, is_success=success) + self.planner.current_task.update_task_result(task_result) + + status = "Success" if success else "Failed" + output = CODE_STATUS.format(code=code, status=status, result=result) + if success: + output += "The code written has been executed successfully."
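+        # The task attempt is over; clear per-run working memory so the next task starts clean.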
+ self.rc.working_memory.clear() + return output + + async def _check_data(self): + if not self.planner.plan.get_finished_tasks() or self.planner.plan.current_task.task_type not in [ + TaskType.DATA_PREPROCESS.type_name, + TaskType.FEATURE_ENGINEERING.type_name, + TaskType.MODEL_TRAIN.type_name, + ]: + return + logger.info("Check updated data") + code = await CheckData().run(self.planner.plan) + if not code.strip(): + return + result, success = await self.execute_code.run(code) + if success: + print(result) + data_info = DATA_INFO.format(info=result) + self.rc.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData)) + + async def _run_special_command(self, cmd) -> str: + """command requiring special check or parsing.""" + # TODO: duplicate with Engineer2._run_special_command, consider dedup + + # finish current task before end. + command_output = "" + if cmd["command_name"] == "end" and not self.planner.plan.is_plan_finished(): + self.planner.plan.finish_all_tasks() + command_output += "All tasks are finished.\n" + command_output += await super()._run_special_command(cmd) + return command_output diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index 44a85fe065..f90a928dfc 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -1,11 +1,11 @@ from __future__ import annotations import json -from typing import Literal, Union +from typing import Literal from pydantic import Field, model_validator -from metagpt.actions.di.ask_review import ReviewConst +# from metagpt.actions.di.ask_review import ReviewConst from metagpt.actions.di.execute_nb_code import ExecuteNbCode from metagpt.actions.di.write_analysis_code import CheckData, WriteAnalysisCode from metagpt.logs import logger @@ -15,6 +15,7 @@ from metagpt.strategy.task_type import TaskType from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter REACT_THINK_PROMPT = """ # User Requirement @@ -39,10 +40,11 @@ class DataInterpreter(Role): use_plan: bool = True use_reflection: bool = False execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True) - tools: Union[str, list[str]] = [] # Use special symbol [""] to indicate use of all registered tools + tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: ToolRecommender = None react_mode: Literal["plan_and_act", "react"] = "plan_and_act" max_react_loop: int = 10 # used for react mode + user_requirement: str = "" @model_validator(mode="after") def set_plan_and_tool(self) -> "Interpreter": @@ -50,7 +52,7 @@ def set_plan_and_tool(self) -> "Interpreter": self.use_plan = ( self.react_mode == "plan_and_act" ) # create a flag for convenience, overwrite any passed-in value - if self.tools: + if self.tools and not self.tool_recommender: self.tool_recommender = BM25ToolRecommender(tools=self.tools) self.set_actions([WriteAnalysisCode]) self._set_state(0) @@ -62,7 +64,7 @@ def working_memory(self): async def _think(self) -> bool: """Useful in 'react' mode. 
Use LLM to decide whether and what to do next.""" - user_requirement = self.get_memories()[0].content + self.user_requirement = self.get_memories()[-1].content context = self.working_memory.get() if not context: @@ -71,9 +73,10 @@ async def _think(self) -> bool: self._set_state(0) return True - prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context) - rsp = await self.llm.aask(prompt) - rsp_dict = json.loads(CodeParser.parse_code(block=None, text=rsp)) + prompt = REACT_THINK_PROMPT.format(user_requirement=self.user_requirement, context=context) + async with ThoughtReporter(enable_llm_stream=True): + rsp = await self.llm.aask(prompt) + rsp_dict = json.loads(CodeParser.parse_code(text=rsp)) self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant")) need_action = rsp_dict["state"] self._set_state(0) if need_action else self._set_state(-1) @@ -83,9 +86,10 @@ async def _think(self) -> bool: async def _act(self) -> Message: """Useful in 'react' mode. Return a Message conforming to Role._act interface.""" code, _, _ = await self._write_and_exec_code() - return Message(content=code, role="assistant", cause_by=WriteAnalysisCode) + return Message(content=code, role="assistant", sent_from=self._setting, cause_by=WriteAnalysisCode) async def _plan_and_act(self) -> Message: + self._set_state(0) try: rsp = await super()._plan_and_act() await self.execute_code.terminate() @@ -108,7 +112,7 @@ async def _write_and_exec_code(self, max_retry: int = 3): plan_status = self.planner.get_plan_status() if self.use_plan else "" # tool info - if self.tools: + if self.tool_recommender: context = ( self.working_memory.get()[-1].content if self.working_memory.get() else "" ) # thoughts from _think stage in 'react' mode @@ -135,11 +139,11 @@ async def _write_and_exec_code(self, max_retry: int = 3): ### process execution result ### counter += 1 - if not success and counter >= max_retry: - logger.info("coding failed!") - review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER) - if ReviewConst.CHANGE_WORDS[0] in review: - counter = 0 # redo the task again with help of human suggestions + # if not success and counter >= max_retry: + # logger.info("coding failed!") + # review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER) + # if ReviewConst.CHANGE_WORDS[0] in review: + # counter = 0 # redo the task again with help of human suggestions return code, result, success @@ -153,10 +157,8 @@ async def _write_code( logger.info(f"ready to {todo.name}") use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial - user_requirement = self.get_memories()[0].content - code = await todo.run( - user_requirement=user_requirement, + user_requirement=self.user_requirement, plan_status=plan_status, tool_info=tool_info, working_memory=self.working_memory.get(), diff --git a/metagpt/roles/di/engineer2.py b/metagpt/roles/di/engineer2.py new file mode 100644 index 0000000000..683b5f4a93 --- /dev/null +++ b/metagpt/roles/di/engineer2.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import os +from pathlib import Path + +from pydantic import Field + +from metagpt.logs import logger + +# from metagpt.actions.write_code_review import ValidateAndRewriteCode +from metagpt.prompts.di.engineer2 import ( + CURRENT_STATE, + ENGINEER2_INSTRUCTION, + WRITE_CODE_PROMPT, + WRITE_CODE_SYSTEM_PROMPT, +) +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import 
UserMessage +from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE +from metagpt.tools.libs.cr import CodeReview +from metagpt.tools.libs.deployer import Deployer +from metagpt.tools.libs.git import git_create_pull +from metagpt.tools.libs.image_getter import ImageGetter +from metagpt.tools.libs.terminal import Terminal +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import CodeParser, awrite +from metagpt.utils.report import EditorReporter + + +@register_tool(include_functions=["write_new_code"]) +class Engineer2(RoleZero): + name: str = "Alex" + profile: str = "Engineer" + goal: str = "Take on game, app, and web development and deployment." + instruction: str = ENGINEER2_INSTRUCTION + terminal: Terminal = Field(default_factory=Terminal, exclude=True) + deployer: Deployer = Field(default_factory=Deployer, exclude=True) + tools: list[str] = [ + "Plan", + "Editor", + "RoleZero", + "Terminal:run_command", + "Browser:goto,scroll", + "git_create_pull", + "SearchEnhancedQA", + "Engineer2", + "CodeReview", + "ImageGetter", + "Deployer", + ] + # SWE Agent parameters + run_eval: bool = False + output_diff: str = "" + max_react_loop: int = 40 + + async def _think(self) -> bool: + await self._format_instruction() + res = await super()._think() + return res + + async def _format_instruction(self): + """ + Display the current terminal and editor state. + This information will be dynamically added to the command prompt. + """ + current_directory = (await self.terminal.run_command("pwd")).strip() + self.editor._set_workdir(current_directory) + state = { + "editor_open_file": self.editor.current_file, + "current_directory": current_directory, + } + self.cmd_prompt_current_state = CURRENT_STATE.format(**state).strip() + + def _update_tool_execution(self): + # validate = ValidateAndRewriteCode() + cr = CodeReview() + image_getter = ImageGetter() + self.exclusive_tool_commands.append("Engineer2.write_new_code") + if self.run_eval is True: + # Evaluation tool map + self.tool_execution_map.update( + { + "git_create_pull": git_create_pull, + "Engineer2.write_new_code": self.write_new_code, + "ImageGetter.get_image": image_getter.get_image, + "CodeReview.review": cr.review, + "CodeReview.fix": cr.fix, + "Terminal.run_command": self._eval_terminal_run, + "RoleZero.ask_human": self._end, + "RoleZero.reply_to_human": self._end, + "Deployer.deploy_to_public": self._deploy_to_public, + } + ) + else: + # Default tool map + self.tool_execution_map.update( + { + "git_create_pull": git_create_pull, + "Engineer2.write_new_code": self.write_new_code, + "ImageGetter.get_image": image_getter.get_image, + "CodeReview.review": cr.review, + "CodeReview.fix": cr.fix, + "Terminal.run_command": self.terminal.run_command, + "Deployer.deploy_to_public": self._deploy_to_public, + } + ) + + def _retrieve_experience(self) -> str: + return ENGINEER_EXAMPLE + + async def write_new_code(self, path: str, file_description: str = "") -> str: + """Write a new code file. + + Args: + path (str): The absolute path of the file to be created. + file_description (optional, str): Brief description and important notes of the file content; must be very concise and can be empty. Defaults to "". + """ + # If the path is not absolute, try to fix it with the editor's working directory.
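+ # e.g. a relative "src/index.html" with the editor working in "/workspace/game" would be resolved to "/workspace/game/src/index.html" (illustrative values; the exact behavior of Editor._try_fix_path is assumed).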
+ path = self.editor._try_fix_path(path) + plan_status, _ = self._get_plan_status() + prompt = WRITE_CODE_PROMPT.format( + user_requirement=self.planner.plan.goal, + plan_status=plan_status, + file_path=path, + file_description=file_description, + file_name=os.path.basename(path), + ) + # Sometimes the Engineer repeats the last command to respond. + # Replace the last command with a manual prompt to guide the Engineer to write new code. + memory = self.rc.memory.get(self.memory_k)[:-1] + context = self.llm.format_msg(memory + [UserMessage(content=prompt)]) + + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "code", "filename": Path(path).name, "src_path": path}, "meta") + rsp = await self.llm.aask(context, system_msgs=[WRITE_CODE_SYSTEM_PROMPT]) + code = CodeParser.parse_code(text=rsp) + await awrite(path, code) + await reporter.async_report(path, "path") + + # TODO: Consider adding line numbers to be ready for editing. + return f"The file {path} has been successfully created, with content:\n{code}" + + async def _deploy_to_public(self, dist_dir): + """Fix the dist_dir path to an absolute path before deploying. + Args: + dist_dir (str): The dist directory of the web project after running build. This must be an absolute path. + """ + # Try to fix the path with the editor's working directory. + if not Path(dist_dir).is_absolute(): + default_dir = self.editor._try_fix_path(dist_dir) + if not default_dir.exists(): + raise ValueError("dist_dir must be an absolute path.") + dist_dir = default_dir + return await self.deployer.deploy_to_public(dist_dir) + + async def _eval_terminal_run(self, cmd): + """Treat a pull/push/commit command as the end of the run.""" + if any([cmd_key_word in cmd for cmd_key_word in ["pull", "push", "commit"]]): + # The Engineer2 attempts to submit the repository after fixing the bug, thereby reaching the end of the fixing process. + logger.info(f"Engineer2 uses cmd: {cmd}\nCurrent test case is finished.") + # Set self.rc.todo to None to stop the engineer.
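+ # e.g. "git commit -m 'fix'", "git push origin main", or "git pull" all match the keyword check above and end the eval run instead of executing (illustrative commands, not from the diff).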
+ self._set_state(-1) + else: + command_output = await self.terminal.run_command(cmd) + return command_output + + async def _end(self): + if not self.planner.plan.is_plan_finished(): + self.planner.plan.finish_all_tasks() + return await super()._end() diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py new file mode 100644 index 0000000000..5799d5e166 --- /dev/null +++ b/metagpt/roles/di/role_zero.py @@ -0,0 +1,626 @@ +from __future__ import annotations + +import inspect +import json +import re +import traceback +from datetime import datetime +from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple + +from pydantic import Field, model_validator + +from metagpt.actions import Action, UserRequirement +from metagpt.actions.di.run_command import RunCommand +from metagpt.actions.search_enhanced_qa import SearchEnhancedQA +from metagpt.const import IMAGES +from metagpt.exp_pool import exp_cache +from metagpt.exp_pool.context_builders import RoleZeroContextBuilder +from metagpt.exp_pool.serializers import RoleZeroSerializer +from metagpt.logs import logger +from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory +from metagpt.prompts.di.role_zero import ( + ASK_HUMAN_COMMAND, + ASK_HUMAN_GUIDANCE_FORMAT, + CMD_PROMPT, + DETECT_LANGUAGE_PROMPT, + END_COMMAND, + JSON_REPAIR_PROMPT, + QUICK_RESPONSE_SYSTEM_PROMPT, + QUICK_THINK_EXAMPLES, + QUICK_THINK_PROMPT, + QUICK_THINK_SYSTEM_PROMPT, + QUICK_THINK_TAG, + REGENERATE_PROMPT, + REPORT_TO_HUMAN_PROMPT, + ROLE_INSTRUCTION, + SUMMARY_PROBLEM_WHEN_DUPLICATE, + SUMMARY_PROMPT, + SYSTEM_PROMPT, +) +from metagpt.roles import Role +from metagpt.schema import AIMessage, Message, UserMessage +from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever +from metagpt.strategy.planner import Planner +from metagpt.tools.libs.browser import Browser +from metagpt.tools.libs.editor import Editor +from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images +from metagpt.utils.repair_llm_raw_output import ( + RepairType, + repair_escape_error, + repair_llm_raw_output, +) +from metagpt.utils.report import ThoughtReporter + + +@register_tool(include_functions=["ask_human", "reply_to_human"]) +class RoleZero(Role): + """A role who can think and act dynamically""" + + # Basic Info + name: str = "Zero" + profile: str = "RoleZero" + goal: str = "" + system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask + system_prompt: str = SYSTEM_PROMPT # Default system prompt template; rendered in _think with role info, available commands, and examples + cmd_prompt: str = CMD_PROMPT + cmd_prompt_current_state: str = "" + instruction: str = ROLE_INSTRUCTION + task_type_desc: Optional[str] = None + + # React Mode + react_mode: Literal["react"] = "react" + max_react_loop: int = 50 # used for react mode + + # Tools + tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools + tool_recommender: Optional[ToolRecommender] = None + tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {} + special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Terminal.run_command", "RoleZero.ask_human"] + # List of exclusive tool commands. + # If multiple instances of these commands appear, only the first occurrence will be retained.
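+ # e.g. if one response contains two Editor.edit_file_by_replace commands, only the first is executed; the second would target pre-edit line numbers and could corrupt the file (editor's reading of the rule below).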
+ exclusive_tool_commands: list[str] = [ + "Editor.edit_file_by_replace", + "Editor.insert_content_at_line", + "Editor.append_file", + "Editor.open_file", + ] + # Equipped with three basic tools by default for optional use + editor: Editor = Editor(enable_auto_lint=True) + browser: Browser = Browser() + + # Experience + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever() + + # Others + observe_all_msg_from_buffer: bool = True + command_rsp: str = "" # the raw string containing the commands + commands: list[dict] = [] # commands to be executed + memory_k: int = 200 # number of memories (messages) to use as historical context + use_fixed_sop: bool = False + respond_language: str = "" # Language for responding humans and publishing messages. + use_summary: bool = True # whether to summarize at the end + + @model_validator(mode="after") + def set_plan_and_tool(self) -> "RoleZero": + # We force using this parameter for DataAnalyst + assert self.react_mode == "react" + + # Roughly the same part as DataInterpreter.set_plan_and_tool + self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop) + if self.tools and not self.tool_recommender: + self.tool_recommender = BM25ToolRecommender(tools=self.tools, force=True) + self.set_actions([RunCommand]) + + # HACK: Init Planner, control it through dynamic thinking; Consider formalizing as a react mode + self.planner = Planner(goal="", working_memory=self.rc.working_memory, auto_run=True) + + return self + + @model_validator(mode="after") + def set_tool_execution(self) -> "RoleZero": + # default map + self.tool_execution_map = { + "Plan.append_task": self.planner.plan.append_task, + "Plan.reset_task": self.planner.plan.reset_task, + "Plan.replace_task": self.planner.plan.replace_task, + "RoleZero.ask_human": self.ask_human, + "RoleZero.reply_to_human": self.reply_to_human, + "SearchEnhancedQA.run": SearchEnhancedQA().run, + } + self.tool_execution_map.update( + { + f"Browser.{i}": getattr(self.browser, i) + for i in [ + "click", + "close_tab", + "go_back", + "go_forward", + "goto", + "hover", + "press", + "scroll", + "tab_focus", + "type", + ] + } + ) + self.tool_execution_map.update( + { + f"Editor.{i}": getattr(self.editor, i) + for i in [ + "append_file", + "create_file", + "edit_file_by_replace", + "find_file", + "goto_line", + "insert_content_at_line", + "open_file", + "read", + "scroll_down", + "scroll_up", + "search_dir", + "search_file", + "similarity_search", + # "set_workdir", + "write", + ] + } + ) + # can be updated by subclass + self._update_tool_execution() + return self + + @model_validator(mode="after") + def set_longterm_memory(self) -> "RoleZero": + """Set up long-term memory for the role if enabled in the configuration. + + If `enable_longterm_memory` is True, set up long-term memory. + The role name will be used as the collection name. 
+ """ + + if self.config.role_zero.enable_longterm_memory: + # Use config.role_zero to initialize long-term memory + self.rc.memory = RoleZeroLongTermMemory( + **self.rc.memory.model_dump(), + persist_path=self.config.role_zero.longterm_memory_persist_path, + collection_name=self.name.replace(" ", ""), + memory_k=self.config.role_zero.memory_k, + similarity_top_k=self.config.role_zero.similarity_top_k, + use_llm_ranker=self.config.role_zero.use_llm_ranker, + ) + logger.info(f"Long-term memory set for role '{self.name}'") + + return self + + def _update_tool_execution(self): + pass + + async def _think(self) -> bool: + """Useful in 'react' mode. Use LLM to decide whether and what to do next.""" + # Compatibility + if self.use_fixed_sop: + return await super()._think() + + ### 0. Preparation ### + if not self.rc.todo: + return False + + if not self.planner.plan.goal: + self.planner.plan.goal = self.get_memories()[-1].content + detect_language_prompt = DETECT_LANGUAGE_PROMPT.format(requirement=self.planner.plan.goal) + self.respond_language = await self.llm.aask(detect_language_prompt) + ### 1. Experience ### + example = self._retrieve_experience() + + ### 2. Plan Status ### + plan_status, current_task = self._get_plan_status() + + ### 3. Tool/Command Info ### + tools = await self.tool_recommender.recommend_tools() + tool_info = json.dumps({tool.name: tool.schemas for tool in tools}) + + ### Role Instruction ### + instruction = self.instruction.strip() + system_prompt = self.system_prompt.format( + role_info=self._get_prefix(), + task_type_desc=self.task_type_desc, + available_commands=tool_info, + example=example, + instruction=instruction, + ) + + ### Make Decision Dynamically ### + prompt = self.cmd_prompt.format( + current_state=self.cmd_prompt_current_state, + plan_status=plan_status, + current_task=current_task, + respond_language=self.respond_language, + ) + + ### Recent Observation ### + memory = self.rc.memory.get(self.memory_k) + memory = await self.parse_browser_actions(memory) + memory = await self.parse_editor_result(memory) + memory = self.parse_images(memory) + + req = self.llm.format_msg(memory + [UserMessage(content=prompt)]) + state_data = dict( + plan_status=plan_status, + current_task=current_task, + instruction=instruction, + ) + async with ThoughtReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "react"}) + self.command_rsp = await self.llm_cached_aask(req=req, system_msgs=[system_prompt], state_data=state_data) + self.command_rsp = await self._check_duplicates(req, self.command_rsp) + return True + + @exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer()) + async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str], **kwargs) -> str: + """Use `exp_cache` to automatically manage experiences. + + The `RoleZeroContextBuilder` attempts to add experiences to `req`. + The `RoleZeroSerializer` extracts essential parts of `req` for the experience pool, trimming lengthy entries to retain only necessary parts. 
+ """ + return await self.llm.aask(req, system_msgs=system_msgs) + + async def parse_browser_actions(self, memory: list[Message]) -> list[Message]: + if not self.browser.is_empty_page: + pattern = re.compile(r"Command Browser\.(\w+) executed") + for index, msg in zip(range(len(memory), 0, -1), memory[::-1]): + if pattern.search(msg.content): + memory.insert(index, UserMessage(cause_by="browser", content=await self.browser.view())) + break + return memory + + async def parse_editor_result(self, memory: list[Message], keep_latest_count=5) -> list[Message]: + """Retain the latest result and remove outdated editor results.""" + pattern = re.compile(r"Command Editor\.(\w+?) executed") + new_memory = [] + i = 0 + for msg in reversed(memory): + matches = pattern.findall(msg.content) + if matches: + i += 1 + if i > keep_latest_count: + new_content = msg.content[: msg.content.find("Command Editor")] + new_content += "\n".join([f"Command Editor.{match} executed." for match in matches]) + msg = UserMessage(content=new_content) + new_memory.append(msg) + # Reverse the new memory list so the latest message is at the end + new_memory.reverse() + return new_memory + + def parse_images(self, memory: list[Message]) -> list[Message]: + if not self.llm.support_image_input(): + return memory + for msg in memory: + if IMAGES in msg.metadata or msg.role != "user": + continue + images = extract_and_encode_images(msg.content) + if images: + msg.add_metadata(IMAGES, images) + return memory + + def _get_prefix(self) -> str: + time_info = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + return super()._get_prefix() + f" The current time is {time_info}." + + async def _act(self) -> Message: + if self.use_fixed_sop: + return await super()._act() + + commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp) + self.rc.memory.add(AIMessage(content=self.command_rsp)) + if not ok: + error_msg = commands + self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand)) + return error_msg + logger.info(f"Commands: \n{commands}") + outputs = await self._run_commands(commands) + logger.info(f"Commands outputs: \n{outputs}") + self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand)) + + return AIMessage( + content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", + sent_from=self.name, + cause_by=RunCommand, + ) + + async def _react(self) -> Message: + # NOTE: Diff 1: Each time landing here means news is observed, set todo to allow news processing in _think + self._set_state(0) + + # problems solvable by quick thinking doesn't need to a formal think-act cycle + quick_rsp, _ = await self._quick_think() + if quick_rsp: + return quick_rsp + + actions_taken = 0 + rsp = AIMessage(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act + while actions_taken < self.rc.max_react_loop: + # NOTE: Diff 2: Keep observing within _react, news will go into memory, allowing adapting to new info + await self._observe() + + # think + has_todo = await self._think() + if not has_todo: + break + # act + logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}") + rsp = await self._act() + actions_taken += 1 + + # post-check + if self.rc.max_react_loop >= 10 and actions_taken >= self.rc.max_react_loop: + # If max_react_loop is a small value (e.g. 
< 10), it is intended to be reached, stopping the agent + logger.warning(f"reached max_react_loop: {actions_taken}") + human_rsp = await self.ask_human( + "I have reached my max action rounds, do you want me to continue? Yes or no" + ) + if "yes" in human_rsp.lower(): + actions_taken = 0 + return rsp # return output from the last action + + def format_quick_system_prompt(self) -> str: + """Format the system prompt for quick thinking.""" + return QUICK_THINK_SYSTEM_PROMPT.format(examples=QUICK_THINK_EXAMPLES, role_info=self._get_prefix()) + + async def _quick_think(self) -> Tuple[Message, str]: + answer = "" + rsp_msg = None + if self.rc.news[-1].cause_by != any_to_str(UserRequirement): + # Agents themselves won't generate quick questions; use this rule to reduce extra LLM calls + return rsp_msg, "" + + # routing + memory = self.get_memories(k=self.memory_k) + context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) + async with ThoughtReporter() as reporter: + await reporter.async_report({"type": "classify"}) + intent_result = await self.llm.aask(context, system_msgs=[self.format_quick_system_prompt()]) + + if "QUICK" in intent_result or "AMBIGUOUS" in intent_result: # llm call with the original context + async with ThoughtReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "quick"}) + answer = await self.llm.aask( + self.llm.format_msg(memory), + system_msgs=[QUICK_RESPONSE_SYSTEM_PROMPT.format(role_info=self._get_prefix())], + ) + # If the answer contains the substring '[Message] from A to B:', remove it. + pattern = r"\[Message\] from .+? to .+?:\s*" + answer = re.sub(pattern, "", answer, count=1) + if "command_name" in answer: + # An actual TASK intent misclassified as QUICK; correct it here. FIXME: a better way is to classify it correctly in the first place + answer = "" + intent_result = "TASK" + elif "SEARCH" in intent_result: + query = "\n".join(str(msg) for msg in memory) + answer = await SearchEnhancedQA().run(query) + + if answer: + self.rc.memory.add(AIMessage(content=answer, cause_by=QUICK_THINK_TAG)) + await self.reply_to_human(content=answer) + rsp_msg = AIMessage( + content=answer, + sent_from=self.name, + cause_by=QUICK_THINK_TAG, + ) + + return rsp_msg, intent_result + + async def _check_duplicates(self, req: list[dict], command_rsp: str, check_window: int = 10): + past_rsp = [mem.content for mem in self.rc.memory.get(check_window)] + if command_rsp in past_rsp and '"command_name": "end"' not in command_rsp: + # Normal responses with thought content are highly unlikely to recur verbatim + # If an identical response is detected, it is a bad response, mostly due to the LLM repeating generated content + # In this case, ask the human for help and regenerate + # TODO: switch to llm_cached_aask + + # Hard rule to ask human for help + if past_rsp.count(command_rsp) >= 3: + if '"command_name": "Plan.finish_current_task",' in command_rsp: + # Detect a duplicate of the 'Plan.finish_current_task' command, and use the 'end' command to finish the task.
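+ # Escalation is tiered (editor's summary of the logic below): one duplicate within the check window triggers regeneration; three or more trigger summarizing the problem and asking the human, except for a duplicated Plan.finish_current_task, which ends the run.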
+ logger.warning(f"Duplicate response detected: {command_rsp}") + return END_COMMAND + problem = await self.llm.aask( + req + [UserMessage(content=SUMMARY_PROBLEM_WHEN_DUPLICATE.format(language=self.respond_language))] + ) + ASK_HUMAN_COMMAND[0]["args"]["question"] = ASK_HUMAN_GUIDANCE_FORMAT.format(problem=problem).strip() + ask_human_command = "```json\n" + json.dumps(ASK_HUMAN_COMMAND, indent=4, ensure_ascii=False) + "\n```" + return ask_human_command + # Try correction by self + logger.warning(f"Duplicate response detected: {command_rsp}") + regenerate_req = req + [UserMessage(content=REGENERATE_PROMPT)] + regenerate_req = self.llm.format_msg(regenerate_req) + command_rsp = await self.llm.aask(regenerate_req) + return command_rsp + + async def _parse_commands(self, command_rsp) -> Tuple[List[Dict], bool]: + """Retrieves commands from the Large Language Model (LLM). + + This function attempts to retrieve a list of commands from the LLM by + processing the response (`self.command_rsp`). It handles potential errors + during parsing and LLM response formats. + + Returns: + A tuple containing: + - A boolean flag indicating success (True) or failure (False). + """ + try: + commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp) + if commands.endswith("]") and not commands.startswith("["): + commands = "[" + commands + commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON)) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON for: {command_rsp}. Trying to repair...") + commands = await self.llm.aask( + msg=JSON_REPAIR_PROMPT.format(json_data=command_rsp, json_decode_error=str(e)) + ) + try: + commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands)) + except json.JSONDecodeError: + # repair escape error of code and math + commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp) + new_command = repair_escape_error(commands) + commands = json.loads( + repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON) + ) + except Exception as e: + tb = traceback.format_exc() + print(tb) + error_msg = str(e) + return error_msg, False, command_rsp + + # 为了对LLM不按格式生成进行容错 + if isinstance(commands, dict): + commands = commands["commands"] if "commands" in commands else [commands] + + # Set the exclusive command flag to False. + command_flag = [command["command_name"] not in self.exclusive_tool_commands for command in commands] + if command_flag.count(False) > 1: + # Keep only the first exclusive command + index_of_first_exclusive = command_flag.index(False) + commands = commands[: index_of_first_exclusive + 1] + command_rsp = "```json\n" + json.dumps(commands, indent=4, ensure_ascii=False) + "\n```" + logger.info( + "exclusive command more than one in current command list. 
change the command list.\n" + command_rsp + ) + return commands, True, command_rsp + + async def _run_commands(self, commands) -> str: + outputs = [] + for cmd in commands: + output = f"Command {cmd['command_name']} executed" + # handle special command first + if self._is_special_command(cmd): + special_command_output = await self._run_special_command(cmd) + outputs.append(output + ":" + special_command_output) + continue + # run command as specified by tool_execute_map + if cmd["command_name"] in self.tool_execution_map: + tool_obj = self.tool_execution_map[cmd["command_name"]] + try: + if inspect.iscoroutinefunction(tool_obj): + tool_output = await tool_obj(**cmd["args"]) + else: + tool_output = tool_obj(**cmd["args"]) + if tool_output: + output += f": {str(tool_output)}" + outputs.append(output) + except Exception as e: + tb = traceback.format_exc() + logger.exception(str(e) + tb) + outputs.append(output + f": {tb}") + break # Stop executing if any command fails + else: + outputs.append(f"Command {cmd['command_name']} not found.") + break + outputs = "\n\n".join(outputs) + + return outputs + + def _is_special_command(self, cmd) -> bool: + return cmd["command_name"] in self.special_tool_commands + + async def _run_special_command(self, cmd) -> str: + """command requiring special check or parsing""" + command_output = "" + + if cmd["command_name"] == "Plan.finish_current_task": + if not self.planner.plan.is_plan_finished(): + self.planner.plan.finish_current_task() + command_output = ( + "Current task is finished. If you no longer need to take action, use the command ‘end’ to stop." + ) + + elif cmd["command_name"] == "end": + command_output = await self._end() + elif cmd["command_name"] == "RoleZero.ask_human": + human_response = await self.ask_human(**cmd["args"]) + if human_response.strip().lower().endswith(("stop", "")): + human_response += "The user has asked me to stop because I have encountered a problem." + self.rc.memory.add(UserMessage(content=human_response, cause_by=RunCommand)) + end_output = "\nCommand end executed:" + end_output += await self._end() + return end_output + return human_response + # output from bash.run may be empty, add decorations to the output to ensure visibility. + elif cmd["command_name"] == "Terminal.run_command": + tool_obj = self.tool_execution_map[cmd["command_name"]] + tool_output = await tool_obj(**cmd["args"]) + if len(tool_output) <= 10: + command_output += ( + f"\n[command]: {cmd['args']['cmd']} \n[command output] : {tool_output} (pay attention to this.)" + ) + else: + command_output += f"\n[command]: {cmd['args']['cmd']} \n[command output] : {tool_output}" + + return command_output + + def _get_plan_status(self) -> Tuple[str, str]: + plan_status = self.planner.plan.model_dump(include=["goal", "tasks"]) + current_task = ( + self.planner.plan.current_task.model_dump(exclude=["code", "result", "is_success"]) + if self.planner.plan.current_task + else "" + ) + # format plan status + # Example: + # [GOAL] create a 2048 game + # [TASK_ID 1] (finished) Create a Product Requirement Document (PRD) for the 2048 game. This task depends on tasks[]. [Assign to Alice] + # [TASK_ID 2] ( ) Design the system architecture for the 2048 game. This task depends on tasks[1]. 
[Assign to Bob] + formatted_plan_status = f"[GOAL] {plan_status['goal']}\n" + if len(plan_status["tasks"]) > 0: + formatted_plan_status += "[Plan]\n" + for task in plan_status["tasks"]: + formatted_plan_status += f"[TASK_ID {task['task_id']}] ({'finished' if task['is_finished'] else ' '}){task['instruction']} This task depends on tasks{task['dependent_task_ids']}. [Assign to {task['assignee']}]\n" + else: + formatted_plan_status += "No Plan \n" + return formatted_plan_status, current_task + + def _retrieve_experience(self) -> str: + """Default implementation of experience retrieval. Can be overwritten in subclasses.""" + context = [str(msg) for msg in self.rc.memory.get(self.memory_k)] + context = "\n\n".join(context) + example = self.experience_retriever.retrieve(context=context) + return example + + async def ask_human(self, question: str) -> str: + """Use this when you fail the current task or if you are unsure of the situation encountered. Your response should contain a brief summary of your situation, ended with a clear and concise question.""" + # NOTE: Can be overwritten in remote setting + from metagpt.environment.mgx.mgx_env import MGXEnv # avoid circular import + + if not isinstance(self.rc.env, MGXEnv): + return "Not in MGXEnv, command will not be executed." + return await self.rc.env.ask_human(question, sent_from=self) + + async def reply_to_human(self, content: str) -> str: + """Reply to human user with the content provided. Use this when you have a clear answer or solution to the user's question.""" + # NOTE: Can be overwritten in remote setting + from metagpt.environment.mgx.mgx_env import MGXEnv # avoid circular import + + if not isinstance(self.rc.env, MGXEnv): + return "Not in MGXEnv, command will not be executed." + return await self.rc.env.reply_to_human(content, sent_from=self) + + async def _end(self, **kwarg): + self._set_state(-1) + memory = self.rc.memory.get(self.memory_k) + # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking. 
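+ # i.e. scan the five most recent messages; if none of them called reply_to_human, produce a final reply first so the user is never left without an answer (editor's paraphrase of the check below).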
+ if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]): + logger.info("manually reply to human") + reply_to_human_prompt = REPORT_TO_HUMAN_PROMPT.format(respond_language=self.respond_language) + async with ThoughtReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "quick"}) + reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)])) + await self.reply_to_human(content=reply_content) + self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand)) + outputs = "" + # Summary of the Completed Task and Deliverables + if self.use_summary: + logger.info("end current run and summarize") + outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)])) + return outputs diff --git a/metagpt/roles/di/swe_agent.py b/metagpt/roles/di/swe_agent.py new file mode 100644 index 0000000000..731b00b0b4 --- /dev/null +++ b/metagpt/roles/di/swe_agent.py @@ -0,0 +1,83 @@ +import json + +from pydantic import Field + +from metagpt.logs import logger +from metagpt.prompts.di.swe_agent import ( + CURRENT_BASH_STATE, + MINIMAL_EXAMPLE, + NEXT_STEP_TEMPLATE, +) +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import Message +from metagpt.tools.libs.git import git_create_pull +from metagpt.tools.libs.terminal import Bash + + +class SWEAgent(RoleZero): + name: str = "Swen" + profile: str = "Issue Solver" + goal: str = "Resolve GitHub issue or bug in any existing codebase" + _instruction: str = NEXT_STEP_TEMPLATE + tools: list[str] = [ + "Bash", + "Browser:goto,scroll", + "RoleZero", + "git_create_pull", + ] + terminal: Bash = Field(default_factory=Bash, exclude=True) + output_diff: str = "" + max_react_loop: int = 40 + run_eval: bool = False + + async def _think(self) -> bool: + await self._format_instruction() + res = await super()._think() + return res + + def _update_tool_execution(self): + self.tool_execution_map.update( + { + "Bash.run": self.terminal.run, + "git_create_pull": git_create_pull, + } + ) + + async def _format_instruction(self): + """ + Formats the instruction message for the SWE agent. + Runs the "state" command in the terminal, parses its output as JSON, + and uses it to format the `_instruction` template. + """ + state_output = await self.terminal.run("state") + bash_state = json.loads(state_output) + self.cmd_prompt_current_state = CURRENT_BASH_STATE.format(**bash_state).strip() + + async def _act(self) -> Message: + message = await super()._act() + if self.run_eval: + self._parse_commands_for_eval() + return message + + async def _parse_commands_for_eval(self): + """ + Handles actions based on parsed commands. + Parses commands, checks for a "submit" action, and generates a patch using `git diff`. + Stores the cleaned patch in `output_diff`. Logs any exceptions. + This function is specifically added for SWE bench evaluation. + """ + # If todo switches to None, it indicates that this is the final round of reactions, and the Swe-Agent will stop. Use git diff to store any changes made. 
+ if not self.rc.todo: + from metagpt.tools.swe_agent_commands.swe_agent_utils import extract_patch + + try: + diff_output = await self.terminal.run("git diff --cached") + clear_diff = extract_patch(diff_output) + logger.info(f"Diff output: \n{clear_diff}") + if clear_diff: + self.output_diff = clear_diff + except Exception as e: + logger.error(f"Error during submission: {e}") + + def _retrieve_experience(self) -> str: + return MINIMAL_EXAMPLE diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py new file mode 100644 index 0000000000..7a8b8b5bec --- /dev/null +++ b/metagpt/roles/di/team_leader.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Annotated + +from pydantic import Field + +from metagpt.actions.di.run_command import RunCommand +from metagpt.const import TEAMLEADER_NAME +from metagpt.prompts.di.role_zero import QUICK_THINK_TAG +from metagpt.prompts.di.team_leader import ( + FINISH_CURRENT_TASK_CMD, + TL_INFO, + TL_INSTRUCTION, + TL_THOUGHT_GUIDANCE, +) +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import AIMessage, Message, UserMessage +from metagpt.strategy.experience_retriever import ExpRetriever, SimpleExpRetriever +from metagpt.tools.tool_registry import register_tool + + +@register_tool(include_functions=["publish_team_message"]) +class TeamLeader(RoleZero): + name: str = TEAMLEADER_NAME + profile: str = "Team Leader" + goal: str = "Manage a team to assist users" + thought_guidance: str = TL_THOUGHT_GUIDANCE + # TeamLeader only reacts once each time, but may encounter errors or need to ask human, thus allowing 2 more turns + max_react_loop: int = 3 + + tools: list[str] = ["Plan", "RoleZero", "TeamLeader"] + + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever() + + use_summary: bool = False + + def _update_tool_execution(self): + self.tool_execution_map.update( + { + "TeamLeader.publish_team_message": self.publish_team_message, + "TeamLeader.publish_message": self.publish_team_message, # alias + } + ) + + def _get_team_info(self) -> str: + if not self.rc.env: + return "" + team_info = "" + for role in self.rc.env.roles.values(): + # if role.profile == "Team Leader": + # continue + team_info += f"{role.name}: {role.profile}, {role.goal}\n" + return team_info + + def _get_prefix(self) -> str: + role_info = super()._get_prefix() + team_info = self._get_team_info() + return TL_INFO.format(role_info=role_info, team_info=team_info) + + async def _think(self) -> bool: + self.instruction = TL_INSTRUCTION.format(team_info=self._get_team_info()) + return await super()._think() + + def publish_message(self, msg: Message, send_to="no one"): + """Overwrite Role.publish_message, send to no one if called within Role.run (except for quick think), send to the specified role if called dynamically.""" + if not msg: + return + if not self.rc.env: + # If env does not exist, do not publish the message + return + if msg.cause_by != QUICK_THINK_TAG: + msg.send_to = send_to + self.rc.env.publish_message(msg, publicer=self.profile) + + def publish_team_message(self, content: str, send_to: str): + """ + Publish a message to a team member, use member name to fill send_to args. You may copy the full original content or add additional information from upstream. This will make team members start their work. 
+ DONT omit any necessary info such as path, link, environment, programming language, framework, requirement, constraint from original content to team members because you are their sole info source. + """ + self._set_state(-1) # each time publishing a message, pause to wait for the response + if send_to == self.name: + return # Avoid sending message to self + # Specify the outer send_to to overwrite the default "no one" value. Use UserMessage because message from self is like a user request for others. + self.publish_message( + UserMessage(content=content, sent_from=self.name, send_to=send_to, cause_by=RunCommand), send_to=send_to + ) + + def finish_current_task(self): + self.planner.plan.finish_current_task() + self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD)) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 6962b1bb5b..d9d0e8b70b 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -22,16 +22,19 @@ import json from collections import defaultdict from pathlib import Path -from typing import Optional, Set +from typing import List, Optional, Set -from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks +from pydantic import BaseModel, Field + +from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug +from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.summarize_code import SummarizeCode from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange from metagpt.const import ( - BUGFIX_FILENAME, CODE_PLAN_AND_CHANGE_FILE_REPO, + MESSAGE_ROUTE_TO_SELF, REQUIREMENT_FILENAME, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, @@ -39,6 +42,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import ( + AIMessage, CodePlanAndChangeContext, CodeSummarizeContext, CodingContext, @@ -46,7 +50,15 @@ Documents, Message, ) -from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set +from metagpt.utils.common import ( + any_to_name, + any_to_str, + any_to_str_set, + get_project_srcs_path, + init_python_folder, +) +from metagpt.utils.git_repository import ChangeType +from metagpt.utils.project_repo import ProjectRepo IS_PASS_PROMPT = """ {context} @@ -84,10 +96,12 @@ class Engineer(Role): summarize_todos: list = [] next_todo_action: str = "" n_summarize: int = 0 + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - + self.enable_memory = False self.set_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange]) self.code_todos = [] @@ -112,26 +126,24 @@ async def _act_sp_with_cr(self, review=False) -> Set[str]: coding_context = await todo.run() # Code review if review: - action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm) + action = WriteCodeReview( + i_context=coding_context, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) self._init_action(action) coding_context = await action.run() dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path} if self.config.inc: dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path) - await self.project_repo.srcs.save( + await 
self.repo.srcs.save( filename=coding_context.filename, dependencies=list(dependencies), content=coding_context.code_doc.content, ) - msg = Message( - content=coding_context.model_dump_json(), - instruct_content=coding_context, - role=self.profile, - cause_by=WriteCode, - ) - self.rc.memory.add(msg) - changed_files.add(coding_context.code_doc.filename) if not changed_files: logger.info("Nothing has changed.") @@ -150,28 +162,26 @@ async def _act(self) -> Message | None: if isinstance(self.rc.todo, SummarizeCode): self.next_todo_action = any_to_name(WriteCode) return await self._act_summarize() - return None + return await self.rc.todo.run(self.rc.history) async def _act_write_code(self): - changed_files = await self._act_sp_with_cr(review=self.use_code_review) - return Message( - content="\n".join(changed_files), - role=self.profile, - cause_by=WriteCodeReview if self.use_code_review else WriteCode, - send_to=self, - sent_from=self, + await self._act_sp_with_cr(review=self.use_code_review) + return AIMessage( + content="", cause_by=WriteCodeReview if self.use_code_review else WriteCode, send_to=MESSAGE_ROUTE_TO_SELF ) async def _act_summarize(self): tasks = [] for todo in self.summarize_todos: + if self.n_summarize >= self.config.max_auto_summarize_code: + break summary = await todo.run() summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name dependencies = {todo.i_context.design_filename, todo.i_context.task_filename} for filename in todo.i_context.codes_filenames: - rpath = self.project_repo.src_relative_path / filename + rpath = self.repo.src_relative_path / filename dependencies.add(str(rpath)) - await self.project_repo.resources.code_summary.save( + await self.repo.resources.code_summary.save( filename=summary_filename, content=summary, dependencies=dependencies ) is_pass, reason = await self._is_pass(summary) @@ -179,29 +189,46 @@ async def _act_summarize(self): todo.i_context.reason = reason tasks.append(todo.i_context.model_dump()) - await self.project_repo.docs.code_summary.save( + await self.repo.docs.code_summary.save( filename=Path(todo.i_context.design_filename).name, content=todo.i_context.model_dump_json(), dependencies=dependencies, ) else: - await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) - + await self.repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) + self.summarize_todos = [] logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}") if not tasks or self.config.max_auto_summarize_code == 0: - return Message( - content="", - role=self.profile, + self.n_summarize = 0 + kvs = self.input_args.model_dump() + kvs["changed_src_filenames"] = [ + str(self.repo.srcs.workdir / i) for i in list(self.repo.srcs.changed_files.keys()) + ] + if self.repo.docs.code_plan_and_change.changed_files: + kvs["changed_code_plan_and_change_filenames"] = [ + str(self.repo.docs.code_plan_and_change.workdir / i) + for i in list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + if self.repo.docs.code_summary.changed_files: + kvs["changed_code_summary_filenames"] = [ + str(self.repo.docs.code_summary.workdir / i) + for i in list(self.repo.docs.code_summary.changed_files.keys()) + ] + return AIMessage( + content=f"Coding is complete. 
The source code is at {self.repo.workdir.name}/{self.repo.srcs.root_path}, containing: " + + "\n".join( + list(self.repo.resources.code_summary.changed_files.keys()) + + list(self.repo.srcs.changed_files.keys()) + + list(self.repo.resources.code_plan_and_change.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="SummarizeCodeOutput"), cause_by=SummarizeCode, - sent_from=self, send_to="Edward", # The name of QaEngineer ) # The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. # This parameter is used for debugging the workflow. self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0 - return Message( - content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self - ) + return AIMessage(content="", cause_by=SummarizeCode, send_to=MESSAGE_ROUTE_TO_SELF) async def _act_code_plan_and_change(self): """Write code plan and change that guides subsequent WriteCode and WriteCodeReview""" @@ -209,27 +236,21 @@ async def _act_code_plan_and_change(self): code_plan_and_change = node.instruct_content.model_dump_json() dependencies = { REQUIREMENT_FILENAME, - str(self.project_repo.docs.prd.root_path / self.rc.todo.i_context.prd_filename), - str(self.project_repo.docs.system_design.root_path / self.rc.todo.i_context.design_filename), - str(self.project_repo.docs.task.root_path / self.rc.todo.i_context.task_filename), + str(Path(self.rc.todo.i_context.prd_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.design_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.task_filename).relative_to(self.repo.workdir)), } code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename) - await self.project_repo.docs.code_plan_and_change.save( + await self.repo.docs.code_plan_and_change.save( filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies ) - await self.project_repo.resources.code_plan_and_change.save( + await self.repo.resources.code_plan_and_change.save( filename=code_plan_and_change_filepath.with_suffix(".md").name, content=node.content, dependencies=dependencies, ) - return Message( - content=code_plan_and_change, - role=self.profile, - cause_by=WriteCodePlanAndChange, - send_to=self, - sent_from=self, - ) + return AIMessage(content="", cause_by=WriteCodePlanAndChange, send_to=MESSAGE_ROUTE_TO_SELF) async def _is_pass(self, summary) -> (str, str): rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False) @@ -238,45 +259,52 @@ async def _is_pass(self, summary) -> (str, str): return True, rsp return False, rsp - async def _think(self) -> Action | None: - if not self.src_workspace: - self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name - write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug]) - write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode]) - summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) + async def _think(self) -> bool: if not self.rc.news: - return None + return False msg = self.rc.news[0] + input_args = msg.instruct_content + if msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}: + self.input_args = input_args + self.repo = ProjectRepo(input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) + 
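+ # Routing sketch (editor's summary of the filters below): in inc mode, news from PrepareDocuments/WriteTasks/FixBug -> WriteCodePlanAndChange; news from WriteTasks/WriteCodePlanAndChange/SummarizeCode -> WriteCode; news from WriteCode/WriteCodeReview sent by self -> SummarizeCode.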
write_plan_and_change_filters = any_to_str_set([PrepareDocuments, WriteTasks, FixBug]) + write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode]) + summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) if self.config.inc and msg.cause_by in write_plan_and_change_filters: logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}") await self._new_code_plan_and_change_action(cause_by=msg.cause_by) - return self.rc.todo + return bool(self.rc.todo) if msg.cause_by in write_code_filters: logger.debug(f"TODO WriteCode:{msg.model_dump_json()}") await self._new_code_actions() - return self.rc.todo + return bool(self.rc.todo) if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self): logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}") await self._new_summarize_actions() - return self.rc.todo - return None + return bool(self.rc.todo) + return False - async def _new_coding_context(self, filename, dependency) -> CodingContext: - old_code_doc = await self.project_repo.srcs.get(filename) + async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]: + old_code_doc = await self.repo.srcs.get(filename) if not old_code_doc: - old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="") + old_code_doc = Document(root_path=str(self.repo.src_relative_path), filename=filename, content="") dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)} task_doc = None design_doc = None code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None for i in dependencies: if str(i.parent) == TASK_FILE_REPO: - task_doc = await self.project_repo.docs.task.get(i.name) + task_doc = await self.repo.docs.task.get(i.name) elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO: - design_doc = await self.project_repo.docs.system_design.get(i.name) + design_doc = await self.repo.docs.system_design.get(i.name) elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO: - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name) + code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(i.name) if not task_doc or not design_doc: + if filename == "__init__.py": # `__init__.py` created by `init_python_folder` + return None logger.error(f'Detected source code "{filename}" from an unknown origin.') raise ValueError(f'Detected source code "{filename}" from an unknown origin.') context = CodingContext( @@ -288,30 +316,71 @@ async def _new_coding_context(self, filename, dependency) -> CodingContext: ) return context - async def _new_coding_doc(self, filename, dependency): + async def _new_coding_doc(self, filename, dependency) -> Optional[Document]: context = await self._new_coding_context(filename, dependency) + if not context: + return None # `__init__.py` created by `init_python_folder` coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json() + root_path=str(self.repo.src_relative_path), filename=filename, content=context.model_dump_json() ) return coding_doc async def _new_code_actions(self): bug_fix = await self._is_fixbug() # Prepare file repos - changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files - changed_task_files = self.project_repo.docs.task.changed_files + changed_src_files = self.repo.srcs.changed_files + if 
self.context.kwargs.src_filename: + changed_src_files = {self.context.kwargs.src_filename: ChangeType.UNTRACTED} + if bug_fix: + changed_src_files = self.repo.srcs.all_files changed_files = Documents() # Recode caused by upstream changes. - for filename in changed_task_files: - design_doc = await self.project_repo.docs.system_design.get(filename) - task_doc = await self.project_repo.docs.task.get(filename) - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename) + if hasattr(self.input_args, "changed_task_filenames"): + changed_task_filenames = self.input_args.changed_task_filenames + else: + changed_task_filenames = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + for filename in changed_task_filenames: + task_filename = Path(filename) + design_filename = None + if hasattr(self.input_args, "changed_system_design_filenames"): + changed_system_design_filenames = self.input_args.changed_system_design_filenames + else: + changed_system_design_filenames = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + for i in changed_system_design_filenames: + if task_filename.name == Path(i).name: + design_filename = Path(i) + break + code_plan_and_change_filename = None + if hasattr(self.input_args, "changed_code_plan_and_change_filenames"): + changed_code_plan_and_change_filenames = self.input_args.changed_code_plan_and_change_filenames + else: + changed_code_plan_and_change_filenames = [ + str(self.repo.docs.code_plan_and_change.workdir / i) + for i in list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + for i in changed_code_plan_and_change_filenames: + if task_filename.name == Path(i).name: + code_plan_and_change_filename = Path(i) + break + design_doc = await Document.load(filename=design_filename, project_path=self.repo.workdir) + task_doc = await Document.load(filename=task_filename, project_path=self.repo.workdir) + code_plan_and_change_doc = await Document.load( + filename=code_plan_and_change_filename, project_path=self.repo.workdir + ) task_list = self._parse_tasks(task_doc) + await self._init_python_folder(task_list) for task_filename in task_list: - old_code_doc = await self.project_repo.srcs.get(task_filename) + if self.context.kwargs.src_filename and task_filename != self.context.kwargs.src_filename: + continue + old_code_doc = await self.repo.srcs.get(task_filename) if not old_code_doc: old_code_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=task_filename, content="" + root_path=str(self.repo.src_relative_path), filename=task_filename, content="" ) if not code_plan_and_change_doc: context = CodingContext( @@ -326,7 +395,7 @@ async def _new_code_actions(self): code_plan_and_change_doc=code_plan_and_change_doc, ) coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), + root_path=str(self.repo.src_relative_path), filename=task_filename, content=context.model_dump_json(), ) @@ -337,31 +406,42 @@ async def _new_code_actions(self): ) changed_files.docs[task_filename] = coding_doc self.code_todos = [ - WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values() + WriteCode(i_context=i, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm) + for i in changed_files.docs.values() ] # Code directly modified by the user. 
- dependency = await self.git_repo.get_dependency() + dependency = await self.repo.git_repo.get_dependency() for filename in changed_src_files: if filename in changed_files.docs: continue coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency) + if not coding_doc: + continue # `__init__.py` created by `init_python_folder` changed_files.docs[filename] = coding_doc - self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm)) + self.code_todos.append( + WriteCode( + i_context=coding_doc, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) + ) if self.code_todos: self.set_todo(self.code_todos[0]) async def _new_summarize_actions(self): - src_files = self.project_repo.srcs.all_files + src_files = self.repo.srcs.all_files # Generate a SummarizeCode action for each pair of (system_design_doc, task_doc). summarizations = defaultdict(list) for filename in src_files: - dependencies = await self.project_repo.srcs.get_dependency(filename=filename) + dependencies = await self.repo.srcs.get_dependency(filename=filename) ctx = CodeSummarizeContext.loads(filenames=list(dependencies)) summarizations[ctx].append(filename) for ctx, filenames in summarizations.items(): + if not ctx.design_filename or not ctx.task_filename: + continue # cause by `__init__.py` which is created by `init_python_folder` ctx.codes_filenames = filenames - new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm) + new_summarize = SummarizeCode( + i_context=ctx, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) for i, act in enumerate(self.summarize_todos): if act.i_context.task_filename == new_summarize.i_context.task_filename: self.summarize_todos[i] = new_summarize @@ -371,34 +451,63 @@ async def _new_summarize_actions(self): self.summarize_todos.append(new_summarize) if self.summarize_todos: self.set_todo(self.summarize_todos[0]) - self.summarize_todos.pop(0) async def _new_code_plan_and_change_action(self, cause_by: str): """Create a WriteCodePlanAndChange action for subsequent to-do actions.""" - files = self.project_repo.all_files options = {} if cause_by != any_to_str(FixBug): - requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME) + requirement_doc = await Document.load(filename=self.input_args.requirements_filename) options["requirement"] = requirement_doc.content else: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) + fixbug_doc = await Document.load(filename=self.input_args.issue_filename) options["issue"] = fixbug_doc.content - code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options) - self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm) + # The code here is flawed: if there are multiple unrelated requirements, this piece of logic will break + if hasattr(self.input_args, "changed_prd_filenames"): + code_plan_and_change_ctx = CodePlanAndChangeContext( + requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=self.input_args.changed_prd_filenames[0], + design_filename=self.input_args.changed_system_design_filenames[0], + task_filename=self.input_args.changed_task_filenames[0], + ) + else: + code_plan_and_change_ctx = CodePlanAndChangeContext( + requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=str(self.repo.docs.prd.workdir / self.repo.docs.prd.all_files[0]), + 
design_filename=str(self.repo.docs.system_design.workdir / self.repo.docs.system_design.all_files[0]), + task_filename=str(self.repo.docs.task.workdir / self.repo.docs.task.all_files[0]), + ) + self.rc.todo = WriteCodePlanAndChange( + i_context=code_plan_and_change_ctx, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) @property def action_description(self) -> str: """AgentStore uses this attribute to display to the user what actions the current role should take.""" return self.next_todo_action + async def _init_python_folder(self, task_list: List[str]): + for i in task_list: + filename = Path(i) + if filename.suffix != ".py": + continue + workdir = self.repo.srcs.workdir / filename.parent + if not workdir.exists(): + workdir = self.repo.workdir / filename.parent + await init_python_folder(workdir) + async def _is_fixbug(self) -> bool: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) - return bool(fixbug_doc and fixbug_doc.content) + return bool(self.input_args and hasattr(self.input_args, "issue_filename")) async def _get_any_code_plan_and_change(self) -> Optional[Document]: - changed_files = self.project_repo.docs.code_plan_and_change.changed_files + changed_files = self.repo.docs.code_plan_and_change.changed_files for filename in changed_files.keys(): - doc = await self.project_repo.docs.code_plan_and_change.get(filename) + doc = await self.repo.docs.code_plan_and_change.get(filename) if doc and doc.content: return doc return None diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index fbe139a991..0f3613866f 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -4,16 +4,21 @@ @Time : 2023/5/11 14:43 @Author : alexanderwu @File : product_manager.py -@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. +@Modified By: liushaojie, 2024/10/17. """ - from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role -from metagpt.utils.common import any_to_name +from metagpt.actions.search_enhanced_qa import SearchEnhancedQA +from metagpt.prompts.product_manager import PRODUCT_MANAGER_INSTRUCTION +from metagpt.roles.di.role_zero import RoleZero +from metagpt.roles.role import RoleReactMode +from metagpt.tools.libs.browser import Browser +from metagpt.tools.libs.editor import Editor +from metagpt.utils.common import any_to_name, any_to_str, tool2name +from metagpt.utils.git_repository import GitRepository -class ProductManager(Role): +class ProductManager(RoleZero): """ Represents a Product Manager role responsible for product development and management. @@ -26,26 +31,34 @@ class ProductManager(Role): name: str = "Alice" profile: str = "Product Manager" - goal: str = "efficiently create a successful product that meets market demands and user expectations" + goal: str = "Create a Product Requirement Document or market research/competitive product research." 
constraints: str = "utilize the same language as the user requirements for seamless communication" - todo_action: str = "" + instruction: str = PRODUCT_MANAGER_INSTRUCTION + tools: list[str] = ["RoleZero", Browser.__name__, Editor.__name__, SearchEnhancedQA.__name__] + + todo_action: str = any_to_name(WritePRD) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) + if self.use_fixed_sop: + self.enable_memory = False + self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD]) + self._watch([UserRequirement, PrepareDocuments]) + self.rc.react_mode = RoleReactMode.BY_ORDER - self.set_actions([PrepareDocuments, WritePRD]) - self._watch([UserRequirement, PrepareDocuments]) - self.todo_action = any_to_name(PrepareDocuments) + def _update_tool_execution(self): + wp = WritePRD() + self.tool_execution_map.update(tool2name(WritePRD, ["run"], wp.run)) async def _think(self) -> bool: """Decide what to do""" - if self.git_repo and not self.config.git_reinit: + if not self.use_fixed_sop: + return await super()._think() + + if GitRepository.is_git_dir(self.config.project_path) and not self.config.git_reinit: self._set_state(1) else: self._set_state(0) self.config.git_reinit = False self.todo_action = any_to_name(WritePRD) return bool(self.rc.todo) - - async def _observe(self, ignore_memory=False) -> int: - return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 422d2889b3..cb0ead9dec 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,13 +5,12 @@ @Author : alexanderwu @File : project_manager.py """ - from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles.role import Role +from metagpt.roles.di.role_zero import RoleZero -class ProjectManager(Role): +class ProjectManager(RoleZero): """ Represents a Project Manager role responsible for overseeing project execution and team efficiency. @@ -30,8 +29,22 @@ class ProjectManager(Role): ) constraints: str = "use same language as user requirement" + instruction: str = """Use WriteTasks tool to write a project task list""" + max_react_loop: int = 1 # FIXME: Read and edit files requires more steps, consider later + tools: list[str] = ["Editor:write,read,similarity_search", "RoleZero", "WriteTasks"] + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - + # NOTE: The following init setting will only be effective when self.use_fixed_sop is changed to True + self.enable_memory = False self.set_actions([WriteTasks]) self._watch([WriteDesign]) + + def _update_tool_execution(self): + wt = WriteTasks() + self.tool_execution_map.update( + { + "WriteTasks.run": wt.run, + "WriteTasks": wt.run, # alias + } + ) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index c73c10ef35..fc8fa53534 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,14 +14,26 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. 
""" +from typing import Optional -from metagpt.actions import DebugError, RunCode, WriteTest +from pydantic import BaseModel, Field + +from metagpt.actions import DebugError, RunCode, UserRequirement, WriteTest +from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import MESSAGE_ROUTE_TO_NONE +from metagpt.const import MESSAGE_ROUTE_TO_NONE, MESSAGE_ROUTE_TO_SELF from metagpt.logs import logger from metagpt.roles import Role -from metagpt.schema import Document, Message, RunCodeContext, TestingContext -from metagpt.utils.common import any_to_str_set, parse_recipient +from metagpt.schema import AIMessage, Document, Message, RunCodeContext, TestingContext +from metagpt.utils.common import ( + any_to_str, + any_to_str_set, + get_project_srcs_path, + init_python_folder, + parse_recipient, +) +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter class QaEngineer(Role): @@ -34,9 +46,12 @@ class QaEngineer(Role): ) test_round_allowed: int = 5 test_round: int = 0 + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) def __init__(self, **kwargs): super().__init__(**kwargs) + self.enable_memory = False # FIXME: a bit hack here, only init one action to circumvent _think() logic, # will overwrite _think() in future updates @@ -45,67 +60,61 @@ def __init__(self, **kwargs): self.test_round = 0 async def _write_test(self, message: Message) -> None: - src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs - changed_files = set(src_file_repo.changed_files.keys()) - # Unit tests only. - if self.config.reqa_file and self.config.reqa_file not in changed_files: - changed_files.add(self.config.reqa_file) + reqa_file = self.context.kwargs.reqa_file or self.config.reqa_file + changed_files = {reqa_file} if reqa_file else set(self.repo.srcs.changed_files.keys()) for filename in changed_files: # write tests if not filename or "test" in filename: continue - code_doc = await src_file_repo.get(filename) - if not code_doc: + code_doc = await self.repo.srcs.get(filename) + if not code_doc or not code_doc.content: continue if not code_doc.filename.endswith(".py"): continue - test_doc = await self.project_repo.tests.get("test_" + code_doc.filename) + test_doc = await self.repo.tests.get("test_" + code_doc.filename) if not test_doc: test_doc = Document( - root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content="" + root_path=str(self.repo.tests.root_path), filename="test_" + code_doc.filename, content="" ) logger.info(f"Writing {test_doc.filename}..") context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) + context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run() - await self.project_repo.tests.save_doc( - doc=context.test_doc, dependencies={context.code_doc.root_relative_path} - ) + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "test", "filename": test_doc.filename}, "meta") + + doc = await self.repo.tests.save_doc( + doc=context.test_doc, dependencies={context.code_doc.root_relative_path} + ) + await reporter.async_report(self.repo.workdir / doc.root_relative_path, "path") # prepare context for run tests in next round run_code_context = RunCodeContext( command=["python", context.test_doc.root_relative_path], 
code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, - working_directory=str(self.project_repo.workdir), - additional_python_paths=[str(self.context.src_workspace)], + working_directory=str(self.repo.workdir), + additional_python_paths=[str(self.repo.srcs.workdir)], ) self.publish_message( - Message( - content=run_code_context.model_dump_json(), - role=self.profile, - cause_by=WriteTest, - sent_from=self, - send_to=self, - ) + AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF) ) - logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.") + logger.info(f"Done {str(self.repo.tests.workdir)} generating.") async def _run_code(self, msg): run_code_context = RunCodeContext.loads(msg.content) - src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get( - run_code_context.code_filename - ) + src_doc = await self.repo.srcs.get(run_code_context.code_filename) if not src_doc: return - test_doc = await self.project_repo.tests.get(run_code_context.test_filename) + test_doc = await self.repo.tests.get(run_code_context.test_filename) if not test_doc: return run_code_context.code = src_doc.content run_code_context.test_code = test_doc.content result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run() run_code_context.output_filename = run_code_context.test_filename + ".json" - await self.project_repo.test_outputs.save( + await self.repo.test_outputs.save( filename=run_code_context.output_filename, content=result.model_dump_json(), dependencies={src_doc.root_relative_path, test_doc.root_relative_path}, @@ -115,43 +124,58 @@ async def _run_code(self, msg): # the recipient might be Engineer or myself recipient = parse_recipient(result.summary) mappings = {"Engineer": "Alex", "QaEngineer": "Edward"} - self.publish_message( - Message( - content=run_code_context.model_dump_json(), - role=self.profile, - cause_by=RunCode, - sent_from=self, - send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + if recipient != "Engineer": + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=MESSAGE_ROUTE_TO_SELF, + ) + ) + else: + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + ) ) - ) async def _debug_error(self, msg): run_code_context = RunCodeContext.loads(msg.content) - code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run() - await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code) + code = await DebugError( + i_context=run_code_context, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ).run() + await self.repo.tests.save(filename=run_code_context.test_filename, content=code) run_code_context.output = None self.publish_message( - Message( - content=run_code_context.model_dump_json(), - role=self.profile, - cause_by=DebugError, - sent_from=self, - send_to=self, - ) + AIMessage(content=run_code_context.model_dump_json(), cause_by=DebugError, send_to=MESSAGE_ROUTE_TO_SELF) ) async def _act(self) -> Message: + if self.input_args.project_path: + await 
init_python_folder(self.repo.tests.workdir) if self.test_round > self.test_round_allowed: - result_msg = Message( - content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)", - role=self.profile, + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] + result_msg = AIMessage( + content=f"Exceeding {self.test_round_allowed} rounds of tests, stop. " + + "\n".join(list(self.repo.tests.changed_files.keys())), cause_by=WriteTest, - sent_from=self.profile, + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), send_to=MESSAGE_ROUTE_TO_NONE, ) return result_msg - code_filters = any_to_str_set({SummarizeCode}) + code_filters = any_to_str_set({PrepareDocuments, SummarizeCode}) test_filters = any_to_str_set({WriteTest, DebugError}) run_filters = any_to_str_set({RunCode}) for msg in self.rc.news: @@ -166,16 +190,42 @@ async def _act(self) -> Message: elif msg.cause_by in run_filters: # I ran my test code, time to fix bugs, if any await self._debug_error(msg) + elif msg.cause_by == any_to_str(UserRequirement): + return await self._parse_user_requirement(msg) self.test_round += 1 - return Message( + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] + return AIMessage( content=f"Round {self.test_round} of tests done", - role=self.profile, + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), cause_by=WriteTest, - sent_from=self.profile, send_to=MESSAGE_ROUTE_TO_NONE, ) - async def _observe(self, ignore_memory=False) -> int: - # This role has events that trigger and execute themselves based on conditions, and cannot rely on the - # content of memory to activate. 
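The result messages above pack ad-hoc key-values into `instruct_content` via `AIMessage.create_instruct_value(kvs=..., class_name="WriteTestOutput")`, so the next role receives a typed payload rather than raw text. A condensed, runnable sketch of that mechanism (the sample payload values are invented; the full helper is defined later in `metagpt/schema.py`):

```python
from typing import Any, Dict

from pydantic import BaseModel, create_model


def create_instruct_value(kvs: Dict[str, Any], class_name: str = "WriteTestOutput") -> BaseModel:
    # Build a one-off pydantic model whose fields mirror the kv payload,
    # then validate the payload against it.
    dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
    return dynamic_class.model_validate(kvs)


payload = create_instruct_value(
    {"project_path": "/tmp/demo_repo", "changed_test_filenames": ["tests/test_main.py"]}
)
print(type(payload).__name__, payload.model_dump())
```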
- return await super()._observe(ignore_memory=True) + async def _parse_user_requirement(self, msg: Message) -> AIMessage: + action = PrepareDocuments( + send_to=any_to_str(self), + key_descriptions={ + "project_path": 'the project path if exists in "Original Requirement"', + "reqa_file": 'the file name to rewrite unit test if exists in "Original Requirement"', + }, + context=self.context, + ) + rsp = await action.run([msg]) + if not self.src_workspace: + self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name + return rsp + + async def _think(self) -> bool: + if not self.rc.news: + return False + msg = self.rc.news[0] + if msg.cause_by == any_to_str(SummarizeCode): + self.input_args = msg.instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) + return True diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e0f8a7ea69..1851dd20f7 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,27 +23,31 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Iterable, Optional, Set, Type, Union +from typing import Iterable, Optional, Set, Type, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement +from metagpt.base import BaseEnvironment, BaseRole +from metagpt.const import MESSAGE_ROUTE_TO_SELF from metagpt.context_mixin import ContextMixin from metagpt.logs import logger from metagpt.memory import Memory from metagpt.provider import HumanProvider -from metagpt.schema import Message, MessageQueue, SerializationMixin +from metagpt.schema import ( + AIMessage, + Message, + MessageQueue, + SerializationMixin, + Task, + TaskResult, +) from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator -from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -if TYPE_CHECKING: - from metagpt.environment import Environment # noqa: F401 - - PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """ CONSTRAINT_TEMPLATE = "the constraint is {constraints}. 
" @@ -91,7 +95,7 @@ class RoleContext(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` - env: "Environment" = Field(default=None, exclude=True) # # avoid circular import + env: BaseEnvironment = Field(default=None, exclude=True) # # avoid circular import # TODO judge if ser&deser msg_buffer: MessageQueue = Field( default_factory=MessageQueue, exclude=True @@ -117,14 +121,8 @@ def important_memory(self) -> list[Message]: def history(self) -> list[Message]: return self.memory.get() - @classmethod - def model_rebuild(cls, **kwargs): - from metagpt.environment.base_env import Environment # noqa: F401 - - super().model_rebuild(**kwargs) - -class Role(SerializationMixin, ContextMixin, BaseModel): +class Role(BaseRole, SerializationMixin, ContextMixin, BaseModel): """Role/Agent""" model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @@ -135,6 +133,9 @@ class Role(SerializationMixin, ContextMixin, BaseModel): constraints: str = "" desc: str = "" is_human: bool = False + enable_memory: bool = ( + True # Stateless, atomic roles, or roles that use external storage can disable this to save memory. + ) role_id: str = "" states: list[str] = [] @@ -153,6 +154,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): # builtin variables recovered: bool = False # to tag if a recovered role latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted + observe_all_msg_from_buffer: bool = False # whether to save all msgs from buffer to memory for role's awareness __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` @@ -170,7 +172,9 @@ def _process_role_extra(self): self._check_actions() self.llm.system_prompt = self._get_prefix() self.llm.cost_manager = self.context.cost_manager - self._watch(kwargs.pop("watch", [UserRequirement])) + # if observe_all_msg_from_buffer, we should not use cause_by to select messages but observe all + if not self.observe_all_msg_from_buffer: + self._watch(kwargs.pop("watch", [UserRequirement])) if self.latest_observed_msg: self.recovered = True @@ -186,29 +190,6 @@ def set_todo(self, value: Optional[Action]): value.context = self.context self.rc.todo = value - @property - def git_repo(self): - """Git repo""" - return self.context.git_repo - - @git_repo.setter - def git_repo(self, value): - self.context.git_repo = value - - @property - def src_workspace(self): - """Source workspace under git repo""" - return self.context.src_workspace - - @src_workspace.setter - def src_workspace(self, value): - self.context.src_workspace = value - - @property - def project_repo(self) -> ProjectRepo: - project_repo = ProjectRepo(self.context.git_repo) - return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo - @property def prompt_schema(self): """Prompt schema: json/markdown""" @@ -246,10 +227,9 @@ def _check_actions(self): return self def _init_action(self, action: Action): - if not action.private_config: - action.set_llm(self.llm, override=True) - else: - action.set_llm(self.llm, override=False) + action.set_context(self.context) + override = not action.private_config + action.set_llm(self.llm, override=override) action.set_prefix(self._get_prefix()) def set_action(self, action: Action): @@ -325,7 +305,7 @@ def _set_state(self, state: int): logger.debug(f"actions={self.actions}, state={state}") self.set_todo(self.actions[self.rc.state] if 
state >= 0 else None) - def set_env(self, env: "Environment"): + def set_env(self, env: BaseEnvironment): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self.rc.env = env @@ -391,22 +371,21 @@ async def _act(self) -> Message: logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") response = await self.rc.todo.run(self.rc.history) if isinstance(response, (ActionOutput, ActionNode)): - msg = Message( + msg = AIMessage( content=response.content, instruct_content=response.instruct_content, - role=self._setting, cause_by=self.rc.todo, sent_from=self, ) elif isinstance(response, Message): msg = response else: - msg = Message(content=response, role=self.profile, cause_by=self.rc.todo, sent_from=self) + msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self) self.rc.memory.add(msg) return msg - async def _observe(self, ignore_memory=False) -> int: + async def _observe(self) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = [] @@ -415,12 +394,17 @@ async def _observe(self, ignore_memory=False) -> int: if not news: news = self.rc.msg_buffer.pop_all() # Store the read messages in your own memory to prevent duplicate processing. - old_messages = [] if ignore_memory else self.rc.memory.get() - self.rc.memory.add_batch(news) - # Filter out messages of interest. + old_messages = [] if not self.enable_memory else self.rc.memory.get() + # Filter in messages of interest. self.rc.news = [ n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages ] + if self.observe_all_msg_from_buffer: + # save all new messages from the buffer into memory, the role may not react to them but can be aware of them + self.rc.memory.add_batch(news) + else: + # only save messages of interest into memory + self.rc.memory.add_batch(self.rc.news) self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg # Design Rules: @@ -435,9 +419,19 @@ def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: return + if MESSAGE_ROUTE_TO_SELF in msg.send_to: + msg.send_to.add(any_to_str(self)) + msg.send_to.remove(MESSAGE_ROUTE_TO_SELF) + if not msg.sent_from or msg.sent_from == MESSAGE_ROUTE_TO_SELF: + msg.sent_from = any_to_str(self) + if all(to in {any_to_str(self), self.name} for to in msg.send_to): # Message to myself + self.put_message(msg) + return if not self.rc.env: # If env does not exist, do not publish the message return + if isinstance(msg, AIMessage) and not msg.agent: + msg.with_agent(self._setting) self.rc.env.publish_message(msg) def put_message(self, message): @@ -452,11 +446,11 @@ async def _react(self) -> Message: Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act + rsp = AIMessage(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act while actions_taken < self.rc.max_react_loop: # think - await self._think() - if self.rc.todo is None: + has_todo = await self._think() + if not has_todo: break # act logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}") @@ -467,7 +461,7 @@ async def _react(self) -> Message: async def _act_by_order(self) -> Message: """switch action each time by 
order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state - rsp = Message(content="No actions taken yet") # return default message if actions=[] + rsp = AIMessage(content="No actions taken yet") # return default message if actions=[] for i in range(start_idx, len(self.states)): self._set_state(i) rsp = await self._act() @@ -492,6 +486,8 @@ async def _plan_and_act(self) -> Message: await self.planner.process_task_result(task_result) rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response + rsp.role = "assistant" + rsp.sent_from = self._setting self.rc.memory.add(rsp) # add to persistent memory @@ -522,6 +518,8 @@ async def react(self) -> Message: else: raise ValueError(f"Unsupported react mode: {self.rc.react_mode}") self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None + if isinstance(rsp, AIMessage): + rsp.with_agent(self._setting) return rsp def get_memories(self, k=0) -> list[Message]: @@ -592,6 +590,3 @@ def action_description(self) -> str: if self.actions: return any_to_name(self.actions[0]) return "" - - -RoleContext.model_rebuild() diff --git a/metagpt/schema.py b/metagpt/schema.py index 071518d62e..52badcc21a 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,9 +18,11 @@ import asyncio import json import os.path +import time import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for +from enum import Enum from json import JSONDecodeError from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union @@ -30,25 +32,36 @@ ConfigDict, Field, PrivateAttr, + create_model, field_serializer, field_validator, - model_serializer, - model_validator, ) +from metagpt.base.base_serialization import BaseSerialization from metagpt.const import ( + AGENT, MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, - PRDS_FILE_REPO, + SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.repo_parser import DotClassInfo -from metagpt.utils.common import any_to_str, any_to_str_set, import_class +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + CodeParser, + any_to_str, + any_to_str_set, + aread, + import_class, + read_json_file, + write_json_file, +) from metagpt.utils.exceptions import handle_exception +from metagpt.utils.report import TaskReporter from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, actionoutput_mapping_to_str, @@ -56,66 +69,65 @@ ) -class SerializationMixin(BaseModel, extra="forbid"): - """ - PolyMorphic subclasses Serialization / Deserialization Mixin - - First of all, we need to know that pydantic is not designed for polymorphism. - - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need - to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. - - More details: - - https://docs.pydantic.dev/latest/concepts/serialization/ - - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` - """ +class SerializationMixin(BaseSerialization): + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. 
- __is_polymorphic_base = False - __subclasses_map__ = {} + If an exception occurs, `handle_exception` will catch it and return `None`. - @model_serializer(mode="wrap") - def __serialize_with_class_type__(self, default_serializer) -> Any: - # default serializer, then append the `__module_class_name` field and return - ret = default_serializer(self) - ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - return ret + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.get_serialization_path() + + serialized_data = self.model_dump() + + write_json_file(file_path, serialized_data, use_fallback=True) + logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") + + return file_path - @model_validator(mode="wrap") @classmethod - def __convert_to_real_type__(cls, value: Any, handler): - if isinstance(value, dict) is False: - return handler(value) - - # it is a dict so make sure to remove the __module_class_name - # because we don't allow extra keywords but want to ensure - # e.g Cat.model_validate(cat.model_dump()) works - class_full_name = value.pop("__module_class_name", None) - - # if it's not the polymorphic base we construct via default handler - if not cls.__is_polymorphic_base: - if class_full_name is None: - return handler(value) - elif str(cls) == f"": - return handler(value) - else: - # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") - pass + @handle_exception + def deserialize(cls, file_path: str = None) -> BaseModel: + """Deserializes a JSON file to an instance of cls. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + An instance of the cls. + """ + + file_path = file_path or cls.get_serialization_path() - # otherwise we lookup the correct polymorphic type and construct that - # instead - if class_full_name is None: - raise ValueError("Missing __module_class_name field") + data: dict = read_json_file(file_path) - class_type = cls.__subclasses_map__.get(class_full_name, None) + model = cls(**data) + logger.debug(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") + + return model + + @classmethod + def get_serialization_path(cls) -> str: + """Get the serialization path for the class. - if class_type is None: - # TODO could try dynamic import - raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!") + This method constructs a file path for serialization based on the class name. + The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName' + is the name of the class. - return class_type(**value) + Returns: + str: The path to the serialization file. 
+ """ - def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): - cls.__is_polymorphic_base = is_polymorphic_base - cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls - super().__init_subclass__(**kwargs) + return str(SERDESER_PATH / f"{cls.__qualname__}.json") class SimpleMessage(BaseModel): @@ -154,6 +166,30 @@ def __str__(self): def __repr__(self): return self.content + @classmethod + async def load( + cls, filename: Union[str, Path], project_path: Optional[Union[str, Path]] = None + ) -> Optional["Document"]: + """ + Load a document from a file. + + Args: + filename (Union[str, Path]): The path to the file to load. + project_path (Optional[Union[str, Path]], optional): The path to the project. Defaults to None. + + Returns: + Optional[Document]: The loaded document, or None if the file does not exist. + + """ + if not filename or not Path(filename).exists(): + return None + content = await aread(filename=filename) + doc = cls(content=content, filename=str(filename)) + if project_path and Path(filename).is_relative_to(project_path): + doc.root_path = Path(filename).relative_to(project_path).parent + doc.filename = Path(filename).name + return doc + class Documents(BaseModel): """A class representing a collection of documents. @@ -185,16 +221,25 @@ def to_action_output(self) -> "ActionOutput": return ActionOutput(content=self.model_dump_json(), instruct_content=self) +class Resource(BaseModel): + """Used by `Message`.`parse_resources`""" + + resource_type: str # the type of resource + value: str # a string type of resource content + description: str # explanation + + class Message(BaseModel): """list[: ]""" id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 - content: str + content: str # natural language for user or agent instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) role: str = "user" # system / user / assistant cause_by: str = Field(default="", validate_default=True) sent_from: str = Field(default="", validate_default=True) send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) + metadata: Dict[str, Any] = Field(default_factory=dict) # metadata for `content` and `instruct_content` @field_validator("id", mode="before") @classmethod @@ -310,14 +355,75 @@ def load(val): logger.error(f"parse json failed: {val}, error:{err}") return None + async def parse_resources(self, llm: "BaseLLM", key_descriptions: Dict[str, str] = None) -> Dict: + """ + `parse_resources` corresponds to the in-context adaptation capability of the input of the atomic action, + which will be migrated to the context builder later. + + Args: + llm (BaseLLM): The instance of the BaseLLM class. + key_descriptions (Dict[str, str], optional): A dictionary containing descriptions for each key, + if provided. Defaults to None. + + Returns: + Dict: A dictionary containing parsed resources. + + """ + if not self.content: + return {} + content = f"## Original Requirement\n```text\n{self.content}\n```\n" + return_format = ( + "Return a markdown JSON object with:\n" + '- a "resources" key contain a list of objects. 
Each object with:\n' + ' - a "resource_type" key explain the type of resource;\n' + ' - a "value" key containing a string type of resource content;\n' + ' - a "description" key explaining why;\n' + ) + key_descriptions = key_descriptions or {} + for k, v in key_descriptions.items(): + return_format += f'- a "{k}" key containing {v};\n' + return_format += '- a "reason" key explaining why;\n' + instructions = ['Lists all the resources contained in the "Original Requirement".', return_format] + rsp = await llm.aask(msg=content, system_msgs=instructions) + json_data = CodeParser.parse_code(text=rsp, lang="json") + m = json.loads(json_data) + m["resources"] = [Resource(**i) for i in m.get("resources", [])] + return m + + def add_metadata(self, key: str, value: str): + self.metadata[key] = value + + @staticmethod + def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseModel: + """ + Dynamically creates a Pydantic BaseModel subclass based on a given dictionary. + + Parameters: + - data: A dictionary from which to create the BaseModel subclass. + + Returns: + - A Pydantic BaseModel subclass instance populated with the given data. + """ + if not class_name: + class_name = "DM" + uuid.uuid4().hex[0:8] + dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()}) + return dynamic_class.model_validate(kvs) + + def is_user_message(self) -> bool: + return self.role == "user" + + def is_ai_message(self) -> bool: + return self.role == "assistant" + class UserMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ - def __init__(self, content: str): - super().__init__(content=content, role="user") + def __init__(self, content: str, **kwargs): + kwargs.pop("role", None) + super().__init__(content=content, role="user", **kwargs) class SystemMessage(Message): @@ -325,8 +431,9 @@ class SystemMessage(Message): Facilitate support for OpenAI messages """ - def __init__(self, content: str): - super().__init__(content=content, role="system") + def __init__(self, content: str, **kwargs): + kwargs.pop("role", None) + super().__init__(content=content, role="system", **kwargs) class AIMessage(Message): @@ -334,8 +441,17 @@ class AIMessage(Message): Facilitate support for OpenAI messages """ - def __init__(self, content: str): - super().__init__(content=content, role="assistant") + def __init__(self, content: str, **kwargs): + kwargs.pop("role", None) + super().__init__(content=content, role="assistant", **kwargs) + + def with_agent(self, name: str): + self.add_metadata(key=AGENT, value=name) + return self + + @property + def agent(self) -> str: + return self.metadata.get(AGENT, "") class Task(BaseModel): @@ -347,6 +463,7 @@ class Task(BaseModel): result: str = "" is_success: bool = False is_finished: bool = False + assignee: str = "" def reset(self): self.code = "" @@ -355,8 +472,8 @@ def reset(self): self.is_finished = False def update_task_result(self, task_result: TaskResult): - self.code = task_result.code - self.result = task_result.result + self.code = self.code + "\n" + task_result.code + self.result = self.result + "\n" + task_result.result self.is_success = task_result.is_success @@ -368,7 +485,17 @@ class TaskResult(BaseModel): is_success: bool +@register_tool( + include_functions=[ + "append_task", + "reset_task", + "replace_task", + "finish_current_task", + ] +) class Plan(BaseModel): + """Plan is a sequence of tasks towards a goal.""" + goal: str context: str = "" tasks: list[Task] = [] @@ -441,19 +568,23 @@ def 
add_tasks(self, tasks: list[Task]): def reset_task(self, task_id: str): """ - Clear code and result of the task based on task_id, and set the task as unfinished. + Reset a task based on task_id, i.e. set Task.is_finished=False and request redo. This also resets all tasks depending on it. Args: task_id (str): The ID of the task to be reset. - - Returns: - None """ if task_id in self.task_map: task = self.task_map[task_id] task.reset() + # reset all downstream tasks that are dependent on the reset task + for dep_task in self.tasks: + if task_id in dep_task.dependent_task_ids: + # FIXME: if LLM generates cyclic tasks, this will result in infinite recursion + self.reset_task(dep_task.task_id) - def replace_task(self, new_task: Task): + self._update_current_task() + + def _replace_task(self, new_task: Task): """ Replace an existing task with the new input task based on task_id, and reset all tasks depending on it. @@ -476,7 +607,9 @@ def replace_task(self, new_task: Task): if new_task.task_id in task.dependent_task_ids: self.reset_task(task.task_id) - def append_task(self, new_task: Task): + self._update_current_task() + + def _append_task(self, new_task: Task): """ Append a new task to the end of existing task sequences @@ -486,7 +619,11 @@ def append_task(self, new_task: Task): Returns: None """ - assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" + # assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" + if self.has_task_id(new_task.task_id): + logger.warning( + "Task already in current plan, should use replace_task instead. Overwriting the existing task." + ) assert all( [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids] @@ -501,12 +638,17 @@ def has_task_id(self, task_id: str) -> bool: return task_id in self.task_map def _update_current_task(self): + self.tasks = self._topological_sort(self.tasks) + # Update the task map for quick access to tasks by ID + self.task_map = {task.task_id: task for task in self.tasks} + current_task_id = "" for task in self.tasks: if not task.is_finished: current_task_id = task.task_id break - self.current_task_id = current_task_id # all tasks finished + self.current_task_id = current_task_id + TaskReporter().report({"tasks": [i.model_dump() for i in self.tasks], "current_task_id": current_task_id}) @property def current_task(self) -> Task: @@ -523,6 +665,15 @@ def finish_current_task(self): self.current_task.is_finished = True self._update_current_task() # set to next task + def finish_all_tasks(self): + "Finish all tasks." + while self.current_task: + self.finish_current_task() + + def is_plan_finished(self) -> bool: + """Check if all tasks are finished""" + return all(task.is_finished for task in self.tasks) + def get_finished_tasks(self) -> list[Task]: """return all finished tasks in correct linearized order @@ -531,6 +682,33 @@ def get_finished_tasks(self) -> list[Task]: """ return [task for task in self.tasks if task.is_finished] + def append_task( + self, task_id: str, dependent_task_ids: list[str], instruction: str, assignee: str, task_type: str = "" + ): + """ + Append a new task with task_id (number) to the end of existing task sequences. + If dependent_task_ids is not empty, the task will depend on the tasks with the ids in the list. + Note that the assignee should be the 'name' of the role. 
+ """ + new_task = Task( + task_id=task_id, + dependent_task_ids=dependent_task_ids, + instruction=instruction, + assignee=assignee, + task_type=task_type, + ) + return self._append_task(new_task) + + def replace_task(self, task_id: str, new_dependent_task_ids: list[str], new_instruction: str, new_assignee: str): + """Replace an existing task (can be current task) based on task_id, and reset all tasks depending on it.""" + new_task = Task( + task_id=task_id, + dependent_task_ids=new_dependent_task_ids, + instruction=new_instruction, + assignee=new_assignee, + ) + return self._replace_task(new_task) + class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" @@ -671,10 +849,6 @@ def __hash__(self): return hash((self.design_filename, self.task_filename)) -class BugFixContext(BaseContext): - filename: str = "" - - class CodePlanAndChangeContext(BaseModel): requirement: str = "" issue: str = "" @@ -682,22 +856,6 @@ class CodePlanAndChangeContext(BaseModel): design_filename: str = "" task_filename: str = "" - @staticmethod - def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: - ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", "")) - for filename in filenames: - filename = Path(filename) - if filename.is_relative_to(PRDS_FILE_REPO): - ctx.prd_filename = filename.name - continue - if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): - ctx.design_filename = filename.name - continue - if filename.is_relative_to(TASK_FILE_REPO): - ctx.task_filename = filename.name - continue - return ctx - # mermaid class view class UMLClassMeta(BaseModel): @@ -785,3 +943,34 @@ def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView: method.return_type = i.return_args.type_ class_view.methods.append(method) return class_view + + +class BaseEnum(Enum): + """Base class for enums.""" + + def __new__(cls, value, desc=None): + """ + Construct an instance of the enum member. + + Args: + cls: The class. + value: The value of the enum member. + desc: The description of the enum member. Defaults to None. + """ + if issubclass(cls, str): + obj = str.__new__(cls, value) + elif issubclass(cls, int): + obj = int.__new__(cls, value) + else: + obj = object.__new__(cls) + obj._value_ = value + obj.desc = desc + return obj + + +class LongTermMemoryItem(BaseModel): + message: Message + created_at: Optional[float] = Field(default_factory=time.time) + + def rag_key(self) -> str: + return self.message.content diff --git a/metagpt/software_company.py b/metagpt/software_company.py index f290d497a7..f74b61191c 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -7,7 +7,7 @@ import typer from metagpt.const import CONFIG_ROOT -from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.common import any_to_str app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) @@ -25,9 +25,9 @@ def generate_repo( reqa_file="", max_auto_summarize_code=0, recover_path=None, -) -> ProjectRepo: +): """Run the startup logic. 
Can be called from CLI or other Python scripts.""" - from metagpt.config2 import config + from metagpt.config2 import Config from metagpt.context import Context from metagpt.roles import ( Architect, @@ -38,6 +38,8 @@ def generate_repo( ) from metagpt.team import Team + config = Config.default() + config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) ctx = Context(config=config) @@ -65,10 +67,10 @@ def generate_repo( idea = company.idea company.invest(investment) - company.run_project(idea) + company.run_project(idea, send_to=any_to_str(ProductManager)) asyncio.run(company.run(n_round=n_round)) - return ctx.repo + return ctx.kwargs.get("project_path") @app.command("", help="Start a new project.") diff --git a/metagpt/strategy/experience_retriever.py b/metagpt/strategy/experience_retriever.py new file mode 100644 index 0000000000..0d5bcad52f --- /dev/null +++ b/metagpt/strategy/experience_retriever.py @@ -0,0 +1,1243 @@ +from typing import Literal + +from pydantic import BaseModel + + +class ExpRetriever(BaseModel): + """interface for experience retriever""" + + def retrieve(self, context: str = "") -> str: + raise NotImplementedError + + +class DummyExpRetriever(ExpRetriever): + """A dummy experience retriever that returns empty string.""" + + def retrieve(self, context: str = "") -> str: + return self.EXAMPLE + + EXAMPLE: str = "" + + +class TRDAllExpRetriever(ExpRetriever): + def retrieve(self, context: str = "") -> str: + return self.EXAMPLE + + EXAMPLE: str = """ +## example 1 +User Requirement: Given some user requirements, write a software framework. +Explanation: Given a complete user requirement, to write a TRD and software framework, you must follow all of the following steps to complete the TRD output required by the user: 1. Call 'write_trd' to generate TRD; 2. Call 'write_framework' to implement TRD into the software framework. +```json +[ + { + "command_name": "write_trd_and_framework", + "task_id": "1", + "dependent_task_ids": [], + "instruction": "Execute `write_trd_and_framework` to write a TRD and software framework based on user requirements", + "args": { + "user_requirements": "This is user requirement balabala...", + "use_case_actors": "These are actors involved in the use case, balabala...", + "additional_technical_requirements": "These are additional technical requirements, balabala..." + } + } +] +``` +## example 2 +User Requirement: Given some user requirements, write a software framework. +Explanation: Given a complete user requirement, to write a software framework, you must follow all of the following steps to complete the TRD output required by the user: 1. Call 'write_trd' to generate TRD; 2. Call 'write_framework' to implement TRD into the software framework. +```json +[ + { + "command_name": "write_trd", + "task_id": "1", + "dependent_task_ids": [], + "instruction": "Execute `write_trd` to write the TRD based on user requirements", + "args": { + "user_requirements": "This is user requirement balabala...", + "use_case_actors": "These are actors involved in the use case, balabala...", + } + }, + { + "command_name": "write_framework", + "task_id": "2", + "dependent_task_ids": ["1"], + "instruction": "Execute `write_framework` to write the framework based on the TRD", + "args": { + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `write_trd`", + "additional_technical_requirements": "These are additional technical requirements, balabala..." 
+    }
+  }
+]
+```
+## example 3
+User Requirement: Given some user requirements, write a TRD, and implement the TRD within a software framework.
+Explanation:
+  Given a complete requirement, follow these steps to write the TRD:
+  1. Call `CompressExternalInterfaces.run` to extract the external interfaces information from the acknowledgement;
+  2. Execute the following steps in order:
+    2.1. Execute `DetectInteraction.run`;
+    2.2. Execute `WriteTRD.run`;
+    2.3. Execute `EvaluateTRD.run`;
+    2.4. Check the result of `EvaluateTRD.run`:
+      2.4.1. If the result of `EvaluateTRD.run` is judged as pass, go to step 3;
+      2.4.2. If the result of `EvaluateTRD.run` is judged as deny, repeat step 2;
+  3. Execute the following steps in order:
+    3.1. Execute `WriteFramework.run`;
+    3.2. Execute `EvaluateFramework.run`;
+    3.3. Check the result of `EvaluateFramework.run`:
+      3.3.1. If the result of `EvaluateFramework.run` is judged as pass, go to step 4;
+      3.3.2. If the result of `EvaluateFramework.run` is judged as deny, repeat step 3;
+      3.3.3. If step 3 has already been repeated more than 9 times, go to step 4;
+  4. Execute `save_framework` to save the result of `WriteFramework.run`;
+```json
+[
+  {
+    "command_name": "CompressExternalInterfaces.run",
+    "args": {
+      "task_id": "1",
+      "dependent_task_ids": [],
+      "instruction": "Execute `CompressExternalInterfaces.run` to extract external interfaces information from acknowledgement.",
+      "acknowledge": "## Interfaces\n balabala..."
+    }
+  },
+  {
+    "command_name": "DetectInteraction.run",
+    "args": {
+      "task_id": "2",
+      "dependent_task_ids": ["1"],
+      "instruction": "Execute `DetectInteraction.run` to detect the interaction events from the user requirements.",
+      "user_requirements": "This is user requirement balabala...",
+      "use_case_actors": "These are actors involved in the use case, balabala...",
+    }
+  },
+  {
+    "command_name": "WriteTRD.run",
+    "args": {
+      "task_id": "3",
+      "dependent_task_ids": ["2"],
+      "instruction": "Execute `WriteTRD.run` to write TRD",
+      "user_requirements": "This is user requirement balabala...",
+      "use_case_actors": "These are actors involved in the use case, balabala...",
+      "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+      "interaction_events": " returned by `DetectInteraction.run`"
+    }
+  },
+  {
+    "command_name": "EvaluateTRD.run",
+    "args": {
+      "task_id": "4",
+      "dependent_task_ids": ["3"],
+      "instruction": "Execute `EvaluateTRD.run` to evaluate the TRD",
+      "user_requirements": "This is user requirement balabala...",
+      "use_case_actors": "These are actors involved in the use case, balabala...",
+      "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+      "interaction_events": "",
+      "trd": " returned by `WriteTRD.run`"
+    }
+  },
+  {
+    "command_name": "DetectInteraction.run",
+    "args": {
+      "task_id": "5",
+      "dependent_task_ids": ["4"],
+      "instruction": "Execute `DetectInteraction.run` to detect the interaction events from the user requirements.",
+      "user_requirements": "This is user requirement balabala...",
+      "use_case_actors": "These are actors involved in the use case, balabala...",
+      "evaluation_conclusion": " returned by `EvaluateTRD.run`"
+    }
+  },
+  {
+    "command_name": "WriteTRD.run",
+    "args": {
+      "task_id": "6",
+      "dependent_task_ids": ["5"],
+      "instruction": "Execute `WriteTRD.run` to write TRD",
+      "user_requirements": "This is user requirement balabala...",
+      "use_case_actors": "These are actors involved in the use case, balabala...",
+      "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+      "interaction_events": " returned by `DetectInteraction.run`",
+      "previous_version_trd": " returned by `WriteTRD.run`"
+    }
+  },
+  {
+    "command_name": "EvaluateTRD.run",
+    "args": {
+      "task_id": "7",
+      "dependent_task_ids": ["6"],
+      "instruction": "Execute `EvaluateTRD.run` to evaluate the TRD",
"user_requirements": "This is user requirement balabala...", + "use_case_actors": "These are actors involved in the use case, balabala...", + "available_external_interfaces": " returned by `CompressExternalInterfaces.run`", + "interaction_events": " returned by `DetectInteraction.run`", + "trd": " returned by `WriteTRD.run`", + } + }, + { + "command_name": "WriteFramework.run", + "args": { + "task_id": "8", + "dependent_task_ids": ["7"], + "instruction": "Execute `WriteFramework.run` to write a software framework according to the TRD", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "EvaluateFramework.run", + "args": { + "task_id": "9", + "dependent_task_ids": ["8"], + "instruction": "Execute `EvaluateFramework.run` to evaluate the software framework returned by `WriteFramework.run`", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "legacy_output": " returned by `WriteFramework.run`", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "WriteFramework.run", + "args": { + "task_id": "10", + "dependent_task_ids": ["9"], + "instruction": "Execute `WriteFramework.run` to write a software framework according to the TRD", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "EvaluateFramework.run", + "args": { + "task_id": "11", + "dependent_task_ids": ["10"], + "instruction": "Execute `EvaluateFramework.run` to evaluate the software framework returned by `WriteFramework.run`", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "legacy_output": " returned by `WriteFramework.run`", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "save_framework", + "args": { + "task_id": "12", + "dependent_task_ids": ["11"], + "instruction": "Execute `save_framework` to save the software framework returned by `WriteFramework.run`", + "dir_data": " returned by `WriteFramework.run`", + } + } +] +``` + """ + + +class TRDToolExpRetriever(ExpRetriever): + """A TRD-related experience retriever that returns empty string.""" + + def retrieve(self, context: str = "") -> str: + return self.EXAMPLE + + EXAMPLE: str = """ +## example 1 +User Requirement: Given some user requirements, write a software framework. +Explanation: Given a complete user requirement, to write a TRD and software framework, you must follow all of the following steps to complete the TRD output required by the user: 1. Call 'write_trd' to generate TRD; 2. Call 'write_framework' to implement TRD into the software framework. 
+```json
+[
+    {
+        "command_name": "write_trd_and_framework",
+        "task_id": "1",
+        "dependent_task_ids": [],
+        "instruction": "Execute `write_trd_and_framework` to write a TRD and software framework based on user requirements",
+        "args": {
+            "user_requirements": "This is user requirement balabala...",
+            "use_case_actors": "These are actors involved in the use case, balabala...",
+            "additional_technical_requirements": "These are additional technical requirements, balabala..."
+        }
+    }
+]
+```
+    """
+    # EXAMPLE: str = """
+    # ## example 1
+    # User Requirement: Given some user requirements, write a software framework.
+    # Explanation: Given a complete user requirement, to write a software framework, you must follow all of the following steps to complete the TRD output required by the user: 1. Call 'write_trd' to generate TRD; 2. Call 'write_framework' to implement TRD into the software framework.
+    # ```json
+    # [
+    #     {
+    #         "command_name": "write_trd",
+    #         "task_id": "1",
+    #         "dependent_task_ids": [],
+    #         "instruction": "Execute `write_trd` to write the TRD based on user requirements",
+    #         "args": {
+    #             "user_requirements": "This is user requirement balabala...",
+    #             "use_case_actors": "These are actors involved in the use case, balabala...",
+    #         }
+    #     },
+    #     {
+    #         "command_name": "write_framework",
+    #         "task_id": "2",
+    #         "dependent_task_ids": ["1"],
+    #         "instruction": "Execute `write_framework` to write the framework based on the TRD",
+    #         "args": {
+    #             "use_case_actors": "These are actors involved in the use case, balabala...",
+    #             "trd": " returned by `write_trd`",
+    #             "additional_technical_requirements": "These are additional technical requirements, balabala..."
+    #         }
+    #     }
+    # ]
+    # ```
+    # """


+class TRDExpRetriever(ExpRetriever):
+    """A TRD-related experience retriever that returns a manually crafted example."""
+
+    def retrieve(self, context: str = "") -> str:
+        return self.EXAMPLE
+
+    EXAMPLE: str = """
+    ## example 1
+    User Requirement: Given some user requirements, write a TRD, and implement the TRD within a software framework.
+    Explanation:
+      Given a complete requirement, writing the TRD requires following these steps:
+      1. Call `CompressExternalInterfaces.run` to extract the external interfaces information from the acknowledgement;
+      2. Execute the following steps in order:
+        2.1. Execute `DetectInteraction.run`;
+        2.2. Execute `WriteTRD.run`;
+        2.3. Execute `EvaluateTRD.run`;
+        2.4. Check the result of `EvaluateTRD.run`:
+          2.4.1. If the result of `EvaluateTRD.run` is judged as pass, proceed to step 3;
+          2.4.2. If the result of `EvaluateTRD.run` is judged as deny, repeat step 2;
+      3. Execute the following steps in order:
+        3.1. Execute `WriteFramework.run`;
+        3.2. Execute `EvaluateFramework.run`;
+        3.3. Check the result of `EvaluateFramework.run`:
+          3.3.1. If the result of `EvaluateFramework.run` is judged as pass, proceed to step 4;
+          3.3.2. If the result of `EvaluateFramework.run` is judged as deny, repeat step 3;
+          3.3.3. If step 3 has already been repeated more than 9 times, proceed to step 4;
+      4. Execute `save_framework` to save the result of `WriteFramework.run`;
+    ```json
+    [
+        {
+            "command_name": "CompressExternalInterfaces.run",
+            "args": {
+                "task_id": "1",
+                "dependent_task_ids": [],
+                "instruction": "Execute `CompressExternalInterfaces.run` to extract the external interfaces information from the acknowledgement.",
+                "acknowledge": "## Interfaces\n balabala..."
+            }
+        },
+        {
+            "command_name": "DetectInteraction.run",
+            "args": {
+                "task_id": "2",
+                "dependent_task_ids": ["1"],
+                "instruction": "Execute `DetectInteraction.run` to extract the interaction events from the user requirements.",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala..."
+            }
+        },
+        {
+            "command_name": "WriteTRD.run",
+            "args": {
+                "task_id": "3",
+                "dependent_task_ids": ["2"],
+                "instruction": "Execute `WriteTRD.run` to write TRD",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+                "interaction_events": " returned by `DetectInteraction.run`"
+            }
+        },
+        {
+            "command_name": "EvaluateTRD.run",
+            "args": {
+                "task_id": "4",
+                "dependent_task_ids": ["3"],
+                "instruction": "Execute `EvaluateTRD.run` to evaluate the TRD",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+                "interaction_events": " returned by `DetectInteraction.run`",
+                "trd": " returned by `WriteTRD.run`"
+            }
+        },
+        {
+            "command_name": "DetectInteraction.run",
+            "args": {
+                "task_id": "5",
+                "dependent_task_ids": ["4"],
+                "instruction": "Execute `DetectInteraction.run` to extract the interaction events from the user requirements.",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "evaluation_conclusion": " returned by `EvaluateTRD.run`"
+            }
+        },
+        {
+            "command_name": "WriteTRD.run",
+            "args": {
+                "task_id": "6",
+                "dependent_task_ids": ["5"],
+                "instruction": "Execute `WriteTRD.run` to write TRD",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+                "interaction_events": " returned by `DetectInteraction.run`",
+                "previous_version_trd": " returned by `WriteTRD.run`"
+            }
+        },
+        {
+            "command_name": "EvaluateTRD.run",
+            "args": {
+                "task_id": "7",
+                "dependent_task_ids": ["6"],
+                "instruction": "Execute `EvaluateTRD.run` to evaluate the TRD",
+                "user_requirements": "This is user requirement balabala...",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "available_external_interfaces": " returned by `CompressExternalInterfaces.run`",
+                "interaction_events": " returned by `DetectInteraction.run`",
+                "trd": " returned by `WriteTRD.run`"
+            }
+        },
+        {
+            "command_name": "WriteFramework.run",
+            "args": {
+                "task_id": "8",
+                "dependent_task_ids": ["7"],
+                "instruction": "Execute `WriteFramework.run` to write a software framework according to the TRD",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
+                "trd": " returned by `WriteTRD.run`",
+                "acknowledge": "## Interfaces\n balabala...",
+                "additional_technical_requirements": "These are additional technical requirements, balabala..."
+            }
+        },
+        {
+            "command_name": "EvaluateFramework.run",
+            "args": {
+                "task_id": "9",
+                "dependent_task_ids": ["8"],
+                "instruction": "Execute `EvaluateFramework.run` to evaluate the software framework returned by `WriteFramework.run`",
+                "use_case_actors": "These are actors involved in the use case, balabala...",
"trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "legacy_output": " returned by `WriteFramework.run`", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "WriteFramework.run", + "args": { + "task_id": "10", + "dependent_task_ids": ["9"], + "instruction": "Execute `WriteFramework.run` to write a software framework according to the TRD", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "EvaluateFramework.run", + "args": { + "task_id": "11", + "dependent_task_ids": ["10"], + "instruction": "Execute `EvaluateFramework.run` to evaluate the software framework returned by `WriteFramework.run`", + "use_case_actors": "These are actors involved in the use case, balabala...", + "trd": " returned by `WriteTRD.run`", + "acknowledge": "## Interfaces\n balabala...", + "legacy_output": " returned by `WriteFramework.run`", + "additional_technical_requirements": "These are additional technical requirements, balabala...", + } + }, + { + "command_name": "save_framework", + "args": { + "task_id": "12", + "dependent_task_ids": ["11"], + "instruction": "Execute `save_framework` to save the software framework returned by `WriteFramework.run`", + "dir_data": " returned by `WriteFramework.run`", + } + } + ] + ``` + """ + + +TL_EXAMPLE = """ +## example 1 +User Requirement: Create a cli snake game. +Explanation: The requirement is about software development. Assign each tasks to a different team member based on their expertise. When publishing message to Product Manager, we copy original user requirement directly to ensure no information loss. +```json +[ + { + "command_name": "Plan.append_task", + "args": { + "task_id": "1", + "dependent_task_ids": [], + "instruction": "Use Vite, React, MUI, Tailwind CSS for the program. And create a product requirement document (PRD). ", + "assignee": "Alice" + } + }, + { + "command_name": "Plan.append_task", + "args": { + "task_id": "2", + "dependent_task_ids": ["1"], + "instruction": "Use Vite, React, MUI, Tailwind CSS for the program. Design the software architecture for the CLI snake game.", + "assignee": "Bob" + } + }, + { + "command_name": "Plan.append_task", + "args": { + "task_id": "3", + "dependent_task_ids": ["2"], + "instruction": "Break down the architecture into manageable tasks, identify task dependencies, and prepare a detailed task list for implementation.", + "assignee": "Eve" + } + }, + { + "command_name": "Plan.append_task", + "args": { + "task_id": "4", + "dependent_task_ids": ["3"], + "instruction": "Use Vite, React, MUI, Tailwind CSS for the program. Implement the core game logic for the CLI snake game, including snake movement, food generation, and score tracking.", + "assignee": "Alex" + } + }, + { + "command_name": "TeamLeader.publish_message", + "args": { + "content": "Use Vite, React, MUI, Tailwind CSS for the program. Create a cli snake game.", + "send_to": "Alice" + } + }, + { + "command_name": "RoleZero.reply_to_human", + "args": { + "content": "I have assigned the tasks to the team members. Alice will create the PRD, Bob will design the software architecture, Eve will break down the architecture into tasks, Alex will implement the core game logic, and Edward will write comprehensive tests. 
+        }
+    },
+    {
+        "command_name": "end"
+    }
+]
+```


+## example 2
+User Requirement: Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.
+Explanation: DON'T decompose the requirement if it is a DATA-RELATED task; assign a single task directly to Data Analyst David. He will manage the decomposition and implementation.
+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "TeamLeader.publish_message",
+        "args": {
+            "content": "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.",
+            "send_to": "David"
+        }
+    },
+    {
+        "command_name": "RoleZero.reply_to_human",
+        "args": {
+            "content": "I have assigned the task to David. He will break down the task further by himself and start solving it."
+        }
+    },
+    {
+        "command_name": "end"
+    }
+]
+```

+## example 3
+Conversation History:
+[
+    ...,
+    {'role': 'assistant', 'content': 'from Alice(Product Manager) to {''}: Request is completed, with outputs: Command WritePRD executed: PRD filename: "/tmp/workspace/snake_game/docs/prd.json"'},
+]
+Explanation: You received a message from Alice, the Product Manager, saying that she has completed the PRD. Use Plan.finish_current_task to mark her task as finished and move the plan to the next task. Based on the plan status, the next task is for Bob (Architect), so publish a message asking him to start. The message content should contain the important path info.
+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    },
+    {
+        "command_name": "TeamLeader.publish_message",
+        "args": {
+            "content": "Please design the software architecture for the snake game based on the PRD created by Alice. The PRD is at '/tmp/workspace/snake_game/docs/prd.json'.",
+            "send_to": "Bob"
+        }
+    },
+    {
+        "command_name": "RoleZero.reply_to_human",
+        "args": {
+            "content": "Alice has completed the PRD. I have marked her task as finished and sent the PRD to Bob. Bob will work on the software architecture."
+        }
+    },
+    {
+        "command_name": "end"
+    }
+]
+```

+## example 4
+User Question: how does the project go?
+Explanation: The user is asking for a general update on the project status. Give a straight answer about the current task the team is working on and provide a summary of the completed tasks.
+```json
+[
+    {
+        "command_name": "RoleZero.reply_to_human",
+        "args": {
+            "content": "The team is currently working on ... We have completed ..."
+        }
+    },
+    {
+        "command_name": "end"
+    }
+]
+```

+## example 5
+OBSERVATION: the current task is none and all tasks are finished.
+Explanation: The last command was "Plan.finish_current_task" or "RoleZero.reply_to_human" and the current task is now none, which means everything is done. Just output the command "end".
+```json
+[
+    {
+        "command_name": "end"
+    }
+]
+```

+## example 6
+OBSERVATION: The previously completed task is identical to the current task.
+Explanation: The current task has been accomplished previously.
+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    }
+]
+```

+## example 7
+OBSERVATION: the task assigned to Alice is still ongoing as it has not been marked as finished. The current task in the plan is for Alice to create the PRD.
+Explanation: I attempted to locate historical records containing 'send to []', and discovered an entry stating 'PRD is finished and marked.' This indicates that Alice's task has been completed.
+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    }
+]
+```
+"""


+class SimpleExpRetriever(ExpRetriever):
+    """A simple experience retriever that returns manually crafted examples."""
+
+    def retrieve(self, context: str = "") -> str:
+        return TL_EXAMPLE


+class KeywordExpRetriever(ExpRetriever):
+    """An experience retriever that returns examples based on keywords in the context."""
+
+    def retrieve(self, context: str, exp_type: Literal["plan", "task"] = "plan") -> str:
+        if exp_type == "plan":
+            if "deploy" in context.lower():
+                return DEPLOY_EXAMPLE
+            elif "issue" in context.lower():
+                return FIX_ISSUE_EXAMPLE
+            elif "https:" in context.lower() or "http:" in context.lower() or "search" in context.lower():
+                if "search" in context.lower() or "click" in context.lower():
+                    return WEB_SCRAPING_EXAMPLE
+                return WEB_SCRAPING_EXAMPLE_SIMPLE
+        # elif exp_type == "task":
+        #     if "diagnose" in context.lower():
+        #         return SEARCH_SYMBOL_EXAMPLE
+        return ""


+DEPLOY_EXAMPLE = """
+## example 1
+User Requirement: launch a service from workspace/web_snake_game/web_snake_game, and deploy it to public
+Explanation: Launching a service requires the Terminal tool with daemon mode; write this into the task instruction.
+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "Use the Terminal tool to launch the service in daemon mode",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "2",
+            "dependent_task_ids": ["1"],
+            "instruction": "Test the service with a simple request",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "3",
+            "dependent_task_ids": ["2"],
+            "instruction": "Deploy the service to public",
+            "assignee": "David"
+        }
+    }
+]
+```
+"""


+FIX_ISSUE_EXAMPLE = """
+## example 1
+User Requirement: Write a fix for this issue: https://github.com/xxx/xxx/issues/xxx, and commit, push your changes, and create a PR to the target repo.
+Explanation: The requirement is to fix an issue in an existing repository. The process is broken down into several steps, each demanding specific actions and tools.
+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "Read the issue description to understand the problem using the Browser tool.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "2",
+            "dependent_task_ids": ["1"],
+            "instruction": "Clone the repository using the Terminal tool.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "3",
+            "dependent_task_ids": ["2"],
+            "instruction": "Use Editor to search relevant function(s) or open relevant files, then diagnose and identify the source of the problem.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "4",
+            "dependent_task_ids": ["3"],
+            "instruction": "Use Editor tool to fix the problem in the corresponding file(s).",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "5",
+            "dependent_task_ids": ["4"],
+            "instruction": "Commit, push the changes to the repository, and create a pull request to the target repository.",
+            "assignee": "David"
+        }
+    }
+]
+```
+"""

+ENGINEER_EXAMPLE = """
+## example 1
+User Requirement: Please implement the core game logic for the 2048 game, including tile movements, merging logic, score tracking, and keyboard interaction. Refer to the project schedule located at '/tmp/project_schedule.json' and the system design document at '/tmp/system_design.json' for detailed information.
+Explanation: I will first need to read the system design document and the project schedule to understand the specific requirements and architecture outlined for the game development. I should NOT create tasks at this stage.

+```json
+[
+    {
+        "command_name": "Editor.read",
+        "args": {
+            "path": "/tmp/project_schedule.json"
+        }
+    },
+    {
+        "command_name": "Editor.read",
+        "args": {
+            "path": "/tmp/system_design.json"
+        }
+    }
+]
+```
+## example 2
+User Requirement: Implement the core game project in Vue/React framework. Document has been read.
+Explanation: This is a project that needs to be implemented using Vue.js according to the system design document and user requirements. Therefore, I need to copy the Vue/React template to the project folder first.
+```json
+[
+    {
+        "command_name": "Terminal.run_command",
+        "args": {
+            "cmd": "cp -r {{template_folder}}/* {{workspace}}/{{project_name}}/ && cd {{workspace}}/{{project_name}} && pwd && tree"
+        }
+    }
+]
+```

+## example 3
+User Requirement: Writing code.

+Here's the plan:
+1. Rewrite the code in index.html and in the src folder. Specifically, this includes index.html, src/main.jsx, src/index.css, and src/App.jsx, which are the main structure file, the entry point of the project, the global style file, and the main component. All these files must use Tailwind CSS for styling.
+2. Create new files when needed. In the current ecommerce website project, I need to create homepage.jsx and product.jsx.
+3. Install, build and deploy after the project is finished.
+If the project is a Vue or React project, install the dependencies after finishing the project. And then deploy the project to the public.
+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "Rewrite the index.html file with the project title and main entry point.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "2",
+            "dependent_task_ids": ["1"],
+            "instruction": "Rewrite the src/App.jsx file, which is the main component. Use Tailwind CSS for styling.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "3",
+            "dependent_task_ids": ["2"],
+            "instruction": "Rewrite the src/index.css file with Tailwind CSS.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "4",
+            "dependent_task_ids": ["2","3"],
+            "instruction": "Rewrite the src/main.jsx file, which will include the project entry point, global styles, and the router.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "5",
+            "dependent_task_ids": ["2","3","4"],
+            "instruction": "Create the src/homepage.jsx, which will include the homepage content. Use Tailwind CSS for styling.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "6",
+            "dependent_task_ids": ["2","3","4","5"],
+            "instruction": "Create the src/product.jsx, which will include the product detail page. Use Tailwind CSS for styling.",
+            "assignee": "Alex"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "7",
+            "dependent_task_ids": [],
+            "instruction": "Install the necessary dependencies, configure the project structure, and deploy it to the public.",
+            "assignee": "Alex"
+        }
+    }
+]
+```

+## example 4
+Explanation: Take on one task, such as writing or rewriting a file. Upon completion, finish the current task.

+```json
+[
+    {
+        "command_name": "Engineer2.write_new_code",
+        "args": {
+            "path": "/absolute/path/to/src/index.html"
+        }
+    },
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    }
+]
+```
+## example 5
+Explanation: The project has been completed. This project is a Vue/React project, so I will install and build it to create a static dist folder in the current project folder.

+```json
+[
+    {
+        "command_name": "Terminal.run_command",
+        "args": {
+            "cmd": "pnpm install && pnpm run build"
+        }
+    }
+]
+```

+## example 6
+Explanation: After installing and building the project, a static dist folder is created in the current project folder. I will deploy the project to the public.
+```json
+[
+    {
+        "command_name": "Deployer.deploy_to_public",
+        "args": {
+            "dist_dir": "/example/dist"
+        }
+    }
+]
+```

+## example 7
+I have received a GitHub issue URL.
+I will use the browser to review the detailed information of this issue in order to understand the problem.
+```json
+[
+    {
+        "command_name": "Browser.goto",
+        "args": {
+            "url": "https://github.com/geekan/MetaGPT/issues/1275"
+        }
+    }
+]
+```

+## example 8
+I need to locate the `openai_api.py` file, so I will search for it.
+```json
+[
+    {
+        "command_name": "Editor.find_file",
+        "args": {
+            "file_name": "openai_api.py"
+        }
+    }
+]
+```

+## example 9
+I have located the openai_api.py file. I want to edit this file, so I will open it first.
+```json
+[
+    {
+        "command_name": "Editor.open_file",
+        "args": {
+            "path": "/workspace/MetaGPT/provider/openai_api.py"
+        }
+    }
+]
+```

+## example 10
+I have opened the openai_api.py file. However, the range of lines shown is from 001 to 100, and I want to see more. Therefore, I want to use the scroll_down command to view additional lines.
+```json
+[
+    {
+        "command_name": "Editor.scroll_down",
+        "args": {}
+    }
+]
+```

+## example 11
+I want to change the key bindings from (w/s) to the arrow keys (up, down), and add the space bar to pause.
+The previous file looks like:
+142| while not self.is_game_over():
+143|     if event.key == pygame.K_w:
+144|         self.move_up()
+145|     elif event.key == pygame.K_s:
+146|         self.move_down()
+147|     self.add_random_tile()
+Since I only need to modify lines 143 to 146, I will use Editor.edit_file_by_replace. The original content will be replaced by the new code.
+The Editor tool is exclusive. If I use this tool, I cannot use any other commands in the current response.
+```json
+[
+    {
+        "command_name": "Editor.edit_file_by_replace",
+        "args": {
+            "file_name": "/workspace/MetaGPT/provider/openai_api.py",
+            "first_replaced_line_number": 143,
+            "first_replaced_line_content": "    if event.key == pygame.K_w:",
+            "new_content": "    if event.key == pygame.K_UP:\\n        self.move_up()\\n    elif event.key == pygame.K_DOWN:\\n        self.move_down()\\n    elif event.key == pygame.K_SPACE:\\n        self.stop()",
+            "last_replaced_line_number": 146,
+            "last_replaced_line_content": "        self.move_down()"
+        }
+    }
+]
+```

+## example 12
+I want to add a score variable in the initialization of the game.
+The previous file looks like:
+028| if restart:
+029|     self.snake = Snake()
+030|     self.food = Food(self.board_size)
+031|     self.start_game()
+032|     self.location = (0,0)
+I only need to add a few lines to the file, so I will use Editor.insert_content_at_line. The new code will not cover the original code.
+Note that the Editor command must be executed in a single response, so this step will only involve using the Editor command.
+```json
+[
+    {
+        "command_name": "Editor.insert_content_at_line",
+        "args": {
+            "file_name": "/workspace/MetaGPT/provider/openai_api.py",
+            "line_number": 31,
+            "insert_content": "    self.score = Score()"
+        }
+    }
+]
+```
+After executing the command, the file will be:
+028| if restart:
+029|     self.snake = Snake()
+030|     self.food = Food(self.board_size)
+031|     self.score = Score()
+032|     self.start_game()
+033|     self.location = (0,0)
+In the next turn, I will try to add another code snippet.

+## example 13

+Create a pull request (Optional): Merge the changes from the new branch into the master branch.
+Thought: Now that the changes have been pushed to the remote repository, due to the user's requirement, let's create a pull request to merge the changes into the master branch.
+```json
+[
+    {
+        "command_name": "git_create_pull",
+        "args": {
+            "base": "master",
+            "head": "test-fix",
+            "base_repo_name": "garylin2099/MetaGPT",
+            "head_repo_name": "seeker-jie/MetaGPT",
+            "app_name": "github",
+            "title": "Fix Issue #1275: produced TypeError: openai.types.completion_usage.CompletionUsage() argument after ** must be a mapping, not NoneType",
+            "body": "This pull request addresses issue #1275 by ensuring that chunk.usage is not None before passing it to CompletionUsage."
+        }
+    }
+]
+```

+## example 14
+The requirement is to create a product website featuring goods such as caps, dresses, and T-shirts.
+I believe pictures would improve the site, so I will get the images first.
+```json
+[
+    {
+        "command_name": "ImageGetter.get_image",
+        "args": {
+            "search_term": "cap",
+            "save_file_path": "/tmp/workspace/images/cap.png"
+        }
+    }
+]
+```
+"""

+WEB_SCRAPING_EXAMPLE = """
+## action 1
+User Requirement: Scrape and list the restaurant names of the first page by searching for the keyword `beef` on the website https://www.yelp.com/.
+Explanation: The requirement is to scrape data from a website and extract information about restaurants. The process involves searching for restaurants with a specific keyword, then retrieving and presenting the data in a structured format.

+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "Navigate to the yelp website.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "2",
+            "dependent_task_ids": ["1"],
+            "instruction": "Search for restaurants with the keyword 'beef'.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "3",
+            "dependent_task_ids": ["2"],
+            "instruction": "View and print the html content of the search result page before scraping data, to understand the structure.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "4",
+            "dependent_task_ids": ["3"],
+            "instruction": "Parse the html content to scrape the restaurant names and print them.",
+            "assignee": "David"
+        }
+    }
+]
+```

+## action 2
+Explanation: To search for restaurants, I will now go to the website https://www.yelp.com/ first.

+```json
+[
+    {
+        "command_name": "Browser.goto",
+        "args": {
+            "url": "https://www.yelp.com/"
+        }
+    }
+]
+```

+## action 3
+Explanation: Since the Browser has successfully navigated to the website, and I find that the element id of the search box is 53, I will finish the current task and then use the Browser tool to type the keyword `beef` in the search box and press enter.

+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    },
+    {
+        "command_name": "Browser.type",
+        "args": {
+            "element_id": 53,
+            "content": "beef",
+            "press_enter_after": true
+        }
+    }
+]
+```

+## action 4
+Explanation: Since the Browser has successfully searched the keyword `beef`, I will finish the current task and then write code to view and print the html content of the page.

+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    },
+    {
+        "command_name": "DataAnalyst.write_and_exec_code",
+        "args": {}
+    }
+]
+```

+## action 5
+Explanation: Since I have successfully viewed the html content in the context, I will first finish the current task and then write code to parse the html content and extract the restaurant names.

+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    },
+    {
+        "command_name": "DataAnalyst.write_and_exec_code",
+        "args": {}
+    }
+]
+```

+...
+"""


+WEB_SCRAPING_EXAMPLE_SIMPLE = """
+## action 1
+User Requirement: List the restaurant names on the website https://www.yelp.com/search?find_desc=beef&find_loc=New+York%2C+NY.
+Explanation: The requirement is to scrape data from a website and extract information about restaurants. The process involves retrieving and presenting the data in a structured format.
+
+```json
+[
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "1",
+            "dependent_task_ids": [],
+            "instruction": "View and print the html content of the page before scraping data, to understand the structure.",
+            "assignee": "David"
+        }
+    },
+    {
+        "command_name": "Plan.append_task",
+        "args": {
+            "task_id": "2",
+            "dependent_task_ids": ["1"],
+            "instruction": "Parse the html content to scrape the restaurant names and print them.",
+            "assignee": "David"
+        }
+    }
+]
+```

+## action 2
+Explanation: To scrape data from the website, I will first view and print the html content of the page.

+```json
+[
+    {
+        "command_name": "DataAnalyst.write_and_exec_code",
+        "args": {}
+    }
+]
+```

+## action 3
+Explanation: Since I have successfully viewed the html content in the context, I will first finish the current task and then write code to parse the html content and extract the restaurant names.

+```json
+[
+    {
+        "command_name": "Plan.finish_current_task",
+        "args": {}
+    },
+    {
+        "command_name": "DataAnalyst.write_and_exec_code",
+        "args": {}
+    }
+]
+```
+...
+"""
diff --git a/metagpt/strategy/planner.py b/metagpt/strategy/planner.py
index fbf7848372..1b4b1cca9e 100644
--- a/metagpt/strategy/planner.py
+++ b/metagpt/strategy/planner.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+from typing import List
 
 from pydantic import BaseModel, Field
 
@@ -40,8 +41,16 @@
 ## Current Task
 {current_task}
 
+## Finished Section of Current Task
+### code
+```python
+{current_task_code}
+```
+### execution result
+{current_task_result}
+
 ## Task Guidance
-Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
+Write code for the incomplete sections of 'Current Task'. And avoid duplicating code from 'Finished Tasks' and 'Finished Section of Current Task', such as repeated import of packages, reading data, etc.
 Specifically, {guidance}
 """
 
@@ -119,7 +128,7 @@ async def ask_review(
         If human confirms the task result, then we deem the task completed, regardless of whether the code run succeeds;
         if auto mode, then the code run has to succeed for the task to be considered completed.
""" - auto_run = auto_run or self.auto_run + auto_run = auto_run if auto_run is not None else self.auto_run if not auto_run: context = self.get_useful_memories() review, confirmed = await AskReview().run( @@ -157,8 +166,10 @@ def get_useful_memories(self, task_exclude_field=None) -> list[Message]: return context_msg + self.working_memory.get() - def get_plan_status(self) -> str: + def get_plan_status(self, exclude: List[str] = None) -> str: # prepare components of a plan status + exclude = exclude or [] + exclude_prompt = "omit here" finished_tasks = self.plan.get_finished_tasks() code_written = [remove_comments(task.code) for task in finished_tasks] code_written = "\n\n".join(code_written) @@ -170,9 +181,11 @@ def get_plan_status(self) -> str: # combine components in a prompt prompt = PLAN_STATUS.format( - code_written=code_written, - task_results=task_results, + code_written=code_written if "code" not in exclude else exclude_prompt, + task_results=task_results if "task_result" not in exclude else exclude_prompt, current_task=self.current_task.instruction, + current_task_code=self.current_task.code if "code" not in exclude else exclude_prompt, + current_task_result=self.current_task.result if "task_result" not in exclude else exclude_prompt, guidance=guidance, ) diff --git a/metagpt/strategy/solver.py b/metagpt/strategy/solver.py index e532f736bd..4aedb42aa8 100644 --- a/metagpt/strategy/solver.py +++ b/metagpt/strategy/solver.py @@ -39,7 +39,7 @@ async def solve(self): self.graph.topological_sort() for key in self.graph.execution_order: op = self.graph.nodes[key] - await op.fill(self.context, self.llm, mode="root") + await op.fill(req=self.context, llm=self.llm, mode="root") class TOTSolver(BaseSolver): diff --git a/metagpt/strategy/task_type.py b/metagpt/strategy/task_type.py index d21705c162..f4c2a09c81 100644 --- a/metagpt/strategy/task_type.py +++ b/metagpt/strategy/task_type.py @@ -9,6 +9,7 @@ IMAGE2WEBPAGE_PROMPT, MODEL_EVALUATE_PROMPT, MODEL_TRAIN_PROMPT, + WEB_SCRAPING_PROMPT, ) @@ -62,11 +63,16 @@ class TaskType(Enum): WEBSCRAPING = TaskTypeDef( name="web scraping", desc="For scraping data from web pages.", + guidance=WEB_SCRAPING_PROMPT, ) EMAIL_LOGIN = TaskTypeDef( name="email login", desc="For logging to an email.", ) + DEVELOP_SOFTWARE = TaskTypeDef( + name="develop software", + desc="SOP related to develop software such as Writes a PRD, Writes a design, Writes a project plan and Writes code to implement designed features according to the project plan", + ) @property def type_name(self): diff --git a/metagpt/strategy/thinking_command.py b/metagpt/strategy/thinking_command.py new file mode 100644 index 0000000000..f08afa448f --- /dev/null +++ b/metagpt/strategy/thinking_command.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from enum import Enum + +from pydantic import BaseModel + +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.memory import Memory +from metagpt.roles import Role +from metagpt.schema import Message + + +class CommandDef(BaseModel): + name: str + signature: str = "" + desc: str = "" + + +class Command(Enum): + # commands for planning + APPEND_TASK = CommandDef( + name="append_task", + signature="append_task(task_id: str, dependent_task_ids: list[str], instruction: str, assignee: str)", + desc="Append a new task with task_id (number) to the end of existing task sequences. 
+    )
+    RESET_TASK = CommandDef(
+        name="reset_task",
+        signature="reset_task(task_id: str)",
+        desc="Reset a task based on task_id, i.e. set Task.is_finished=False and request a redo. This also resets all tasks depending on it.",
+    )
+    REPLACE_TASK = CommandDef(
+        name="replace_task",
+        signature="replace_task(task_id: str, new_dependent_task_ids: list[str], new_instruction: str, new_assignee: str)",
+        desc="Replace an existing task (can be the current task) based on task_id, and reset all tasks depending on it.",
+    )
+    FINISH_CURRENT_TASK = CommandDef(
+        name="finish_current_task",
+        signature="finish_current_task()",
+        desc="Finish the current task, set Task.is_finished=True, and set the current task to the next task.",
+    )
+
+    # commands for env interaction
+    PUBLISH_MESSAGE = CommandDef(
+        name="publish_message",
+        signature="publish_message(content: str, send_to: str)",
+        desc="Publish a message to a team member, using the member name to fill the send_to arg. You may copy the full original content or add additional information from upstream. This will make team members start their work. DON'T omit any necessary info such as path, link, environment, programming language, framework, requirement, or constraint from the original content, because you are the team members' sole info source.",
+    )
+    REPLY_TO_HUMAN = CommandDef(
+        name="reply_to_human",
+        signature="reply_to_human(content: str)",
+        desc="Reply to the human user with the content provided. Use this when you have a clear answer or solution to the user's question.",
+    )
+    ASK_HUMAN = CommandDef(
+        name="ask_human",
+        signature="ask_human(question: str)",
+        desc="Use this when you fail the current task or are unsure of the situation encountered. Your response should contain a brief summary of your situation, ending with a clear and concise question.",
+    )
+
+    # common commands
+    PASS = CommandDef(
+        name="pass",
+        signature="pass",
+        desc="Pass and do nothing, if you don't think the plan needs to be updated or a message needs to be published or forwarded. The reasons can be that the latest message is unnecessary or obsolete, or that you want to wait for more information before making a move.",
+    )
+
+    @property
+    def cmd_name(self):
+        return self.value.name
+
+
+def prepare_command_prompt(commands: list[Command]) -> str:
+    command_prompt = ""
+    for i, command in enumerate(commands):
+        command_prompt += f"{i+1}. {command.value.signature}:\n{command.value.desc}\n\n"
+    return command_prompt
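+
+
+# A quick illustration (ours, not part of the original patch): prepare_command_prompt
+# renders the numbered command menu that is shown to the LLM, e.g.
+#
+#     >>> print(prepare_command_prompt([Command.FINISH_CURRENT_TASK]))
+#     1. finish_current_task():
+#     Finish the current task, set Task.is_finished=True, and set the current task to the next task.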
+
+
+async def run_env_command(role: Role, cmd: list[dict], role_memory: Memory = None):
+    if not isinstance(role.rc.env, MGXEnv):
+        return
+    if cmd["command_name"] == Command.PUBLISH_MESSAGE.cmd_name:
+        role.publish_message(Message(**cmd["args"]))
+    if cmd["command_name"] == Command.ASK_HUMAN.cmd_name:
+        # TODO: Operation on role memory should not appear here, consider moving it into role
+        role.rc.working_memory.add(Message(content=cmd["args"]["question"], role="assistant"))
+        human_rsp = await role.rc.env.ask_human(sent_from=role, **cmd["args"])
+        role.rc.working_memory.add(Message(content=human_rsp, role="user"))
+    elif cmd["command_name"] == Command.REPLY_TO_HUMAN.cmd_name:
+        # TODO: consider if the message should go into memory
+        await role.rc.env.reply_to_human(sent_from=role, **cmd["args"])
+
+
+def run_plan_command(role: Role, cmd: list[dict]):
+    if cmd["command_name"] == Command.APPEND_TASK.cmd_name:
+        role.planner.plan.append_task(**cmd["args"])
+    elif cmd["command_name"] == Command.RESET_TASK.cmd_name:
+        role.planner.plan.reset_task(**cmd["args"])
+    elif cmd["command_name"] == Command.REPLACE_TASK.cmd_name:
+        role.planner.plan.replace_task(**cmd["args"])
+    elif cmd["command_name"] == Command.FINISH_CURRENT_TASK.cmd_name:
+        if role.planner.plan.is_plan_finished():
+            return
+        if role.task_result:
+            role.planner.plan.current_task.update_task_result(task_result=role.task_result)
+        role.planner.plan.finish_current_task()
+        role.rc.working_memory.clear()
+
+
+async def run_commands(role: Role, cmds: list[dict], role_memory: Memory = None):
+    print(*cmds, sep="\n")
+    for cmd in cmds:
+        await run_env_command(role, cmd, role_memory)
+        run_plan_command(role, cmd)
+
+    if role.planner.plan.is_plan_finished():
+        role._set_state(-1)
diff --git a/metagpt/strategy/tot.py b/metagpt/strategy/tot.py
index 88c2ac9ff7..17ce632118 100644
--- a/metagpt/strategy/tot.py
+++ b/metagpt/strategy/tot.py
@@ -62,7 +62,7 @@ async def generate_thoughts(self, current_state="", current_node=None) -> List[T
             current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
         )
         rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
-        thoughts = CodeParser.parse_code(block="", text=rsp)
+        thoughts = CodeParser.parse_code(text=rsp)
         thoughts = eval(thoughts)  # fixme 避免不跟随,生成过多nodes
         # valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
diff --git a/metagpt/team.py b/metagpt/team.py
index 79c4c36aa1..c3498b96b6 100644
--- a/metagpt/team.py
+++ b/metagpt/team.py
@@ -20,7 +20,7 @@
 from metagpt.environment import Environment
 from metagpt.logs import logger
 from metagpt.roles import Role
-from metagpt.schema import Message
+from metagpt.schema import UserMessage
 from metagpt.utils.common import (
     NoMoneyException,
     read_json_file,
@@ -102,7 +102,7 @@ def run_project(self, idea, send_to: str = ""):
         # Human requirement.
         self.env.publish_message(
-            Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL),
+            UserMessage(content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL),
             peekable=False,
         )
 
diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py
index 35fa046589..2027dbb1d1 100644
--- a/metagpt/tools/__init__.py
+++ b/metagpt/tools/__init__.py
@@ -6,33 +6,18 @@
 @File    : __init__.py
 """
 
-from enum import Enum
 from metagpt.tools import libs  # this registers all tools
 from metagpt.tools.tool_registry import TOOL_REGISTRY
-
-_ = libs, TOOL_REGISTRY  # Avoid pre-commit error
-
-
-class SearchEngineType(Enum):
-    SERPAPI_GOOGLE = "serpapi"
-    SERPER_GOOGLE = "serper"
-    DIRECT_GOOGLE = "google"
-    DUCK_DUCK_GO = "ddg"
-    CUSTOM_ENGINE = "custom"
-    BING = "bing"
+from metagpt.configs.search_config import SearchEngineType
+from metagpt.configs.browser_config import WebBrowserEngineType
 
-class WebBrowserEngineType(Enum):
-    PLAYWRIGHT = "playwright"
-    SELENIUM = "selenium"
-    CUSTOM = "custom"
-
-    @classmethod
-    def __missing__(cls, key):
-        """Default type conversion"""
-        return cls.CUSTOM
+_ = libs, TOOL_REGISTRY  # Avoid pre-commit error
 
 
 class SearchInterface:
     async def asearch(self, *args, **kwargs):
         ...
+
+
+__all__ = ["SearchEngineType", "WebBrowserEngineType", "TOOL_REGISTRY"]
diff --git a/metagpt/tools/libs/__init__.py b/metagpt/tools/libs/__init__.py
index 91596fd3d8..6f8f754e77 100644
--- a/metagpt/tools/libs/__init__.py
+++ b/metagpt/tools/libs/__init__.py
@@ -10,8 +10,14 @@
     sd_engine,
     gpt_v_generator,
     web_scraping,
-    email_login,
+    # email_login,
+    terminal,
+    editor,
+    browser,
+    deployer,
+    git,
 )
+from metagpt.tools.libs.env import get_env, set_get_env_entry, default_get_env, get_env_description, get_env_default
 
 _ = (
     data_preprocess,
@@ -19,5 +25,15 @@
     sd_engine,
     gpt_v_generator,
     web_scraping,
-    email_login,
+    # email_login,
+    terminal,
+    editor,
+    browser,
+    deployer,
+    git,
+    get_env,
+    get_env_default,
+    get_env_description,
+    set_get_env_entry,
+    default_get_env,
 )  # Avoid pre-commit error
diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py
new file mode 100644
index 0000000000..f5aff553eb
--- /dev/null
+++ b/metagpt/tools/libs/browser.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import time
+from typing import Literal, Optional
+
+from playwright.async_api import Browser as Browser_
+from playwright.async_api import (
+    BrowserContext,
+    Frame,
+    Page,
+    Playwright,
+    Request,
+    async_playwright,
+)
+from pydantic import BaseModel, ConfigDict, Field
+
+from metagpt.tools.tool_registry import register_tool
+from metagpt.utils.a11y_tree import (
+    click_element,
+    get_accessibility_tree,
+    get_backend_node_id,
+    hover_element,
+    key_press,
+    parse_accessibility_tree,
+    scroll_page,
+    type_text,
+)
+from metagpt.utils.proxy_env import get_proxy_from_env
+from metagpt.utils.report import BrowserReporter
+
+
+@register_tool(
+    tags=["web", "browse"],
+    include_functions=[
+        "click",
+        "close_tab",
+        "go_back",
+        "go_forward",
+        "goto",
+        "hover",
+        "press",
+        "scroll",
+        "tab_focus",
+        "type",
+    ],
+)
+class Browser(BaseModel):
+    """A tool for browsing the web. Don't initialize a new instance of this class if one already exists.
+
+    Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone
+    task, executing actions each time based on the content seen on the webpage before proceeding to the next step.
+
+    ## Example
+    Issue: The details of the latest issue in the geekan/MetaGPT repository.
+    Plan: Use a browser to view the details of the latest issue in the geekan/MetaGPT repository.
+    Solution:
+    Let's first open the issue page of the MetaGPT repository with the `Browser.goto` command
+
+    >>> await browser.goto("https://github.com/geekan/MetaGPT/issues")
+
+    From the output webpage, we've identified that the latest issue can be accessed by clicking on the element with ID "1141".
+
+    >>> await browser.click(1141)
+
+    Finally, we have found the webpage for the latest issue; we can close the tab and finish the current task.
+
+    >>> await browser.close_tab()
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    playwright: Optional[Playwright] = Field(default=None, exclude=True)
+    browser_instance: Optional[Browser_] = Field(default=None, exclude=True)
+    browser_ctx: Optional[BrowserContext] = Field(default=None, exclude=True)
+    page: Optional[Page] = Field(default=None, exclude=True)
+    accessibility_tree: list = Field(default_factory=list)
+    headless: bool = Field(default=True)
+    proxy: Optional[dict] = Field(default_factory=get_proxy_from_env)
+    is_empty_page: bool = Field(default=True)
+    reporter: BrowserReporter = Field(default_factory=BrowserReporter)
+
+    async def start(self) -> None:
+        """Start Playwright and launch a browser."""
+        if self.playwright is None:
+            self.playwright = playwright = await async_playwright().start()
+            browser = self.browser_instance = await playwright.chromium.launch(headless=self.headless, proxy=self.proxy)
+            browser_ctx = self.browser_ctx = await browser.new_context()
+            self.page = await browser_ctx.new_page()
+
+    async def stop(self):
+        if self.playwright:
+            playwright = self.playwright
+            self.playwright = None
+            self.browser_instance = None
+            self.browser_ctx = None
+            await playwright.stop()
+
+    async def click(self, element_id: int):
+        """Click on an element with a specific id on the webpage."""
+        await click_element(self.page, get_backend_node_id(element_id, self.accessibility_tree))
+        return await self._wait_page()
+
+    async def type(self, element_id: int, content: str, press_enter_after: bool = False):
+        """Type the given content into the field with the given element id."""
+        if press_enter_after:
+            content += "\n"
+        await click_element(self.page, get_backend_node_id(element_id, self.accessibility_tree))
+        await type_text(self.page, content)
+        return await self._wait_page()
+
+    async def hover(self, element_id: int):
+        """Hover over an element with the given id."""
+        await hover_element(self.page, get_backend_node_id(element_id, self.accessibility_tree))
+        return await self._wait_page()
+
+    async def press(self, key_comb: str):
+        """Simulate the pressing of a key combination on the keyboard (e.g., Ctrl+v)."""
+        await key_press(self.page, key_comb)
+        return await self._wait_page()
+
+    async def scroll(self, direction: Literal["down", "up"]):
+        """Scroll the page up or down."""
+        await scroll_page(self.page, direction)
+        return await self._wait_page()
+
+    async def goto(self, url: str, timeout: float = 90000):
+        """Navigate to a specific URL."""
+        if self.page is None:
+            await self.start()
+        async with self.reporter as reporter:
+            await reporter.async_report(url, "url")
+            await self.page.goto(url, timeout=timeout)
+            self.is_empty_page = False
+        return await self._wait_page()
+
+    async def go_back(self):
+        """Navigate to the previously viewed page."""
+        await self.page.go_back()
+        return await self._wait_page()
+
+    async def go_forward(self):
+        """Navigate to the next page (if a previous 'go_back' action was performed)."""
+        await self.page.go_forward()
+        return await self._wait_page()
+
+    async def tab_focus(self, page_number: int):
+        """Focus on the tab with the given page number and bring it to the front."""
+        page = self.browser_ctx.pages[page_number]
+        await page.bring_to_front()
+        return await self._wait_page()
+
+    async def close_tab(self):
+        """Close the currently active tab."""
+        await self.page.close()
+        if len(self.browser_ctx.pages) > 0:
+            self.page = self.browser_ctx.pages[-1]
+        else:
+            self.page = await self.browser_ctx.new_page()
+            self.is_empty_page = True
+        return await self._wait_page()
+
+    async def _wait_page(self):
+        page = self.page
+        await self._wait_until_page_idle(page)
+        self.accessibility_tree = await get_accessibility_tree(page)
+        await self.reporter.async_report(page, "page")
+        return f"SUCCESS, URL: {page.url} has been loaded."
+
+    def _register_page_event(self, page: Page):
+        page.last_busy_time = time.time()
+        page.requests = set()
+        page.on("domcontentloaded", self._update_page_last_busy_time)
+        page.on("load", self._update_page_last_busy_time)
+        page.on("request", self._on_page_request)
+        page.on("requestfailed", self._on_page_requestfinished)
+        page.on("requestfinished", self._on_page_requestfinished)
+        page.on("frameattached", self._on_frame_change)
+        page.on("framenavigated", self._on_frame_change)
+
+    async def _wait_until_page_idle(self, page) -> None:
+        if not hasattr(page, "last_busy_time"):
+            self._register_page_event(page)
+        else:
+            page.last_busy_time = time.time()
+        while time.time() - page.last_busy_time < 0.5:
+            await page.wait_for_timeout(100)
+
+    async def _update_page_last_busy_time(self, page: Page):
+        page.last_busy_time = time.time()
+
+    async def _on_page_request(self, request: Request):
+        page = request.frame.page
+        page.requests.add(request)
+        await self._update_page_last_busy_time(page)
+
+    async def _on_page_requestfinished(self, request: Request):
+        request.frame.page.requests.discard(request)
+
+    async def _on_frame_change(self, frame: Frame):
+        await self._update_page_last_busy_time(frame.page)
+
+    async def view(self):
+        observation = parse_accessibility_tree(self.accessibility_tree)
+        return f"Current Browser Viewer\n URL: {self.page.url}\nOBSERVATION:\n{observation[0]}\n"
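+
+    # Illustrative usage sketch (ours, not part of the original patch):
+    #
+    #     async with Browser() as browser:        # start()/stop() via __aenter__/__aexit__
+    #         await browser.goto("https://example.com")
+    #         print(await browser.view())         # numbered accessibility-tree elements
+    #         await browser.click(42)             # 42 = an element id taken from view()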
+
+    async def __aenter__(self):
+        await self.start()
+        return self
+
+    async def __aexit__(self, *args, **kwargs):
+        await self.stop()
diff --git a/metagpt/tools/libs/cr.py b/metagpt/tools/libs/cr.py
new file mode 100644
index 0000000000..0a53dd1945
--- /dev/null
+++ b/metagpt/tools/libs/cr.py
@@ -0,0 +1,102 @@
+import difflib
+import json
+from pathlib import Path
+from typing import Optional
+
+import aiofiles
+from bs4 import BeautifulSoup
+from unidiff import PatchSet
+
+import metagpt.ext.cr
+from metagpt.ext.cr.actions.code_review import CodeReview as CodeReview_
+from metagpt.ext.cr.actions.modify_code import ModifyCode
+from metagpt.ext.cr.utils.schema import Point
+from metagpt.tools.libs.browser import Browser
+from metagpt.tools.tool_registry import register_tool
+from metagpt.utils.report import EditorReporter
+
+
+@register_tool(tags=["codereview"], include_functions=["review", "fix"])
+class CodeReview:
+    """Review and fix the patch content from the pull request URL or a file."""
+
+    async def review(
+        self,
+        patch_path: str,
+        output_file: str,
+        point_file: Optional[str] = None,
+    ) -> str:
+        """Review a PR and save code review comments.
+
+        Notes:
+            If the user does not specify an output path, save it using a relative path in the current working directory.
+
+        Args:
+            patch_path: The local path of the patch file or the URL of the pull request.
+            output_file: Output file path where code review comments will be saved.
+            point_file: File path for specifying code review points. If not specified, this parameter does not need to be passed.
+
+        Examples:
+
+            >>> cr = CodeReview()
+            >>> await cr.review(patch_path="https://github.com/geekan/MetaGPT/pull/136", output_file="cr/MetaGPT_136.json")
+            >>> await cr.review(patch_path="/data/uploads/dev-master.diff", output_file="cr/dev-master.json")
+            >>> await cr.review(patch_path="/data/uploads/main.py", output_file="cr/main.json")
+        """
+        patch = await self._get_patch_content(patch_path)
+        point_file = point_file if point_file else Path(metagpt.ext.cr.__file__).parent / "points.json"
+        await EditorReporter().async_report(str(point_file), "path")
+        async with aiofiles.open(point_file, "rb") as f:
+            cr_point_content = await f.read()
+        cr_points = [Point(**i) for i in json.loads(cr_point_content)]
+        try:
+            comments = await CodeReview_().run(patch, cr_points, output_file)
+        except ValueError as e:
+            return str(e)
+        return f"The number of defects: {len(comments)}, the comments are stored in {output_file}, and the checkpoints are stored in {str(point_file)}"
+
+    async def fix(
+        self,
+        patch_path: str,
+        cr_file: str,
+        output_dir: str,
+    ) -> str:
+        """Fix the patch content based on code review comments.
+
+        Args:
+            patch_path: The local path of the patch file or the URL of the pull request.
+            cr_file: File path where code review comments are stored.
+            output_dir: Directory where the fixed patch files will be saved.
+        """
+        patch = await self._get_patch_content(patch_path)
+        async with aiofiles.open(cr_file, "r", encoding="utf-8") as f:
+            comments = json.loads(await f.read())
+        await ModifyCode(pr="").run(patch, comments, output_dir)
+        return f"The fixed patch files are stored in {output_dir}"
+
+    async def _get_patch_content(self, patch_path):
+        if patch_path.startswith(("https://", "http://")):
+            # async with aiohttp.ClientSession(trust_env=True) as client:
+            #     async with client.get(f"{patch_path}.diff") as resp:
+            #         patch_file_content = await resp.text()
+            async with Browser() as browser:
+                await browser.goto(f"{patch_path}.diff")
+                patch_file_content = await browser.page.content()
+                if patch_file_content.startswith("<html>"):
+                    soup = BeautifulSoup(patch_file_content, "html.parser")
+                    pre = soup.find("pre")
+                    if pre:
+                        patch_file_content = pre.text
+        else:
+            async with aiofiles.open(patch_path, encoding="utf-8") as f:
+                patch_file_content = await f.read()
+                await EditorReporter().async_report(patch_path)
+            if not patch_path.endswith((".diff", ".patch")):
+                name = Path(patch_path).name
+                patch_file_content = "".join(
+                    difflib.unified_diff([], patch_file_content.splitlines(keepends=True), "/dev/null", f"b/{name}"),
+                )
+                patch_file_content = f"diff --git a/{name} b/{name}\n{patch_file_content}"
+
+        patch: PatchSet = PatchSet(patch_file_content)
+        return patch
diff --git a/metagpt/tools/libs/deployer.py b/metagpt/tools/libs/deployer.py
new file mode 100644
index 0000000000..2b4d996d17
--- /dev/null
+++ b/metagpt/tools/libs/deployer.py
@@ -0,0 +1,26 @@
+from metagpt.tools.tool_registry import register_tool
+
+
+# An un-implemented tool reserved for deploying a local service to public
+@register_tool(
+    include_functions=[
+        "deploy_to_public",
+    ]
+)
+class Deployer:
+    """Deploy a local service to public. Used only for final deployment; you should NOT use it for development and testing."""
+
+    async def static_server(self, src_path: str) -> str:
+        """This function will be implemented in the remote service."""
+        return "http://127.0.0.1:8000/index.html"
+
+    async def deploy_to_public(self, dist_dir: str):
+        """
+        Deploy a web project to the public.
+        Args:
+            dist_dir (str): The dist directory of the web project after the build.
+        Example:
+            >>> deployer = Deployer()
+            >>> await deployer.deploy_to_public("2048_game/dist")
+        """
+        url = await self.static_server(dist_dir)
+        return "The project is deployed to: " + url + "\n Deployment succeeded!"
diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py
new file mode 100644
index 0000000000..223216d4f3
--- /dev/null
+++ b/metagpt/tools/libs/editor.py
@@ -0,0 +1,1135 @@
+"""
+This file is borrowed from OpenDevin
+You can find the original repository here:
+https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
+"""
+import os
+import re
+import shutil
+import tempfile
+from pathlib import Path
+from typing import List, Optional, Union
+
+import tiktoken
+from pydantic import BaseModel, ConfigDict
+
+from metagpt.const import DEFAULT_WORKSPACE_ROOT
+from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, IndexRepo
+from metagpt.tools.libs.linter import Linter
+from metagpt.tools.tool_registry import register_tool
+from metagpt.utils.common import awrite
+from metagpt.utils.file import File
+from metagpt.utils.report import EditorReporter
+
+# This is also used in unit tests!
+LINTER_ERROR_MSG = "[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n"
+
+
+INDENTATION_INFO = """
+The previous line is:
+"{pre_line}"
+The indentation has {pre_line_indent} spaces.
+
+The error line is:
+"{insert_line}"
+The indentation has {insert_line_indent} spaces.
+
+Please check the indentation of the code to ensure that it is not causing any errors.
+Try using indentation with either {sub_4_space} or {add_4_space} spaces.
+"""
+
+ERROR_GUIDANCE = """
+{linter_error_msg}
+
+[This is how your edit would have looked if applied]
+-------------------------------------------------
+{window_after_applied}
+-------------------------------------------------
+
+[This is the original code before your edit]
+-------------------------------------------------
+{window_before_applied}
+-------------------------------------------------
+
+Your changes have NOT been applied. Please fix your edit command and try again.
+{guidance_message}
+
+"""
+
+LINE_NUMBER_AND_CONTENT_MISMATCH = """Error: The `{position}_replaced_line_number` does not match the `{position}_replaced_line_content`. Please correct the parameters.
+The `{position}_replaced_line_number` is {line_number} and the corresponding content is "{true_content}".
+But the `{position}_replaced_line_content` is "{fake_content}".
+The content around the specified line is:
+{context}
+Pay attention to the new content. Ensure that it aligns with the new parameters.
+"""
+SUCCESS_EDIT_INFO = """
+[File: {file_name} ({n_total_lines} lines total after edit)]
+{window_after_applied}
+[File updated (edited at line {line_number})].
+"""
+# Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.
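+
+# Illustrative note (ours, not part of the original patch): the templates above
+# are combined when an edit fails linting, roughly like
+#
+#     guidance_message = INDENTATION_INFO.format(
+#         pre_line="    x = 1", pre_line_indent=4,
+#         insert_line="        y = 2", insert_line_indent=8,
+#         sub_4_space=4, add_4_space=12,
+#     )
+#     error_info = ERROR_GUIDANCE.format(
+#         linter_error_msg=LINTER_ERROR_MSG,
+#         window_after_applied="...", window_before_applied="...",
+#         guidance_message=guidance_message,
+#     )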
+
+
+class FileBlock(BaseModel):
+    """A block of content in a file"""
+
+    file_path: str
+    block_content: str
+
+
+class LineNumberError(Exception):
+    pass
+
+
+@register_tool(
+    include_functions=[
+        "write",
+        "read",
+        "open_file",
+        "goto_line",
+        "scroll_down",
+        "scroll_up",
+        "create_file",
+        "edit_file_by_replace",
+        "insert_content_at_line",
+        "append_file",
+        "search_dir",
+        "search_file",
+        "find_file",
+        "similarity_search",
+    ]
+)
+class Editor(BaseModel):
+    """
+    A tool for reading, understanding, writing, and editing files.
+    Supports local files, including text-based files (txt, md, json, py, html, js, css, etc.), pdf, and docx; excludes images, csv, excel, and online links.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    resource: EditorReporter = EditorReporter()
+    current_file: Optional[Path] = None
+    current_line: int = 1
+    window: int = 200
+    enable_auto_lint: bool = False
+    working_dir: Path = DEFAULT_WORKSPACE_ROOT
+
+    def write(self, path: str, content: str):
+        """Write the whole content to a file. When used, make sure the content arg contains the full content of the file."""
+
+        path = self._try_fix_path(path)
+
+        if "\n" not in content and "\\n" in content:
+            # A very raw rule to correct the content: If 'content' lacks actual newlines ('\n') but includes '\\n', consider
+            # replacing them with '\n' to potentially correct mistaken representations of newline characters.
+            content = content.replace("\\n", "\n")
+        directory = os.path.dirname(path)
+        if directory and not os.path.exists(directory):
+            os.makedirs(directory)
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(content)
+        # self.resource.report(path, "path")
+        return f"The writing/coding of the file '{os.path.basename(path)}' is now completed. The file '{os.path.basename(path)}' has been successfully created."
+
+    async def read(self, path: str) -> FileBlock:
+        """Read the whole content of a file. Use an absolute path as the argument to specify the file location."""
+
+        path = self._try_fix_path(path)
+
+        error = FileBlock(
+            file_path=str(path),
+            block_content="The file is too large to read. Use `Editor.similarity_search` to read the file instead.",
+        )
+        path = Path(path)
+        if path.stat().st_size > 5 * DEFAULT_MIN_TOKEN_COUNT:
+            return error
+        content = await File.read_text_file(path)
+        if not content:
+            return FileBlock(file_path=str(path), block_content="")
+        if self.is_large_file(content=content):
+            return error
+        self.resource.report(str(path), "path")
+
+        lines = content.splitlines(keepends=True)
+        lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)]
+        result = FileBlock(
+            file_path=str(path),
+            block_content="".join(lines_with_num),
+        )
+        return result
+
+    @staticmethod
+    def _is_valid_filename(file_name: str) -> bool:
+        if not file_name or not file_name.strip():
+            return False
+        invalid_chars = '<>:"/\\|?*'
+        if os.name == "nt":  # Windows
+            invalid_chars = '<>:"/\\|?*'
+        elif os.name == "posix":  # Unix-like systems
+            invalid_chars = "\0"
+
+        for char in invalid_chars:
+            if char in file_name:
+                return False
+        return True
+
+    @staticmethod
+    def _is_valid_path(path: Path) -> bool:
+        try:
+            return path.exists()
+        except PermissionError:
+            return False
+
+    @staticmethod
+    def _create_paths(file_path: Path) -> bool:
+        try:
+            if file_path.parent:
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+            return True
+        except PermissionError:
+            return False
+
+    def _check_current_file(self, file_path: Optional[Path] = None) -> bool:
+        if file_path is None:
+            file_path = self.current_file
+        if not file_path or not file_path.is_file():
+            raise ValueError("No file open. Use the open_file function first.")
+        return True
+
+    @staticmethod
+    def _clamp(value, min_value, max_value):
+        return max(min_value, min(value, max_value))
+
+    def _lint_file(self, file_path: Path) -> tuple[Optional[str], Optional[int]]:
+        """Lint the file at the given path and return a tuple with the lint error message (None if there are no errors)
+        and the line number of the first error, if any.
+
+        Returns:
+            tuple[str | None, int | None]: (lint_error, first_error_line_number)
+        """
+
+        linter = Linter(root=self.working_dir)
+        lint_error = linter.lint(str(file_path))
+        if not lint_error:
+            # Linting successful. No issues found.
+            return None, None
+        return "ERRORS:\n" + lint_error.text, lint_error.lines[0]
+
+    def _print_window(self, file_path: Path, targeted_line: int, window: int):
+        self._check_current_file(file_path)
+        with file_path.open() as file:
+            content = file.read()
+
+            # Ensure the content ends with a newline character
+            if not content.endswith("\n"):
+                content += "\n"
+
+            lines = content.splitlines(True)  # Keep all line ending characters
+            total_lines = len(lines)
+
+            # cover edge cases
+            self.current_line = self._clamp(targeted_line, 1, total_lines)
+            half_window = max(1, window // 2)
+
+            # Ensure at least one line above and below the targeted line
+            start = max(1, self.current_line - half_window)
+            end = min(total_lines, self.current_line + half_window)
+
+            # Adjust start and end to ensure at least one line above and below
+            if start == 1:
+                end = min(total_lines, start + window - 1)
+            if end == total_lines:
+                start = max(1, end - window + 1)
+
+            output = ""
+
+            # only display this when there's at least one line above
+            if start > 1:
+                output += f"({start - 1} more lines above)\n"
+            else:
+                output += "(this is the beginning of the file)\n"
+            for i in range(start, end + 1):
+                _new_line = f"{i:03d}|{lines[i - 1]}"
+                if not _new_line.endswith("\n"):
+                    _new_line += "\n"
+                output += _new_line
+            if end < total_lines:
+                output += f"({total_lines - end} more lines below)\n"
+            else:
+                output += "(this is the end of the file)\n"
+            output = output.rstrip()
+
+            return output
+
+    @staticmethod
+    def _cur_file_header(current_file: Path, total_lines: int) -> str:
+        if not current_file:
+            return ""
+        return f"[File: {current_file.resolve()} ({total_lines} lines total)]\n"
+
+    def _set_workdir(self, path: str) -> None:
+        """
+        Sets the working directory to the given path, e.g., the repo directory.
+        You MUST set it before opening a file.
+
+        Args:
+            path: str: The path to set as the working directory.
+        """
+        self.working_dir = Path(path)
+
+    def open_file(
+        self, path: Union[Path, str], line_number: Optional[int] = 1, context_lines: Optional[int] = None
+    ) -> str:
+        """Opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line.
+        It only shows the first 200 lines (one editor window) by default! Max `context_lines` supported is 2000; use `scroll_down`/`scroll_up`
+        to view more of the file.
+
+        Args:
+            path: str: The path to the file to open, preferred absolute path.
+            line_number: int | None = 1: The line number to move to. Defaults to 1.
+            context_lines: int | None = None: Only shows this number of lines in the context window (usually from line 1), with line_number as the center (if possible). Defaults to the editor window size (200).
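+
+        Example (illustrative; the path and line count are hypothetical):
+            >>> editor = Editor()
+            >>> output = editor.open_file("/data/uploads/example.py", line_number=50)
+            >>> output.splitlines()[0]
+            '[File: /data/uploads/example.py (300 lines total)]'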
+        """
+        if context_lines is None:
+            context_lines = self.window
+
+        path = self._try_fix_path(path)
+
+        if not path.is_file():
+            raise FileNotFoundError(f"File {path} not found")
+
+        self.current_file = path
+        with path.open() as file:
+            total_lines = max(1, sum(1 for _ in file))
+
+        if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
+            raise ValueError(f"Line number must be between 1 and {total_lines}")
+        self.current_line = line_number
+
+        # Override WINDOW with context_lines
+        if context_lines is None or context_lines < 1:
+            context_lines = self.window
+
+        output = self._cur_file_header(path, total_lines)
+        output += self._print_window(path, self.current_line, self._clamp(context_lines, 1, 2000))
+        self.resource.report(path, "path")
+        return output
+
+    def goto_line(self, line_number: int) -> str:
+        """Moves the window to show the specified line number.
+
+        Args:
+            line_number: int: The line number to move to.
+        """
+        self._check_current_file()
+
+        with self.current_file.open() as file:
+            total_lines = max(1, sum(1 for _ in file))
+        if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
+            raise ValueError(f"Line number must be between 1 and {total_lines}")
+
+        self.current_line = self._clamp(line_number, 1, total_lines)
+
+        output = self._cur_file_header(self.current_file, total_lines)
+        output += self._print_window(self.current_file, self.current_line, self.window)
+        return output
+
+    def scroll_down(self) -> str:
+        """Moves the window down by one window (200 lines by default)."""
+        self._check_current_file()
+
+        with self.current_file.open() as file:
+            total_lines = max(1, sum(1 for _ in file))
+        self.current_line = self._clamp(self.current_line + self.window, 1, total_lines)
+        output = self._cur_file_header(self.current_file, total_lines)
+        output += self._print_window(self.current_file, self.current_line, self.window)
+        return output
+
+    def scroll_up(self) -> str:
+        """Moves the window up by one window (200 lines by default)."""
+        self._check_current_file()
+
+        with self.current_file.open() as file:
+            total_lines = max(1, sum(1 for _ in file))
+        self.current_line = self._clamp(self.current_line - self.window, 1, total_lines)
+        output = self._cur_file_header(self.current_file, total_lines)
+        output += self._print_window(self.current_file, self.current_line, self.window)
+        return output
+
+    async def create_file(self, filename: str) -> str:
+        """Creates and opens a new file with the given name.
+
+        Args:
+            filename: str: The name of the file to create. If the parent directory does not exist, it will be created.
+        """
+        filename = self._try_fix_path(filename)
+
+        if filename.exists():
+            raise FileExistsError(f"File '{filename}' already exists.")
+        await awrite(filename, "\n")
+
+        self.open_file(filename)
+        return f"[File {filename} created.]"
+
+    @staticmethod
+    def _append_impl(lines, content):
+        """Internal method to handle appending to a file.
+
+        Args:
+            lines: list[str]: The lines in the original file.
+            content: str: The content to append to the file.
+
+        Returns:
+            content: str: The new content of the file.
+            n_added_lines: int: The number of lines added to the file.
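+
+        Example (illustrative): appending "b\n" to the lines ["a\n"] returns
+        ("a\nb\n", 1), i.e. the merged content and the number of appended lines.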
+ """ + content_lines = content.splitlines(keepends=True) + n_added_lines = len(content_lines) + if lines and not (len(lines) == 1 and lines[0].strip() == ""): + # file is not empty + if not lines[-1].endswith("\n"): + lines[-1] += "\n" + new_lines = lines + content_lines + content = "".join(new_lines) + else: + # file is empty + content = "".join(content_lines) + + return content, n_added_lines + + @staticmethod + def _insert_impl(lines, start, content): + """Internal method to handle inserting to a file. + + Args: + lines: list[str]: The lines in the original file. + start: int: The start line number for inserting. + content: str: The content to insert to the file. + + Returns: + content: str: The new content of the file. + n_added_lines: int: The number of lines added to the file. + + Raises: + LineNumberError: If the start line number is invalid. + """ + inserted_lines = [content + "\n" if not content.endswith("\n") else content] + if len(lines) == 0: + new_lines = inserted_lines + elif start is not None: + if len(lines) == 1 and lines[0].strip() == "": + # if the file with only 1 line and that line is empty + lines = [] + + if len(lines) == 0: + new_lines = inserted_lines + else: + new_lines = lines[: start - 1] + inserted_lines + lines[start - 1 :] + else: + raise LineNumberError( + f"Invalid line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + + content = "".join(new_lines) + n_added_lines = len(inserted_lines) + return content, n_added_lines + + @staticmethod + def _edit_impl(lines, start, end, content): + """Internal method to handle editing a file. + + REQUIRES (should be checked by caller): + start <= end + start and end are between 1 and len(lines) (inclusive) + content ends with a newline + + Args: + lines: list[str]: The lines in the original file. + start: int: The start line number for editing. + end: int: The end line number for editing. + content: str: The content to replace the lines with. + + Returns: + content: str: The new content of the file. + n_added_lines: int: The number of lines added to the file. + """ + # Handle cases where start or end are None + if start is None: + start = 1 # Default to the beginning + if end is None: + end = len(lines) # Default to the end + # Check arguments + if not (1 <= start <= len(lines)): + raise LineNumberError( + f"Invalid start line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + if not (1 <= end <= len(lines)): + raise LineNumberError( + f"Invalid end line number: {end}. Line numbers must be between 1 and {len(lines)} (inclusive)." + ) + if start > end: + raise LineNumberError(f"Invalid line range: {start}-{end}. Start must be less than or equal to end.") + + # Split content into lines and ensure it ends with a newline + if not content.endswith("\n"): + content += "\n" + content_lines = content.splitlines(True) + + # Calculate the number of lines to be added + n_added_lines = len(content_lines) + + # Remove the specified range of lines and insert the new content + new_lines = lines[: start - 1] + content_lines + lines[end:] + + # Handle the case where the original lines are empty + if len(lines) == 0: + new_lines = content_lines + + # Join the lines to create the new content + content = "".join(new_lines) + return content, n_added_lines + + def _get_indentation_info(self, content, first_line): + """ + The indentation of the first insert line and the previous line, along with guidance for the next attempt. 
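+
+        For example (illustrative), with content "def f():\n    pass\n" and first_line=2,
+        the previous line "def f():" has 0 spaces and the error line "    pass" has 4
+        spaces, so the returned guidance suggests retrying with either 0 or 8 spaces.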
+        """
+        content_lines = content.split("\n")
+        pre_line = content_lines[first_line - 2] if first_line - 2 >= 0 else ""
+        pre_line_indent = len(pre_line) - len(pre_line.lstrip())
+        insert_line = content_lines[first_line - 1]
+        insert_line_indent = len(insert_line) - len(insert_line.lstrip())
+        ret_str = INDENTATION_INFO.format(
+            pre_line=pre_line,
+            pre_line_indent=pre_line_indent,
+            insert_line=insert_line,
+            insert_line_indent=insert_line_indent,
+            sub_4_space=max(insert_line_indent - 4, 0),
+            add_4_space=insert_line_indent + 4,
+        )
+        return ret_str
+
+    def _edit_file_impl(
+        self,
+        file_name: Path,
+        start: Optional[int] = None,
+        end: Optional[int] = None,
+        content: str = "",
+        is_insert: bool = False,
+        is_append: bool = False,
+    ) -> str:
+        """Internal method to handle common logic for edit_/append_file methods.
+
+        Args:
+            file_name: Path: The name of the file to edit or append to.
+            start: int | None = None: The start line number for editing. Ignored if is_append is True.
+            end: int | None = None: The end line number for editing. Ignored if is_append is True.
+            content: str: The content to replace the lines with or to append.
+            is_insert: bool = False: Whether to insert content at the given line number instead of editing.
+            is_append: bool = False: Whether to append content to the file instead of editing.
+        """
+
+        ERROR_MSG = f"[Error editing file {file_name}. Please confirm the file is correct.]"
+        ERROR_MSG_SUFFIX = (
+            "Your changes have NOT been applied. Please fix your edit command and try again.\n"
+            "You either need to 1) Open the correct file and try again or 2) Specify the correct line number arguments.\n"
+            "DO NOT re-run the same failed edit command. Running it again will lead to the same error."
+        )
+
+        if not self._is_valid_filename(file_name.name):
+            raise FileNotFoundError("Invalid file name.")
+
+        if not self._is_valid_path(file_name):
+            raise FileNotFoundError("Invalid path or file name.")
+
+        if not self._create_paths(file_name):
+            raise PermissionError("Could not access or create directories.")
+
+        if not file_name.is_file():
+            raise FileNotFoundError(f"File {file_name} not found.")
+
+        if is_insert and is_append:
+            raise ValueError("Cannot insert and append at the same time.")
+
+        # Use a temporary file to write changes
+        content = str(content or "")
+        temp_file_path = ""
+        src_abs_path = file_name.resolve()
+        first_error_line = None
+        # A temporary file that backs up the previous content; it is removed automatically.
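+        # (NamedTemporaryFile with delete=True deletes the backup as soon as the
+        # handle is closed, so it never outlives this call.)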
+ temp_backup_file = tempfile.NamedTemporaryFile("w", delete=True) + + try: + # lint the original file + # enable_auto_lint = os.getenv("ENABLE_AUTO_LINT", "false").lower() == "true" + if self.enable_auto_lint: + original_lint_error, _ = self._lint_file(file_name) + + # Create a temporary file + with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + temp_file_path = temp_file.name + + # Read the original file and check if empty and for a trailing newline + with file_name.open() as original_file: + lines = original_file.readlines() + + if is_append: + content, n_added_lines = self._append_impl(lines, content) + elif is_insert: + try: + content, n_added_lines = self._insert_impl(lines, start, content) + except LineNumberError as e: + return (f"{ERROR_MSG}\n" f"{e}\n" f"{ERROR_MSG_SUFFIX}") + "\n" + else: + try: + content, n_added_lines = self._edit_impl(lines, start, end, content) + except LineNumberError as e: + return (f"{ERROR_MSG}\n" f"{e}\n" f"{ERROR_MSG_SUFFIX}") + "\n" + + if not content.endswith("\n"): + content += "\n" + + # Write the new content to the temporary file + temp_file.write(content) + + # Replace the original file with the temporary file atomically + shutil.move(temp_file_path, src_abs_path) + + # Handle linting + # NOTE: we need to get env var inside this function + # because the env var will be set AFTER the agentskills is imported + if self.enable_auto_lint: + # BACKUP the original file + temp_backup_file.writelines(lines) + temp_backup_file.flush() + lint_error, first_error_line = self._lint_file(file_name) + + # Select the errors caused by the modification + def extract_last_part(line): + parts = line.split(":") + if len(parts) > 1: + return parts[-1].strip() + return line.strip() + + def subtract_strings(str1, str2) -> str: + lines1 = str1.splitlines() + lines2 = str2.splitlines() + + last_parts1 = [extract_last_part(line) for line in lines1] + + remaining_lines = [line for line in lines2 if extract_last_part(line) not in last_parts1] + + result = "\n".join(remaining_lines) + return result + + if original_lint_error and lint_error: + lint_error = subtract_strings(original_lint_error, lint_error) + if lint_error == "": + lint_error = None + first_error_line = None + + if lint_error is not None: + # if first_error_line is not None: + # show_line = int(first_error_line) + + # show the first insert line. + if is_append: + # original end-of-file + show_line = len(lines) + # insert OR edit WILL provide meaningful line numbers + elif start is not None and end is not None: + show_line = int((start + end) / 2) + else: + raise ValueError("Invalid state. This should never happen.") + + guidance_message = self._get_indentation_info(content, start or len(lines)) + guidance_message += ( + "You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n" + "DO NOT re-run the same failed edit command. Running it again will lead to the same error." 
+                )
+                lint_error_info = ERROR_GUIDANCE.format(
+                    linter_error_msg=LINTER_ERROR_MSG + lint_error,
+                    window_after_applied=self._print_window(file_name, show_line, n_added_lines + 20),
+                    window_before_applied=self._print_window(
+                        Path(temp_backup_file.name), show_line, n_added_lines + 20
+                    ),
+                    guidance_message=guidance_message,
+                ).strip()
+
+                # recover the original file
+                shutil.move(temp_backup_file.name, src_abs_path)
+                return lint_error_info
+
+        except FileNotFoundError as e:
+            return f"File not found: {e}\n"
+        except IOError as e:
+            return f"An error occurred while handling the file: {e}\n"
+        except ValueError as e:
+            return f"Invalid input: {e}\n"
+        except Exception as e:
+            guidance_message = self._get_indentation_info(content, start or len(lines))
+            guidance_message += (
+                "You either need to 1) Specify the correct start/end line arguments or 2) Enlarge the range of original code.\n"
+                "DO NOT re-run the same failed edit command. Running it again will lead to the same error."
+            )
+            error_info = ERROR_GUIDANCE.format(
+                linter_error_msg=LINTER_ERROR_MSG + str(e),
+                window_after_applied=self._print_window(file_name, start or len(lines), 100),
+                window_before_applied=self._print_window(Path(temp_backup_file.name), start or len(lines), 100),
+                guidance_message=guidance_message,
+            ).strip()
+            # Clean up the temporary file if an error occurs
+            shutil.move(temp_backup_file.name, src_abs_path)
+            if temp_file_path and Path(temp_file_path).exists():
+                Path(temp_file_path).unlink()
+
+            # logger.warning(f"An unexpected error occurred: {e}")
+            raise Exception(f"{error_info}") from e
+        # Update the file information and print the updated content
+        with file_name.open("r", encoding="utf-8") as file:
+            n_total_lines = max(1, len(file.readlines()))
+        if first_error_line is not None and int(first_error_line) > 0:
+            self.current_line = first_error_line
+        else:
+            if is_append:
+                self.current_line = max(1, len(lines))  # end of original file
+            else:
+                self.current_line = start or n_total_lines or 1
+        success_edit_info = SUCCESS_EDIT_INFO.format(
+            file_name=file_name.resolve(),
+            n_total_lines=n_total_lines,
+            window_after_applied=self._print_window(file_name, self.current_line, self.window),
+            line_number=self.current_line,
+        ).strip()
+        return success_edit_info
+
+    def edit_file_by_replace(
+        self,
+        file_name: str,
+        first_replaced_line_number: int,
+        first_replaced_line_content: str,
+        last_replaced_line_number: int,
+        last_replaced_line_content: str,
+        new_content: str,
+    ) -> str:
+        """
+        Line numbers start from 1. Replace the lines from first_replaced_line_number to last_replaced_line_number (inclusive) with new_content in the open file.
+        All of the new_content will be entered, so make sure your indentation is formatted properly.
+        The new_content must be a complete block of code.
+
+        Example 1:
+        Given a file "/workspace/example.txt" with the following content:
+        ```
+        001|contain f
+        002|contain g
+        003|contain h
+        004|contain i
+        ```
+
+        EDITING: If you want to replace line 2 and line 3
+
+        edit_file_by_replace(
+            "/workspace/example.txt",
+            first_replaced_line_number=2,
+            first_replaced_line_content="contain g",
+            last_replaced_line_number=3,
+            last_replaced_line_content="contain h",
+            new_content="new content",
+        )
+        This will replace line 2 and line 3 with "new content".
+ + The resulting file will be: + ``` + 001|contain f + 002|new content + 003|contain i + ``` + Example 2: + Given a file "/workspace/example.txt" with the following content: + ``` + 001|contain f + 002|contain g + 003|contain h + 004|contain i + ``` + EDITING: If you want to remove the line 2 and line 3. + edit_file_by_replace( + "/workspace/example.txt", + first_replaced_line_number=2, + first_replaced_line_content="contain g", + last_replaced_line_number=3, + last_replaced_line_content="contain h", + new_content="", + ) + This will remove line 2 and line 3. + The resulting file will be: + ``` + 001|contain f + 002| + 003|contain i + ``` + Args: + file_name (str): The name of the file to edit. + first_replaced_line_number (int): The line number to start the edit at, starting from 1. + first_replaced_line_content (str): The content of the start replace line, according to the first_replaced_line_number. + last_replaced_line_number (int): The line number to end the edit at (inclusive), starting from 1. + last_replaced_line_content (str): The content of the end replace line, according to the last_replaced_line_number. + new_content (str): The text to replace the current selection with, must conform to PEP8 standards. The content in the start line and end line will also be replaced. + + """ + + file_name = self._try_fix_path(file_name) + + # Check if the first_replaced_line_number and last_replaced_line_number correspond to the appropriate content. + mismatch_error = "" + with file_name.open() as file: + content = file.read() + # Ensure the content ends with a newline character + if not content.endswith("\n"): + content += "\n" + lines = content.splitlines(True) + total_lines = len(lines) + check_list = [ + ("first", first_replaced_line_number, first_replaced_line_content), + ("last", last_replaced_line_number, last_replaced_line_content), + ] + for position, line_number, line_content in check_list: + if line_number > len(lines) or lines[line_number - 1].rstrip() != line_content: + start = max(1, line_number - 3) + end = min(total_lines, line_number + 3) + context = "\n".join( + [ + f'The {cur_line_number:03d} line is "{lines[cur_line_number-1].rstrip()}"' + for cur_line_number in range(start, end + 1) + ] + ) + mismatch_error += LINE_NUMBER_AND_CONTENT_MISMATCH.format( + position=position, + line_number=line_number, + true_content=lines[line_number - 1].rstrip() + if line_number - 1 < len(lines) + else "OUT OF FILE RANGE!", + fake_content=line_content.replace("\n", "\\n"), + context=context.strip(), + ) + if mismatch_error: + raise ValueError(mismatch_error) + ret_str = self._edit_file_impl( + file_name, + start=first_replaced_line_number, + end=last_replaced_line_number, + content=new_content, + ) + # TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation) + self.resource.report(file_name, "path") + return ret_str + + def _edit_file_by_replace(self, file_name: str, to_replace: str, new_content: str) -> str: + """Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`. + + Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc. + + Include enough lines to make code in `to_replace` unique. `to_replace` should NOT be empty. 
+ + For example, given a file "/workspace/example.txt" with the following content: + ``` + line 1 + line 2 + line 2 + line 3 + ``` + + EDITING: If you want to replace the second occurrence of "line 2", you can make `to_replace` unique: + + edit_file_by_replace( + '/workspace/example.txt', + to_replace='line 2\nline 3', + new_content='new line\nline 3', + ) + + This will replace only the second "line 2" with "new line". The first "line 2" will remain unchanged. + + The resulting file will be: + ``` + line 1 + line 2 + new line + line 3 + ``` + + REMOVAL: If you want to remove "line 2" and "line 3", you can set `new_content` to an empty string: + + edit_file_by_replace( + '/workspace/example.txt', + to_replace='line 2\nline 3', + new_content='', + ) + + Args: + file_name: (str): The name of the file to edit. + to_replace: (str): The content to search for and replace. + new_content: (str): The new content to replace the old content with. + NOTE: + This tool is exclusive. If you use this tool, you cannot use any other commands in the current response. + If you need to use it multiple times, wait for the next turn. + """ + # FIXME: support replacing *all* occurrences + + if to_replace == new_content: + raise ValueError("`to_replace` and `new_content` must be different.") + + # search for `to_replace` in the file + # if found, replace it with `new_content` + # if not found, perform a fuzzy search to find the closest match and replace it with `new_content` + file_name = self._try_fix_path(file_name) + with file_name.open("r") as file: + file_content = file.read() + + if to_replace.strip() == "": + if file_content.strip() == "": + raise ValueError(f"The file '{file_name}' is empty. Please use the append method to add content.") + raise ValueError("`to_replace` must not be empty.") + + if file_content.count(to_replace) > 1: + raise ValueError( + "`to_replace` appears more than once, please include enough lines to make code in `to_replace` unique." + ) + start = file_content.find(to_replace) + if start != -1: + # Convert start from index to line number + start_line_number = file_content[:start].count("\n") + 1 + end_line_number = start_line_number + len(to_replace.splitlines()) - 1 + else: + + def _fuzzy_transform(s: str) -> str: + # remove all space except newline + return re.sub(r"[^\S\n]+", "", s) + + # perform a fuzzy search (remove all spaces except newlines) + to_replace_fuzzy = _fuzzy_transform(to_replace) + file_content_fuzzy = _fuzzy_transform(file_content) + # find the closest match + start = file_content_fuzzy.find(to_replace_fuzzy) + if start == -1: + return f"[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]" + # Convert start from index to line number for fuzzy match + start_line_number = file_content_fuzzy[:start].count("\n") + 1 + end_line_number = start_line_number + len(to_replace.splitlines()) - 1 + + ret_str = self._edit_file_impl( + file_name, + start=start_line_number, + end=end_line_number, + content=new_content, + is_insert=False, + ) + # lint_error = bool(LINTER_ERROR_MSG in ret_str) + # TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation) + self.resource.report(file_name, "path") + return ret_str + + def insert_content_at_line(self, file_name: str, line_number: int, insert_content: str) -> str: + """Insert a complete block of code before the given line number in a file. 
That is, the new content will start at the beginning of the specified line, and the existing content of that line will be moved down.
+        This operation will NOT modify the content of the lines before or after the given line number.
+        This function cannot insert content at the end of the file; please use append_file instead.
+        For example, if the file has the following content:
+        ```
+        001|contain g
+        002|contain h
+        003|contain i
+        004|contain j
+        ```
+        and you call
+        insert_content_at_line(
+            file_name='file.txt',
+            line_number=2,
+            insert_content='new line'
+        )
+        the file will be updated to:
+        ```
+        001|contain g
+        002|new line
+        003|contain h
+        004|contain i
+        005|contain j
+        ```
+
+        Args:
+            file_name: (str): The name of the file to edit.
+            line_number (int): The line number (starting from 1) to insert the content before. The inserted content will be added between line line_number-1 and line line_number.
+            insert_content (str): The content to insert between the previous line and the current line. The insert_content must be a complete block of code.
+
+        NOTE:
+            This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
+            If you need to use it multiple times, wait for the next turn.
+        """
+        file_name = self._try_fix_path(file_name)
+        ret_str = self._edit_file_impl(
+            file_name,
+            start=line_number,
+            end=line_number,
+            content=insert_content,
+            is_insert=True,
+            is_append=False,
+        )
+        self.resource.report(file_name, "path")
+        return ret_str
+
+    def append_file(self, file_name: str, content: str) -> str:
+        """Append content to the given file.
+        It appends text `content` to the end of the specified file.
+
+        Args:
+            file_name: str: The name of the file to edit.
+            content: str: The content to append.
+        NOTE:
+            This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
+            If you need to use it multiple times, wait for the next turn.
+        """
+        file_name = self._try_fix_path(file_name)
+        ret_str = self._edit_file_impl(
+            file_name,
+            start=None,
+            end=None,
+            content=content,
+            is_insert=False,
+            is_append=True,
+        )
+        self.resource.report(file_name, "path")
+        return ret_str
+
+    def search_dir(self, search_term: str, dir_path: str = "./") -> str:
+        """Searches for search_term in all files under dir_path. If dir_path is not provided, searches in the current directory.
+
+        Args:
+            search_term: str: The term to search for.
+            dir_path: str: The path to the directory to search.
+        """
+        dir_path = self._try_fix_path(dir_path)
+        if not dir_path.is_dir():
+            raise FileNotFoundError(f"Directory {dir_path} not found")
+        matches = []
+        for root, _, files in os.walk(dir_path):
+            for file in files:
+                if file.startswith("."):
+                    continue
+                file_path = Path(root) / file
+                with file_path.open("r", errors="ignore") as f:
+                    for line_num, line in enumerate(f, 1):
+                        if search_term in line:
+                            matches.append((file_path, line_num, line.strip()))
+
+        if not matches:
+            return f'No matches found for "{search_term}" in {dir_path}'
+
+        num_matches = len(matches)
+        num_files = len(set(match[0] for match in matches))
+
+        if num_files > 100:
+            return f'Matches for "{search_term}" found in {num_files} files under {dir_path} (more than 100). Please narrow your search.'
+ + res_list = [f'[Found {num_matches} matches for "{search_term}" in {dir_path}]'] + for file_path, line_num, line in matches: + res_list.append(f"{file_path} (Line {line_num}): {line}") + res_list.append(f'[End of matches for "{search_term}" in {dir_path}]') + return "\n".join(res_list) + + def search_file(self, search_term: str, file_path: Optional[str] = None) -> str: + """Searches for search_term in file. If file is not provided, searches in the current open file. + + Args: + search_term: str: The term to search for. + file_path: str | None: The path to the file to search. + """ + if file_path is None: + file_path = self.current_file + else: + file_path = self._try_fix_path(file_path) + if file_path is None: + raise FileNotFoundError("No file specified or open. Use the open_file function first.") + if not file_path.is_file(): + raise FileNotFoundError(f"File {file_path} not found") + + matches = [] + with file_path.open() as file: + for i, line in enumerate(file, 1): + if search_term in line: + matches.append((i, line.strip())) + res_list = [] + if matches: + res_list.append(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]') + for match in matches: + res_list.append(f"Line {match[0]}: {match[1]}") + res_list.append(f'[End of matches for "{search_term}" in {file_path}]') + else: + res_list.append(f'[No matches found for "{search_term}" in {file_path}]') + + extra = {"type": "search", "symbol": search_term, "lines": [i[0] - 1 for i in matches]} if matches else None + self.resource.report(file_path, "path", extra=extra) + return "\n".join(res_list) + + def find_file(self, file_name: str, dir_path: str = "./") -> str: + """Finds all files with the given name in the specified directory. + + Args: + file_name: str: The name of the file to find. + dir_path: str: The path to the directory to search. + """ + file_name = self._try_fix_path(file_name) + dir_path = self._try_fix_path(dir_path) + if not dir_path.is_dir(): + raise FileNotFoundError(f"Directory {dir_path} not found") + + matches = [] + for root, _, files in os.walk(dir_path): + for file in files: + if str(file_name) in file: + matches.append(Path(root) / file) + + res_list = [] + if matches: + res_list.append(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]') + for match in matches: + res_list.append(f"{match}") + res_list.append(f'[End of matches for "{file_name}" in {dir_path}]') + else: + res_list.append(f'[No matches found for "{file_name}" in {dir_path}]') + return "\n".join(res_list) + + def _try_fix_path(self, path: Union[Path, str]) -> Path: + """Tries to fix the path if it is not absolute.""" + if not isinstance(path, Path): + path = Path(path) + if not path.is_absolute(): + path = self.working_dir / path + return path + + @staticmethod + async def similarity_search(query: str, path: Union[str, Path]) -> List[str]: + """Given a filename or a pathname, performs a similarity search for a given query across the specified file or path. + + This method searches the index repository for the provided query, classifying the specified + files or paths. It performs a search on each cluster of files and handles non-indexed files + separately, merging results from structured indices with any direct results from non-indexed files. + This function call does not depend on other functions. + + Args: + query (str): The search query string to look for in the indexed files. + path (Union[str, Path]): A pathname or filename to search within. 
+
+        Returns:
+            List[str]: A list of results as strings, containing the text from the merged results
+                and any direct results from non-indexed files.
+
+        Example:
+            >>> query = "The problem to be analyzed from the document"
+            >>> file_or_path = "The pathname or filename you want to search within"
+            >>> texts: List[str] = await Editor.similarity_search(query=query, path=file_or_path)
+            >>> print(texts)
+        """
+        return await IndexRepo.cross_repo_search(query=query, file_or_path=path)
+
+    @staticmethod
+    def is_large_file(content: str, mix_token_count: int = 0) -> bool:
+        encoding = tiktoken.get_encoding("cl100k_base")
+        token_count = len(encoding.encode(content))
+        mix_token_count = mix_token_count or DEFAULT_MIN_TOKEN_COUNT
+        return token_count >= mix_token_count
diff --git a/metagpt/tools/libs/env.py b/metagpt/tools/libs/env.py
new file mode 100644
index 0000000000..c1757c5f9c
--- /dev/null
+++ b/metagpt/tools/libs/env.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2024/4/25
+@Author  : mashenquan
+@File    : env.py
+@Desc: Implement `get_env`. RFC 216 2.4.2.4.2.
+"""
+import os
+from typing import Dict, Optional
+
+
+class EnvKeyNotFoundError(Exception):
+    def __init__(self, info):
+        super().__init__(info)
+
+
+def to_app_key(key: str, app_name: str = None) -> str:
+    return f"{app_name}-{key}" if app_name else key
+
+
+def split_app_key(app_key: str) -> (str, str):
+    if "-" not in app_key:
+        return "", app_key
+    app_name, key = app_key.split("-", 1)
+    return app_name, key
+
+
+async def default_get_env(key: str, app_name: str = None) -> str:
+    app_key = to_app_key(key=key, app_name=app_name)
+    if app_key in os.environ:
+        return os.environ[app_key]
+
+    env_app_key = app_key.replace("-", "_")  # "-" is not supported in Linux environment variable names
+    if env_app_key in os.environ:
+        return os.environ[env_app_key]
+
+    from metagpt.context import Context
+
+    context = Context()
+    val = context.kwargs.get(app_key, None)
+    if val is not None:
+        return val
+
+    raise EnvKeyNotFoundError(f"EnvKeyNotFoundError: {key}, app_name:{app_name or ''}")
+
+
+async def default_get_env_description() -> Dict[str, str]:
+    result = {}
+    for k in os.environ.keys():
+        app_name, key = split_app_key(k)
+        call = f'await get_env(key="{key}", app_name="{app_name}")'
+        result[call] = f"Return the value of environment variable `{k}`."
+
+    from metagpt.context import Context
+
+    context = Context()
+    for k in context.kwargs.__dict__.keys():
+        app_name, key = split_app_key(k)
+        call = f'await get_env(key="{key}", app_name="{app_name}")'
+        result[call] = f"Get the value of environment variable `{k}`."
+    return result
+
+
+_get_env_entry = default_get_env
+_get_env_description_entry = default_get_env_description
+
+
+async def get_env(key: str, app_name: str = None) -> str:
+    """
+    Retrieve the value of the environment variable for the specified key.
+
+    Args:
+        key (str): The key of the environment variable.
+        app_name (str, optional): The name of the application. Defaults to None.
+
+    Returns:
+        str: The value corresponding to the given key in the environment variables.
+
+    Raises:
+        EnvKeyNotFoundError: If no value is found for the given key.
+
+    Example:
+        This function can be used to retrieve environment variables asynchronously.
+        It should be called using `await`.
+
+        >>> from metagpt.tools.libs.env import get_env
+        >>> api_key = await get_env("API_KEY")
+        >>> print(api_key)
+
+
+        >>> from metagpt.tools.libs.env import get_env
+        >>> api_key = await get_env(key="API_KEY", app_name="GITHUB")
+        >>> print(api_key)
+
+
+    Note:
+        This is an asynchronous function and must be called using `await`.
+    """
+    global _get_env_entry
+    if _get_env_entry:
+        return await _get_env_entry(key=key, app_name=app_name)
+
+    return await default_get_env(key=key, app_name=app_name)
+
+
+async def get_env_default(key: str, app_name: str = None, default_value: str = None) -> Optional[str]:
+    """
+    Retrieves the value for the specified environment variable key. If the key is not found,
+    returns the default value.
+
+    Args:
+        key (str): The name of the environment variable to retrieve.
+        app_name (str, optional): The name of the application or component to associate with the environment variable.
+        default_value (str, optional): The default value to return if the environment variable is not found.
+
+    Returns:
+        str or None: The value of the environment variable if found, otherwise the default value.
+
+    Example:
+        >>> from metagpt.tools.libs.env import get_env_default
+        >>> api_key = await get_env_default(key="NOT_EXISTS_API_KEY", default_value="")
+        >>> print(api_key)
+
+
+        >>> from metagpt.tools.libs.env import get_env_default
+        >>> api_key = await get_env_default(key="NOT_EXISTS_API_KEY", app_name="GITHUB", default_value="")
+        >>> print(api_key)
+
+
+    """
+    try:
+        return await get_env(key=key, app_name=app_name)
+    except EnvKeyNotFoundError:
+        return default_value
+
+
+async def get_env_description() -> Dict[str, str]:
+    global _get_env_description_entry
+
+    if _get_env_description_entry:
+        return await _get_env_description_entry()
+
+    return await default_get_env_description()
+
+
+def set_get_env_entry(value, description):
+    """Replace the `get_env` entry and the `get_env_description` entry.
+
+    Args:
+        value (function): The new `get_env` implementation.
+        description (function): The new `get_env_description` implementation.
+
+    This function replaces the `get_env` entry with `value` and the
+    `get_env_description` entry with `description`.
+    """
+    global _get_env_entry
+    global _get_env_description_entry
+    _get_env_entry = value
+    _get_env_description_entry = description
diff --git a/metagpt/tools/libs/git.py b/metagpt/tools/libs/git.py
new file mode 100644
index 0000000000..8a3e464f0c
--- /dev/null
+++ b/metagpt/tools/libs/git.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import urllib
+from pathlib import Path
+from typing import Optional
+
+from github.Issue import Issue
+from github.PullRequest import PullRequest
+
+from metagpt.tools.tool_registry import register_tool
+
+
+@register_tool(tags=["software development", "git", "create a git pull request or merge request"])
+async def git_create_pull(
+    base: str,
+    head: str,
+    app_name: str,
+    base_repo_name: str,
+    head_repo_name: str = None,
+    title: Optional[str] = None,
+    body: Optional[str] = None,
+    issue: Optional[Issue] = None,
+) -> PullRequest:
+    """
+    Creates a pull request on a Git repository. Prefer this tool over the Browser for creating pull requests.
+
+    Args:
+        base (str): The name of the base branch where the pull request will be merged.
+        head (str): The name of the branch that contains the changes for the pull request.
+        app_name (str): The name of the platform hosting the repository (e.g., "github", "gitlab", "bitbucket").
+        base_repo_name (str): The full name of the target repository (in the format "user/repo") where the pull request will be created.
+        head_repo_name (Optional[str]): The full name of the source repository (in the format "user/repo") from which the changes will be pulled.
+        title (Optional[str]): The title of the pull request. Defaults to None.
+        body (Optional[str]): The description or body content of the pull request. Defaults to None.
+        issue (Optional[Issue]): An optional issue related to the pull request. Defaults to None.
+
+    Example:
+        >>> # create pull request
+        >>> base_repo_name = "geekan/MetaGPT"
+        >>> head_repo_name = "ioris/MetaGPT"
+        >>> base = "master"
+        >>> head = "feature/http"
+        >>> title = "feat: modify http lib"
+        >>> body = "Change HTTP library used to send requests"
+        >>> app_name = "github"
+        >>> pr = await git_create_pull(
+        >>>     base_repo_name=base_repo_name,
+        >>>     head_repo_name=head_repo_name,
+        >>>     base=base,
+        >>>     head=head,
+        >>>     title=title,
+        >>>     body=body,
+        >>>     app_name=app_name,
+        >>> )
+        >>> if isinstance(pr, PullRequest):
+        >>>     print(pr)
+        PullRequest("feat: modify http lib")
+        >>> if isinstance(pr, str):
+        >>>     print(f"Visit this url to create a new pull request: '{pr}'")
+        Visit this url to create a new pull request: 'https://github.com/geekan/MetaGPT/compare/master...iorisa:MetaGPT:feature/http'
+
+    Returns:
+        Union[PullRequest, str]: The created pull request, or a URL for creating it manually.
+    """
+    from metagpt.utils.git_repository import GitRepository
+
+    git_credentials_path = Path.home() / ".git-credentials"
+    with open(git_credentials_path, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+
+    access_token = None
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+        parsed_url = urllib.parse.urlparse(line)
+        if app_name in parsed_url.hostname:
+            colon_index = parsed_url.netloc.find(":")
+            at_index = parsed_url.netloc.find("@")
+            access_token = parsed_url.netloc[colon_index + 1 : at_index]
+            break
+    return await GitRepository.create_pull(
+        base=base,
+        head=head,
+        base_repo_name=base_repo_name,
+        head_repo_name=head_repo_name,
+        title=title,
+        body=body,
+        issue=issue,
+        access_token=access_token,
+    )
+
+
+@register_tool(tags=["software development", "create a git issue"])
+async def git_create_issue(
+    repo_name: str,
+    title: str,
+    access_token: str,
+    body: Optional[str] = None,
+) -> Issue:
+    """
+    Creates an issue on a Git repository.
+
+    Args:
+        repo_name (str): The name of the repository.
+        title (str): The title of the issue.
+        access_token (str): The access token for authentication. Use `get_env` to get the access token.
+        body (Optional[str], optional): The body of the issue. Defaults to None.
+
+    Example:
+        >>> repo_name = "geekan/MetaGPT"
+        >>> title = "This is a new issue"
+        >>> from metagpt.tools.libs import get_env
+        >>> access_token = await get_env(key="access_token", app_name="github")
+        >>> body = "This is the issue body."
+        >>> issue = await git_create_issue(
+        >>>     repo_name=repo_name,
+        >>>     title=title,
+        >>>     access_token=access_token,
+        >>>     body=body,
+        >>> )
+        >>> print(issue)
+        Issue("This is a new issue")
+
+    Returns:
+        Issue: The created issue.
+    """
+    from metagpt.utils.git_repository import GitRepository
+
+    return await GitRepository.create_issue(repo_name=repo_name, title=title, body=body, access_token=access_token)
diff --git a/metagpt/tools/libs/gpt_v_generator.py b/metagpt/tools/libs/gpt_v_generator.py
index 4eba3d5eec..62c36b2f88 100644
--- a/metagpt/tools/libs/gpt_v_generator.py
+++ b/metagpt/tools/libs/gpt_v_generator.py
@@ -7,8 +7,9 @@
 """
 import re
 from pathlib import Path
+from typing import Optional
 
-from metagpt.const import DEFAULT_WORKSPACE_ROOT
+from metagpt.config2 import Config
 from metagpt.logs import logger
 from metagpt.tools.tool_registry import register_tool
 from metagpt.utils.common import CodeParser, encode_image
@@ -36,11 +37,11 @@ class GPTvGenerator:
     It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly.
     """
 
-    def __init__(self):
+    def __init__(self, config: Optional[Config]):
         """Initialize GPTvGenerator class with default values from the configuration."""
-        from metagpt.config2 import config
         from metagpt.llm import LLM
 
+        config = config if config else Config.default()
         self.llm = LLM(llm_config=config.get_openai_llm())
         self.llm.model = "gpt-4-vision-preview"
 
@@ -84,12 +85,12 @@ def save_webpages(webpages: str, save_folder_name: str = "example") -> Path:
         Path: The path of the saved webpages.
     """
     # Create a folder called webpages in the workspace directory to store HTML, CSS, and JavaScript files
-    webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / save_folder_name
+    webpages_path = Config.default().workspace.path / "webpages" / save_folder_name
     logger.info(f"code will be saved at {webpages_path}")
     webpages_path.mkdir(parents=True, exist_ok=True)
 
     index_path = webpages_path / "index.html"
-    index_path.write_text(CodeParser.parse_code(block=None, text=webpages, lang="html"))
+    index_path.write_text(CodeParser.parse_code(text=webpages, lang="html"))
 
     extract_and_save_code(folder=webpages_path, text=webpages, pattern="styles?.css", language="css")
 
@@ -102,5 +103,5 @@ def extract_and_save_code(folder, text, pattern, language):
     word = re.search(pattern, text)
     if word:
         path = folder / word.group(0)
-        code = CodeParser.parse_code(block=None, text=text, lang=language)
+        code = CodeParser.parse_code(text=text, lang=language)
         path.write_text(code, encoding="utf-8")
diff --git a/metagpt/tools/libs/image_getter.py b/metagpt/tools/libs/image_getter.py
new file mode 100644
index 0000000000..ecbaaf5100
--- /dev/null
+++ b/metagpt/tools/libs/image_getter.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from playwright.async_api import Browser as Browser_
+from playwright.async_api import BrowserContext, Page, Playwright, async_playwright
+from playwright.async_api import TimeoutError as PlaywrightTimeoutError
+from pydantic import BaseModel, ConfigDict, Field
+
+from metagpt.tools.tool_registry import register_tool
+from metagpt.utils.common import decode_image
+from metagpt.utils.proxy_env import get_proxy_from_env
+from metagpt.utils.report import BrowserReporter
+
+DOWNLOAD_PICTURE_JAVASCRIPT = """
+async () => {{
+    var img = document.querySelector('{img_element_selector}');
+    if (img && img.src) {{
+        const response = await fetch(img.src);
+        if (response.ok) {{
+            const blob = await response.blob();
+            return await new Promise(resolve => {{
+                const reader = new FileReader();
+                reader.onloadend = () => resolve(reader.result);
+                reader.readAsDataURL(blob);
+            }});
+        }}
+    }}
+    return null;
+}}
+"""
+
+
+@register_tool(include_functions=["get_image"])
+class ImageGetter(BaseModel):
+    """
+    A tool to get images.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    playwright: Optional[Playwright] = Field(default=None, exclude=True)
+    browser_instance: Optional[Browser_] = Field(default=None, exclude=True)
+    browser_ctx: Optional[BrowserContext] = Field(default=None, exclude=True)
+    page: Optional[Page] = Field(default=None, exclude=True)
+    headless: bool = Field(default=True)
+    proxy: Optional[dict] = Field(default_factory=get_proxy_from_env)
+    reporter: BrowserReporter = Field(default_factory=BrowserReporter)
+    url: str = "https://unsplash.com/s/photos/{search_term}/"
+    img_element_selector: str = ".zNNw1 > div > img:nth-of-type(2)"
+
+    async def start(self) -> None:
+        """Starts Playwright and launches a browser."""
+        if self.playwright is None:
+            self.playwright = playwright = await async_playwright().start()
+            browser = self.browser_instance = await playwright.chromium.launch(headless=self.headless, proxy=self.proxy)
+            browser_ctx = self.browser_ctx = await browser.new_context()
+            self.page = await browser_ctx.new_page()
+
+    async def get_image(self, search_term, image_save_path):
+        """
+        Get an image related to the search term.
+
+        Args:
+            search_term (str): The term to search for the image. The search term must be in English; using any other language may lead to a mismatch.
+            image_save_path (str): The file path where the image will be saved.
+        """
+        # Search for images from https://unsplash.com/s/photos/
+        if self.page is None:
+            await self.start()
+        await self.page.goto(self.url.format(search_term=search_term), wait_until="domcontentloaded")
+        # Wait until the image element is loaded
+        try:
+            await self.page.wait_for_selector(self.img_element_selector)
+        except PlaywrightTimeoutError:
+            return f"{search_term} not found. Please broaden the search term."
+        # Get the base64 code of the first retrieved image
+        image_base64 = await self.page.evaluate(
+            DOWNLOAD_PICTURE_JAVASCRIPT.format(img_element_selector=self.img_element_selector)
+        )
+        if image_base64:
+            image = decode_image(image_base64)
+            image.save(image_save_path)
+            return f"{search_term} found. The image is saved in {image_save_path}."
+        return f"{search_term} not found. Please broaden the search term."
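+
+# A minimal usage sketch (illustrative; the search term and save path are
+# hypothetical, and an asyncio event loop is assumed):
+#
+#   import asyncio
+#
+#   async def demo():
+#       getter = ImageGetter()
+#       print(await getter.get_image("mountain lake", "/tmp/lake.jpg"))
+#
+#   asyncio.run(demo())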
diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py new file mode 100644 index 0000000000..4c4e6c59b4 --- /dev/null +++ b/metagpt/tools/libs/index_repo.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import asyncio +import json +import re +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple, Union + +import tiktoken +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.schema import NodeWithScore +from pydantic import BaseModel, Field, model_validator + +from metagpt.config2 import Config +from metagpt.context import Context +from metagpt.logs import logger +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig +from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files +from metagpt.utils.file import File +from metagpt.utils.report import EditorReporter + +UPLOADS_INDEX_ROOT = "/data/.index/uploads" +DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT +UPLOAD_ROOT = "/data/uploads" +DEFAULT_ROOT = UPLOAD_ROOT +CHATS_INDEX_ROOT = "/data/.index/chats" +CHATS_ROOT = "/data/chats/" +OTHER_TYPE = "other" + +DEFAULT_MIN_TOKEN_COUNT = 10000 +DEFAULT_MAX_TOKEN_COUNT = 100000000 + + +class IndexRepoMeta(BaseModel): + min_token_count: int + max_token_count: int + + +class TextScore(BaseModel): + filename: str + text: str + score: Optional[float] = None + + +class IndexRepo(BaseModel): + persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` + root_path: str = ( + DEFAULT_ROOT # `/data/uploads` or r`/data/chats/[a-z0-9]+`, the root path of files indexed by the index repo. + ) + fingerprint_filename: str = "fingerprint.json" + meta_filename: str = "meta.json" + model: Optional[str] = None + min_token_count: int = DEFAULT_MIN_TOKEN_COUNT + max_token_count: int = DEFAULT_MAX_TOKEN_COUNT + recall_count: int = 5 + embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) + fingerprints: Dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _update_fingerprints(self) -> "IndexRepo": + """Load fingerprints from the fingerprint file if not already loaded. + + Returns: + IndexRepo: The updated IndexRepo instance. + """ + if not self.fingerprints: + filename = Path(self.persist_path) / self.fingerprint_filename + if not filename.exists(): + return self + with open(str(filename), "r") as reader: + self.fingerprints = json.load(reader) + return self + + async def search( + self, query: str, filenames: Optional[List[Path]] = None + ) -> Optional[List[Union[NodeWithScore, TextScore]]]: + """Search for documents related to the given query. + + Args: + query (str): The search query. + filenames (Optional[List[Path]]): A list of filenames to filter the search. + + Returns: + Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore. 
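+
+        Example (illustrative; the query and path are hypothetical):
+            >>> repo = IndexRepo()
+            >>> hits = await repo.search("how is login handled?", [Path("/data/uploads/auth.py")])
+            >>> print(len(hits))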
+        """
+        encoding = tiktoken.get_encoding("cl100k_base")
+        result: List[Union[NodeWithScore, TextScore]] = []
+        filenames, excludes = await self._filter(filenames)
+        if not filenames:
+            raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}")
+        resource = EditorReporter()
+        for i in filenames:
+            await resource.async_report(str(i), "path")
+        filter_filenames = set()
+        meta = await self._read_meta()
+        new_files = {}
+        for i in filenames:
+            if Path(i).suffix.lower() in {".pdf", ".doc", ".docx"}:
+                if str(i) not in self.fingerprints:
+                    new_files[i] = ""
+                    logger.warning(f'file: "{i}" not indexed')
+                filter_filenames.add(str(i))
+                continue
+            content = await File.read_text_file(i)
+            token_count = len(encoding.encode(content))
+            if not self._is_buildable(
+                token_count, min_token_count=meta.min_token_count, max_token_count=meta.max_token_count
+            ):
+                result.append(TextScore(filename=str(i), text=content))
+                continue
+            file_fingerprint = generate_fingerprint(content)
+            if str(i) not in self.fingerprints or (self.fingerprints.get(str(i)) != file_fingerprint):
+                new_files[i] = content
+                logger.warning(f'file: "{i}" changed but not indexed')
+                continue
+            filter_filenames.add(str(i))
+        if new_files:
+            added, others = await self.add(paths=list(new_files.keys()), file_datas=new_files)
+            filter_filenames.update([str(i) for i in added])
+            for i in others:
+                result.append(TextScore(filename=str(i), text=new_files.get(i)))
+                filter_filenames.discard(str(i))
+        nodes = await self._search(query=query, filters=filter_filenames)
+        return result + nodes
+
+    async def merge(
+        self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]]
+    ) -> List[Union[NodeWithScore, TextScore]]:
+        """Merge results from multiple indices based on the query.
+
+        Args:
+            query (str): The search query.
+            indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices.
+
+        Returns:
+            List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
+        """
+        flat_nodes = [node for indices in indices_list if indices for node in indices if node]
+        if len(flat_nodes) <= self.recall_count:
+            return flat_nodes
+
+        if not self.embedding:
+            config = Config.default()
+            if self.model:
+                config.embedding.model = self.model
+            factory = RAGEmbeddingFactory(config)
+            self.embedding = factory.get_rag_embedding()
+
+        scores = []
+        query_embedding = await self.embedding.aget_text_embedding(query)
+        for i in flat_nodes:
+            try:
+                text_embedding = await self.embedding.aget_text_embedding(i.text)
+            except Exception as e:  # the text exceeds the embedding model's maximum length
+                tenth = int(len(i.text) / 10)  # DEFAULT_MIN_TOKEN_COUNT = 10000
+                logger.warning(
+                    f"{e}, tenth len={tenth}, pre_part_len={len(i.text[: tenth * 6])}, post_part_len={len(i.text[tenth * 4:])}"
+                )
+                pre_win_part = await self.embedding.aget_text_embedding(i.text[: tenth * 6])
+                post_win_part = await self.embedding.aget_text_embedding(i.text[tenth * 4 :])
+                similarity = max(
+                    self.embedding.similarity(query_embedding, pre_win_part),
+                    self.embedding.similarity(query_embedding, post_win_part),
+                )
+                scores.append((similarity, i))
+                continue
+            similarity = self.embedding.similarity(query_embedding, text_embedding)
+            scores.append((similarity, i))
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return [i[1] for i in scores][: self.recall_count]
+
+    async def add(
+        self, paths: List[Path], file_datas: Dict[Union[str, Path], str] = None
+    ) -> Tuple[List[str], List[str]]:
+        """Add new documents to the index.
+ + Args: + paths (List[Path]): A list of paths to the documents to be added. + file_datas (Dict[Union[str, Path], str]): A list of file content. + + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists: + 1. The list of filenames that were successfully added to the index. + 2. The list of filenames that were not added to the index because they were not buildable. + """ + encoding = tiktoken.get_encoding("cl100k_base") + filenames, _ = await self._filter(paths) + filter_filenames = [] + delete_filenames = [] + file_datas = file_datas or {} + for i in filenames: + content = file_datas.get(i) or await File.read_text_file(i) + file_datas[i] = content + if not self._is_fingerprint_changed(filename=i, content=content): + continue + token_count = len(encoding.encode(content)) + if self._is_buildable(token_count): + filter_filenames.append(i) + logger.debug(f"{i} is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + else: + delete_filenames.append(i) + logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames, file_datas=file_datas) + return filter_filenames, delete_filenames + + async def _add_batch( + self, + filenames: List[Union[str, Path]], + delete_filenames: List[Union[str, Path]], + file_datas: Dict[Union[str, Path], str], + ): + """Add and remove documents in a batch operation. + + Args: + filenames (List[Union[str, Path]]): List of filenames to add. + delete_filenames (List[Union[str, Path]]): List of filenames to delete. + """ + if not filenames: + return + logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") + engine = None + Context() + if Path(self.persist_path).exists(): + logger.debug(f"load index from {self.persist_path}") + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], + ) + try: + engine.delete_docs(filenames + delete_filenames) + logger.info(f"delete docs {filenames + delete_filenames}") + engine.add_docs(input_files=filenames) + logger.info(f"add docs {filenames}") + except NotImplementedError as e: + logger.debug(f"{e}") + filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys()))) + engine = None + logger.info(f"{e}. Rebuild all.") + if not engine: + engine = SimpleEngine.from_docs( + input_files=[str(i) for i in filenames], + retriever_configs=[FAISSRetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + logger.info(f"add docs {filenames}") + engine.persist(persist_dir=self.persist_path) + for i in filenames: + content = file_datas.get(i) or await File.read_text_file(i) + fp = generate_fingerprint(content) + self.fingerprints[str(i)] = fp + await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + await self._save_meta() + + def __str__(self): + """Return a string representation of the IndexRepo. + + Returns: + str: The filename of the index repository. + """ + return f"{self.persist_path}" + + def _is_buildable(self, token_count: int, min_token_count: int = -1, max_token_count=-1) -> bool: + """Check if the token count is within the buildable range. + + Args: + token_count (int): The number of tokens in the content. + + Returns: + bool: True if buildable, False otherwise. 
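+
+        Example (illustrative, with the default thresholds):
+            >>> repo = IndexRepo()
+            >>> repo._is_buildable(5_000)  # below DEFAULT_MIN_TOKEN_COUNT, kept as plain text
+            False
+            >>> repo._is_buildable(50_000)  # within range, indexed
+            True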
+ """ + min_token_count = min_token_count if min_token_count >= 0 else self.min_token_count + max_token_count = max_token_count if max_token_count >= 0 else self.max_token_count + if token_count < min_token_count or token_count > max_token_count: + return False + return True + + async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]): + """Filter the provided filenames to only include valid text files. + + Args: + filenames (Optional[List[Union[str, Path]]]): List of filenames to filter. + + Returns: + Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths. + """ + root_path = Path(self.root_path).absolute() + if not filenames: + filenames = [root_path] + pathnames = [] + excludes = [] + for i in filenames: + path = Path(i).absolute() + if not path.is_relative_to(root_path): + excludes.append(path) + logger.debug(f"{path} not is_relative_to {root_path})") + continue + if not path.is_dir(): + is_text = await File.is_textual_file(path) + if is_text: + pathnames.append(path) + continue + subfiles = list_files(path) + for j in subfiles: + is_text = await File.is_textual_file(j) + if is_text: + pathnames.append(j) + + logger.debug(f"{pathnames}, excludes:{excludes})") + return pathnames, excludes + + async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]: + """Perform a search for the given query using the index. + + Args: + query (str): The search query. + filters (Set[str]): A set of filenames to filter the search results. + + Returns: + List[NodeWithScore]: A list of nodes with scores matching the query. + """ + if not filters: + return [] + if not Path(self.persist_path).exists(): + raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") + Context() + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], + ) + rsp = await engine.aretrieve(query) + return [i for i in rsp if i.metadata.get("file_path") in filters] + + def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool: + """Check if the fingerprint of the given document content has changed. + + Args: + filename (Union[str, Path]): The filename of the document. + content (str): The content of the document. + + Returns: + bool: True if the fingerprint has changed, False otherwise. + """ + old_fp = self.fingerprints.get(str(filename)) + if not old_fp: + return True + fp = generate_fingerprint(content) + return old_fp != fp + + @staticmethod + def find_index_repo_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + """Map the file path to the corresponding index repo. + + Args: + files (List[Union[str, Path]]): A list of file paths or Path objects to be classified. + + Returns: + Tuple[Dict[str, Set[Path]], Dict[str, str]]: + - A dictionary mapping the index repo path to the files. + - A dictionary mapping the index repo path to their corresponding root directories. 
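+
+        Example (illustrative paths):
+            >>> clusters, roots = IndexRepo.find_index_repo_path(["/data/uploads/a.md", "/data/chats/ab12/b.md"])
+            >>> sorted(clusters)
+            ['/data/.index/chats/ab12', '/data/.index/uploads']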
+ """ + mappings = { + UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), + CHATS_INDEX_ROOT: re.compile(r"^/data/chats/[a-z0-9]+($|/.*)"), + } + + clusters = {} + roots = {} + for i in files: + path = Path(i).absolute() + path_type = OTHER_TYPE + for type_, pattern in mappings.items(): + if re.match(pattern, str(i)): + path_type = type_ + break + if path_type == CHATS_INDEX_ROOT: + chat_id = path.parts[3] + path_type = str(Path(path_type) / chat_id) + roots[path_type] = str(Path(CHATS_ROOT) / chat_id) + elif path_type == UPLOADS_INDEX_ROOT: + roots[path_type] = UPLOAD_ROOT + + if path_type in clusters: + clusters[path_type].add(path) + else: + clusters[path_type] = {path} + + return clusters, roots + + async def _save_meta(self): + meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + await awrite(filename=Path(self.persist_path) / self.meta_filename, data=meta.model_dump_json()) + + async def _read_meta(self) -> IndexRepoMeta: + default_meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + + filename = Path(self.persist_path) / self.meta_filename + if not filename.exists(): + return default_meta + meta_data = await aread(filename=filename) + try: + meta = IndexRepoMeta.model_validate_json(meta_data) + return meta + except Exception as e: + logger.warning(f"Load meta error: {e}") + return default_meta + + @staticmethod + async def cross_repo_search(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Search for a query across multiple repositories. + + This asynchronous function searches for the specified query in files + located at the given path or file. + + Args: + query (str): The search term to look for in the files. + file_or_path (Union[str, Path]): The path to the file or directory + where the search should be conducted. This can be a string path + or a Path object. + + Returns: + List[str]: A list of strings containing the paths of files that + contain the query results. + + Raises: + ValueError: If the query string is empty. 
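+
+        Example (illustrative; assumes readable files exist under /data/uploads):
+            >>> texts = await IndexRepo.cross_repo_search(query="health check", file_or_path="/data/uploads")
+            >>> print(texts[0])  # the best-matching text snippet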
+        """
+        if not file_or_path or not Path(file_or_path).exists():
+            raise ValueError(f'"{str(file_or_path)}" does not exist')
+        files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path)
+        clusters, roots = IndexRepo.find_index_repo_path(files)
+        futures = []
+        others = set()
+        for persist_path, filenames in clusters.items():
+            if persist_path == OTHER_TYPE:
+                others.update(filenames)
+                continue
+            root = roots[persist_path]
+            repo = IndexRepo(persist_path=persist_path, root_path=root)
+            futures.append(repo.search(query=query, filenames=list(filenames)))
+
+        for i in others:
+            futures.append(File.read_text_file(i))
+
+        futures_results = []
+        if futures:
+            futures_results = await asyncio.gather(*futures)
+
+        result = []
+        v_result = []
+        for i in futures_results:
+            if not i:
+                continue
+            if isinstance(i, str):
+                result.append(i)
+            else:
+                v_result.append(i)
+
+        repo = IndexRepo()
+        merged = await repo.merge(query=query, indices_list=v_result)
+        return [i.text for i in merged] + result
diff --git a/metagpt/tools/libs/linter.py b/metagpt/tools/libs/linter.py
new file mode 100644
index 0000000000..0497e49c06
--- /dev/null
+++ b/metagpt/tools/libs/linter.py
@@ -0,0 +1,233 @@
+"""
+This file is borrowed from OpenDevin
+You can find the original repository here:
+https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/utils/aider/linter.py
+"""
+import os
+import subprocess
+import sys
+import traceback
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+from grep_ast import TreeContext, filename_to_lang
+from tree_sitter_languages import get_parser  # noqa: E402
+
+# tree_sitter is throwing a FutureWarning
+warnings.simplefilter("ignore", category=FutureWarning)
+
+
+@dataclass
+class LintResult:
+    text: str
+    lines: list
+
+
+class Linter:
+    def __init__(self, encoding="utf-8", root=None):
+        self.encoding = encoding
+        self.root = root
+
+        self.languages = dict(
+            python=self.py_lint,
+            sql=self.fake_lint,  # base_lint lacks support for full SQL syntax. Use fake_lint to bypass the validation.
+            css=self.fake_lint,  # base_lint lacks support for CSS syntax. Use fake_lint to bypass the validation.
+            js=self.fake_lint,  # base_lint lacks support for JavaScript syntax. Use fake_lint to bypass the validation.
+ javascript=self.fake_lint, + ) + self.all_lint_cmd = None + + def set_linter(self, lang, cmd): + if lang: + self.languages[lang] = cmd + return + + self.all_lint_cmd = cmd + + def get_rel_fname(self, fname): + if self.root: + return os.path.relpath(fname, self.root) + else: + return fname + + def run_cmd(self, cmd, rel_fname, code): + cmd += " " + rel_fname + cmd = cmd.split() + process = subprocess.Popen(cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout, _ = process.communicate() + errors = stdout.decode().strip() + self.returncode = process.returncode + if self.returncode == 0: + return # zero exit status + + cmd = " ".join(cmd) + res = "" + res += errors + line_num = extract_error_line_from(res) + return LintResult(text=res, lines=[line_num]) + + def get_abs_fname(self, fname): + if os.path.isabs(fname): + return fname + elif os.path.isfile(fname): + rel_fname = self.get_rel_fname(fname) + return os.path.abspath(rel_fname) + else: # if a temp file + return self.get_rel_fname(fname) + + def lint(self, fname, cmd=None) -> Optional[LintResult]: + code = Path(fname).read_text(self.encoding) + absolute_fname = self.get_abs_fname(fname) + if cmd: + cmd = cmd.strip() + if not cmd: + lang = filename_to_lang(fname) + if not lang: + return None + if self.all_lint_cmd: + cmd = self.all_lint_cmd + else: + cmd = self.languages.get(lang) + if callable(cmd): + linkres = cmd(fname, absolute_fname, code) + elif cmd: + linkres = self.run_cmd(cmd, absolute_fname, code) + else: + linkres = basic_lint(absolute_fname, code) + return linkres + + def flake_lint(self, rel_fname, code): + fatal = "F821,F822,F831,E112,E113,E999,E902" + flake8 = f"flake8 --select={fatal} --isolated" + + try: + flake_res = self.run_cmd(flake8, rel_fname, code) + except FileNotFoundError: + flake_res = None + return flake_res + + def py_lint(self, fname, rel_fname, code): + error = self.flake_lint(rel_fname, code) + if not error: + error = lint_python_compile(fname, code) + if not error: + error = basic_lint(rel_fname, code) + return error + + def fake_lint(self, fname, rel_fname, code): + return None + + +def lint_python_compile(fname, code): + try: + compile(code, fname, "exec") # USE TRACEBACK BELOW HERE + return + except IndentationError as err: + end_lineno = getattr(err, "end_lineno", err.lineno) + if isinstance(end_lineno, int): + line_numbers = list(range(end_lineno - 1, end_lineno)) + else: + line_numbers = [] + + tb_lines = traceback.format_exception(type(err), err, err.__traceback__) + last_file_i = 0 + + target = "# USE TRACEBACK" + target += " BELOW HERE" + for i in range(len(tb_lines)): + if target in tb_lines[i]: + last_file_i = i + break + tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :] + + res = "".join(tb_lines) + return LintResult(text=res, lines=line_numbers) + + +def basic_lint(fname, code): + """ + Use tree-sitter to look for syntax errors, display them with tree context. 
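+
+    Example (illustrative):
+        basic_lint("broken.py", "def f(:\n    pass") returns a LintResult whose
+        `lines` field lists the line numbers of the ERROR nodes found by tree-sitter.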
+    """
+
+    lang = filename_to_lang(fname)
+    if not lang:
+        return
+
+    parser = get_parser(lang)
+    tree = parser.parse(bytes(code, "utf-8"))
+
+    errors = traverse_tree(tree.root_node)
+    if not errors:
+        return
+    return LintResult(text=f"{fname}:{errors[0]}", lines=errors)
+
+
+def extract_error_line_from(lint_error):
+    # moved from openhands.agentskills#_lint_file
+    first_error_line = None
+    for line in lint_error.splitlines(True):
+        if line.strip():
+            # The format of the error message is: <filename>:<line>:<column>: <message>
+            parts = line.split(":")
+            if len(parts) >= 2:
+                try:
+                    first_error_line = int(parts[1])
+                    break
+                except ValueError:
+                    continue
+    return first_error_line
+
+
+def tree_context(fname, code, line_nums):
+    context = TreeContext(
+        fname,
+        code,
+        color=False,
+        line_number=True,
+        child_context=False,
+        last_line=False,
+        margin=0,
+        mark_lois=True,
+        loi_pad=3,
+        # header_max=30,
+        show_top_of_file_parent_scope=False,
+    )
+    line_nums = set(line_nums)
+    context.add_lines_of_interest(line_nums)
+    context.add_context()
+    output = context.format()
+
+    return output
+
+
+# Traverse the tree to find errors
+def traverse_tree(node):
+    errors = []
+    if node.type == "ERROR" or node.is_missing:
+        line_no = node.start_point[0] + 1
+        errors.append(line_no)
+
+    for child in node.children:
+        errors += traverse_tree(child)
+
+    return errors
+
+
+def main():
+    """
+    Main function to parse files provided as command line arguments.
+    """
+    if len(sys.argv) < 2:
+        print("Usage: python linter.py <file1> <file2> ...")
+        sys.exit(1)
+
+    linter = Linter(root=os.getcwd())
+    for file_path in sys.argv[1:]:
+        errors = linter.lint(file_path)
+        if errors:
+            print(errors)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py
index b62e39db89..4cf7d23109 100644
--- a/metagpt/tools/libs/sd_engine.py
+++ b/metagpt/tools/libs/sd_engine.py
@@ -14,7 +14,7 @@
 from aiohttp import ClientSession
 from PIL import Image, PngImagePlugin
 
-from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
+from metagpt.const import SD_OUTPUT_FILE_REPO, SD_URL, SOURCE_ROOT
 from metagpt.logs import logger
 from metagpt.tools.tool_registry import register_tool
 
@@ -68,7 +68,7 @@ def __init__(self, sd_url=""):
         Args:
             sd_url (str, optional): URL of the stable diffusion service. Defaults to "".
         """
-        self.sd_url = sd_url
+        self.sd_url = SD_URL if not sd_url else sd_url
         self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img"
         # Define default payload settings for SD API
         self.payload = payload
@@ -76,12 +76,12 @@ def __init__(self, sd_url=""):
 
     def construct_payload(
         self,
-        prompt,
-        negtive_prompt=default_negative_prompt,
-        width=512,
-        height=512,
-        sd_model="galaxytimemachinesGTM_photoV20",
-    ):
+        prompt: object,
+        negtive_prompt: object = default_negative_prompt,
+        width: object = 512,
+        height: object = 512,
+        sd_model: object = "galaxytimemachinesGTM_photoV20",
+    ) -> object:
         """Modify and set the API parameters for image generation.
 
         Args:
diff --git a/metagpt/tools/libs/shell.py b/metagpt/tools/libs/shell.py
new file mode 100644
index 0000000000..046830070b
--- /dev/null
+++ b/metagpt/tools/libs/shell.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import subprocess
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+
+async def shell_execute(
+    command: Union[List[str], str], cwd: str | Path = None, env: Dict = None, timeout: int = 600
+) -> Tuple[str, str, int]:
+    """
+    Execute a command and return its standard output, standard error, and return code.
+
+    Args:
+        command (Union[List[str], str]): The command to execute and its arguments. It can be provided either as a list
+            of strings or as a single string.
+        cwd (str | Path, optional): The current working directory for the command. Defaults to None.
+        env (Dict, optional): Environment variables to set for the command. Defaults to None.
+        timeout (int, optional): Timeout for the command execution in seconds. Defaults to 600.
+
+    Returns:
+        Tuple[str, str, int]: A tuple containing the command's standard output and standard error as strings, and its
+            return code as an int.
+
+    Raises:
+        subprocess.TimeoutExpired: If the command does not finish within `timeout` seconds.
+
+    Example:
+        >>> # command is a list
+        >>> stdout, stderr, returncode = await shell_execute(command=["ls", "-l"], cwd="/home/user", env={"PATH": "/usr/bin"})
+        >>> print(stdout)
+        total 8
+        -rw-r--r-- 1 user user 0 Mar 22 10:00 file1.txt
+        -rw-r--r-- 1 user user 0 Mar 22 10:00 file2.txt
+        ...
+
+        >>> # command is a string of shell script
+        >>> stdout, stderr, returncode = await shell_execute(command="ls -l", cwd="/home/user", env={"PATH": "/usr/bin"})
+        >>> print(stdout)
+        total 8
+        -rw-r--r-- 1 user user 0 Mar 22 10:00 file1.txt
+        -rw-r--r-- 1 user user 0 Mar 22 10:00 file2.txt
+        ...
+
+    References:
+        This function currently delegates to `subprocess.run`, which blocks until the command completes or the
+        timeout expires.
+    """
+    cwd = str(cwd) if cwd else None
+    shell = True if isinstance(command, str) else False
+    result = subprocess.run(command, cwd=cwd, capture_output=True, text=True, env=env, timeout=timeout, shell=shell)
+    return result.stdout, result.stderr, result.returncode
diff --git a/metagpt/tools/libs/software_development.py b/metagpt/tools/libs/software_development.py
new file mode 100644
index 0000000000..0955faa7ad
--- /dev/null
+++ b/metagpt/tools/libs/software_development.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+from metagpt.actions.requirement_analysis.framework import (
+    EvaluateFramework,
+    WriteFramework,
+    save_framework,
+)
+from metagpt.actions.requirement_analysis.trd import (
+    CompressExternalInterfaces,
+    DetectInteraction,
+    EvaluateTRD,
+    WriteTRD,
+)
+from metagpt.const import ASSISTANT_ALIAS, TEST_DATA_PATH
+from metagpt.context import Context
+from metagpt.logs import ToolLogItem, log_tool_output, logger
+from metagpt.utils.common import aread
+from metagpt.utils.cost_manager import CostManager
+
+
+async def import_git_repo(url: str) -> Path:
+    """
+    Imports a project from a Git website and formats it to MetaGPT project format to enable incremental appending requirements.
+
+    Args:
+        url (str): The Git project URL, such as "https://github.com/geekan/MetaGPT.git".
+
+    Returns:
+        Path: The path of the formatted project.
+ + Example: + # The Git project URL to input + >>> git_url = "https://github.com/geekan/MetaGPT.git" + + # Import the Git repository and get the formatted project path + >>> formatted_project_path = await import_git_repo(git_url) + >>> print("Formatted project path:", formatted_project_path) + /PATH/TO/THE/FORMMATTED/PROJECT + """ + from metagpt.actions.import_repo import ImportRepo + from metagpt.context import Context + + log_tool_output( + output=[ToolLogItem(name=ASSISTANT_ALIAS, value=import_git_repo.__name__)], tool_name=import_git_repo.__name__ + ) + + ctx = Context() + action = ImportRepo(repo_path=url, context=ctx) + await action.run() + + outputs = [ToolLogItem(name="MetaGPT Project", value=str(ctx.repo.workdir))] + log_tool_output(output=outputs, tool_name=import_git_repo.__name__) + + return ctx.repo.workdir + + +async def extract_external_interfaces(acknowledge: str) -> str: + """ + Extracts and compresses information about external system interfaces from a given acknowledgement text. + + Args: + acknowledge (str): A natural text of acknowledgement containing details about external system interfaces. + + Returns: + str: A compressed version of the information about external system interfaces. + + Example: + >>> acknowledge = "## Interfaces\\n..." + >>> external_interfaces = await extract_external_interfaces(acknowledge=acknowledge) + >>> print(external_interfaces) + ```json\n[\n{\n"id": 1,\n"inputs": {... + """ + compress_acknowledge = CompressExternalInterfaces() + return await compress_acknowledge.run(acknowledge=acknowledge) + + +async def mock_asearch_acknowledgement(use_case_actors: str): + return await aread(filename=TEST_DATA_PATH / "requirements/1.acknowledge.md") + + +async def write_trd( + use_case_actors: str, + user_requirements: str, + investment: float = 10, + context: Optional[Context] = None, +) -> str: + """ + Handles the writing of a Technical Requirements Document (TRD) based on user requirements. + + Args: + user_requirements (str): The new/incremental user requirements. + use_case_actors (str): Description of the actors involved in the use case. + investment (float): Budget. Automatically stops optimizing TRD when the budget is overdrawn. + context (Context, optional): The context configuration. Default is None. + Returns: + str: The newly created TRD. + + Example: + >>> # Given a new user requirements, write out a new TRD. + >>> user_requirements = "Write a 'snake game' TRD." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> investment = 10.0 + >>> trd = await write_trd( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> investment=investment, + >>> ) + >>> print(trd) + ## Technical Requirements Document\n ... + """ + context = context or Context(cost_manager=CostManager(max_budget=investment)) + compress_acknowledge = CompressExternalInterfaces() + acknowledgement = await mock_asearch_acknowledgement(use_case_actors) # Replaced by acknowledgement_repo later. 
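+    # Sketch of the loop below: detect interaction events, draft the TRD, then let
+    # EvaluateTRD critique it; the conclusion feeds the next round until the evaluation
+    # passes or the budget tracked by `context.cost_manager` is exhausted.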
+    external_interfaces = await compress_acknowledge.run(acknowledge=acknowledgement)
+    detect_interaction = DetectInteraction(context=context)
+    w_trd = WriteTRD(context=context)
+    evaluate_trd = EvaluateTRD(context=context)
+    is_pass = False
+    evaluation_conclusion = ""
+    interaction_events = ""
+    trd = ""
+    while not is_pass and (context.cost_manager.total_cost < context.cost_manager.max_budget):
+        interaction_events = await detect_interaction.run(
+            user_requirements=user_requirements,
+            use_case_actors=use_case_actors,
+            legacy_interaction_events=interaction_events,
+            evaluation_conclusion=evaluation_conclusion,
+        )
+        trd = await w_trd.run(
+            user_requirements=user_requirements,
+            use_case_actors=use_case_actors,
+            available_external_interfaces=external_interfaces,
+            evaluation_conclusion=evaluation_conclusion,
+            interaction_events=interaction_events,
+            previous_version_trd=trd,
+        )
+        evaluation = await evaluate_trd.run(
+            user_requirements=user_requirements,
+            use_case_actors=use_case_actors,
+            trd=trd,
+            interaction_events=interaction_events,
+        )
+        is_pass = evaluation.is_pass
+        evaluation_conclusion = evaluation.conclusion
+
+    return trd
+
+
+async def write_framework(
+    use_case_actors: str,
+    trd: str,
+    additional_technical_requirements: str,
+    output_dir: Optional[str] = "",
+    investment: float = 20.0,
+    context: Optional[Context] = None,
+    max_loop: int = 20,
+) -> str:
+    """
+    Run the action to generate a software framework based on the provided TRD and related information.
+
+    Args:
+        use_case_actors (str): Description of the use case actors involved.
+        trd (str): Technical Requirements Document detailing the requirements.
+        additional_technical_requirements (str): Any additional technical requirements.
+        output_dir (str, optional): Path to save the software framework files. Defaults to an empty string.
+        investment (float): Budget. Automatically stops optimizing the framework when the budget is overdrawn.
+        context (Context, optional): The context configuration. Default is None.
+        max_loop (int, optional): Acts as a safety exit valve when cost statistics fail. Default is 20.
+
+    Returns:
+        str: The generated software framework as a string of pathnames.
+
+    Example:
+        >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;"
+        >>> trd = "## TRD\\n..."
+        >>> additional_technical_requirements = "Using Java language, ..."
+        >>> investment = 15.0
+        >>> framework = await write_framework(
+        >>>    use_case_actors=use_case_actors,
+        >>>    trd=trd,
+        >>>    additional_technical_requirements=additional_technical_requirements,
+        >>>    investment=investment,
+        >>> )
+        >>> print(framework)
+        [{"path":"balabala", "filename":"...", ...
+    """
+    context = context or Context(cost_manager=CostManager(max_budget=investment))
+    write_framework = WriteFramework(context=context)
+    evaluate_framework = EvaluateFramework(context=context)
+    is_pass = False
+    framework = ""
+    evaluation_conclusion = ""
+    acknowledgement = await mock_asearch_acknowledgement(use_case_actors)  # Replaced by acknowledgement_repo later.
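+    # Besides the budget check, `loop_count`/`max_loop` act as a safety exit valve for
+    # providers whose cost statistics fail (total_cost then stays near zero forever).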
+    loop_count = 0
+    output_dir = (
+        Path(output_dir)
+        if output_dir
+        else context.config.workspace.path / (datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:8])
+    )
+    file_list = []
+    while not is_pass and (context.cost_manager.total_cost < context.cost_manager.max_budget):
+        try:
+            framework = await write_framework.run(
+                use_case_actors=use_case_actors,
+                trd=trd,
+                acknowledge=acknowledgement,
+                legacy_output=framework,
+                evaluation_conclusion=evaluation_conclusion,
+                additional_technical_requirements=additional_technical_requirements,
+            )
+        except Exception as e:
+            logger.info(f"{e}")
+            break
+        evaluation = await evaluate_framework.run(
+            use_case_actors=use_case_actors,
+            trd=trd,
+            acknowledge=acknowledgement,
+            legacy_output=framework,
+            additional_technical_requirements=additional_technical_requirements,
+        )
+        is_pass = evaluation.is_pass
+        evaluation_conclusion = evaluation.conclusion
+        loop_count += 1
+        logger.info(f"Loop {loop_count}")
+        if context.cost_manager.total_cost < 1 and loop_count > max_loop:
+            break
+        file_list = await save_framework(dir_data=framework, trd=trd, output_dir=output_dir)
+        logger.info(f"Output:\n{file_list}")
+
+    return "## Software Framework" + "".join([f"\n- {i}" for i in file_list])
+
+
+async def write_trd_and_framework(
+    use_case_actors: str,
+    user_requirements: str,
+    additional_technical_requirements: str,
+    investment: float = 50.0,
+    output_dir: Optional[str] = "",
+    context: Optional[Context] = None,
+) -> str:
+    context = context or Context(cost_manager=CostManager(max_budget=investment))
+    trd = await write_trd(use_case_actors=use_case_actors, user_requirements=user_requirements, context=context)
+    return await write_framework(
+        use_case_actors=use_case_actors,
+        trd=trd,
+        additional_technical_requirements=additional_technical_requirements,
+        output_dir=output_dir,
+        context=context,
+    )
diff --git a/metagpt/tools/libs/terminal.py b/metagpt/tools/libs/terminal.py
new file mode 100644
index 0000000000..9f35450526
--- /dev/null
+++ b/metagpt/tools/libs/terminal.py
@@ -0,0 +1,269 @@
+import asyncio
+import os
+import re
+from asyncio import Queue
+from asyncio.subprocess import PIPE, STDOUT
+from typing import Optional
+
+from metagpt.config2 import Config
+from metagpt.const import DEFAULT_WORKSPACE_ROOT, SWE_SETUP_PATH
+from metagpt.logs import logger
+from metagpt.tools.tool_registry import register_tool
+from metagpt.utils.report import END_MARKER_VALUE, TerminalReporter
+
+
+@register_tool()
+class Terminal:
+    """
+    A tool for running terminal commands.
+    Don't initialize a new instance of this class if one already exists.
+    For commands that need to be executed within a Conda environment, it is recommended
+    to use the `execute_in_conda_env` method.
+    """
+
+    def __init__(self):
+        self.shell_command = ["bash"]  # FIXME: should consider windows support later
+        self.command_terminator = "\n"
+        self.stdout_queue = Queue(maxsize=1000)
+        self.observer = TerminalReporter()
+        self.process: Optional[asyncio.subprocess.Process] = None
+        # Commands listed in `forbidden_commands` are replaced with a no-op and the advice below is returned instead.
+        # Example entry: {"cmd": "forbidden_reason/advice"}
+        self.forbidden_commands = {
+            "run dev": "Use Deployer.deploy_to_public instead.",
+            # "serve " keeps its trailing space so that only the standalone serve command is matched.
+            "serve ": "Use Deployer.deploy_to_public instead.",
+        }
+
+    async def _start_process(self):
+        # Start a persistent shell process
+        self.process = await asyncio.create_subprocess_exec(
+            *self.shell_command,
+            stdin=PIPE,
+            stdout=PIPE,
+            stderr=STDOUT,
+            executable="bash",
+            env=os.environ.copy(),
+            cwd=DEFAULT_WORKSPACE_ROOT.absolute(),
+        )
+        await self._check_state()
+
+    async def _check_state(self):
+        """
+        Check the state of the terminal, e.g. the current directory of the terminal process. Useful for agent to understand.
+        """
+        output = await self.run_command("pwd")
+        logger.info(f"The terminal is at: {output}")
+
+    async def run_command(self, cmd: str, daemon=False) -> str:
+        """
+        Executes a specified command in the terminal and streams the output back in real time.
+        This command maintains state across executions, such as the current directory,
+        allowing for sequential commands to be contextually aware.
+
+        Args:
+            cmd (str): The command to execute in the terminal.
+            daemon (bool): If True, executes the command in an asynchronous task, allowing
+                           the main program to continue execution.
+        Returns:
+            str: The command's output or an empty string if `daemon` is True. Remember that
+                 when `daemon` is True, use the `get_stdout_output` method to get the output.
+        """
+        if self.process is None:
+            await self._start_process()
+
+        output = ""
+        # Replace forbidden commands with a no-op
+        commands = re.split(r"\s*&&\s*", cmd)
+        for cmd_name, reason in self.forbidden_commands.items():
+            # "true" is a no-op command in the Linux terminal.
+            for index, command in enumerate(commands):
+                if cmd_name in command:
+                    output += f"Failed to execute {command}. {reason}\n"
+                    commands[index] = "true"
+        cmd = " && ".join(commands)
+
+        # Send the command
+        self.process.stdin.write((cmd + self.command_terminator).encode())
+        self.process.stdin.write(
+            f'echo "{END_MARKER_VALUE}"{self.command_terminator}'.encode()  # write EOF
+        )  # Unique marker to signal command end
+        await self.process.stdin.drain()
+        if daemon:
+            asyncio.create_task(self._read_and_process_output(cmd))
+        else:
+            output += await self._read_and_process_output(cmd)
+
+        return output
+
+    async def execute_in_conda_env(self, cmd: str, env, daemon=False) -> str:
+        """
+        Executes a given command within a specified Conda environment automatically without
+        the need for manual activation. Users just need to provide the name of the Conda
+        environment and the command to execute.
+
+        Args:
+            cmd (str): The command to execute within the Conda environment.
+            env (str, optional): The name of the Conda environment to activate before executing the command.
+                                 If not specified, the command will run in the current active environment.
+            daemon (bool): If True, the command is run in an asynchronous task, similar to `run_command`,
+                           affecting error logging and handling in the same manner.
+
+        Returns:
+            str: The command's output, or an empty string if `daemon` is True, with output processed
+                 asynchronously in that case.
+
+        Note:
+            This function wraps `run_command`, prepending the necessary Conda activation commands
+            to ensure the specified environment is active for the command's execution.
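+
+        Example (illustrative; assumes a Conda environment named "py39" exists):
+            >>> terminal = Terminal()
+            >>> output = await terminal.execute_in_conda_env("python --version", env="py39")
+            >>> print(output)
+            Python 3.9.x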
+        """
+        cmd = f"conda run -n {env} {cmd}"
+        return await self.run_command(cmd, daemon=daemon)
+
+    async def get_stdout_output(self) -> str:
+        """
+        Retrieves all collected output from background running commands and returns it as a string.
+
+        Returns:
+            str: The collected output from background running commands, returned as a string.
+        """
+        output_lines = []
+        while not self.stdout_queue.empty():
+            line = await self.stdout_queue.get()
+            output_lines.append(line)
+        return "\n".join(output_lines)
+
+    async def _read_and_process_output(self, cmd, daemon=False) -> str:
+        async with self.observer as observer:
+            cmd_output = []
+            # report the command
+            await observer.async_report(cmd + self.command_terminator, "cmd")
+            # Read the output until the unique marker is found.
+            # We read bytes directly from stdout instead of text because when reading text,
+            # '\r' is changed to '\n', resulting in excessive output.
+            tmp = b""
+            while True:
+                output = tmp + await self.process.stdout.read(1)
+                if not output:
+                    continue
+                *lines, tmp = output.splitlines(True)
+                for line in lines:
+                    line = line.decode()
+                    ix = line.rfind(END_MARKER_VALUE)
+                    if ix >= 0:
+                        line = line[0:ix]
+                        if line:
+                            # report stdout in real-time
+                            await observer.async_report(line, "output")
+                            cmd_output.append(line)
+                        return "".join(cmd_output)
+                    # log stdout in real-time
+                    await observer.async_report(line, "output")
+                    cmd_output.append(line)
+                    if daemon:
+                        await self.stdout_queue.put(line)
+
+    async def close(self):
+        """Close the persistent shell process."""
+        self.process.stdin.close()
+        await self.process.wait()
+
+
+@register_tool(include_functions=["run"])
+class Bash(Terminal):
+    """
+    A class to run bash commands directly and provides custom shell functions.
+    All custom functions in this class can ONLY be called via the `Bash.run` method.
+    """
+
+    def __init__(self):
+        """init"""
+        os.environ["SWE_CMD_WORK_DIR"] = str(Config.default().workspace.path)
+        super().__init__()
+        self.start_flag = False
+
+    async def start(self):
+        await self.run_command(f"cd {Config.default().workspace.path}")
+        await self.run_command(f"source {SWE_SETUP_PATH}")
+
+    async def run(self, cmd) -> str:
+        """
+        Executes a bash command.
+
+        Args:
+            cmd (str): The bash command to execute.
+
+        Returns:
+            str: The output of the command.
+
+        This method allows for executing standard bash commands as well as
+        utilizing several custom shell functions defined in the environment.
+
+        Custom Shell Functions:
+
+        - open <path> [<line_number>]
+          Opens the file at the given path in the editor. If line_number is provided,
+          the window will move to include that line.
+          Arguments:
+            path (str): The path to the file to open.
+            line_number (int, optional): The line number to move the window to.
+                                         If not provided, the window will start at the top of the file.
+
+        - goto <line_number>
+          Moves the window to show <line_number>.
+          Arguments:
+            line_number (int): The line number to move the window to.
+
+        - scroll_down
+          Moves the window down {WINDOW} lines.
+
+        - scroll_up
+          Moves the window up {WINDOW} lines.
+
+        - create <filename>
+          Creates and opens a new file with the given name.
+          Arguments:
+            filename (str): The name of the file to create.
+
+        - search_dir_and_preview <search_term> [<dir>]
+          Searches for search_term in all files in dir and gives their code preview
+          with line numbers. If dir is not provided, searches in the current directory.
+          Arguments:
+            search_term (str): The term to search for.
+            dir (str, optional): The directory to search in. Defaults to the current directory.
+
+        - search_file <search_term> [<file>]
+          Searches for search_term in file. If file is not provided, searches in the current open file.
+          Arguments:
+            search_term (str): The term to search for.
+            file (str, optional): The file to search in. Defaults to the current open file.
+
+        - find_file <file_name> [<dir>]
+          Finds all files with the given name in dir. If dir is not provided, searches in the current directory.
+          Arguments:
+            file_name (str): The name of the file to search for.
+            dir (str, optional): The directory to search in. Defaults to the current directory.
+
+        - edit <start_line>:<end_line>
+          <replacement_text>
+          EOF
+          Line numbers start from 1. Replaces lines <start_line> through <end_line> (inclusive) with the given text in the open file.
+          The replacement text is terminated by a line with only EOF on it. All of the <replacement_text> will be entered, so make
+          sure your indentation is formatted properly. Python files will be checked for syntax errors after the edit. If the system
+          detects a syntax error, the edit will not be executed. Simply try to edit the file again, but make sure to read the error
+          message and modify the edit command you issue accordingly. Issuing the same command a second time will just lead to the same
+          error message again. All code modifications made via the 'edit' command must strictly follow the PEP8 standard.
+          Arguments:
+            start_line (int): The line number to start the edit at, starting from 1.
+            end_line (int): The line number to end the edit at (inclusive), starting from 1.
+            replacement_text (str): The text to replace the current selection with, must conform to PEP8 standards.
+
+        - submit
+          Submits your current code locally. It can only be executed once, as the last action before the `end`.
+
+        Note: Make sure to use these functions as per their defined arguments and behaviors.
+        """
+        if not self.start_flag:
+            await self.start()
+            self.start_flag = True
+
+        return await self.run_command(cmd)
diff --git a/metagpt/tools/libs/web_scraping.py b/metagpt/tools/libs/web_scraping.py
index bc34b13063..9e7a8041c5 100644
--- a/metagpt/tools/libs/web_scraping.py
+++ b/metagpt/tools/libs/web_scraping.py
@@ -1,20 +1,52 @@
+import contextlib
+from uuid import uuid4
+
+from metagpt.tools.libs.browser import Browser
 from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
+from metagpt.utils.file import MemoryFileSystem
+from metagpt.utils.parse_html import simplify_html
 
 
-@register_tool(tags=["web scraping", "web"])
-async def scrape_web_playwright(url):
-    """
-    Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright.
+@register_tool(tags=["web scraping"])
+async def view_page_element_to_scrape(url: str, requirement: str, keep_links: bool = False) -> str:
+    """View the HTML content of the current page to understand its structure.
 
     Args:
-        url (str): The main URL to fetch inner text from.
-
+        url (str): The URL of the web page to scrape.
+        requirement (str): Providing a clear and detailed requirement helps in focusing the inspection on the desired elements.
+        keep_links (bool): Whether to keep the hyperlinks in the HTML content. Set to True if links are required.
     Returns:
-        dict: The inner text content and html structure of the web page, keys are 'inner_text', 'html'.
+        str: The HTML content of the page.
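+
+    Example (illustrative URL and requirement):
+        >>> html = await view_page_element_to_scrape(
+        ...     url="https://example.com/products",
+        ...     requirement="find the element that lists product prices",
+        ...     keep_links=False,
+        ... )
+        >>> print(html[:200])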
""" - # Create a PlaywrightWrapper instance for the Chromium browser - web = await PlaywrightWrapper().run(url) + async with Browser() as browser: + await browser.goto(url) + page = browser.page + html = await page.content() + html = simplify_html(html, url=page.url, keep_links=keep_links) + mem_fs = MemoryFileSystem() + filename = f"{uuid4().hex}.html" + with mem_fs.open(filename, "w") as f: + f.write(html) + + # Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback. + with contextlib.suppress(Exception): + from metagpt.rag.engines import SimpleEngine # avoid circular import + + # TODO make `from_docs` asynchronous + engine = SimpleEngine.from_docs(input_files=[filename], fs=mem_fs) + nodes = await engine.aretrieve(requirement) + html = "\n".join(i.text for i in nodes) + + mem_fs.rm_file(filename) + return html + - # Return the inner text content of the web page - return {"inner_text": web.inner_text.strip(), "html": web.html.strip()} +# async def get_elements_outerhtml(self, element_ids: list[int]): +# """Inspect the outer HTML of the elements in Current Browser Viewer. +# """ +# page = self.page +# data = [] +# for element_id in element_ids: +# html = await get_element_outer_html(page, get_backend_node_id(element_id, self.accessibility_tree)) +# data.append(html) +# return "\n".join(f"[{element_id}]. {html}" for element_id, html in zip(element_ids, data)) diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index 66b5ba9503..2756a24c5b 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -26,6 +26,8 @@ class GoogleAPIWrapper(BaseModel): api_key: str cse_id: str + discovery_service_url: Optional[str] = None + loop: Optional[asyncio.AbstractEventLoop] = None executor: Optional[futures.Executor] = None proxy: Optional[str] = None @@ -56,7 +58,7 @@ def validate_google(cls, values: dict) -> dict: @property def google_api_client(self): - build_kwargs = {"developerKey": self.api_key} + build_kwargs = {"developerKey": self.api_key, "discoveryServiceUrl": self.discovery_service_url} if self.proxy: parse_result = urlparse(self.proxy) proxy_type = parse_result.scheme diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 5744b1b621..b3ccb06495 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -6,7 +6,7 @@ @File : search_engine_serpapi.py """ import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) + url: str = "https://serpapi.com/search" aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -49,22 +50,18 @@ async def run(self, query, max_results: int = 8, as_string: bool = True, **kwarg async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" - def construct_url_and_params() -> Tuple[str, Dict[str, str]]: - params = self.get_params(query) - params["source"] = "python" - params["num"] = max_results - params["output"] = "json" - url = "https://serpapi.com/search" - return url, params + params = self.get_params(query) + params["source"] = "python" + params["num"] = max_results + params["output"] = "json" - url, params = construct_url_and_params() if not self.aiosession: 
async with aiohttp.ClientSession() as session: - async with session.get(url, params=params, proxy=self.proxy) as response: + async with session.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get(url, params=params, proxy=self.proxy) as response: + async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index ba2fb4f93d..932f2eb44b 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -7,7 +7,7 @@ """ import json import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -17,6 +17,7 @@ class SerperWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) api_key: str + url: str = "https://google.serper.dev/search" payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10}) aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -33,6 +34,7 @@ def validate_serper(cls, values: dict) -> dict: "To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://serper.dev/." ) + return values async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: @@ -46,20 +48,16 @@ async def run(self, query: str, max_results: int = 8, as_string: bool = True, ** async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" - def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries, max_results) - url = "https://google.serper.dev/search" - headers = self.get_headers() - return url, payloads, headers + payloads = self.get_payloads(queries, max_results) + headers = self.get_headers() - url, payloads, headers = construct_url_and_payload_and_headers() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() diff --git a/metagpt/tools/swe_agent_commands/__init__.py b/metagpt/tools/swe_agent_commands/__init__.py new file mode 100644 index 0000000000..c0d3e2a602 --- /dev/null +++ b/metagpt/tools/swe_agent_commands/__init__.py @@ -0,0 +1,7 @@ +""" +This folder is borrowed from princeton-nlp/SWE-agent +You can find the original repository here: +https://github.com/princeton-nlp/SWE-agent/tree/main/config/commands +We are using a modified version from OpenDevin: +https://github.com/OpenDevin/OpenDevin/tree/main/opendevin/runtime/plugins/swe_agent_commands +""" diff --git a/metagpt/tools/swe_agent_commands/_setup_default_env.sh b/metagpt/tools/swe_agent_commands/_setup_default_env.sh new file mode 100644 index 0000000000..8fb4a379e2 --- 
/dev/null
+++ b/metagpt/tools/swe_agent_commands/_setup_default_env.sh
@@ -0,0 +1,20 @@
+# _setup_default_env.sh
+# Default Mode from SWE-Bench
+# https://github.com/princeton-nlp/SWE-agent/blob/ca54d5556b9db4f4f2be21f09530ce69a72c0305/config/configs/default_sys-env_window100-detailed_cmd_format-last_5_history-1_demos.yaml
+
+export WINDOW=100
+export OVERLAP=2
+export CURRENT_LINE=0
+export CURRENT_FILE=''
+export SEARCH_RESULTS=()
+export SEARCH_FILES=()
+export SEARCH_INDEX=0
+
+state() {
+  local working_dir="$PWD"
+  if [ ! -e "$CURRENT_FILE" ]; then
+    echo '{"open_file": "n/a", "working_dir": "'$working_dir'"}'
+  else
+    echo '{"open_file": "'$(realpath "$CURRENT_FILE")'", "working_dir": "'$working_dir'"}'
+  fi
+}
diff --git a/metagpt/tools/swe_agent_commands/_split_string b/metagpt/tools/swe_agent_commands/_split_string
new file mode 100755
index 0000000000..ecc363e718
--- /dev/null
+++ b/metagpt/tools/swe_agent_commands/_split_string
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import sys
+
+
+def print_flake8_output(input_string, show_line_numbers=False):
+    for value in input_string.split("\n"):
+        parts = value.split()
+        if not show_line_numbers:
+            print(f"- {' '.join(parts[1:])}")
+        else:
+            line_nums = ":".join(parts[0].split(":")[1:])
+            print(f"- {line_nums} {' '.join(parts[1:])}")
+
+
+if __name__ == "__main__":
+    lint_output = sys.argv[1]
+    print_flake8_output(lint_output)
diff --git a/metagpt/tools/swe_agent_commands/_split_string.py b/metagpt/tools/swe_agent_commands/_split_string.py
new file mode 100755
index 0000000000..ecc363e718
--- /dev/null
+++ b/metagpt/tools/swe_agent_commands/_split_string.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import sys
+
+
+def print_flake8_output(input_string, show_line_numbers=False):
+    for value in input_string.split("\n"):
+        parts = value.split()
+        if not show_line_numbers:
+            print(f"- {' '.join(parts[1:])}")
+        else:
+            line_nums = ":".join(parts[0].split(":")[1:])
+            print(f"- {line_nums} {' '.join(parts[1:])}")
+
+
+if __name__ == "__main__":
+    lint_output = sys.argv[1]
+    print_flake8_output(lint_output)
diff --git a/metagpt/tools/swe_agent_commands/defaults.sh b/metagpt/tools/swe_agent_commands/defaults.sh
new file mode 100644
index 0000000000..d416dcbf54
--- /dev/null
+++ b/metagpt/tools/swe_agent_commands/defaults.sh
@@ -0,0 +1,192 @@
+_print() {
+    local total_lines=$(awk 'END {print NR}' $CURRENT_FILE)
+    echo "[File: $(realpath $CURRENT_FILE) ($total_lines lines total)]"
+    lines_above=$(jq -n "$CURRENT_LINE - $WINDOW/2" | jq '[0, .] | max | floor')
+    lines_below=$(jq -n "$total_lines - $CURRENT_LINE - $WINDOW/2" | jq '[0, .] | max | round')
+    if [ $lines_above -gt 0 ]; then
+        echo "($lines_above more lines above)"
+    fi
+    cat $CURRENT_FILE | grep -n $ | head -n $(jq -n "[$CURRENT_LINE + $WINDOW/2, $WINDOW/2] | max | floor") | tail -n $(jq -n "$WINDOW")
+    if [ $lines_below -gt 0 ]; then
+        echo "($lines_below more lines below)"
+    fi
+}
+
+_constrain_line() {
+    if [ -z "$CURRENT_FILE" ]
+    then
+        echo "No file open. Use the open command first."
+        return
+    fi
+    local max_line=$(awk 'END {print NR}' $CURRENT_FILE)
+    local half_window=$(jq -n "$WINDOW/2" | jq 'floor')
+    export CURRENT_LINE=$(jq -n "[$CURRENT_LINE, $max_line - $half_window] | min")
+    export CURRENT_LINE=$(jq -n "[$CURRENT_LINE, $half_window] | max")
+}
+
+# @yaml
+# signature: open <path> [<line_number>]
+# docstring: opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line
+# arguments:
+#   path:
+#     type: string
+#     description: the path to the file to open
+#     required: true
+#   line_number:
+#     type: integer
+#     description: the line number to move the window to (if not provided, the window will start at the top of the file)
+#     required: false
+open() {
+    if [ -z "$1" ]
+    then
+        echo "Usage: open <file>"
+        return
+    fi
+    # Check if the second argument is provided
+    if [ -n "$2" ]; then
+        # Check if the provided argument is a valid number
+        if ! [[ $2 =~ ^[0-9]+$ ]]; then
+            echo "Usage: open <file> [<line_number>]"
+            echo "Error: <line_number> must be a number"
+            return  # Exit if the line number is not valid
+        fi
+        local max_line=$(awk 'END {print NR}' $1)
+        if [ $2 -gt $max_line ]; then
+            echo "Warning: <line_number> ($2) is greater than the number of lines in the file ($max_line)"
+            echo "Warning: Setting <line_number> to $max_line"
+            local line_number=$(jq -n "$max_line")  # Set line number to max if greater than max
+        elif [ $2 -lt 1 ]; then
+            echo "Warning: <line_number> ($2) is less than 1"
+            echo "Warning: Setting <line_number> to 1"
+            local line_number=$(jq -n "1")  # Set line number to 1 if less than 1
+        else
+            local OFFSET=$(jq -n "$WINDOW/6" | jq 'floor')
+            local line_number=$(jq -n "[$2 + $WINDOW/2 - $OFFSET, 1] | max | floor")
+        fi
+    else
+        local line_number=$(jq -n "$WINDOW/2")  # Set default line number if not provided
+    fi
+
+    if [ -f "$1" ]; then
+        export CURRENT_FILE=$(realpath $1)
+        export CURRENT_LINE=$line_number
+        _constrain_line
+        _print
+    elif [ -d "$1" ]; then
+        echo "Error: $1 is a directory. You can only open files. Use cd or ls to navigate directories."
+    else
+        echo "File $1 not found"
+    fi
+}
+
+# @yaml
+# signature: goto <line_number>
+# docstring: moves the window to show <line_number>
+# arguments:
+#   line_number:
+#     type: integer
+#     description: the line number to move the window to
+#     required: true
+goto() {
+    if [ $# -gt 1 ]; then
+        echo "goto allows only one line number at a time."
+        return
+    fi
+    if [ -z "$CURRENT_FILE" ]
+    then
+        echo "No file open. Use the open command first."
+        return
+    fi
+    if [ -z "$1" ]
+    then
+        echo "Usage: goto <line_number>"
+        return
+    fi
+    if ! [[ $1 =~ ^[0-9]+$ ]]
+    then
+        echo "Usage: goto <line_number>"
+        echo "Error: <line_number> must be a number"
+        return
+    fi
+    local max_line=$(awk 'END {print NR}' $CURRENT_FILE)
+    if [ $1 -gt $max_line ]
+    then
+        echo "Error: <line_number> must be less than or equal to $max_line"
+        return
+    fi
+    local OFFSET=$(jq -n "$WINDOW/6" | jq 'floor')
+    export CURRENT_LINE=$(jq -n "[$1 + $WINDOW/2 - $OFFSET, 1] | max | floor")
+    _constrain_line
+    _print
+}
+
+# @yaml
+# signature: scroll_down
+# docstring: moves the window down {WINDOW} lines
+scroll_down() {
+    if [ -z "$CURRENT_FILE" ]
+    then
+        echo "No file open. Use the open command first."
+        return
+    fi
+    export CURRENT_LINE=$(jq -n "$CURRENT_LINE + $WINDOW - $OVERLAP")
+    _constrain_line
+    _print
+}
+
+# @yaml
+# signature: scroll_up
+# docstring: moves the window up {WINDOW} lines
+scroll_up() {
+    if [ -z "$CURRENT_FILE" ]
+    then
+        echo "No file open. Use the open command first."
+        return
+    fi
+    export CURRENT_LINE=$(jq -n "$CURRENT_LINE - $WINDOW + $OVERLAP")
+    _constrain_line
+    _print
+}
+
+# @yaml
+# signature: create <filename>
+# docstring: creates and opens a new file with the given name
+# arguments:
+#   filename:
+#     type: string
+#     description: the name of the file to create
+#     required: true
+create() {
+    if [ -z "$1" ]; then
+        echo "Usage: create <filename>"
+        return
+    fi
+
+    # Check if the file already exists
+    if [ -e "$1" ]; then
+        echo "Error: File '$1' already exists."
+        open "$1"
+        return
+    fi
+
+    # Create the file with an empty new line
+    printf "\n" > "$1"
+    # Use the existing open command to open the created file
+    open "$1"
+}
+
+# @yaml
+# signature: submit
+# docstring: submits your current code. the last action before the `end`, it can only be executed once.
+submit() {
+    # Check if the patch file exists and is non-empty
+    if [ -s "$SWE_CMD_WORK_DIR/test.patch" ]; then
+        # Apply the patch in reverse
+        git apply -R < "$SWE_CMD_WORK_DIR/test.patch"
+    fi
+
+    git add -A
+    echo "<>"
+}
diff --git a/metagpt/tools/swe_agent_commands/edit_linting.sh b/metagpt/tools/swe_agent_commands/edit_linting.sh
new file mode 100644
index 0000000000..e6d675ada0
--- /dev/null
+++ b/metagpt/tools/swe_agent_commands/edit_linting.sh
@@ -0,0 +1,165 @@
+# @yaml
+# signature: |-
+#   edit <start_line>:<end_line>
+#   <replacement_text>
+#   EOF
+# docstring: Line numbers start from 1. Replaces lines <start_line> through <end_line> (inclusive) with the given text in the open file. The replacement text is terminated by a line with only EOF on it. All of the <replacement_text> will be entered, so make sure your indentation is formatted properly. Python files will be checked for syntax errors after the edit. If the system detects a syntax error, the edit will not be executed. Simply try to edit the file again, but make sure to read the error message and modify the edit command you issue accordingly. Issuing the same command a second time will just lead to the same error message again. All code modifications made via the 'edit' command must strictly follow the PEP8 standard.
+# end_name: EOF
+# arguments:
+#   start_line:
+#     type: integer
+#     description: the line number to start the edit at, starting from 1.
+#     required: true
+#   end_line:
+#     type: integer
+#     description: the line number to end the edit at (inclusive), starting from 1.
+#     required: true
+#   replacement_text:
+#     type: string
+#     description: the text to replace the current selection with; must conform to PEP8 standards.
+#     required: true
+edit() {
+    if [ -z "$CURRENT_FILE" ]
+    then
+        echo 'No file open. Use the `open` command first.'
+        return
+    fi
+
+    local start_line="$(echo $1: | cut -d: -f1)"
+    local end_line="$(echo $1: | cut -d: -f2)"
+
+    if [ -z "$start_line" ] || [ -z "$end_line" ]
+    then
+        echo "Usage: edit <start_line>:<end_line>"
+        return
+    fi
+
+    local re='^[0-9]+$'
+    if ! [[ $start_line =~ $re ]]; then
+        echo "Usage: edit <start_line>:<end_line>"
+        echo "Error: start_line must be a number"
+        return
+    fi
+    if ! [[ $end_line =~ $re ]]; then
+        echo "Usage: edit <start_line>:<end_line>"
+        echo "Error: end_line must be a number"
+        return
+    fi
+
+    # Run linter for original file
+    if [[ $CURRENT_FILE == *.py ]]; then
+        original_lint_output=$(flake8 --isolated --select=F821,F822,F831,E112,E113,E999,E902 "$CURRENT_FILE" 2>&1)
+    else
+        # do nothing
+        original_lint_output=""
+    fi
+
+    # Bash array starts at 0, so let's adjust
+    local start_line=$((start_line - 1))
+    local end_line=$((end_line))
+
+    local line_count=0
+    local replacement=()
+    while IFS= read -r line
+    do
+        replacement+=("$line")
+        ((line_count++))
+    done
+
+    # Create a backup of the current file
+    cp "$CURRENT_FILE" "$SWE_CMD_WORK_DIR/$(basename "$CURRENT_FILE")_backup"
+
+    # Read the file line by line into an array
+    mapfile -t lines < "$CURRENT_FILE"
+    local new_lines=("${lines[@]:0:$start_line}" "${replacement[@]}" "${lines[@]:$((end_line))}")
+    # Write the new stuff directly back into the original file
+    printf "%s\n" "${new_lines[@]}" >| "$CURRENT_FILE"
+
+    # Run linter
+    if [[ $CURRENT_FILE == *.py ]]; then
+        lint_output=$(flake8 --isolated --select=F821,F822,F831,E112,E113,E999,E902 "$CURRENT_FILE" 2>&1)
+    else
+        # do nothing
+        lint_output=""
+    fi
+
+    # Create temporary files
+    temp_original=$(mktemp)
+    temp_modified=$(mktemp)
+
+    # Remove line numbers and save cleaned outputs to temporary files
+    echo "$original_lint_output" | sed 's/:[0-9]\+:[0-9]\+:/:LINE:COL:/g' > "$temp_original"
+    echo "$lint_output" | sed 's/:[0-9]\+:[0-9]\+:/:LINE:COL:/g' > "$temp_modified"
+
+    # Compare the temporary files
+    if cmp -s "$temp_original" "$temp_modified"; then
+        lint_output=""
+    else
+        # print linter results
+        echo "Linter output for the original file:"
+        cat "$temp_original"
+        echo "Linter output for the modified file:"
+        cat "$temp_modified"
+    fi
+
+    # Clean up temporary files
+    rm "$temp_original" "$temp_modified"
+
+    # if there is no output, then the file is good
+    if [ -z "$lint_output" ]; then
+        export CURRENT_LINE=$start_line
+        _constrain_line
+        _print
+
+        echo "File updated. Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
+    else
+        echo "Your proposed edit has introduced new syntax error(s). Please understand the fixes and retry your edit command."
+        echo ""
+        echo "ERRORS:"
+        _split_string "$lint_output"
+        echo ""
+
+        # Save original values
+        original_current_line=$CURRENT_LINE
+        original_window=$WINDOW
+
+        # Update values
+        export CURRENT_LINE=$(( (line_count / 2) + start_line ))  # Set to "center" of edit
+        export WINDOW=$((line_count + 10))  # Show +/- 5 lines around edit
+
+        echo "This is how your edit would have looked if applied"
+        echo "-------------------------------------------------"
+        _constrain_line
+        _print
+        echo "-------------------------------------------------"
+        echo ""
+
+        # Restoring CURRENT_FILE to original contents.
+        cp "$SWE_CMD_WORK_DIR/$(basename "$CURRENT_FILE")_backup" "$CURRENT_FILE"
+
+        export CURRENT_LINE=$(( ((end_line - start_line + 1) / 2) + start_line ))
+        export WINDOW=$((end_line - start_line + 10))
+
+        echo "This is the original code before your edit"
+        echo "-------------------------------------------------"
+        _constrain_line
+        _print
+        echo "-------------------------------------------------"
+
+        # Restore original values
+        export CURRENT_LINE=$original_current_line
+        export WINDOW=$original_window
+
+        echo "Your changes have NOT been applied. Please fix your edit command and try again."
+ echo "You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code." + echo "DO NOT re-run the same failed edit command. Running it again will lead to the same error." + fi + + + # Remove backup file + rm -f "$SWE_CMD_WORK_DIR/$(basename "$CURRENT_FILE")_backup" +} diff --git a/metagpt/tools/swe_agent_commands/search.sh b/metagpt/tools/swe_agent_commands/search.sh new file mode 100644 index 0000000000..b973b2d12b --- /dev/null +++ b/metagpt/tools/swe_agent_commands/search.sh @@ -0,0 +1,245 @@ +# @yaml +# signature: search_dir_and_preview [] +# docstring: searches for search_term in all files in dir and give their code preview with line number if you think need a first look. The output will vary depending on the length of the search results, but the file path, line number & corresponding code or number of occurrences will always be output. If dir is not provided, searches in the current directory +# arguments: +# search_term: +# type: string +# description: the term to search for +# required: true +# dir: +# type: string +# description: the directory to search in (if not provided, searches in the current directory) +# required: false +search_dir_and_preview() { + if [ $# -eq 1 ]; then + local search_term="$1" + local dir="./" + elif [ $# -eq 2 ]; then + local search_term="$1" + if [ -d "$2" ]; then + local dir="$2" + else + echo "Directory $2 not found" + return + fi + else + echo "Usage: search_dir_and_preview []" + return + fi + dir=$(realpath "$dir") + local matches=$(find "$dir" -type f -path '*.py' -exec grep -nIH -- "$search_term" {} + | cut -d: -f1 | sort | uniq -c) +< 100, print an error + if [ $num_files -gt 100 ]; then + echo "More than $num_files files matched for \"$search_term\" in $dir. Please narrow your search." + return + fi + + match_with_cnt=$(echo "$matches" | awk '{$2=$2; gsub(/^\.+\/+/, "./", $2); print $2 " ("$1" matches)"}') +< [] +# docstring: searches for search_term in file. If file is not provided, searches in the current open file +# arguments: +# search_term: +# type: string +# description: the term to search for +# required: true +# file: +# type: string +# description: the file to search in (if not provided, searches in the current open file) +# required: false +search_file() { + # Check if the first argument is provided + if [ -z "$1" ]; then + echo "Usage: search_file []" + return + fi + # Check if the second argument is provided + if [ -n "$2" ]; then + # Check if the provided argument is a valid file + if [ -f "$2" ]; then + local file="$2" # Set file if valid + else + echo "Usage: search_file []" + echo "Error: File name $2 not found. Please provide a valid file name." + return # Exit if the file is not valid + fi + else + # Check if a file is open + if [ -z "$CURRENT_FILE" ]; then + echo "No file open. Use the open command first." 
+            return  # Exit if no file is open
+        fi
+        local file="$CURRENT_FILE"  # Set file to the current open file
+    fi
+    local search_term="$1"
+    file=$(realpath "$file")
+    # Use grep to directly get the desired formatted output
+    local matches=$(grep -nH -- "$search_term" "$file")
+    # Check if no matches were found
+    if [ -z "$matches" ]; then
+        echo "No matches found for \"$search_term\" in $file"
+        return
+    fi
+    # Calculate total number of matches
+    local num_matches=$(echo "$matches" | wc -l | awk '{$1=$1; print $0}')
+
+    # calculate total number of lines matched
+    local num_lines=$(echo "$matches" | cut -d: -f1 | sort | uniq | wc -l | awk '{$1=$1; print $0}')
+    # if num_lines is > 100, print an error
+    if [ $num_lines -gt 100 ]; then
+        echo "More than $num_lines lines matched for \"$search_term\" in $file. Please narrow your search."
+        return
+    fi
+
+    # Print the total number of matches and the matches themselves
+    echo "Found $num_matches matches for \"$search_term\" in $file:"
+    echo "$matches" | cut -d: -f1-2 | sort -u -t: -k2,2n | while IFS=: read -r filename line_number; do
+        echo "Line $line_number:$(sed -n "${line_number}p" "$file")"
+    done
+    echo "End of matches for \"$search_term\" in $file"
+}
+
+# @yaml
+# signature: find_file <file_name> [<dir>]
+# docstring: finds all files with the given name in dir. If dir is not provided, searches in the current directory
+# arguments:
+#   file_name:
+#     type: string
+#     description: the name of the file to search for
+#     required: true
+#   dir:
+#     type: string
+#     description: the directory to search in (if not provided, searches in the current directory)
+#     required: false
+find_file() {
+    if [ $# -eq 1 ]; then
+        local file_name="$1"
+        local dir="./"
+    elif [ $# -eq 2 ]; then
+        local file_name="$1"
+        if [ -d "$2" ]; then
+            local dir="$2"
+        else
+            echo "Directory $2 not found"
+            return
+        fi
+    else
+        echo "Usage: find_file <file_name> [<dir>]"
+        return
+    fi
+
+    dir=$(realpath "$dir")
+    local matches=$(find "$dir" -type f -name "$file_name")
+    # if no matches, return
+    if [ -z "$matches" ]; then
+        echo "No matches found for \"$file_name\" in $dir"
+        return
+    fi
+    # Calculate total number of matches
+    local num_matches=$(echo "$matches" | wc -l | awk '{$1=$1; print $0}')
+    echo "Found $num_matches matches for \"$file_name\" in $dir:"
+    echo "$matches" | awk '{print $0}'
+}
diff --git a/metagpt/tools/swe_agent_commands/setup_default.sh b/metagpt/tools/swe_agent_commands/setup_default.sh
new file mode 100644
index 0000000000..2656500012
--- /dev/null
+++ b/metagpt/tools/swe_agent_commands/setup_default.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+pip install flake8
+
+# Default Mode from SWE-Bench
+# https://github.com/princeton-nlp/SWE-agent/blob/ca54d5556b9db4f4f2be21f09530ce69a72c0305/config/configs/default_sys-env_window100-detailed_cmd_format-last_5_history-1_demos.yaml#L103-L106
+SCRIPT_PATH="${BASH_SOURCE[0]}"  # use BASH_SOURCE to avoid the influence of `source *.sh`, which would cause CUR_DIR=/bin
+CUR_DIR=$(dirname $(readlink -f $SCRIPT_PATH))
+REPO_ROOT_DIR=$CUR_DIR"/../../.."
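+
+# _setup_default_env.sh is expected to export the env vars (e.g. SWE_CMD_WORK_DIR) that the command scripts sourced below rely on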
+source $REPO_ROOT_DIR/metagpt/tools/swe_agent_commands/_setup_default_env.sh + +# make _split_string (py) available +export PATH=$PATH:$REPO_ROOT_DIR/metagpt/tools/swe_agent_commands + +source $REPO_ROOT_DIR/metagpt/tools/swe_agent_commands/defaults.sh +source $REPO_ROOT_DIR/metagpt/tools/swe_agent_commands/search.sh +source $REPO_ROOT_DIR/metagpt/tools/swe_agent_commands/edit_linting.sh + +echo "SWE_CMD_WORK_DIR: $SWE_CMD_WORK_DIR" diff --git a/metagpt/tools/swe_agent_commands/swe_agent_utils.py b/metagpt/tools/swe_agent_commands/swe_agent_utils.py new file mode 100644 index 0000000000..9e293f4d28 --- /dev/null +++ b/metagpt/tools/swe_agent_commands/swe_agent_utils.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import numpy as np +from datasets import load_dataset, load_from_disk + + +def extract_patch(command_output): + patch_lines = [] + recording = False + for line in command_output.split("\n"): + if line.startswith("diff --git"): + recording = True + if recording: + patch_lines.append(line) + return "\n".join(patch_lines) + + +def load_hf_dataset(dataset_name_or_path: str, cache_dir, split: str = "test", existing_ids: list = []): + data_dir = cache_dir / dataset_name_or_path + if Path(data_dir).exists(): + dataset = load_from_disk(data_dir) + else: + dataset = load_dataset(dataset_name_or_path) + dataset.save_to_disk(data_dir) + print(dataset) + if split not in dataset: + raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}") + dataset = dataset[split] + np.array(list(map(len, dataset["instance_id"]))) + + if existing_ids: + dataset = dataset.filter( + lambda x: x["instance_id"] not in existing_ids, + desc="Filtering out existing ids", + load_from_cache_file=False, + ) + + return dataset diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index 42c65b9e77..a84cbeea07 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -1,3 +1,4 @@ +import ast import inspect from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces @@ -5,9 +6,10 @@ PARSER = GoogleDocstringParser -def convert_code_to_tool_schema(obj, include: list[str] = None): +def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict: + """Converts an object (function or class) to a tool schema by inspecting the object""" docstring = inspect.getdoc(obj) - assert docstring, "no docstring found for the objects, skip registering" + # assert docstring, "no docstring found for the objects, skip registering" if inspect.isclass(obj): schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}} @@ -18,8 +20,7 @@ def convert_code_to_tool_schema(obj, include: list[str] = None): continue # method_doc = inspect.getdoc(method) method_doc = get_class_method_docstring(obj, name) - if method_doc: - schema["methods"][name] = function_docstring_to_schema(method, method_doc) + schema["methods"][name] = function_docstring_to_schema(method, method_doc) elif inspect.isfunction(obj): schema = function_docstring_to_schema(obj, docstring) @@ -27,7 +28,17 @@ def convert_code_to_tool_schema(obj, include: list[str] = None): return schema -def function_docstring_to_schema(fn_obj, docstring) -> dict: +def convert_code_to_tool_schema_ast(code: str) -> list[dict]: + """Converts a code string to a list of tool schemas by parsing the code with AST""" + + visitor = CodeVisitor(code) + parsed_code = ast.parse(code) + visitor.visit(parsed_code) + + return visitor.get_tool_schemas() + + +def function_docstring_to_schema(fn_obj, 
docstring="") -> dict: """ Converts a function's docstring into a schema dictionary. @@ -62,3 +73,67 @@ def get_class_method_docstring(cls, method_name): if method.__doc__: return method.__doc__ return None # No docstring found in the class hierarchy + + +class CodeVisitor(ast.NodeVisitor): + """Visit and convert the AST nodes within a code file to tool schemas""" + + def __init__(self, source_code: str): + self.tool_schemas = {} # {tool_name: tool_schema} + self.source_code = source_code + + def visit_ClassDef(self, node): + class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}} + for body_node in node.body: + if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and ( + not body_node.name.startswith("_") or body_node.name == "__init__" + ): + func_schemas = self._get_function_schemas(body_node) + class_schemas["methods"].update({body_node.name: func_schemas}) + class_schemas["code"] = ast.get_source_segment(self.source_code, node) + self.tool_schemas[node.name] = class_schemas + + def visit_FunctionDef(self, node): + self._visit_function(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function(node) + + def _visit_function(self, node): + if node.name.startswith("_"): + return + function_schemas = self._get_function_schemas(node) + function_schemas["code"] = ast.get_source_segment(self.source_code, node) + self.tool_schemas[node.name] = function_schemas + + def _get_function_schemas(self, node): + docstring = remove_spaces(ast.get_docstring(node)) + overall_desc, param_desc = PARSER.parse(docstring) + return { + "type": "async_function" if isinstance(node, ast.AsyncFunctionDef) else "function", + "description": overall_desc, + "signature": self._get_function_signature(node), + "parameters": param_desc, + } + + def _get_function_signature(self, node): + args = [] + defaults = dict(zip([arg.arg for arg in node.args.args][-len(node.args.defaults) :], node.args.defaults)) + for arg in node.args.args: + arg_str = arg.arg + if arg.annotation: + annotation = ast.unparse(arg.annotation) + arg_str += f": {annotation}" + if arg.arg in defaults: + default_value = ast.unparse(defaults[arg.arg]) + arg_str += f" = {default_value}" + args.append(arg_str) + + return_annotation = "" + if node.returns: + return_annotation = f" -> {ast.unparse(node.returns)}" + + return f"({', '.join(args)}){return_annotation}" + + def get_tool_schemas(self): + return self.tool_schemas diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index 69b9a4b5d3..25f403c772 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -1,20 +1,22 @@ from __future__ import annotations import json +import traceback from typing import Any -import jieba import numpy as np from pydantic import BaseModel, field_validator from rank_bm25 import BM25Okapi from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.prompts.di.role_zero import JSON_REPAIR_PROMPT from metagpt.schema import Plan from metagpt.tools import TOOL_REGISTRY from metagpt.tools.tool_data_type import Tool from metagpt.tools.tool_registry import validate_tool_names from metagpt.utils.common import CodeParser +from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output TOOL_INFO_PROMPT = """ ## Capabilities @@ -62,6 +64,10 @@ class ToolRecommender(BaseModel): @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v: list[str]) -> dict[str, Tool]: + # If `v` is already a dictionary (e.g., 
during deserialization), return it as is.
+        if isinstance(v, dict):
+            return v
+
         # One can use special symbol [""] to indicate use of all registered tools
         if v == [""]:
             return TOOL_REGISTRY.get_all_tools()
@@ -102,11 +108,13 @@
         return ranked_tools
 
-    async def get_recommended_tool_info(self, **kwargs) -> str:
+    async def get_recommended_tool_info(self, fixed: list[str] = None, **kwargs) -> str:
         """
         Wrap recommended tools with their info in a string, which can be used directly in a prompt.
         """
         recommended_tools = await self.recommend_tools(**kwargs)
+        if fixed:
+            recommended_tools.extend([self.tools[tool_name] for tool_name in fixed if tool_name in self.tools])
         if not recommended_tools:
             return ""
         tool_schemas = {tool.name: tool.schemas for tool in recommended_tools}
@@ -132,9 +140,30 @@ async def rank_tools(
             available_tools=available_tools,
             topk=topk,
         )
-        rsp = await LLM().aask(prompt)
-        rsp = CodeParser.parse_code(block=None, text=rsp)
-        ranked_tools = json.loads(rsp)
+        rsp = await LLM().aask(prompt, stream=False)
+
+        # Temporary workaround: once the role zero version is complete, the code within this commented block can be replaced directly
+        # ------------- begin ---------------
+        try:
+            ranked_tools = CodeParser.parse_code(block=None, lang="json", text=rsp)
+            ranked_tools = json.loads(
+                repair_llm_raw_output(output=ranked_tools, req_keys=[None], repair_type=RepairType.JSON)
+            )
+        except json.JSONDecodeError:
+            ranked_tools = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=rsp))
+            ranked_tools = json.loads(CodeParser.parse_code(block=None, lang="json", text=ranked_tools))
+        except Exception:
+            tb = traceback.format_exc()
+            print(tb)
+
+        # Tolerate LLM output that does not follow the requested format
+        if isinstance(ranked_tools, dict):
+            ranked_tools = list(ranked_tools.values())[0]
+        # ------------- end ---------------
+
+        if not isinstance(ranked_tools, list):
+            logger.warning(f"Invalid rank result: {ranked_tools}, will use the recalled tools instead.")
+            ranked_tools = list(available_tools.keys())
 
         valid_tools = validate_tool_names(ranked_tools)
@@ -182,7 +211,7 @@ def _init_corpus(self):
         self.bm25 = BM25Okapi(tokenized_corpus)
 
     def _tokenize(self, text):
-        return jieba.lcut(text)  # FIXME: needs more sophisticated tokenization
+        return text.split()  # FIXME: needs more sophisticated tokenization
 
     async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
         query = plan.current_task.instruction if plan else context
@@ -193,7 +222,7 @@ async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 2
         recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
 
         logger.info(
-            f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
+            f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[np.round(doc_scores[index], 4) for index in top_indexes]}"
         )
 
         return recalled_tools
diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py
index 11269cb0fc..49820b458e 100644
--- a/metagpt/tools/tool_registry.py
+++ b/metagpt/tools/tool_registry.py
@@ -7,17 +7,20 @@
 """
 from __future__ import annotations
 
+import contextlib
 import inspect
 import os
 from collections import defaultdict
-from typing import Union
+from pathlib import Path
 
-import yaml
 from pydantic import BaseModel
 
 from metagpt.const import TOOL_SCHEMA_PATH
 from metagpt.logs import logger
-from metagpt.tools.tool_convert import convert_code_to_tool_schema
+from metagpt.tools.tool_convert import (
+    convert_code_to_tool_schema,
+    convert_code_to_tool_schema_ast,
+)
 from 
metagpt.tools.tool_data_type import Tool, ToolSchema @@ -27,21 +30,23 @@ class ToolRegistry(BaseModel): def register_tool( self, - tool_name, - tool_path, - schema_path="", - tool_code="", - tags=None, - tool_source_object=None, - include_functions=None, - verbose=False, + tool_name: str, + tool_path: str, + schemas: dict = None, + schema_path: str = "", + tool_code: str = "", + tags: list[str] = None, + tool_source_object=None, # can be any classes or functions + include_functions: list[str] = None, + verbose: bool = False, ): if self.has_tool(tool_name): return schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml" - schemas = make_schema(tool_source_object, include_functions, schema_path) + if not schemas: + schemas = make_schema(tool_source_object, include_functions, schema_path) if not schemas: return @@ -95,7 +100,9 @@ def decorator(cls): if "metagpt" in file_path: # split to handle ../metagpt/metagpt/tools/... where only metapgt/tools/... is needed file_path = "metagpt" + file_path.split("metagpt")[-1] - source_code = inspect.getsource(cls) + source_code = "" + with contextlib.suppress(OSError): + source_code = inspect.getsource(cls) TOOL_REGISTRY.register_tool( tool_name=cls.__name__, @@ -112,14 +119,8 @@ def decorator(cls): def make_schema(tool_source_object, include, path): - os.makedirs(os.path.dirname(path), exist_ok=True) # Create the necessary directories try: schema = convert_code_to_tool_schema(tool_source_object, include=include) - with open(path, "w", encoding="utf-8") as f: - yaml.dump(schema, f, sort_keys=False) - # import json - # with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f: - # json.dump(schema, f, ensure_ascii=False, indent=4) except Exception as e: schema = {} logger.error(f"Fail to make schema: {e}") @@ -127,15 +128,67 @@ def make_schema(tool_source_object, include, path): return schema -def validate_tool_names(tools: Union[list[str], str]) -> str: +def validate_tool_names(tools: list[str]) -> dict[str, Tool]: assert isinstance(tools, list), "tools must be a list of str" valid_tools = {} for key in tools: - # one can define either tool names or tool type names, take union to get the whole set - if TOOL_REGISTRY.has_tool(key): - valid_tools.update({key: TOOL_REGISTRY.get_tool(key)}) + # one can define either tool names OR tool tags OR tool path, take union to get the whole set + # if tool paths are provided, they will be registered on the fly + if os.path.isdir(key) or os.path.isfile(key): + valid_tools.update(register_tools_from_path(key)) + elif TOOL_REGISTRY.has_tool(key.split(":")[0]): + if ":" in key: + # handle class tools with methods specified, such as Editor:read,write + class_tool_name = key.split(":")[0] + method_names = key.split(":")[1].split(",") + class_tool = TOOL_REGISTRY.get_tool(class_tool_name) + + methods_filtered = {} + for method_name in method_names: + if method_name in class_tool.schemas["methods"]: + methods_filtered[method_name] = class_tool.schemas["methods"][method_name] + else: + logger.warning(f"invalid method {method_name} under tool {class_tool_name}, skipped") + class_tool_filtered = class_tool.model_copy(deep=True) + class_tool_filtered.schemas["methods"] = methods_filtered + + valid_tools.update({class_tool_name: class_tool_filtered}) + + else: + valid_tools.update({key: TOOL_REGISTRY.get_tool(key)}) elif TOOL_REGISTRY.has_tool_tag(key): valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key)) else: logger.warning(f"invalid tool name or tool type name: {key}, skipped") return valid_tools + + 
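+# Illustrative example: a role configured with tools=["Plan", "Editor:read,write", "path/to/my_tools.py"]
+# (the file path is an assumed placeholder) resolves here to a {name: Tool} dict, with any path
+# entries registered on the fly by the helpers below.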
+def register_tools_from_file(file_path) -> dict[str, Tool]: + file_name = Path(file_path).name + if not file_name.endswith(".py") or file_name == "setup.py" or file_name.startswith("test"): + return {} + registered_tools = {} + code = Path(file_path).read_text(encoding="utf-8") + tool_schemas = convert_code_to_tool_schema_ast(code) + for name, schemas in tool_schemas.items(): + tool_code = schemas.pop("code", "") + TOOL_REGISTRY.register_tool( + tool_name=name, + tool_path=file_path, + schemas=schemas, + tool_code=tool_code, + ) + registered_tools.update({name: TOOL_REGISTRY.get_tool(name)}) + return registered_tools + + +def register_tools_from_path(path) -> dict[str, Tool]: + tools_registered = {} + if os.path.isfile(path): + tools_registered.update(register_tools_from_file(path)) + elif os.path.isdir(path): + for root, _, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + tools_registered.update(register_tools_from_file(file_path)) + return tools_registered diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 243871aff3..9e67a35858 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.provider.openai_api import OpenAILLM as GPTAPI from metagpt.utils.common import awrite @@ -282,6 +282,7 @@ async def gpt_msgs_to_code(self, messages: list) -> str: """Choose based on different calling methods""" result = "" if self.chatgpt_method == "API": + config = Config.default() result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages) return result diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py index 01339e51a4..a65bf29bc6 100644 --- a/metagpt/tools/web_browser_engine.py +++ b/metagpt/tools/web_browser_engine.py @@ -92,14 +92,14 @@ def from_browser_config(cls, config: BrowserConfig, **kwargs): return cls(**data, **kwargs) @overload - async def run(self, url: str) -> WebPage: + async def run(self, url: str, per_page_timeout: float = None) -> WebPage: ... @overload - async def run(self, url: str, *urls: str) -> list[WebPage]: + async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> list[WebPage]: ... - async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]: """Runs the browser engine to load one or more web pages. This method is the implementation of the overloaded run signatures. It delegates the task @@ -108,8 +108,9 @@ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: Args: url: The URL of the first web page to load. *urls: Additional URLs of web pages to load, if any. + per_page_timeout: The maximum time for fetching a single page in seconds. Returns: A WebPage object if a single URL is provided, or a list of WebPage objects if multiple URLs are provided. 
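+
+        Example:
+            >>> # `engine` stands for an instance of this browser engine class:
+            >>> page = await engine.run("https://example.com", per_page_timeout=30.0)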
""" - return await self.run_func(url, *urls) + return await self.run_func(url, *urls, per_page_timeout=per_page_timeout) diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py index 2df288b1a9..f38a3b296e 100644 --- a/metagpt/tools/web_browser_engine_playwright.py +++ b/metagpt/tools/web_browser_engine_playwright.py @@ -39,10 +39,11 @@ def __init__(self, **kwargs): if not any(str.startswith(i, "--proxy-server=") for i in args): launch_kwargs["proxy"] = {"server": self.proxy} - if "ignore_https_errors" in kwargs: - self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"] + for key in ["ignore_https_errors", "java_script_enabled", "extra_http_headers", "user_agent"]: + if key in kwargs: + self.context_kwargs[key] = kwargs[key] - async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]: async with async_playwright() as ap: browser_type = getattr(ap, self.browser_type) await self._run_precheck(browser_type) @@ -50,11 +51,17 @@ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: _scrape = self._scrape if urls: - return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls)) - return await _scrape(browser, url) + return await asyncio.gather( + _scrape(browser, url, per_page_timeout), *(_scrape(browser, i, per_page_timeout) for i in urls) + ) + return await _scrape(browser, url, per_page_timeout) - async def _scrape(self, browser, url): + async def _scrape(self, browser, url, timeout: float = None): context = await browser.new_context(**self.context_kwargs) + + if timeout is not None: + context.set_default_timeout(timeout * 1000) # playwright uses milliseconds. 
+ page = await context.new_page() async with page: try: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 7867154618..3217a78c78 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -54,14 +54,16 @@ def launch_args(self): def executable_path(self): return self.launch_kwargs.get("executable_path") - async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]: await self._run_precheck() - _scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url) + _scrape = lambda url, per_page_timeout: self.loop.run_in_executor( + self.executor, self._scrape_website, url, per_page_timeout + ) if urls: - return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls)) - return await _scrape(url) + return await asyncio.gather(_scrape(url, per_page_timeout), *(_scrape(i, per_page_timeout) for i in urls)) + return await _scrape(url, per_page_timeout) async def _run_precheck(self): if self._has_run_precheck: @@ -75,11 +77,11 @@ async def _run_precheck(self): ) self._has_run_precheck = True - def _scrape_website(self, url): + def _scrape_website(self, url, timeout: float = None): with self._get_driver() as driver: try: driver.get(url) - WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) + WebDriverWait(driver, timeout or 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) inner_text = driver.execute_script("return document.body.innerText;") html = driver.page_source except Exception as e: diff --git a/metagpt/utils/__init__.py b/metagpt/utils/__init__.py index f13175cf88..26042eb0e2 100644 --- a/metagpt/utils/__init__.py +++ b/metagpt/utils/__init__.py @@ -19,6 +19,7 @@ "read_docx", "Singleton", "TOKEN_COSTS", + "new_transaction_id", "count_message_tokens", "count_string_tokens", ] diff --git a/metagpt/utils/a11y_tree.py b/metagpt/utils/a11y_tree.py new file mode 100644 index 0000000000..133c4f63a4 --- /dev/null +++ b/metagpt/utils/a11y_tree.py @@ -0,0 +1,312 @@ +"""See https://github.com/web-arena-x/webarena +""" +from __future__ import annotations + +import re + +from playwright.async_api import BrowserContext, Page + + +async def get_accessibility_tree(page: Page): + cdp_session = await get_page_cdp_session(page) + resp = await cdp_session.send("Accessibility.getFullAXTree") + + seen_ids = set() + accessibility_tree = [] + for node in resp["nodes"]: + if node["nodeId"] not in seen_ids: + accessibility_tree.append(node) + seen_ids.add(node["nodeId"]) + return accessibility_tree + + +async def execute_step(step: str, page: Page, browser_ctx: BrowserContext, accessibility_tree: list): + step = step.strip() + func = step.split("[")[0].strip() if "[" in step else step.split()[0].strip() + if func == "None": + return "" + elif func == "click": + match = re.search(r"click ?\[(\d+)\]", step) + if not match: + raise ValueError(f"Invalid click action {step}") + element_id = match.group(1) + await click_element(page, get_backend_node_id(element_id, accessibility_tree)) + elif func == "hover": + match = re.search(r"hover ?\[(\d+)\]", step) + if not match: + raise ValueError(f"Invalid hover action {step}") + element_id = match.group(1) + await hover_element(page, get_backend_node_id(element_id, accessibility_tree)) + elif func == "type": + # add default enter flag + if not (step.endswith("[0]") or 
step.endswith("[1]")): + step += " [1]" + + match = re.search(r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", step) + if not match: + raise ValueError(f"Invalid type action {step}") + element_id, text, enter_flag = ( + match.group(1), + match.group(2), + match.group(3), + ) + if enter_flag == "1": + text += "\n" + await click_element(page, get_backend_node_id(element_id, accessibility_tree)) + await type_text(page, text) + elif func == "press": + match = re.search(r"press ?\[(.+)\]", step) + if not match: + raise ValueError(f"Invalid press action {step}") + key = match.group(1) + await key_press(page, key) + elif func == "scroll": + # up or down + match = re.search(r"scroll ?\[?(up|down)\]?", step) + if not match: + raise ValueError(f"Invalid scroll action {step}") + direction = match.group(1) + await scroll_page(page, direction) + elif func == "goto": + match = re.search(r"goto ?\[(.+)\]", step) + if not match: + raise ValueError(f"Invalid goto action {step}") + url = match.group(1) + await page.goto(url) + elif func == "new_tab": + page = await browser_ctx.new_page() + elif func == "go_back": + await page.go_back() + elif func == "go_forward": + await page.go_forward() + elif func == "tab_focus": + match = re.search(r"tab_focus ?\[(\d+)\]", step) + if not match: + raise ValueError(f"Invalid tab_focus action {step}") + page_number = int(match.group(1)) + page = browser_ctx.pages[page_number] + await page.bring_to_front() + elif func == "close_tab": + await page.close() + if len(browser_ctx.pages) > 0: + page = browser_ctx.pages[-1] + else: + page = await browser_ctx.new_page() + elif func == "stop": + match = re.search(r'stop\(?"(.+)?"\)', step) + answer = match.group(1) if match else "" + return answer + else: + raise ValueError + await page.wait_for_load_state("domcontentloaded") + return page + + +async def type_text(page: Page, text: str): + await page.keyboard.type(text) + + +async def click_element(page: Page, backend_node_id: int): + cdp_session = await get_page_cdp_session(page) + resp = await get_bounding_rect(cdp_session, backend_node_id) + node_info = resp["result"]["value"] + x, y = await get_element_center(node_info) + # Move to the location of the element + await page.evaluate(f"window.scrollTo({x}- window.innerWidth/2,{y} - window.innerHeight/2);") + # Refresh the relative location of the element + resp = await get_bounding_rect(cdp_session, backend_node_id) + node_info = resp["result"]["value"] + x, y = await get_element_center(node_info) + await page.mouse.click(x, y) + + +async def hover_element(page: Page, backend_node_id: int) -> None: + cdp_session = await get_page_cdp_session(page) + resp = await get_bounding_rect(cdp_session, backend_node_id) + node_info = resp["result"]["value"] + x, y = await get_element_center(node_info) + await page.mouse.move(x, y) + + +async def scroll_page(page: Page, direction: str) -> None: + # perform the action + # code from natbot + if direction == "up": + await page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" + ) + elif direction == "down": + await page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" + ) + + +async def key_press(page: Page, key: str) -> None: + """Press a key.""" + if "Meta" in key and "Mac" not in await page.evaluate("navigator.platform"): + key = key.replace("Meta", "Control") + await page.keyboard.press(key) + + +async def 
get_element_outer_html(page: Page, backend_node_id: int): + cdp_session = await get_page_cdp_session(page) + try: + outer_html = await cdp_session.send("DOM.getOuterHTML", {"backendNodeId": int(backend_node_id)}) + return outer_html["outerHTML"] + except Exception as e: + raise ValueError("Element not found") from e + + +async def get_element_center(node_info): + x, y, width, height = node_info["x"], node_info["y"], node_info["width"], node_info["height"] + center_x = x + width / 2 + center_y = y + height / 2 + return center_x, center_y + + +def extract_step(response: str, action_splitter: str = "```") -> str: + # find the first occurence of action + pattern = rf"{action_splitter}((.|\n)*?){action_splitter}" + match = re.search(pattern, response) + if match: + return match.group(1).strip() + else: + raise ValueError(f'Cannot find the answer phrase "{response}"') + + +async def get_bounding_rect(cdp_session, backend_node_id: str): + try: + remote_object = await cdp_session.send("DOM.resolveNode", {"backendNodeId": int(backend_node_id)}) + remote_object_id = remote_object["object"]["objectId"] + response = await cdp_session.send( + "Runtime.callFunctionOn", + { + "objectId": remote_object_id, + "functionDeclaration": """ + function() { + if (this.nodeType == 3) { + var range = document.createRange(); + range.selectNode(this); + var rect = range.getBoundingClientRect().toJSON(); + range.detach(); + return rect; + } else { + return this.getBoundingClientRect().toJSON(); + } + } + """, + "returnByValue": True, + }, + ) + return response + except Exception as e: + raise ValueError("Element not found") from e + + +IGNORED_ACTREE_PROPERTIES = ( + "focusable", + "editable", + "readonly", + "level", + "settable", + "multiline", + "invalid", +) + + +def parse_accessibility_tree(accessibility_tree): + """Parse the accessibility tree into a string text""" + node_id_to_idx = {} + for idx, node in enumerate(accessibility_tree): + node_id_to_idx[node["nodeId"]] = idx + + obs_nodes_info = {} + + def dfs(idx: int, obs_node_id: str, depth: int) -> str: + tree_str = "" + node = accessibility_tree[idx] + indent = "\t" * depth + valid_node = True + try: + role = node["role"]["value"] + name = node["name"]["value"] + node_str = f"[{obs_node_id}] {role} {repr(name)}" + properties = [] + for property in node.get("properties", []): + try: + if property["name"] in IGNORED_ACTREE_PROPERTIES: + continue + properties.append(f'{property["name"]}: {property["value"]["value"]}') + except KeyError: + pass + + if properties: + node_str += " " + " ".join(properties) + + # check valid + if not node_str.strip(): + valid_node = False + + # empty generic node + if not name.strip(): + if not properties: + if role in [ + "generic", + "img", + "list", + "strong", + "paragraph", + "banner", + "navigation", + "Section", + "LabelText", + "Legend", + "listitem", + ]: + valid_node = False + elif role in ["listitem"]: + valid_node = False + + if valid_node: + tree_str += f"{indent}{node_str}" + obs_nodes_info[obs_node_id] = { + "backend_id": node["backendDOMNodeId"], + "union_bound": node["union_bound"], + "text": node_str, + } + + except Exception: + valid_node = False + + for _, child_node_id in enumerate(node["childIds"]): + if child_node_id not in node_id_to_idx: + continue + # mark this to save some tokens + child_depth = depth + 1 if valid_node else depth + child_str = dfs(node_id_to_idx[child_node_id], child_node_id, child_depth) + if child_str.strip(): + if tree_str.strip(): + tree_str += "\n" + tree_str += child_str + + return 
tree_str
+
+    tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
+    return tree_str, obs_nodes_info
+
+
+async def get_page_cdp_session(page):
+    if hasattr(page, "cdp_session"):
+        return page.cdp_session
+
+    cdp_session = await page.context.new_cdp_session(page)
+    page.cdp_session = cdp_session
+    return cdp_session
+
+
+def get_backend_node_id(element_id, accessibility_tree):
+    element_id = str(element_id)
+    for i in accessibility_tree:
+        if i["nodeId"] == element_id:
+            return i.get("backendDOMNodeId")
+    raise ValueError(f"Element {element_id} not found")
diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py
index ee440ef447..cecb20c5dd 100644
--- a/metagpt/utils/async_helper.py
+++ b/metagpt/utils/async_helper.py
@@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
     new_loop.call_soon_threadsafe(new_loop.stop)
     t.join()
     new_loop.close()
+
+
+class NestAsyncio:
+    """Make asyncio event loop reentrant."""
+
+    is_applied = False
+
+    @classmethod
+    def apply_once(cls):
+        """Ensures `nest_asyncio.apply()` is called only once."""
+        if not cls.is_applied:
+            import nest_asyncio
+
+            nest_asyncio.apply()
+            cls.is_applied = True
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index e443c34664..90f13da23e 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -15,6 +15,8 @@
 import base64
 import contextlib
 import csv
+import functools
+import hashlib
 import importlib
 import inspect
 import json
@@ -23,13 +25,19 @@
 import platform
 import re
 import sys
+import time
 import traceback
+import uuid
+from asyncio import iscoroutinefunction
+from datetime import datetime
+from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, List, Literal, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 from urllib.parse import quote, unquote
 
 import aiofiles
+import aiohttp
 import chardet
 import loguru
 import requests
@@ -37,9 +45,10 @@
 from pydantic_core import to_jsonable_python
 from tenacity import RetryCallState, RetryError, _utils
 
-from metagpt.const import MESSAGE_ROUTE_TO_ALL
+from metagpt.const import MARKDOWN_TITLE_PREFIX, MESSAGE_ROUTE_TO_ALL
 from metagpt.logs import logger
 from metagpt.utils.exceptions import handle_exception
+from metagpt.utils.json_to_markdown import json_to_markdown
 
 
 def check_cmd_exists(command) -> int:
@@ -65,7 +74,7 @@ class OutputParser:
     @classmethod
     def parse_blocks(cls, text: str):
         # First, split the text into different blocks by "##"
-        blocks = text.split("##")
+        blocks = text.split(MARKDOWN_TITLE_PREFIX)
 
         # Create a dict to store each block's title and content
         block_dict = {}
@@ -271,10 +280,10 @@ def parse_blocks(cls, text: str):
         return block_dict
 
     @classmethod
-    def parse_code(cls, block: str, text: str, lang: str = "") -> str:
+    def parse_code(cls, text: str, lang: str = "", block: Optional[str] = None) -> str:
         if block:
             text = cls.parse_block(block, text)
-        pattern = rf"```{lang}.*?\s+(.*?)```"
+        pattern = rf"```{lang}.*?\s+(.*?)\n```"
         match = re.search(pattern, text, re.DOTALL)
         if match:
             code = match.group(1)
@@ -287,7 +296,7 @@ def parse_code(cls, block: str, text: str, lang: str = "") -> str:
 
     @classmethod
     def parse_str(cls, block: str, text: str, lang: str = ""):
-        code = cls.parse_code(block, text, lang)
+        code = cls.parse_code(block=block, text=text, lang=lang)
         code = code.split("=")[-1]
         code = code.strip().strip("'").strip('"')
         return code
@@ -295,7 +304,7 @@ def parse_str(cls, block: str, text: str, lang: str = ""):
     @classmethod
     def parse_file_list(cls, block: str, text: 
str, lang: str = "") -> list[str]: # Regular expression pattern to find the tasks list. - code = cls.parse_code(block, text, lang) + code = cls.parse_code(block=block, text=text, lang=lang) # print(code) pattern = r"\s*(.*=.*)?(\[.*\])" @@ -560,7 +569,7 @@ def log_it(retry_state: "RetryCallState") -> None: return log_it -def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: +def read_json_file(json_file: str, encoding: str = "utf-8") -> list[Any]: if not Path(json_file).exists(): raise FileNotFoundError(f"json_file: {json_file} not exist, return []") @@ -572,13 +581,32 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def handle_unknown_serialization(x: Any) -> str: + """For `to_jsonable_python` debug, get more detail about the x.""" + + if inspect.ismethod(x): + tip = f"Cannot serialize method '{x.__func__.__name__}' of class '{x.__self__.__class__.__name__}'" + elif inspect.isfunction(x): + tip = f"Cannot serialize function '{x.__name__}'" + elif hasattr(x, "__class__"): + tip = f"Cannot serialize instance of '{x.__class__.__name__}'" + elif hasattr(x, "__name__"): + tip = f"Cannot serialize class or module '{x.__name__}'" + else: + tip = f"Cannot serialize object of type '{type(x).__name__}'" + + raise TypeError(tip) + + +def write_json_file(json_file: str, data: Any, encoding: str = "utf-8", indent: int = 4, use_fallback: bool = False): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) + custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None) + with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default) def read_csv_to_list(curr_file: str, header=False, strip_trail=True): @@ -646,7 +674,7 @@ async def wrapper(self, *args, **kwargs): raise Exception(format_trackback_info(limit=None)) except Exception as e: if self.latest_observed_msg: - logger.warning( + logger.exception( "There is a exception in role's execution, in order to resume, " "we delete the newest role communication message in the role's memory." 
)
@@ -659,7 +687,7 @@ async def wrapper(self, *args, **kwargs):
         if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name):
             raise last_error
 
-        raise Exception(format_trackback_info(limit=None))
+        raise Exception(format_trackback_info(limit=None)) from e
 
     return wrapper
@@ -667,6 +695,8 @@
 @handle_exception
 async def aread(filename: str | Path, encoding="utf-8") -> str:
     """Read file asynchronously."""
+    if not filename or not Path(filename).exists():
+        return ""
     try:
         async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
             content = await reader.read()
@@ -783,13 +813,15 @@ def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None)
     return skills
 
 
-def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str:
+def encode_image(image_path_or_pil: Union[Path, Image, str], encoding: str = "utf-8") -> str:
     """encode image from file or PIL.Image into base64"""
     if isinstance(image_path_or_pil, Image.Image):
         buffer = BytesIO()
         image_path_or_pil.save(buffer, format="JPEG")
         bytes_data = buffer.getvalue()
     else:
+        if isinstance(image_path_or_pil, str):
+            image_path_or_pil = Path(image_path_or_pil)
         if not image_path_or_pil.exists():
             raise FileNotFoundError(f"{image_path_or_pil} does not exist")
         with open(str(image_path_or_pil), "rb") as image_file:
@@ -811,6 +843,21 @@ def decode_image(img_url_or_b64: str) -> Image:
     return img
 
 
+def extract_image_paths(content: str) -> list[str]:
+    # The path is expected to be delimited by whitespace, like "xxx /an/absolute/path.jpg xxx"
+    pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff|PNG|JPE?G|GIF|BMP|TIFF)"
+    image_paths = re.findall(pattern, content)
+    return image_paths
+
+
+def extract_and_encode_images(content: str) -> list[str]:
+    images = []
+    for path in extract_image_paths(content):
+        if os.path.exists(path):
+            images.append(encode_image(path))
+    return images
+
+
 def log_and_reraise(retry_state: RetryCallState):
     logger.error(f"Retry attempts exhausted. 
Last exception: {retry_state.outcome.exception()}") logger.warning( @@ -822,19 +869,330 @@ def log_and_reraise(retry_state: RetryCallState): raise retry_state.outcome.exception() -def get_markdown_codeblock_type(filename: str) -> str: +async def get_mime_type(filename: str | Path, force_read: bool = False) -> str: + guess_mime_type, _ = mimetypes.guess_type(filename.name) + if not guess_mime_type: + ext_mappings = {".yml": "text/yaml", ".yaml": "text/yaml"} + guess_mime_type = ext_mappings.get(filename.suffix) + if not force_read and guess_mime_type: + return guess_mime_type + + from metagpt.tools.libs.shell import shell_execute # avoid circular import + + text_set = { + "application/json", + "application/vnd.chipnuts.karaoke-mmd", + "application/javascript", + "application/xml", + "application/x-sh", + "application/sql", + "text/yaml", + } + + try: + stdout, stderr, _ = await shell_execute(f"file --mime-type '{str(filename)}'") + if stderr: + logger.debug(f"file:{filename}, error:{stderr}") + return guess_mime_type + ix = stdout.rfind(" ") + mime_type = stdout[ix:].strip() + if mime_type == "text/plain" and guess_mime_type in text_set: + return guess_mime_type + return mime_type + except Exception as e: + logger.debug(f"file:{filename}, error:{e}") + return "unknown" + + +def get_markdown_codeblock_type(filename: str = None, mime_type: str = None) -> str: """Return the markdown code-block type corresponding to the file extension.""" - mime_type, _ = mimetypes.guess_type(filename) + if not filename and not mime_type: + raise ValueError("Either filename or mime_type must be valid.") + + if not mime_type: + mime_type, _ = mimetypes.guess_type(filename) mappings = { "text/x-shellscript": "bash", "text/x-c++src": "cpp", "text/css": "css", "text/html": "html", "text/x-java": "java", - "application/javascript": "javascript", - "application/json": "json", "text/x-python": "python", "text/x-ruby": "ruby", + "text/x-c": "cpp", + "text/yaml": "yaml", + "application/javascript": "javascript", + "application/json": "json", "application/sql": "sql", + "application/vnd.chipnuts.karaoke-mmd": "mermaid", + "application/x-sh": "bash", + "application/xml": "xml", } return mappings.get(mime_type, "text") + + +def get_project_srcs_path(workdir: str | Path) -> Path: + src_workdir_path = workdir / ".src_workspace" + if src_workdir_path.exists(): + with open(src_workdir_path, "r") as file: + src_name = file.read() + else: + src_name = Path(workdir).name + return Path(workdir) / src_name + + +async def init_python_folder(workdir: str | Path): + if not workdir: + return + workdir = Path(workdir) + if not workdir.exists(): + return + init_filename = Path(workdir) / "__init__.py" + if init_filename.exists(): + return + async with aiofiles.open(init_filename, "a"): + os.utime(init_filename, None) + + +def get_markdown_code_block_type(filename: str) -> str: + if not filename: + return "" + ext = Path(filename).suffix + types = { + ".py": "python", + ".js": "javascript", + ".java": "java", + ".cpp": "cpp", + ".c": "c", + ".html": "html", + ".css": "css", + ".xml": "xml", + ".json": "json", + ".yaml": "yaml", + ".md": "markdown", + ".sql": "sql", + ".rb": "ruby", + ".php": "php", + ".sh": "bash", + ".swift": "swift", + ".go": "go", + ".rs": "rust", + ".pl": "perl", + ".asm": "assembly", + ".r": "r", + ".scss": "scss", + ".sass": "sass", + ".lua": "lua", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".yml": "yaml", + ".ini": "ini", + ".toml": "toml", + ".svg": "xml", # SVG can often be treated as XML + # Add 
more file extensions and corresponding code block types as needed + } + return types.get(ext, "") + + +def to_markdown_code_block(val: str, type_: str = "") -> str: + """ + Convert a string to a Markdown code block. + + This function takes a string and wraps it in a Markdown code block. + If a type is provided, it adds it as a language identifier for syntax highlighting. + + Args: + val (str): The string to be converted to a Markdown code block. + type_ (str, optional): The language identifier for syntax highlighting. + Defaults to an empty string. + + Returns: + str: The input string wrapped in a Markdown code block. + If the input string is empty, it returns an empty string. + + Examples: + >>> to_markdown_code_block("print('Hello, World!')", "python") + \n```python\nprint('Hello, World!')\n```\n + + >>> to_markdown_code_block("Some text") + \n```\nSome text\n```\n + """ + if not val: + return val or "" + val = val.replace("```", "\\`\\`\\`") + return f"\n```{type_}\n{val}\n```\n" + + +async def save_json_to_markdown(content: str, output_filename: str | Path): + """ + Saves the provided JSON content as a Markdown file. + + This function takes a JSON string, converts it to Markdown format, + and writes it to the specified output file. + + Args: + content (str): The JSON content to be converted. + output_filename (str or Path): The path where the output Markdown file will be saved. + + Returns: + None + + Raises: + None: Any exceptions are logged and the function returns without raising them. + + Examples: + >>> await save_json_to_markdown('{"key": "value"}', Path("/path/to/output.md")) + This will save the Markdown converted JSON to the specified file. + + Notes: + - This function handles `json.JSONDecodeError` specifically for JSON parsing errors. + - Any other exceptions during the process are also logged and handled gracefully. + """ + try: + m = json.loads(content) + except json.JSONDecodeError as e: + logger.warning(f"Failed to decode JSON content: {e}") + return + except Exception as e: + logger.warning(f"An unexpected error occurred: {e}") + return + await awrite(filename=output_filename, data=json_to_markdown(m)) + + +def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]: + """ + Generates a mapping of class methods to a given entry with class name as a prefix. + + Args: + cls: The class from which the methods are derived. + methods (List[str]): A list of method names as strings. + entry (Any): The entry to be mapped to each method. + + Returns: + Dict[str, Any]: A dictionary where keys are method names prefixed with the class name and + values are the given entry. If the number of methods is less than 2, + the dictionary will contain a single entry with the class name as the key. + + Example: + >>> class MyClass: + >>> pass + >>> + >>> tool2name(MyClass, ['method1', 'method2'], 'some_entry') + {'MyClass.method1': 'some_entry', 'MyClass.method2': 'some_entry'} + + >>> tool2name(MyClass, ['method1'], 'some_entry') + {'MyClass': 'some_entry', 'MyClass.method1': 'some_entry'} + """ + class_name = cls.__name__ + mappings = {f"{class_name}.{i}": entry for i in methods} + if len(mappings) < 2: + mappings[class_name] = entry + return mappings + + +def new_transaction_id(postfix_len=8) -> str: + """ + Generates a new unique transaction ID based on current timestamp and a random UUID. + + Args: + postfix_len (int): Length of the random UUID postfix to include in the transaction ID. Default is 8. + + Returns: + str: A unique transaction ID composed of timestamp and a random UUID. 
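+
+    Example:
+        >>> new_transaction_id()  # illustrative output; a new value is produced on every call
+        '20240620123059T1a2b3c4d'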
+ """ + return datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:postfix_len] + + +def log_time(method): + """A time-consuming decorator for printing execution duration.""" + + def before_call(): + start_time, cpu_start_time = time.perf_counter(), time.process_time() + logger.info(f"[{method.__name__}] started at: " f"{datetime.now().strftime('%Y-%m-%d %H:%m:%S')}") + return start_time, cpu_start_time + + def after_call(start_time, cpu_start_time): + end_time, cpu_end_time = time.perf_counter(), time.process_time() + logger.info( + f"[{method.__name__}] ended. " + f"Time elapsed: {end_time - start_time:.4} sec, CPU elapsed: {cpu_end_time - cpu_start_time:.4} sec" + ) + + @functools.wraps(method) + def timeit_wrapper(*args, **kwargs): + start_time, cpu_start_time = before_call() + result = method(*args, **kwargs) + after_call(start_time, cpu_start_time) + return result + + @functools.wraps(method) + async def timeit_wrapper_async(*args, **kwargs): + start_time, cpu_start_time = before_call() + result = await method(*args, **kwargs) + after_call(start_time, cpu_start_time) + return result + + return timeit_wrapper_async if iscoroutinefunction(method) else timeit_wrapper + + +async def check_http_endpoint(url: str, timeout: int = 3) -> bool: + """ + Checks the status of an HTTP endpoint. + + Args: + url (str): The URL of the HTTP endpoint to check. + timeout (int, optional): The timeout in seconds for the HTTP request. Defaults to 3. + + Returns: + bool: True if the endpoint is online and responding with a 200 status code, False otherwise. + """ + async with aiohttp.ClientSession() as session: + try: + async with session.get(url, timeout=timeout) as response: + return response.status == 200 + except Exception as e: + print(f"Error accessing the endpoint {url}: {e}") + return False + + +def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path: + """ + Rectifies the given path to ensure a valid output file path. + + If the given `path` is a directory, it creates the directory (if it doesn't exist) and appends the `default_filename` to it. If the `path` is a file path, it creates the parent directory (if it doesn't exist) and returns the `path`. + + Args: + path (Union[str, Path]): The input path, which can be a string or a `Path` object. + default_filename (str): The default filename to use if the `path` is a directory. + + Returns: + Path: The rectified output path. 
+ """ + output_pathname = Path(path) + if output_pathname.is_dir(): + output_pathname.mkdir(parents=True, exist_ok=True) + output_pathname = output_pathname / default_filename + else: + output_pathname.parent.mkdir(parents=True, exist_ok=True) + return output_pathname + + +def generate_fingerprint(text: str) -> str: + """ + Generate a fingerprint for the given text + + Args: + text (str): The text for which the fingerprint needs to be generated + + Returns: + str: The fingerprint value of the text + """ + text_bytes = text.encode("utf-8") + + # calculate SHA-256 hash + sha256 = hashlib.sha256() + sha256.update(text_bytes) + fingerprint = sha256.hexdigest() + + return fingerprint diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index fee706ece2..f8fabfbdc6 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -23,8 +23,8 @@ class DiGraphRepository(GraphRepository): """Graph repository based on DiGraph.""" - def __init__(self, name: str, **kwargs): - super().__init__(name=name, **kwargs) + def __init__(self, name: str | Path, **kwargs): + super().__init__(name=str(name), **kwargs) self._repo = networkx.DiGraph() async def insert(self, subject: str, predicate: str, object_: str): @@ -112,8 +112,14 @@ async def save(self, path: str | Path = None): async def load(self, pathname: str | Path): """Load a directed graph repository from a JSON file.""" data = await aread(filename=pathname, encoding="utf-8") - m = json.loads(data) + self.load_json(data) + + def load_json(self, val: str): + if not val: + return self + m = json.loads(val) self._repo = networkx.node_link_graph(m) + return self @staticmethod async def load_from(pathname: str | Path) -> GraphRepository: @@ -126,9 +132,7 @@ async def load_from(pathname: str | Path) -> GraphRepository: GraphRepository: A new instance of the graph repository loaded from the specified JSON file. """ pathname = Path(pathname) - name = pathname.with_suffix("").name - root = pathname.parent - graph = DiGraphRepository(name=name, root=root) + graph = DiGraphRepository(name=pathname.stem, root=pathname.parent) if pathname.exists(): await graph.load(pathname=pathname) return graph diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py index 3d53a314ce..3fcf1f25b5 100644 --- a/metagpt/utils/embedding.py +++ b/metagpt/utils/embedding.py @@ -7,10 +7,11 @@ """ from llama_index.embeddings.openai import OpenAIEmbedding -from metagpt.config2 import config +from metagpt.config2 import Config def get_embedding() -> OpenAIEmbedding: + config = Config.default() llm = config.get_openai_llm() if llm is None: raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.") diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index f62b44eb86..a3f612bccc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -6,12 +6,19 @@ @File : file.py @Describe : General file operations. 
""" +import base64 from pathlib import Path +from typing import Optional, Tuple, Union import aiofiles +from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem +from metagpt.config2 import Config from metagpt.logs import logger +from metagpt.utils import read_docx +from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint from metagpt.utils.exceptions import handle_exception +from metagpt.utils.repo_to_markdown import is_text_file class File: @@ -68,3 +75,128 @@ async def read(cls, file_path: Path, chunk_size: int = None) -> bytes: content = b"".join(chunks) logger.debug(f"Successfully read file, the path of file: {file_path}") return content + + @staticmethod + async def is_textual_file(filename: Union[str, Path]) -> bool: + """Determines if a given file is a textual file. + + A file is considered a textual file if it is plain text or has a + specific set of MIME types associated with textual formats, + including PDF and Microsoft Word documents. + + Args: + filename (Union[str, Path]): The path to the file to be checked. + + Returns: + bool: True if the file is a textual file, False otherwise. + """ + is_text, mime_type = await is_text_file(filename) + if is_text: + return True + if mime_type == "application/pdf": + return True + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return True + return False + + @staticmethod + async def read_text_file(filename: Union[str, Path]) -> Optional[str]: + """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" + is_text, mime_type = await is_text_file(filename) + if is_text: + return await File._read_text(filename) + if mime_type == "application/pdf": + return await File._read_pdf(filename) + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return await File._read_docx(filename) + return None + + @staticmethod + async def _read_text(path: Union[str, Path]) -> str: + return await aread(path) + + @staticmethod + async def _read_pdf(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + + from llama_index.readers.file import PDFReader + + reader = PDFReader() + lines = reader.load_data(file=Path(path)) + return "\n".join([i.text for i in lines]) + + @staticmethod + async def _read_docx(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + return "\n".join(read_docx(str(path))) + + @staticmethod + async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]: + from metagpt.tools.libs import get_env_default + from metagpt.utils.omniparse_client import OmniParseClient + + env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") + env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") + conf_base_url, conf_timeout = await File._read_omniparse_config() + + base_url = env_base_url or conf_base_url + if not base_url: 
+            return None
+        api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="")
+        timeout = env_timeout or conf_timeout or 600
+        try:
+            timeout = int(timeout)
+        except ValueError:
+            timeout = 600
+
+        try:
+            if not await check_http_endpoint(url=base_url):
+                logger.warning(f"{base_url}: NOT AVAILABLE")
+                return None
+            client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout)
+            file_data = await aread_bin(filename=path)
+            ret = await client.parse_document(file_input=file_data, bytes_filename=str(path))
+        except Exception as e:
+            logger.exception(f"{path}: {e}")
+            return None
+        if not ret.images or not auto_save_image:
+            return ret.text
+
+        result = [ret.text]
+        img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images")
+        img_dir.mkdir(parents=True, exist_ok=True)
+        for i in ret.images:
+            byte_data = base64.b64decode(i.image)
+            filename = img_dir / i.image_name
+            await awrite_bin(filename=filename, data=byte_data)
+            result.append(f"![{i.image_name}]({str(filename)})")
+        return "\n".join(result)
+
+    @staticmethod
+    async def _read_omniparse_config() -> Tuple[str, int]:
+        config = Config.default()
+        if config.omniparse and config.omniparse.url:
+            return config.omniparse.url, config.omniparse.timeout
+        return "", 0
+
+
+class MemoryFileSystem(_MemoryFileSystem):
+    @classmethod
+    def _strip_protocol(cls, path):
+        return super()._strip_protocol(str(path))
diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py
index d19f2b7052..dd6c0709f2 100644
--- a/metagpt/utils/file_repository.py
+++ b/metagpt/utils/file_repository.py
@@ -198,8 +198,9 @@ async def save_doc(self, doc: Document, dependencies: List[str] = None):
         :type dependencies: List[str], optional
         """
-        await self.save(filename=doc.filename, content=doc.content, dependencies=dependencies)
+        doc = await self.save(filename=doc.filename, content=doc.content, dependencies=dependencies)
         logger.debug(f"File Saved: {str(doc.filename)}")
+        return doc

     async def save_pdf(self, doc: Document, with_suffix: str = ".md", dependencies: List[str] = None):
         """Save a Document instance as a PDF file.
@@ -216,8 +217,9 @@ async def save_pdf(self, doc: Document, with_suffix: str = ".md", dependencies:
         """
         m = json.loads(doc.content)
         filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
-        await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
+        doc = await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
         logger.debug(f"File Saved: {str(filename)}")
+        return doc

     async def delete(self, filename: Path | str):
         """Delete a file from the file repository.
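A minimal end-to-end sketch of the new `File` helpers added above (illustrative; the path is hypothetical, and with no OmniParse endpoint configured the readers fall back to the local PDF/docx parsers):

```python
import asyncio
from pathlib import Path

from metagpt.utils.file import File


async def demo():
    path = Path("/tmp/specs/prd.docx")  # hypothetical input document
    if await File.is_textual_file(path):
        # read_text_file dispatches on MIME type: plain text, PDF, or
        # Word documents; it returns None for unsupported formats.
        text = await File.read_text_file(path)
        print((text or "")[:200])


asyncio.run(demo())
```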
diff --git a/metagpt/utils/git_repository.py b/metagpt/utils/git_repository.py index 16f675175a..f3d6350bdc 100644 --- a/metagpt/utils/git_repository.py +++ b/metagpt/utils/git_repository.py @@ -8,16 +8,30 @@ """ from __future__ import annotations +import re import shutil +import uuid from enum import Enum from pathlib import Path -from typing import Dict, List +from subprocess import TimeoutExpired +from typing import Dict, List, Optional, Union +from urllib.parse import quote from git.repo import Repo from git.repo.fun import is_git_dir +from github import Auth, BadCredentialsException, Github +from github.GithubObject import NotSet +from github.Issue import Issue +from github.Label import Label +from github.Milestone import Milestone +from github.NamedUser import NamedUser +from github.PullRequest import PullRequest from gitignore_parser import parse_gitignore +from pydantic import BaseModel +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.logs import logger +from metagpt.tools.libs.shell import shell_execute from metagpt.utils.dependency_file import DependencyFile from metagpt.utils.file_repository import FileRepository @@ -32,6 +46,18 @@ class ChangeType(Enum): UNTRACTED = "U" # File is untracked (not added to version control) +class RateLimitError(Exception): + def __init__(self, message="Rate limit exceeded"): + self.message = message + super().__init__(self.message) + + +class GitBranch(BaseModel): + head: str + base: str + repo_name: str + + class GitRepository: """A class representing a Git repository. @@ -52,7 +78,7 @@ def __init__(self, local_path=None, auto_init=True): self._dependency = None self._gitignore_rules = None if local_path: - self.open(local_path=local_path, auto_init=auto_init) + self.open(local_path=Path(local_path), auto_init=auto_init) def open(self, local_path: Path, auto_init=False): """Open an existing Git repository or initialize a new one if auto_init is True. @@ -68,7 +94,7 @@ def open(self, local_path: Path, auto_init=False): if not auto_init: return local_path.mkdir(parents=True, exist_ok=True) - return self._init(local_path) + self._init(local_path) def _init(self, local_path: Path): """Initialize a new Git repository at the specified path. @@ -130,6 +156,8 @@ def is_git_dir(local_path): :param local_path: The local path to check. :return: True if the directory is a Git repository, False otherwise. """ + if not local_path: + return False git_dir = Path(local_path) / ".git" if git_dir.exists() and is_git_dir(git_dir): return True @@ -160,15 +188,114 @@ def workdir(self) -> Path | None: return None return Path(self._repository.working_dir) + @property + def current_branch(self) -> str: + """ + Returns the name of the current active branch. + + Returns: + str: The name of the current active branch. 
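+
+        Example (illustrative; the branch name depends on the repository state):
+            >>> repo = GitRepository(local_path=".", auto_init=False)
+            >>> repo.current_branch
+            'main'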
+ """ + return self._repository.active_branch.name + + @property + def remote_url(self) -> str: + try: + return self._repository.remotes.origin.url + except AttributeError: + return "" + + @property + def repo_name(self) -> str: + if self.remote_url: + # This assumes a standard HTTPS or SSH format URL + # HTTPS format example: https://github.com/username/repo_name.git + # SSH format example: git@github.com:username/repo_name.git + if self.remote_url.startswith("https://"): + return self.remote_url.split("/", maxsplit=3)[-1].replace(".git", "") + elif self.remote_url.startswith("git@"): + return self.remote_url.split(":")[-1].replace(".git", "") + return "" + + def new_branch(self, branch_name: str) -> str: + """ + Creates a new branch with the given name. + + Args: + branch_name (str): The name of the new branch to create. + + Returns: + str: The name of the newly created branch. + If the provided branch_name is empty, returns the name of the current active branch. + """ + if not branch_name: + return self.current_branch + new_branch = self._repository.create_head(branch_name) + new_branch.checkout() + return new_branch.name + def archive(self, comments="Archive"): """Archive the current state of the Git repository. :param comments: Comments for the archive commit. """ logger.info(f"Archive: {list(self.changed_files.keys())}") + if not self.changed_files: + return self.add_change(self.changed_files) self.commit(comments) + async def push( + self, new_branch: str, comments="Archive", access_token: Optional[str] = None, auth: Optional[Auth] = None + ) -> GitBranch: + """ + Pushes changes to the remote repository. + + Args: + new_branch (str): The name of the new branch to be pushed. + comments (str, optional): Comments to be associated with the push. Defaults to "Archive". + access_token (str, optional): Access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`. + auth (Auth, optional): Optional authentication object. Defaults to None. + + Returns: + GitBranch: The pushed branch object. + + Raises: + ValueError: If neither `auth` nor `access_token` is provided. + BadCredentialsException: If authentication fails due to bad credentials or timeout. + + Note: + This function assumes that `self.current_branch`, `self.new_branch()`, `self.archive()`, + `ctx.config.proxy`, `ctx.config`, `self.remote_url`, `shell_execute()`, and `logger` are + defined and accessible within the scope of this function. + """ + if not auth and not access_token: + raise ValueError('`access_token` is invalid. 
Visit: "https://github.com/settings/tokens"') + from metagpt.context import Context + + base = self.current_branch + head = base if not new_branch else self.new_branch(new_branch) + self.archive(comments) # will skip committing if no changes + ctx = Context() + env = ctx.new_environ() + proxy = ["-c", f"http.proxy={ctx.config.proxy}"] if ctx.config.proxy else [] + token = access_token or auth.token + remote_url = f"https://{token}@" + self.remote_url.removeprefix("https://") + command = ["git"] + proxy + ["push", remote_url] + logger.info(" ".join(command).replace(token, "")) + try: + stdout, stderr, return_code = await shell_execute( + command=command, cwd=str(self.workdir), env=env, timeout=15 + ) + except TimeoutExpired as e: + info = str(e).replace(token, "") + raise BadCredentialsException(status=401, message=info) + info = f"{stdout}\n{stderr}\nexit: {return_code}\n" + info = info.replace(token, "") + print(info) + + return GitBranch(base=base, head=head, repo_name=self.repo_name) + def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository: """Create a new instance of FileRepository associated with this Git repository. @@ -248,6 +375,8 @@ def get_files(self, relative_path: Path | str, root_relative_path: Path | str = if not directory_path.exists(): return [] for file_path in directory_path.iterdir(): + if not file_path.is_relative_to(root_relative_path): + continue if file_path.is_file(): rpath = file_path.relative_to(root_relative_path) files.append(str(rpath)) @@ -283,3 +412,222 @@ def filter_gitignore(self, filenames: List[str], root_relative_path: Path | str continue files.append(filename) return files + + @classmethod + @retry(wait=wait_random_exponential(min=1, max=15), stop=stop_after_attempt(3)) + async def clone_from(cls, url: str | Path, output_dir: str | Path = None) -> "GitRepository": + from metagpt.context import Context + + to_path = Path(output_dir or Path(__file__).parent / f"../../workspace/downloads/{uuid.uuid4().hex}").resolve() + to_path.mkdir(parents=True, exist_ok=True) + repo_dir = to_path / Path(url).stem + if repo_dir.exists(): + shutil.rmtree(repo_dir, ignore_errors=True) + ctx = Context() + env = ctx.new_environ() + proxy = ["-c", f"http.proxy={ctx.config.proxy}"] if ctx.config.proxy else [] + command = ["git", "clone"] + proxy + [str(url)] + logger.info(" ".join(command)) + + stdout, stderr, return_code = await shell_execute(command=command, cwd=str(to_path), env=env, timeout=600) + info = f"{stdout}\n{stderr}\nexit: {return_code}\n" + logger.info(info) + dir_name = Path(url).stem + to_path = to_path / dir_name + if not cls.is_git_dir(to_path): + raise ValueError(info) + logger.info(f"git clone to {to_path}") + return GitRepository(local_path=to_path, auto_init=False) + + async def checkout(self, commit_id: str): + self._repository.git.checkout(commit_id) + logger.info(f"git checkout {commit_id}") + + def log(self) -> str: + """Return git log""" + return self._repository.git.log() + + @staticmethod + async def create_pull( + base: str, + head: str, + base_repo_name: str, + head_repo_name: Optional[str] = None, + *, + title: Optional[str] = None, + body: Optional[str] = None, + maintainer_can_modify: Optional[bool] = None, + draft: Optional[bool] = None, + issue: Optional[Issue] = None, + access_token: Optional[str] = None, + auth: Optional[Auth] = None, + ) -> Union[PullRequest, str]: + """ + Creates a pull request in the specified repository. + + Args: + base (str): The name of the base branch. 
+ head (str): The name of the head branch. + base_repo_name (str): The full repository name (user/repo) where the pull request will be created. + head_repo_name (Optional[str], optional): The full repository name (user/repo) where the pull request will merge from. Defaults to None. + title (Optional[str], optional): The title of the pull request. Defaults to None. + body (Optional[str], optional): The body of the pull request. Defaults to None. + maintainer_can_modify (Optional[bool], optional): Whether maintainers can modify the pull request. Defaults to None. + draft (Optional[bool], optional): Whether the pull request is a draft. Defaults to None. + issue (Optional[Issue], optional): The issue linked to the pull request. Defaults to None. + access_token (Optional[str], optional): The access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`. + auth (Optional[Auth], optional): The authentication method. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html` + + Returns: + PullRequest: The created pull request object. + """ + title = title or NotSet + body = body or NotSet + maintainer_can_modify = maintainer_can_modify or NotSet + draft = draft or NotSet + issue = issue or NotSet + if not auth and not access_token: + raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"') + clone_url = f"https://github.com/{base_repo_name}.git" + try: + auth = auth or Auth.Token(access_token) + g = Github(auth=auth) + base_repo = g.get_repo(base_repo_name) + clone_url = base_repo.clone_url + head_repo = g.get_repo(head_repo_name) if head_repo_name and head_repo_name != base_repo_name else None + if head_repo: + user = head_repo.full_name.split("/")[0] + head = f"{user}:{head}" + pr = base_repo.create_pull( + base=base, + head=head, + title=title, + body=body, + maintainer_can_modify=maintainer_can_modify, + draft=draft, + issue=issue, + ) + except Exception as e: + logger.warning(f"Pull Request Error: {e}") + return GitRepository.create_github_pull_url( + clone_url=clone_url, + base=base, + head=head, + head_repo_name=head_repo_name, + ) + return pr + + @staticmethod + async def create_issue( + repo_name: str, + title: str, + body: Optional[str] = None, + assignee: NamedUser | Optional[str] = None, + milestone: Optional[Milestone] = None, + labels: list[Label] | Optional[list[str]] = None, + assignees: Optional[list[str]] | list[NamedUser] = None, + access_token: Optional[str] = None, + auth: Optional[Auth] = None, + ) -> Issue: + """ + Creates an issue in the specified repository. + + Args: + repo_name (str): The full repository name (user/repo) where the issue will be created. + title (str): The title of the issue. + body (Optional[str], optional): The body of the issue. Defaults to None. + assignee (Union[NamedUser, str], optional): The assignee for the issue, either as a NamedUser object or their username. Defaults to None. + milestone (Optional[Milestone], optional): The milestone to associate with the issue. Defaults to None. + labels (Union[list[Label], list[str]], optional): The labels to associate with the issue, either as Label objects or their names. Defaults to None. + assignees (Union[list[str], list[NamedUser]], optional): The list of usernames or NamedUser objects to assign to the issue. Defaults to None. 
+ access_token (Optional[str], optional): The access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`. + auth (Optional[Auth], optional): The authentication method. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html` + + Returns: + Issue: The created issue object. + """ + body = body or NotSet + assignee = assignee or NotSet + milestone = milestone or NotSet + labels = labels or NotSet + assignees = assignees or NotSet + if not auth and not access_token: + raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"') + auth = auth or Auth.Token(access_token) + g = Github(auth=auth) + + repo = g.get_repo(repo_name) + x_ratelimit_remaining = repo.raw_headers.get("x-ratelimit-remaining") + if ( + x_ratelimit_remaining + and bool(re.match(r"^-?\d+$", x_ratelimit_remaining)) + and int(x_ratelimit_remaining) <= 0 + ): + raise RateLimitError() + issue = repo.create_issue( + title=title, + body=body, + assignee=assignee, + milestone=milestone, + labels=labels, + assignees=assignees, + ) + return issue + + @staticmethod + async def get_repos(access_token: Optional[str] = None, auth: Optional[Auth] = None) -> List[str]: + """ + Fetches a list of public repositories belonging to the authenticated user. + + Args: + access_token (Optional[str], optional): The access token for authentication. Defaults to None. + Visit `https://github.com/settings/tokens` for obtaining a personal access token. + auth (Optional[Auth], optional): The authentication method. Defaults to None. + Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html` for more information. + + Returns: + List[str]: A list of full names of the public repositories belonging to the user. + """ + auth = auth or Auth.Token(access_token) + git = Github(auth=auth) + user = git.get_user() + v = user.get_repos(visibility="public") + return [i.full_name for i in v] + + @staticmethod + def create_github_pull_url(clone_url: str, base: str, head: str, head_repo_name: Optional[str] = None) -> str: + """ + Create a URL for comparing changes between branches or repositories on GitHub. + + Args: + clone_url (str): The URL used for cloning the repository, ending with '.git'. + base (str): The base branch or commit. + head (str): The head branch or commit. + head_repo_name (str, optional): The name of the repository for the head branch. If not provided, assumes the same repository. + + Returns: + str: The URL for comparing changes between the specified branches or commits. + """ + url = clone_url.removesuffix(".git") + f"/compare/{base}..." + if head_repo_name: + url += head_repo_name.replace("/", ":") + url += ":" + head + return url + + @staticmethod + def create_gitlab_merge_request_url(clone_url: str, head: str) -> str: + """ + Create a URL for creating a new merge request on GitLab. + + Args: + clone_url (str): The URL used for cloning the repository, ending with '.git'. + head (str): The name of the branch to be merged. + + Returns: + str: The URL for creating a new merge request for the specified branch. 
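+
+        Example (derived from the URL construction below; the repo and branch names are hypothetical):
+            >>> GitRepository.create_gitlab_merge_request_url("https://gitlab.com/team/app.git", "feat/login")
+            'https://gitlab.com/team/app/-/merge_requests/new?merge_request%5Bsource_branch%5D=feat%2Flogin'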
+ """ + return ( + clone_url.removesuffix(".git") + + "/-/merge_requests/new?merge_request%5Bsource_branch%5D=" + + quote(head, safe="") + ) diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py index eb1fc5e120..f4219fac35 100644 --- a/metagpt/utils/graph_repository.py +++ b/metagpt/utils/graph_repository.py @@ -49,6 +49,10 @@ class GraphKeyword: IS_COMPOSITE_OF = "is_composite_of" IS_AGGREGATE_OF = "is_aggregate_of" HAS_PARTICIPANT = "has_participant" + HAS_SUMMARY = "has_summary" + HAS_INSTALL = "has_install" + HAS_CONFIG = "has_config" + HAS_USAGE = "has_usage" class SPO(BaseModel): diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index 283a682d64..f0c55b07c9 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -13,10 +13,11 @@ OpenAIChatCompletion, ) -from metagpt.config2 import config +from metagpt.config2 import Config def make_sk_kernel(): + config = Config.default() kernel = sk.Kernel() if llm := config.get_azure_llm(): kernel.add_chat_service( diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index e1d140e849..88e58ae444 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -7,23 +7,44 @@ """ import asyncio import os +import re from pathlib import Path +from typing import List, Optional -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger from metagpt.utils.common import awrite, check_cmd_exists -async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: - """suffix: png/svg/pdf +async def mermaid_to_file( + engine, + mermaid_code, + output_file_without_suffix, + width=2048, + height=2048, + config=None, + suffixes: Optional[List[str]] = None, +) -> int: + """Convert Mermaid code to various file formats. - :param mermaid_code: mermaid code - :param output_file_without_suffix: output filename - :param width: - :param height: - :return: 0 if succeed, -1 if failed + Args: + engine (str): The engine to use for conversion. Supported engines are "nodejs", "playwright", "pyppeteer", "ink", and "none". + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. + + Returns: + int: 0 if the conversion is successful, -1 if the conversion fails. 
""" + file_head = "%%{init: {'theme': 'default', 'themeVariables': { 'fontFamily': 'Inter' }}}%%\n" + if not re.match(r"^%%\{.+", mermaid_code): + mermaid_code = file_head + mermaid_code + suffixes = suffixes or ["svg"] # Write the Mermaid code to a temporary file + config = config if config else Config.default() dir_name = os.path.dirname(output_file_without_suffix) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name) @@ -38,7 +59,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt ) return -1 - for suffix in ["pdf", "svg", "png"]: + for suffix in suffixes: output_file = f"{output_file_without_suffix}.{suffix}" # Call the `mmdc` command to convert the Mermaid code to a PNG logger.info(f"Generating {output_file}..") @@ -72,15 +93,17 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt if engine == "playwright": from metagpt.utils.mmdc_playwright import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes) elif engine == "pyppeteer": from metagpt.utils.mmdc_pyppeteer import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes) elif engine == "ink": from metagpt.utils.mmdc_ink import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes=suffixes) + elif engine == "none": + return 0 else: logger.warning(f"Unsupported mermaid engine: {engine}") return 0 diff --git a/metagpt/utils/mmdc_ink.py b/metagpt/utils/mmdc_ink.py index d594adb300..15d6d6083a 100644 --- a/metagpt/utils/mmdc_ink.py +++ b/metagpt/utils/mmdc_ink.py @@ -6,21 +6,29 @@ @File : mermaid.py """ import base64 +from typing import List, Optional from aiohttp import ClientError, ClientSession from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix): - """suffix: png/svg - :param mermaid_code: mermaid code - :param output_file_without_suffix: output filename without suffix - :return: 0 if succeed, -1 if failed +async def mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes: Optional[List[str]] = None): + """Convert Mermaid code to various file formats. + + Args: + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. + + Returns: + int: 0 if the conversion is successful, -1 if the conversion fails. 
""" encoded_string = base64.b64encode(mermaid_code.encode()).decode() - - for suffix in ["svg", "png"]: + suffixes = suffixes or ["png"] + for suffix in suffixes: output_file = f"{output_file_without_suffix}.{suffix}" path_type = "svg" if suffix == "svg" else "img" url = f"https://mermaid.ink/{path_type}/{encoded_string}" diff --git a/metagpt/utils/mmdc_playwright.py b/metagpt/utils/mmdc_playwright.py index 5d455e1c50..cf846a7e92 100644 --- a/metagpt/utils/mmdc_playwright.py +++ b/metagpt/utils/mmdc_playwright.py @@ -7,6 +7,7 @@ """ import os +from typing import List, Optional from urllib.parse import urljoin from playwright.async_api import async_playwright @@ -14,20 +15,22 @@ from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: - """ - Converts the given Mermaid code to various output formats and saves them to files. +async def mermaid_to_file( + mermaid_code, output_file_without_suffix, width=2048, height=2048, suffixes: Optional[List[str]] = None +) -> int: + """Convert Mermaid code to various file formats. Args: - mermaid_code (str): The Mermaid code to convert. - output_file_without_suffix (str): The output file name without the file extension. - width (int, optional): The width of the output image in pixels. Defaults to 2048. - height (int, optional): The height of the output image in pixels. Defaults to 2048. + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. Returns: - int: Returns 1 if the conversion and saving were successful, -1 otherwise. + int: 0 if the conversion is successful, -1 if the conversion fails. """ - suffixes = ["png", "svg", "pdf"] + suffixes = suffixes or ["png"] __dirname = os.path.dirname(os.path.abspath(__file__)) async with async_playwright() as p: diff --git a/metagpt/utils/mmdc_pyppeteer.py b/metagpt/utils/mmdc_pyppeteer.py index f029325f18..36b77b5b26 100644 --- a/metagpt/utils/mmdc_pyppeteer.py +++ b/metagpt/utils/mmdc_pyppeteer.py @@ -6,28 +6,33 @@ @File : mmdc_pyppeteer.py """ import os +from typing import List, Optional from urllib.parse import urljoin from pyppeteer import launch -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: - """ - Converts the given Mermaid code to various output formats and saves them to files. +async def mermaid_to_file( + mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None, suffixes: Optional[List[str]] = None +) -> int: + """Convert Mermaid code to various file formats. Args: - mermaid_code (str): The Mermaid code to convert. - output_file_without_suffix (str): The output file name without the file extension. - width (int, optional): The width of the output image in pixels. Defaults to 2048. - height (int, optional): The height of the output image in pixels. Defaults to 2048. + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. 
+ height (int, optional): The height of the output image. Defaults to 2048. + config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. Returns: - int: Returns 1 if the conversion and saving were successful, -1 otherwise. + int: 0 if the conversion is successful, -1 if the conversion fails. """ - suffixes = ["png", "svg", "pdf"] + config = config if config else Config.default() + suffixes = suffixes or ["png"] __dirname = os.path.dirname(os.path.abspath(__file__)) if config.mermaid.pyppeteer_path: diff --git a/metagpt/utils/omniparse_client.py b/metagpt/utils/omniparse_client.py new file mode 100644 index 0000000000..361e84fd15 --- /dev/null +++ b/metagpt/utils/omniparse_client.py @@ -0,0 +1,238 @@ +import mimetypes +from pathlib import Path +from typing import Union + +import httpx + +from metagpt.rag.schema import OmniParsedResult +from metagpt.utils.common import aread_bin + + +class OmniParseClient: + """ + OmniParse Server Client + This client interacts with the OmniParse server to parse different types of media, documents. + + OmniParse API Documentation: https://docs.cognitivelab.in/api + + Attributes: + ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions. + ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions. + ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions. + """ + + ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"} + ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"} + ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"} + + def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120): + """ + Args: + api_key: Default None, can be used for authentication later. + base_url: Base URL for the API. + max_timeout: Maximum request timeout in seconds. + """ + self.api_key = api_key + self.base_url = base_url + self.max_timeout = max_timeout + + self.parse_media_endpoint = "/parse_media" + self.parse_website_endpoint = "/parse_website" + self.parse_document_endpoint = "/parse_document" + + async def _request_parse( + self, + endpoint: str, + method: str = "POST", + files: dict = None, + params: dict = None, + data: dict = None, + json: dict = None, + headers: dict = None, + **kwargs, + ) -> dict: + """ + Request OmniParse API to parse a document. + + Args: + endpoint (str): API endpoint. + method (str, optional): HTTP method to use. Default is "POST". + files (dict, optional): Files to include in the request. + params (dict, optional): Query string parameters. + data (dict, optional): Form data to include in the request body. + json (dict, optional): JSON data to include in the request body. + headers (dict, optional): HTTP headers to include in the request. + **kwargs: Additional keyword arguments for httpx.AsyncClient.request() + + Returns: + dict: JSON response data. 
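+
+        Example (illustrative; mirrors how `parse_document` calls this helper against a hypothetical local server):
+            >>> client = OmniParseClient(base_url="http://localhost:8000")
+            >>> file_info = await client.get_file_info("demo.pdf")
+            >>> resp = await client._request_parse(client.parse_document_endpoint, files={"file": file_info})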
+ """ + url = f"{self.base_url}{endpoint}" + method = method.upper() + headers = headers or {} + _headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + headers.update(**_headers) + async with httpx.AsyncClient() as client: + response = await client.request( + url=url, + method=method, + files=files, + params=params, + json=json, + data=data, + headers=headers, + timeout=self.max_timeout, + **kwargs, + ) + response.raise_for_status() + return response.json() + + async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: + """ + Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the document parsing. + """ + self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult: + """ + Parse pdf document. + + Args: + file_input: File path or file byte data. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the pdf parsing. + """ + self.verify_file_ext(file_input, {".pdf"}) + # parse_pdf supports parsing by accepting only the byte data of the file. + file_info = await self.get_file_info(file_input, only_bytes=True) + endpoint = f"{self.parse_document_endpoint}/pdf" + resp = await self._request_parse(endpoint=endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info}) + + async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse audio-type data (supports ".mp3", ".wav", ".aac"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info}) + + @staticmethod + def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): + """ + Verify the file extension. + + Args: + file_input: File path or file byte data. 
+ allowed_file_extensions: Set of allowed file extensions. + bytes_filename: Filename to use for verification when `file_input` is byte data. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + """ + verify_file_path = None + if isinstance(file_input, (str, Path)): + verify_file_path = str(file_input) + elif isinstance(file_input, bytes) and bytes_filename: + verify_file_path = bytes_filename + + if not verify_file_path: + # Do not verify if only byte data is provided + return + + file_ext = Path(verify_file_path).suffix.lower() + if file_ext not in allowed_file_extensions: + raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}") + + @staticmethod + async def get_file_info( + file_input: Union[str, bytes, Path], + bytes_filename: str = None, + only_bytes: bool = False, + ) -> Union[bytes, tuple]: + """ + Get file information. + + Args: + file_input: File path or file byte data. + bytes_filename: Filename to use when uploading byte data, useful for determining MIME type. + only_bytes: Whether to return only byte data. Default is False, which returns a tuple. + + Raises: + ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type. + + Notes: + Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types, + the MIME type of the file must be specified when uploading. + + Returns: [bytes, tuple] + Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type). + """ + if isinstance(file_input, (str, Path)): + filename = Path(file_input).name + file_bytes = await aread_bin(file_input) + + if only_bytes: + return file_bytes + + mime_type = mimetypes.guess_type(file_input)[0] + return filename, file_bytes, mime_type + elif isinstance(file_input, bytes): + if only_bytes: + return file_input + if not bytes_filename: + raise ValueError("bytes_filename must be set when passing bytes") + + mime_type = mimetypes.guess_type(bytes_filename)[0] + return bytes_filename, file_input, mime_type + else: + raise ValueError("file_input must be a string (file path) or bytes.") diff --git a/metagpt/utils/parse_docstring.py b/metagpt/utils/parse_docstring.py index 63c0e68909..5df4d66712 100644 --- a/metagpt/utils/parse_docstring.py +++ b/metagpt/utils/parse_docstring.py @@ -3,7 +3,7 @@ def remove_spaces(text): - return re.sub(r"\s+", " ", text).strip() + return re.sub(r"\s+", " ", text).strip() if text else "" class DocstringParser: diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py index 65aa3f2369..985e54d962 100644 --- a/metagpt/utils/parse_html.py +++ b/metagpt/utils/parse_html.py @@ -4,6 +4,7 @@ from typing import Generator, Optional from urllib.parse import urljoin, urlparse +import htmlmin from bs4 import BeautifulSoup from pydantic import BaseModel, PrivateAttr @@ -38,6 +39,22 @@ def get_links(self) -> Generator[str, None, None]: elif url.startswith(("http://", "https://")): yield urljoin(self.url, url) + def get_slim_soup(self, keep_links: bool = False): + soup = _get_soup(self.html) + keep_attrs = ["class", "id"] + if keep_links: + keep_attrs.append("href") + + for i in soup.find_all(True): + for name in list(i.attrs): + if i[name] and name not in keep_attrs: + del i[name] + + for i in soup.find_all(["svg", "img", "video", "audio"]): + i.decompose() + + return soup + def get_html_content(page: str, base: str): soup = _get_soup(page) @@ -48,7 +65,12 @@ def get_html_content(page: str, base: str): def 
_get_soup(page: str): soup = BeautifulSoup(page, "html.parser") # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup - for s in soup(["style", "script", "[document]", "head", "title"]): + for s in soup(["style", "script", "[document]", "head", "title", "footer"]): s.extract() return soup + + +def simplify_html(html: str, url: str, keep_links: bool = False): + html = WebPage(inner_text="", html=html, url=url).get_slim_soup(keep_links).decode() + return htmlmin.minify(html, remove_comments=True, remove_empty_space=True) diff --git a/metagpt/utils/project_repo.py b/metagpt/utils/project_repo.py index bb18b520c2..5761c0188f 100644 --- a/metagpt/utils/project_repo.py +++ b/metagpt/utils/project_repo.py @@ -10,6 +10,7 @@ from __future__ import annotations from pathlib import Path +from typing import Optional from metagpt.const import ( CLASS_VIEW_FILE_REPO, @@ -35,6 +36,7 @@ TEST_OUTPUTS_FILE_REPO, VISUAL_GRAPH_REPO_FILE_REPO, ) +from metagpt.utils.common import get_project_srcs_path from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -129,22 +131,33 @@ def srcs(self) -> FileRepository: return self._git_repo.new_file_repository(self._srcs_path) def code_files_exists(self) -> bool: - git_workdir = self.git_repo.workdir - src_workdir = git_workdir / git_workdir.name + src_workdir = get_project_srcs_path(self.git_repo.workdir) if not src_workdir.exists(): return False - code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files + code_files = self.with_src_path(path=src_workdir).srcs.all_files if not code_files: return False return bool(code_files) def with_src_path(self, path: str | Path) -> ProjectRepo: - try: - self._srcs_path = Path(path).relative_to(self.workdir) - except ValueError: - self._srcs_path = Path(path) + path = Path(path) + if path.is_relative_to(self.workdir): + self._srcs_path = path.relative_to(self.workdir) + else: + self._srcs_path = path return self @property def src_relative_path(self) -> Path | None: return self._srcs_path + + @staticmethod + def search_project_path(filename: str | Path) -> Optional[Path]: + root = Path(filename).parent if Path(filename).is_file() else Path(filename) + root = root.resolve() + while str(root) != "/": + git_repo = root / ".git" + if git_repo.exists(): + return root + root = root.parent + return None diff --git a/metagpt/utils/proxy_env.py b/metagpt/utils/proxy_env.py new file mode 100644 index 0000000000..bcb5c84f5f --- /dev/null +++ b/metagpt/utils/proxy_env.py @@ -0,0 +1,19 @@ +import os + + +def get_proxy_from_env(): + proxy_config = {} + server = None + for i in ("ALL_PROXY", "all_proxy", "HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"): + if os.environ.get(i): + server = os.environ.get(i) + if server: + proxy_config["server"] = server + no_proxy = os.environ.get("NO_PROXY") or os.environ.get("no_proxy") + if no_proxy: + proxy_config["bypass"] = no_proxy + + if not proxy_config: + proxy_config = None + + return proxy_config diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 17e095c5f5..5c57693f74 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -4,12 +4,12 @@ import copy from enum import Enum -from typing import Callable, Union +from typing import Callable, Optional, Union import regex as re from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed -from metagpt.config2 import 
config
+from metagpt.config2 import Config
 from metagpt.logs import logger
 from metagpt.utils.custom_decoder import CustomDecoder
@@ -154,7 +154,9 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType =
     return output


-def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
+def repair_llm_raw_output(
+    output: str, req_keys: list[str], repair_type: RepairType = None, config: Optional[Config] = None
+) -> str:
     """
     in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
     so here we try to repair it and use all repair methods by default.
@@ -169,6 +171,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT
         target: { xxx }
         output: { xxx }]
     """
+    config = config if config else Config.default()
     if not config.repair_llm_output:
         return output
@@ -256,6 +259,7 @@ def run_and_passon(retry_state: RetryCallState) -> None:
             "next_action":"None"
         }
     """
+    config = Config.default()
     if retry_state.outcome.failed:
         if retry_state.args:
             # # can't be used as args=retry_state.args
@@ -276,8 +280,12 @@ def run_and_passon(retry_state: RetryCallState) -> None:
     return run_and_passon


+def repair_stop_after_attempt(retry_state):
+    return stop_after_attempt(3 if Config.default().repair_llm_output else 0)(retry_state)
+
+
 @retry(
-    stop=stop_after_attempt(3 if config.repair_llm_output else 0),
+    stop=repair_stop_after_attempt,
     wait=wait_fixed(1),
     after=run_after_exp_and_passon_next_retry(logger),
 )
@@ -347,3 +355,44 @@ def extract_state_value_from_output(content: str) -> str:
     matches = list(set(matches))
     state = matches[0] if len(matches) > 0 else "-1"
     return state
+
+
+def repair_escape_error(commands):
+    """
+    Repairs escape errors in command responses.
+    When RoleZero parses a command, the command may contain unknown escape characters.
+
+    This function has two steps:
+    1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(".
+    2. Transform escaped characters like '\f' to substrings like "\\\\f".
+
+    Example:
+        When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ",
+        the "content" will be parsed correctly to "\( \frac{1}{2} \)".
+
+        However, if the original JSON string is " {"content":"\( \frac{1}{2} \)"} " directly,
+        it will cause a parsing error.
+ + To repair the wrong JSON string, the following transformations will be used: + "\(" ---> "\\\\(" + '\f' ---> "\\\\f" + "\)" ---> "\\\\)" + + """ + escape_repair_map = { + "\a": "\\\\a", + "\b": "\\\\b", + "\f": "\\\\f", + "\r": "\\\\r", + "\t": "\\\\t", + "\v": "\\\\v", + } + new_command = "" + for index, ch in enumerate(commands): + if ch == "\\" and index + 1 < len(commands): + if commands[index + 1] not in ["n", '"', " "]: + new_command += "\\" + elif ch in escape_repair_map: + ch = escape_repair_map[ch] + new_command += ch + return new_command diff --git a/metagpt/utils/repo_to_markdown.py b/metagpt/utils/repo_to_markdown.py index 76dfe1b829..a5bffffe1d 100644 --- a/metagpt/utils/repo_to_markdown.py +++ b/metagpt/utils/repo_to_markdown.py @@ -5,17 +5,24 @@ """ from __future__ import annotations -import mimetypes +import re from pathlib import Path +from typing import Tuple, Union from gitignore_parser import parse_gitignore from metagpt.logs import logger -from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files +from metagpt.utils.common import ( + aread, + awrite, + get_markdown_codeblock_type, + get_mime_type, + list_files, +) from metagpt.utils.tree import tree -async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str: +async def repo_to_markdown(repo_path: str | Path, output: str | Path = None) -> str: """ Convert a local repository into a markdown representation. @@ -25,56 +32,118 @@ async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, git Args: repo_path (str | Path): The path to the local repository. output (str | Path, optional): The path to save the generated markdown file. Defaults to None. - gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None. Returns: str: The markdown representation of the repository. 
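+
+    Example (illustrative; the paths are hypothetical):
+        >>> markdown = await repo_to_markdown("/data/repos/my_project", output="/tmp/my_project.md")
+        >>> markdown.startswith("## Directory Tree")
+        True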
""" - repo_path = Path(repo_path) - gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve() + repo_path = Path(repo_path).resolve() + gitignore_file = repo_path / ".gitignore" - markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore) + markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore_file) - gitignore_rules = parse_gitignore(full_path=str(gitignore)) + gitignore_rules = parse_gitignore(full_path=str(gitignore_file)) if gitignore_file.exists() else None markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules) if output: - await awrite(filename=str(output), data=markdown, encoding="utf-8") + output_file = Path(output).resolve() + output_file.parent.mkdir(parents=True, exist_ok=True) + await awrite(filename=str(output_file), data=markdown, encoding="utf-8") + logger.info(f"save: {output_file}") return markdown async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str: try: - content = tree(repo_path, gitignore, run_command=True) + content = await tree(repo_path, gitignore, run_command=True) except Exception as e: logger.info(f"{e}, using safe mode.") - content = tree(repo_path, gitignore, run_command=False) + content = await tree(repo_path, gitignore, run_command=False) doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n" return doc -async def _write_files(repo_path, gitignore_rules) -> str: +async def _write_files(repo_path, gitignore_rules=None) -> str: filenames = list_files(repo_path) markdown = "" + pattern = r"^\..*" # Hidden folders/files for filename in filenames: - if gitignore_rules(str(filename)): + if gitignore_rules and gitignore_rules(str(filename)): + continue + ignore = False + for i in filename.parts: + if re.match(pattern, i): + ignore = True + break + if ignore: continue markdown += await _write_file(filename=filename, repo_path=repo_path) return markdown async def _write_file(filename: Path, repo_path: Path) -> str: - relative_path = filename.relative_to(repo_path) - markdown = f"## {relative_path}\n" - - mime_type, _ = mimetypes.guess_type(filename.name) - if "text/" not in mime_type: + is_text, mime_type = await is_text_file(filename) + if not is_text: logger.info(f"Ignore content: {filename}") - markdown += "\n---\n\n" + return "" + + try: + relative_path = filename.relative_to(repo_path) + markdown = f"## {relative_path}\n" + content = await aread(filename, encoding="utf-8") + content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-") + code_block_type = get_markdown_codeblock_type(filename.name) + markdown += f"```{code_block_type}\n{content}\n```\n---\n\n" return markdown - content = await aread(filename, encoding="utf-8") - content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-") - code_block_type = get_markdown_codeblock_type(filename.name) - markdown += f"```{code_block_type}\n{content}\n```\n---\n\n" - return markdown + except Exception as e: + logger.error(e) + return "" + + +async def is_text_file(filename: Union[str, Path]) -> Tuple[bool, str]: + """ + Determines if the specified file is a text file based on its MIME type. + + Args: + filename (Union[str, Path]): The path to the file. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates if the file is a text file + (True for text file, False otherwise), and the second element is the MIME type of the file. 
+ """ + pass_set = { + "application/json", + "application/vnd.chipnuts.karaoke-mmd", + "application/javascript", + "application/xml", + "application/x-sh", + "application/sql", + } + denied_set = { + "application/zlib", + "application/octet-stream", + "image/svg+xml", + "application/pdf", + "application/msword", + "application/vnd.ms-excel", + "audio/x-wav", + "application/x-git", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/zip", + "image/jpeg", + "audio/mpeg", + "video/mp2t", + "inode/x-empty", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "image/png", + "image/vnd.microsoft.icon", + "video/mp4", + } + mime_type = await get_mime_type(Path(filename), force_read=True) + v = "text/" in mime_type or mime_type in pass_set + if v: + return True, mime_type + + if mime_type not in denied_set: + logger.info(mime_type) + return False, mime_type diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py new file mode 100644 index 0000000000..0ae19beba8 --- /dev/null +++ b/metagpt/utils/report.py @@ -0,0 +1,330 @@ +import asyncio +import os +import typing +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Literal, Optional, Union +from urllib.parse import unquote, urlparse, urlunparse +from uuid import UUID, uuid4 + +from aiohttp import ClientSession, UnixConnector +from playwright.async_api import Page as AsyncPage +from playwright.sync_api import Page as SyncPage +from pydantic import BaseModel, Field, PrivateAttr + +from metagpt.const import METAGPT_REPORTER_DEFAULT_URL +from metagpt.logs import create_llm_stream_queue, get_llm_stream_queue + +if typing.TYPE_CHECKING: + from metagpt.roles.role import Role + +try: + import requests_unixsocket as requests +except ImportError: + import requests + +from contextvars import ContextVar + +CURRENT_ROLE: ContextVar["Role"] = ContextVar("role") + + +class BlockType(str, Enum): + """Enumeration for different types of blocks.""" + + TERMINAL = "Terminal" + TASK = "Task" + BROWSER = "Browser" + BROWSER_RT = "Browser-RT" + EDITOR = "Editor" + GALLERY = "Gallery" + NOTEBOOK = "Notebook" + DOCS = "Docs" + THOUGHT = "Thought" + + +END_MARKER_NAME = "end_marker" +END_MARKER_VALUE = "\x18\x19\x1B\x18\n" + + +class ResourceReporter(BaseModel): + """Base class for resource reporting.""" + + block: BlockType = Field(description="The type of block that is reporting the resource") + uuid: UUID = Field(default_factory=uuid4, description="The unique identifier for the resource") + enable_llm_stream: bool = Field(False, description="Indicates whether to connect to an LLM stream for reporting") + callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent") + _llm_task: Optional[asyncio.Task] = PrivateAttr(None) + + def report(self, value: Any, name: str, extra: Optional[dict] = None): + """Synchronously report resource observation data. + + Args: + value: The data to report. + name: The type name of the data. + """ + return self._report(value, name, extra) + + async def async_report(self, value: Any, name: str, extra: Optional[dict] = None): + """Asynchronously report resource observation data. + + Args: + value: The data to report. + name: The type name of the data. + """ + return await self._async_report(value, name, extra) + + @classmethod + def set_report_fn(cls, fn: Callable): + """Set the synchronous report function. + + Args: + fn: A callable function used for synchronous reporting. 
For example: + + >>> def _report(self, value: Any, name: str): + ... print(value, name) + + """ + cls._report = fn + + @classmethod + def set_async_report_fn(cls, fn: Callable): + """Set the asynchronous report function. + + Args: + fn: A callable function used for asynchronous reporting. For example: + + ```python + >>> async def _report(self, value: Any, name: str): + ... print(value, name) + ``` + """ + cls._async_report = fn + + def _report(self, value: Any, name: str, extra: Optional[dict] = None): + if not self.callback_url: + return + + data = self._format_data(value, name, extra) + resp = requests.post(self.callback_url, json=data) + resp.raise_for_status() + return resp.text + + async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None): + if not self.callback_url: + return + + data = self._format_data(value, name, extra) + url = self.callback_url + _result = urlparse(url) + sessiion_kwargs = {} + if _result.scheme.endswith("+unix"): + parsed_list = list(_result) + parsed_list[0] = parsed_list[0][:-5] + parsed_list[1] = "fake.org" + url = urlunparse(parsed_list) + sessiion_kwargs["connector"] = UnixConnector(path=unquote(_result.netloc)) + + async with ClientSession(**sessiion_kwargs) as client: + async with client.post(url, json=data) as resp: + resp.raise_for_status() + return await resp.text() + + def _format_data(self, value, name, extra): + data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream")) + if isinstance(value, BaseModel): + value = value.model_dump(mode="json") + elif isinstance(value, Path): + value = str(value) + + if name == "path": + value = os.path.abspath(value) + data["value"] = value + data["name"] = name + role = CURRENT_ROLE.get(None) + if role: + role_name = role.name + else: + role_name = os.environ.get("METAGPT_ROLE") + data["role"] = role_name + if extra: + data["extra"] = extra + return data + + def __enter__(self): + """Enter the synchronous streaming callback context.""" + return self + + def __exit__(self, *args, **kwargs): + """Exit the synchronous streaming callback context.""" + self.report(None, END_MARKER_NAME) + + async def __aenter__(self): + """Enter the asynchronous streaming callback context.""" + if self.enable_llm_stream: + queue = create_llm_stream_queue() + self._llm_task = asyncio.create_task(self._llm_stream_report(queue)) + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + """Exit the asynchronous streaming callback context.""" + if self.enable_llm_stream and exc_type != asyncio.CancelledError: + await get_llm_stream_queue().put(None) + await self._llm_task + self._llm_task = None + await self.async_report(None, END_MARKER_NAME) + + async def _llm_stream_report(self, queue: asyncio.Queue): + while True: + data = await queue.get() + if data is None: + return + await self.async_report(data, "content") + + async def wait_llm_stream_report(self): + """Wait for the LLM stream report to complete.""" + queue = get_llm_stream_queue() + while self._llm_task: + if queue.empty(): + break + await asyncio.sleep(0.01) + + +class TerminalReporter(ResourceReporter): + """Terminal output callback for streaming reporting of command and output. + + The terminal has state, and an agent can open multiple terminals and input different commands into them. + To correctly display these states, each terminal should have its own unique ID, so in practice, each terminal + should instantiate its own TerminalReporter object. 
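+
+    Example (illustrative; assumes a reachable callback endpoint):
+        >>> reporter = TerminalReporter()
+        >>> await reporter.async_report("pip list", "cmd")
+        >>> await reporter.async_report("aiohttp 3.9.5", "output")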
+ """ + + block: Literal[BlockType.TERMINAL] = BlockType.TERMINAL + + def report(self, value: str, name: Literal["cmd", "output"]): + """Report terminal command or output synchronously.""" + return super().report(value, name) + + async def async_report(self, value: str, name: Literal["cmd", "output"]): + """Report terminal command or output asynchronously.""" + return await super().async_report(value, name) + + +class BrowserReporter(ResourceReporter): + """Browser output callback for streaming reporting of requested URL and page content. + + The browser has state, so in practice, each browser should instantiate its own BrowserReporter object. + """ + + block: Literal[BlockType.BROWSER] = BlockType.BROWSER + + def report(self, value: Union[str, SyncPage], name: Literal["url", "page"]): + """Report browser URL or page content synchronously.""" + if name == "page": + value = {"page_url": value.url, "title": value.title(), "screenshot": str(value.screenshot())} + return super().report(value, name) + + async def async_report(self, value: Union[str, AsyncPage], name: Literal["url", "page"]): + """Report browser URL or page content asynchronously.""" + if name == "page": + value = {"page_url": value.url, "title": await value.title(), "screenshot": str(await value.screenshot())} + return await super().async_report(value, name) + + +class ServerReporter(ResourceReporter): + """Callback for server deployment reporting.""" + + block: Literal[BlockType.BROWSER_RT] = BlockType.BROWSER_RT + + def report(self, value: str, name: Literal["local_url"] = "local_url"): + """Report server deployment synchronously.""" + return super().report(value, name) + + async def async_report(self, value: str, name: Literal["local_url"] = "local_url"): + """Report server deployment asynchronously.""" + return await super().async_report(value, name) + + +class ObjectReporter(ResourceReporter): + """Callback for reporting complete object resources.""" + + def report(self, value: dict, name: Literal["object"] = "object"): + """Report object resource synchronously.""" + return super().report(value, name) + + async def async_report(self, value: dict, name: Literal["object"] = "object"): + """Report object resource asynchronously.""" + return await super().async_report(value, name) + + +class TaskReporter(ObjectReporter): + """Reporter for object resources to Task Block.""" + + block: Literal[BlockType.TASK] = BlockType.TASK + + +class ThoughtReporter(ObjectReporter): + """Reporter for object resources to Task Block.""" + + block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT + + +class FileReporter(ResourceReporter): + """File resource callback for reporting complete file paths. + + There are two scenarios: if the file needs to be output in its entirety at once, use non-streaming callback; + if the file can be partially output for display first, use streaming callback. 
+ """ + + def report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): + """Report file resource synchronously.""" + return super().report(value, name, extra) + + async def async_report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): + """Report file resource asynchronously.""" + return await super().async_report(value, name, extra) + + +class NotebookReporter(FileReporter): + """Equivalent to FileReporter(block=BlockType.NOTEBOOK).""" + + block: Literal[BlockType.NOTEBOOK] = BlockType.NOTEBOOK + + +class DocsReporter(FileReporter): + """Equivalent to FileReporter(block=BlockType.DOCS).""" + + block: Literal[BlockType.DOCS] = BlockType.DOCS + + +class EditorReporter(FileReporter): + """Equivalent to FileReporter(block=BlockType.EDITOR).""" + + block: Literal[BlockType.EDITOR] = BlockType.EDITOR + + +class GalleryReporter(FileReporter): + """Image resource callback for reporting complete file paths. + + Since images need to be complete before display, each callback is a complete file path. However, the Gallery + needs to display the type of image and prompt, so if there is meta information, it should be reported in a + streaming manner. + """ + + block: Literal[BlockType.GALLERY] = BlockType.GALLERY + + def report(self, value: Union[dict, Path], name: Literal["meta", "path"] = "path"): + """Report image resource synchronously.""" + return super().report(value, name) + + async def async_report(self, value: Union[dict, Path], name: Literal["meta", "path"] = "path"): + """Report image resource asynchronously.""" + return await super().async_report(value, name) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 0ba2daa893..b235ceb7ba 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -13,6 +13,7 @@ import tiktoken TOKEN_COSTS = { + "anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015}, "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002}, "gpt-3.5-turbo-0301": {"prompt": 0.0015, "completion": 0.002}, "gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002}, @@ -32,6 +33,10 @@ "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4o": {"prompt": 0.005, "completion": 0.015}, + "gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015}, + "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006}, + "gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens "glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens @@ -51,6 +56,18 @@ "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075}, "yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003}, "yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017}, + "openai/gpt-4": {"prompt": 0.03, "completion": 0.06}, # start, for openrouter + "openai/gpt-4-turbo": {"prompt": 0.01, "completion": 0.03}, + "openai/gpt-4o": {"prompt": 0.005, "completion": 0.015}, + "openai/gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015}, + "openai/gpt-4o-mini": {"prompt": 0.00015, 
"completion": 0.0006}, + "openai/gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006}, + "google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, + "google/gemini-flash-1.5": {"prompt": 0.00025, "completion": 0.00075}, + "deepseek/deepseek-coder": {"prompt": 0.00014, "completion": 0.00028}, + "deepseek/deepseek-chat": {"prompt": 0.00014, "completion": 0.00028}, # end, for openrouter + "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028}, + "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028}, } @@ -145,11 +162,16 @@ # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo TOKEN_MAX = { + "gpt-4o-2024-05-13": 128000, + "gpt-4o": 128000, + "gpt-4o-mini": 128000, + "gpt-4o-mini-2024-07-18": 128000, "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, "gpt-4-1106-preview": 128000, "gpt-4-vision-preview": 128000, "gpt-4-1106-vision-preview": 128000, + "gpt-4-turbo": 128000, "gpt-4": 8192, "gpt-4-0613": 8192, "gpt-4-32k": 32768, @@ -180,6 +202,20 @@ "claude-3-opus-20240229": 200000, "yi-34b-chat-0205": 4000, "yi-34b-chat-200k": 200000, + "openai/gpt-4": 8192, # start, for openrouter + "openai/gpt-4-turbo": 128000, + "openai/gpt-4o": 128000, + "openai/gpt-4o-2024-05-13": 128000, + "openai/gpt-4o-mini": 128000, + "openai/gpt-4o-mini-2024-07-18": 128000, + "anthropic/claude-3.5-sonnet": 200000, + "google/gemini-pro-1.5": 2800000, + "google/gemini-flash-1.5": 2800000, + "deepseek/deepseek-coder": 128000, + "deepseek/deepseek-chat": 128000, # end, for openrouter + "deepseek-chat": 128000, + "deepseek-coder": 128000, + "deepseek-ai/DeepSeek-Coder-V2-Instruct": 32000, # siliconflow } @@ -207,6 +243,10 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4-1106-vision-preview", + "gpt-4o-2024-05-13", + "gpt-4o", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 @@ -280,4 +320,4 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> """ if model not in TOKEN_MAX: return default - return TOKEN_MAX[model] - count_message_tokens(messages) - 1 + return TOKEN_MAX[model] - count_message_tokens(messages, model) - 1 diff --git a/metagpt/utils/tree.py b/metagpt/utils/tree.py index bd79222901..2fcbb50220 100644 --- a/metagpt/utils/tree.py +++ b/metagpt/utils/tree.py @@ -27,14 +27,15 @@ """ from __future__ import annotations -import subprocess from pathlib import Path from typing import Callable, Dict, List from gitignore_parser import parse_gitignore +from metagpt.tools.libs.shell import shell_execute -def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str: + +async def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str: """ Recursively traverses the directory structure and prints it out in a tree-like format. 
@@ -80,7 +81,7 @@ def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = Fal """ root = Path(root).resolve() if run_command: - return _execute_tree(root, gitignore) + return await _execute_tree(root, gitignore) git_ignore_rules = parse_gitignore(gitignore) if gitignore else None dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)} @@ -129,12 +130,7 @@ def _add_line(rows: List[str]) -> List[str]: return rows -def _execute_tree(root: Path, gitignore: str | Path) -> str: +async def _execute_tree(root: Path, gitignore: str | Path) -> str: args = ["--gitfile", str(gitignore)] if gitignore else [] - try: - result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True) - if result.returncode != 0: - raise ValueError(f"tree exits with code {result.returncode}") - return result.stdout - except subprocess.CalledProcessError as e: - raise e + stdout, _, _ = await shell_execute(["tree"] + args + [str(root)]) + return stdout diff --git a/requirements.txt b/requirements.txt index d150d61f3e..ed8965b462 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 tenacity==8.2.3 -tiktoken==0.6.0 +tiktoken==0.7.0 tqdm==4.66.2 #unstructured[local-inference] # selenium>4 @@ -69,5 +69,10 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation -gymnasium==0.29.1 \ No newline at end of file +gymnasium==0.29.1 +pylint~=3.0.3 +pygithub~=2.3 +htmlmin +fsspec +grep-ast~=0.3.3 # linter +tree-sitter~=0.21.3 # linter \ No newline at end of file diff --git a/setup.py b/setup.py index 382e13a47d..c8e705bfb3 100644 --- a/setup.py +++ b/setup.py @@ -32,12 +32,15 @@ def run(self): "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", "llama-index-embeddings-openai==0.1.5", + "llama-index-embeddings-gemini==0.1.6", + "llama-index-embeddings-ollama==0.1.2", "llama-index-llms-azure-openai==0.1.4", "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", + "docx2txt==0.8", ], } @@ -65,28 +68,29 @@ def run(self): extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pre-commit~=3.6.0"],) -setup( - name="metagpt", - version="0.8.0", - description="The Multi-Agent Framework", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/geekan/MetaGPT", - author="Alexander Wu", - author_email="alexanderwu@deepwisdom.ai", - license="MIT", - keywords="metagpt multi-agent multi-role programming gpt llm metaprogramming", - packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]), - python_requires=">=3.9", - install_requires=requirements, - extras_require=extras_require, - cmdclass={ - "install_mermaid": InstallMermaidCLI, - }, - entry_points={ - "console_scripts": [ - "metagpt=metagpt.software_company:app", - ], - }, - include_package_data=True, -) +if __name__ == "__main__": + setup( + name="metagpt", + version="0.8.0", + description="The Multi-Agent Framework", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/geekan/MetaGPT", + author="Alexander Wu", + author_email="alexanderwu@deepwisdom.ai", + license="MIT", + keywords="metagpt 
multi-agent multi-role programming gpt llm metaprogramming", + packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]), + python_requires=">=3.9", + install_requires=requirements, + extras_require=extras_require, + cmdclass={ + "install_mermaid": InstallMermaidCLI, + }, + entry_points={ + "console_scripts": [ + "metagpt=metagpt.software_company:app", + ], + }, + include_package_data=True, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 8603c752aa..1f6661f7c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ import os import re import uuid -from pathlib import Path from typing import Callable import aiohttp.web @@ -23,7 +22,6 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo from tests.mock.mock_aiohttp import MockAioResponse from tests.mock.mock_curl_cffi import MockCurlCffiResponse from tests.mock.mock_httplib2 import MockHttplib2Response @@ -149,13 +147,14 @@ def emit(self, record): @pytest.fixture(scope="function") def context(request): ctx = MetagptContext() - ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") - ctx.repo = ProjectRepo(ctx.git_repo) + repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") + ctx.config.project_path = str(repo.workdir) # Destroy git repo at the end of the test session. def fin(): - if ctx.git_repo: - ctx.git_repo.delete_repository() + if ctx.config.project_path: + git_repo = GitRepository(ctx.config.project_path) + git_repo.delete_repository() # Register the function for destroying the environment. request.addfinalizer(fin) @@ -247,14 +246,16 @@ def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, sear @pytest.fixture def http_server(): - async def handler(request): - return aiohttp.web.Response( - text=""" - MetaGPT

""", - content_type="text/html", - ) - - async def start(): + async def start(handler=None): + if handler is None: + + async def handler(request): + return aiohttp.web.Response( + text=""" + MetaGPT

""", + content_type="text/html", + ) + server = aiohttp.web.Server(handler) runner = aiohttp.web.ServerRunner(server) await runner.setup() @@ -277,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache): @pytest.fixture def git_dir(): """Fixture to get the unittest directory.""" - git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}" + git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}" git_dir.mkdir(parents=True, exist_ok=True) return git_dir diff --git a/tests/data/movie/trailer.mp4 b/tests/data/movie/trailer.mp4 new file mode 100644 index 0000000000..c9620136c8 Binary files /dev/null and b/tests/data/movie/trailer.mp4 differ diff --git a/tests/data/tools/test_script_for_file_manager.py b/tests/data/tools/test_script_for_file_manager.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/ui/1b.png.html b/tests/data/ui/1b.png.html new file mode 100644 index 0000000000..33e9fa442b --- /dev/null +++ b/tests/data/ui/1b.png.html @@ -0,0 +1,164 @@ + + + + + + 法务小超人 + + + +
+ [HTML markup stripped during extraction; the remaining lines of this new 164-line test page contain a "法律意见查询" ("legal opinion lookup") section and the caption "已收录法律意见8394篇" ("8,394 legal opinions indexed").]
+ + \ No newline at end of file diff --git a/tests/metagpt/actions/di/test_ask_review.py b/tests/metagpt/actions/di/test_ask_review.py index 6bb1accf54..d49ad176a4 100644 --- a/tests/metagpt/actions/di/test_ask_review.py +++ b/tests/metagpt/actions/di/test_ask_review.py @@ -6,7 +6,7 @@ @pytest.mark.asyncio async def test_ask_review(mocker): mock_review_input = "confirm" - mocker.patch("builtins.input", return_value=mock_review_input) + mocker.patch("metagpt.actions.di.ask_review.get_human_input", return_value=mock_review_input) rsp, confirmed = await AskReview().run() assert rsp == mock_review_input assert confirmed diff --git a/tests/metagpt/actions/requirement_analysis/__init__.py b/tests/metagpt/actions/requirement_analysis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metagpt/actions/requirement_analysis/requirement/__init__.py b/tests/metagpt/actions/requirement_analysis/requirement/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metagpt/actions/requirement_analysis/requirement/test_pic2txt.py b/tests/metagpt/actions/requirement_analysis/requirement/test_pic2txt.py new file mode 100644 index 0000000000..e5875b6aca --- /dev/null +++ b/tests/metagpt/actions/requirement_analysis/requirement/test_pic2txt.py @@ -0,0 +1,26 @@ +import pytest + +from metagpt.actions.requirement_analysis.requirement.pic2txt import Pic2Txt +from metagpt.const import TEST_DATA_PATH +from metagpt.utils.common import aread + + +@pytest.mark.asyncio +async def test_pic2txt(context): + images = [ + TEST_DATA_PATH / "requirements/pic/1.png", + TEST_DATA_PATH / "requirements/pic/2.png", + TEST_DATA_PATH / "requirements/pic/3.png", + ] + textual_user_requirements = await aread(filename=TEST_DATA_PATH / "requirements/1.original_requirement.txt") + + action = Pic2Txt(context=context) + rsp = await action.run( + image_paths=images, + textual_user_requirement=textual_user_requirements, + ) + assert rsp + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 989e2249cb..23779c9846 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -91,10 +91,10 @@ async def test_action_node_two_layer(): assert node_b in root.children.values() # FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST. 
-    answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
+    answer1 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
     assert "579" in answer1.content

-    answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
+    answer2 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
     assert "579" in answer2.content

@@ -112,7 +112,7 @@ async def test_action_node_review():
     with pytest.raises(RuntimeError):
         _ = await node_a.review()
-    _ = await node_a.fill(context=None, llm=LLM())
+    _ = await node_a.fill(req=None, llm=LLM())
     setattr(node_a.instruct_content, key, "game snake")  # wrong content to review
     review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
@@ -126,7 +126,7 @@
     with pytest.raises(RuntimeError):
         _ = await node.review()
-    _ = await node.fill(context=None, llm=LLM())
+    _ = await node.fill(req=None, llm=LLM())
     review_comments = await node.review(review_mode=ReviewMode.AUTO)
     assert len(review_comments) == 1
@@ -151,7 +151,7 @@ async def test_action_node_revise():
     with pytest.raises(RuntimeError):
         _ = await node_a.review()
-    _ = await node_a.fill(context=None, llm=LLM())
+    _ = await node_a.fill(req=None, llm=LLM())
     setattr(node_a.instruct_content, key, "game snake")  # wrong content to revise
     revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
     assert len(revise_contents) == 1
@@ -164,7 +164,7 @@
     with pytest.raises(RuntimeError):
         _ = await node.revise()
-    _ = await node.fill(context=None, llm=LLM())
+    _ = await node.fill(req=None, llm=LLM())
     setattr(node.instruct_content, key, "game snake")
     revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
     assert len(revise_contents) == 1
@@ -257,7 +257,7 @@ def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
     invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
     img_base64 = encode_image(invoice_path)
     mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
-    node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
+    node = await invoice.fill(req="", llm=LLM(), images=[img_base64])
     assert node.instruct_content.invoice
@@ -303,5 +303,4 @@ def test_action_node_from_pydantic_and_print_everything():
 if __name__ == "__main__":
-    test_create_model_class()
-    test_create_model_class_with_mapping()
+    pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py
index 9924a2e847..314ff54e7a 100644
--- a/tests/metagpt/actions/test_design_api.py
+++ b/tests/metagpt/actions/test_design_api.py
@@ -6,37 +6,102 @@
 @File : test_design_api.py
 @Modified By: mashenquan, 2023-12-6.
According to RFC 135 """ +from pathlib import Path + import pytest from metagpt.actions.design_api import WriteDesign -from metagpt.llm import LLM +from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON @pytest.mark.asyncio -async def test_design_api(context): - inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE - for prd in inputs: - await context.repo.docs.prd.save(filename="new_prd.txt", content=prd) +async def test_design(context): + # Mock new design env + prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。" + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + filename = "prd.txt" + repo = ProjectRepo(context.kwargs.project_path) + await repo.docs.prd.save(filename=filename, content=prd) + kvs = { + "project_path": str(context.kwargs.project_path), + "changed_prd_filenames": [str(repo.docs.prd.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput") - design_api = WriteDesign(context=context) + design_api = WriteDesign(context=context) + result = await design_api.run([Message(content=prd, instruct_content=instruct_content)]) + logger.info(result) + assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files - result = await design_api.run(Message(content=prd, instruct_content=None)) - logger.info(result) + # Mock incremental design env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE) - assert result + result = await design_api.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "legacy_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) @pytest.mark.asyncio -async def test_refined_design_api(context): - await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE) - - design_api = WriteDesign(context=context, llm=LLM()) +async def test_design_api(context, user_requirement, prd_filename, legacy_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename + ) + assert isinstance(result, str) + assert result + assert str(DEFAULT_WORKSPACE_ROOT) in result - result = await design_api.run(Message(content="", instruct_content=None)) - logger.info(result) +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "legacy_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT 
/ "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api_dir(context, user_requirement, prd_filename, legacy_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, + prd_filename=prd_filename, + legacy_design_filename=legacy_design_filename, + output_pathname=str(Path(context.config.project_path) / "1.txt"), + ) + assert isinstance(result, str) assert result + assert str(context.config.project_path) in result + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_design_api_an.py b/tests/metagpt/actions/test_design_api_an.py index 3d11f200d0..4ed3cb3621 100644 --- a/tests/metagpt/actions/test_design_api_an.py +++ b/tests/metagpt/actions/test_design_api_an.py @@ -38,7 +38,7 @@ async def test_write_design_an(mocker): mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root) prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON)) - node = await REFINED_DESIGN_NODE.fill(prompt, llm) + node = await REFINED_DESIGN_NODE.fill(req=prompt, llm=llm) assert "Refined Implementation Approach" in node.instruct_content.model_dump() assert "Refined File list" in node.instruct_content.model_dump() diff --git a/tests/metagpt/actions/test_extract_readme.py b/tests/metagpt/actions/test_extract_readme.py new file mode 100644 index 0000000000..a3428d4d5e --- /dev/null +++ b/tests/metagpt/actions/test_extract_readme.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from pathlib import Path + +import pytest + +from metagpt.actions.extract_readme import ExtractReadMe +from metagpt.llm import LLM + + +@pytest.mark.asyncio +async def test_learn_readme(context): + action = ExtractReadMe( + name="RedBean", + i_context=str(Path(__file__).parent.parent.parent.parent), + llm=LLM(), + context=context, + ) + await action.run() + rows = await action.graph_db.select() + assert rows + assert context.repo.docs.graph_repo.changed_files + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_import_repo.py b/tests/metagpt/actions/test_import_repo.py new file mode 100644 index 0000000000..d498be0395 --- /dev/null +++ b/tests/metagpt/actions/test_import_repo.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import pytest + +from metagpt.actions.import_repo import ImportRepo +from metagpt.context import Context +from metagpt.utils.common import list_files + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "repo_path", + [ + "https://github.com/spec-first/connexion.git", + # "https://github.com/geekan/MetaGPT.git" + ], +) +@pytest.mark.skip +async def test_import_repo(repo_path): + context = Context() + action = ImportRepo(repo_path=repo_path, context=context) + await action.run() + assert context.repo + prd = list_files(context.repo.docs.prd.workdir) + assert prd + design = list_files(context.repo.docs.system_design.workdir) + assert design + assert prd[0].stem == design[0].stem + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 5d0d11efb0..26699dea7d 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -5,13 +5,15 @@ @Author : alexanderwu @File : test_project_management.py """ +import json import pytest from 
metagpt.actions.project_management import WriteTasks -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import ( REFINED_DESIGN_JSON, REFINED_PRD_JSON, @@ -22,29 +24,46 @@ @pytest.mark.asyncio async def test_task(context): - await context.repo.docs.prd.save("1.txt", content=str(PRD)) - await context.repo.docs.system_design.save("1.txt", content=str(DESIGN)) - logger.info(context.git_repo) + # Mock write tasks env + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + repo = ProjectRepo(context.kwargs.project_path) + filename = "1.txt" + await repo.docs.prd.save(filename=filename, content=str(PRD)) + await repo.docs.system_design.save(filename=filename, content=str(DESIGN)) + kvs = { + "project_path": context.kwargs.project_path, + "changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput") action = WriteTasks(context=context) - - result = await action.run(Message(content="", instruct_content=None)) + result = await action.run([Message(content="", instruct_content=instruct_content)]) logger.info(result) - assert result + assert result.instruct_content.changed_task_filenames + # Mock incremental env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON)) + await repo.docs.task.save(filename=filename, content=TASK_SAMPLE) -@pytest.mark.asyncio -async def test_refined_task(context): - await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON)) - await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE) + result = await action.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert result.instruct_content.changed_task_filenames - logger.info(context.git_repo) - action = WriteTasks(context=context, llm=LLM()) +@pytest.mark.asyncio +async def test_task_api(context): + action = WriteTasks() + result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json")) + assert result + assert result.content + m = json.loads(result.content) + assert m - result = await action.run(Message(content="", instruct_content=None)) - logger.info(result) - assert result +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_project_management_an.py b/tests/metagpt/actions/test_project_management_an.py index 5a65e50c93..6d41109c93 100644 --- a/tests/metagpt/actions/test_project_management_an.py +++ b/tests/metagpt/actions/test_project_management_an.py @@ -42,7 +42,7 @@ async def test_project_management_an(mocker): root.instruct_content.model_dump = mock_task_json mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root) - node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm) + node = await PM_NODE.fill(req=dict_to_markdown(REFINED_DESIGN_JSON), llm=llm) assert "Logic Analysis" in node.instruct_content.model_dump() assert "Task list" in node.instruct_content.model_dump() @@ -59,7 +59,7 @@ async def 
test_project_management_an_inc(mocker): mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root) prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON)) - node = await REFINED_PM_NODE.fill(prompt, llm) + node = await REFINED_PM_NODE.fill(req=prompt, llm=llm) assert "Refined Logic Analysis" in node.instruct_content.model_dump() assert "Refined Task list" in node.instruct_content.model_dump() diff --git a/tests/metagpt/actions/test_rebuild_sequence_view.py b/tests/metagpt/actions/test_rebuild_sequence_view.py index 9be3e8a995..e2827c3342 100644 --- a/tests/metagpt/actions/test_rebuild_sequence_view.py +++ b/tests/metagpt/actions/test_rebuild_sequence_view.py @@ -61,7 +61,7 @@ async def test_rebuild(context, mocker): ], ) def test_get_full_filename(root, pathname, want): - res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname) + res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname) assert res == want diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 1709e1f5b9..1c17720312 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -26,12 +26,7 @@ def setup_inc_workdir(context, inc: bool = False): """setup incremental workdir for testing""" - context.src_workspace = context.git_repo.workdir / "src" - if inc: - context.config.inc = inc - context.repo.old_workspace = context.repo.git_repo.workdir / "old" - context.config.project_path = "old" - + context.config.inc = inc return context @@ -110,7 +105,7 @@ async def test_write_refined_code(context, git_dir): # old_workspace contains the legacy code await context.repo.with_src_path(context.repo.old_workspace).srcs.save( - filename="game.py", content=CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE) + filename="game.py", content=CodeParser.parse_code(text=REFINED_CODE_INPUT_SAMPLE) ) ccontext = CodingContext( diff --git a/tests/metagpt/actions/test_write_code_plan_and_change_an.py b/tests/metagpt/actions/test_write_code_plan_and_change_an.py index 5c262b4b71..5bc8604693 100644 --- a/tests/metagpt/actions/test_write_code_plan_and_change_an.py +++ b/tests/metagpt/actions/test_write_code_plan_and_change_an.py @@ -45,7 +45,7 @@ async def test_write_code_plan_and_change_an(mocker, context, git_dir): await context.repo.docs.task.save(filename="2.json", content=json.dumps(REFINED_TASK_JSON)) await context.repo.with_src_path(context.repo.old_workspace).srcs.save( - filename="game.py", content=CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE) + filename="game.py", content=CodeParser.parse_code(text=REFINED_CODE_INPUT_SAMPLE) ) root = ActionNode.from_children( diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 43aa336b75..fcfa81931f 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -6,25 +6,26 @@ @File : test_write_prd.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. 
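The rewritten tests below exercise `WritePRD` both inside a project repo and as a standalone API; a condensed sketch of the standalone call, mirroring `test_write_prd_api` further down (it assumes a configured LLM and writes under the default workspace):

```python
import asyncio

from metagpt.actions import WritePRD


async def demo():
    # Standalone mode: no repo or context wiring; the returned string names
    # the generated PRD file under DEFAULT_WORKSPACE_ROOT.
    result = await WritePRD().run(user_requirement="write a snake game.")
    print(result)


asyncio.run(demo())
```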
""" +import uuid +from pathlib import Path import pytest from metagpt.actions import UserRequirement, WritePRD -from metagpt.const import REQUIREMENT_FILENAME +from metagpt.const import DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import RoleReactMode from metagpt.schema import Message from metagpt.utils.common import any_to_str -from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE -from tests.metagpt.actions.test_write_code import setup_inc_workdir +from metagpt.utils.project_repo import ProjectRepo +from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE @pytest.mark.asyncio async def test_write_prd(new_filename, context): product_manager = ProductManager(context=context) requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) product_manager.rc.react_mode = RoleReactMode.BY_ORDER prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) assert prd.cause_by == any_to_str(WritePRD) @@ -34,38 +35,39 @@ async def test_write_prd(new_filename, context): # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert product_manager.context.repo.docs.prd.changed_files + repo = ProjectRepo(context.kwargs.project_path) + assert repo.docs.prd.changed_files + repo.git_repo.archive() - -@pytest.mark.asyncio -async def test_write_prd_inc(new_filename, context, git_dir): - context = setup_inc_workdir(context, inc=True) - await context.repo.docs.prd.save("1.txt", PRD_SAMPLE) - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) + # Mock incremental requirement + context.config.inc = True + context.config.project_path = context.kwargs.project_path + repo = ProjectRepo(context.config.project_path) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) action = WritePRD(context=context) - prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)) + prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)]) logger.info(NEW_REQUIREMENT_SAMPLE) logger.info(prd) # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert "Refined Requirements" in prd.content + assert repo.git_repo.changed_files @pytest.mark.asyncio async def test_fix_debug(new_filename, context, git_dir): - context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name + # Mock legacy project + context.kwargs.project_path = str(git_dir) + repo = ProjectRepo(context.kwargs.project_path) + repo.with_src_path(git_dir.name) + await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()') + requirements = "ValueError: undefined variable `st`." + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) - await context.repo.with_src_path(context.src_workspace).srcs.save( - filename="main.py", content='if __name__ == "__main__":\nmain()' - ) - requirements = "Please fix the bug in the code." 
- await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) action = WritePRD(context=context) - - prd = await action.run(Message(content=requirements, instruct_content=None)) + prd = await action.run([Message(content=requirements, instruct_content=None)]) logger.info(prd) # Assert the prd is not None or empty @@ -73,5 +75,39 @@ async def test_fix_debug(new_filename, context, git_dir): assert prd.content != "" +@pytest.mark.asyncio +async def test_write_prd_api(context): + action = WritePRD() + result = await action.run(user_requirement="write a snake game.") + assert isinstance(result, str) + assert result + assert str(DEFAULT_WORKSPACE_ROOT) in result + + result = await action.run( + user_requirement="write a snake game.", + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), + ) + assert isinstance(result, str) + assert result + assert str(context.config.project_path) in result + + ix = result.find(":") + legacy_prd_filename = result[ix + 1 :].replace('"', "").strip() + + result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename) + assert isinstance(result, str) + assert result + assert str(DEFAULT_WORKSPACE_ROOT) in result + + result = await action.run( + user_requirement="Add moving enemy.", + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), + legacy_prd_filename=legacy_prd_filename, + ) + assert isinstance(result, str) + assert result + assert str(context.config.project_path) in result + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_write_prd_an.py b/tests/metagpt/actions/test_write_prd_an.py index 378ce42c31..b6e92d3d66 100644 --- a/tests/metagpt/actions/test_write_prd_an.py +++ b/tests/metagpt/actions/test_write_prd_an.py @@ -39,7 +39,7 @@ async def test_write_prd_an(mocker): requirements=NEW_REQUIREMENT_SAMPLE, old_prd=PRD_SAMPLE, ) - node = await REFINED_PRD_NODE.fill(prompt, llm) + node = await REFINED_PRD_NODE.fill(req=prompt, llm=llm) assert "Refined Requirements" in node.instruct_content.model_dump() assert "Refined Product Goals" in node.instruct_content.model_dump() diff --git a/tests/metagpt/environment/mgx_env/run_mgx_env.py b/tests/metagpt/environment/mgx_env/run_mgx_env.py new file mode 100644 index 0000000000..dd9e7c3e53 --- /dev/null +++ b/tests/metagpt/environment/mgx_env/run_mgx_env.py @@ -0,0 +1,170 @@ +import asyncio +import os +import re +import threading +import time + +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.roles import Architect, Engineer, ProductManager, ProjectManager +from metagpt.roles.di.data_analyst import DataAnalyst +from metagpt.roles.di.engineer2 import Engineer2 +from metagpt.roles.di.team_leader import TeamLeader +from metagpt.schema import Message + + +async def main(requirement="", enable_human_input=False, use_fixed_sop=False, allow_idle_time=30): + if use_fixed_sop: + engineer = Engineer(n_borg=5, use_code_review=False) + else: + engineer = Engineer2() + + env = MGXEnv() + env.add_roles( + [ + TeamLeader(), + ProductManager(use_fixed_sop=use_fixed_sop), + Architect(use_fixed_sop=use_fixed_sop), + ProjectManager(use_fixed_sop=use_fixed_sop), + engineer, + # QaEngineer(), + DataAnalyst(), + ] + ) + + if enable_human_input: + # simulate human sending messages in chatbox + stop_event = threading.Event() + human_input_thread = send_human_input(env, stop_event) + + if requirement: + env.publish_message(Message(content=requirement)) + # 
user_defined_recipient = "Alex" + # env.publish_message(Message(content=requirement, send_to={user_defined_recipient}), user_defined_recipient=user_defined_recipient) + + allow_idle_time = allow_idle_time if enable_human_input else 1 + start_time = time.time() + while time.time() - start_time < allow_idle_time: + if not env.is_idle: + await env.run() + start_time = time.time() # reset start time + + if enable_human_input: + print("No more human input, terminating, press ENTER for a full termination.") + stop_event.set() + human_input_thread.join() + + +def send_human_input(env, stop_event): + """ + Simulate sending message in chatbox + Note in local environment, the message is consumed only after current round of env.run is finished + """ + + def send_messages(): + while not stop_event.is_set(): + message = input("Enter a message any time: ") + user_defined_recipient = re.search(r"@(\w+)", message) + if user_defined_recipient: + recipient_name = user_defined_recipient.group(1) + print(f"{recipient_name} will receive the message") + env.publish_message( + Message(content=message, send_to={recipient_name}), user_defined_recipient=recipient_name + ) + else: + env.publish_message(Message(content=message)) + + # Start a thread for sending messages + send_thread = threading.Thread(target=send_messages, args=()) + send_thread.start() + return send_thread + + +GAME_REQ = "create a 2048 game" +GAME_REQ_ZH = "写一个贪吃蛇游戏" +WEB_GAME_REQ = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard." +WEB_GAME_REQ_DEPLOY = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard. When finished, deploy the game to public at port 8090." +TODO_APP_REQ = "Create a website widget for TODO list management. Users should be able to add, mark as complete, and delete tasks. Include features like prioritization, due dates, and categories. Make it visually appealing, responsive, and user-friendly. Use HTML, CSS, and JavaScript. Consider additional features like notifications or task export. Keep it simple and enjoyable for users.dont use vue or react.dont use third party library, use localstorage to save data." +FLAPPY_BIRD_REQ = "write a flappy bird game in pygame, code only" +SIMPLE_DATA_REQ = "load sklearn iris dataset and print a statistic summary" +WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, and train a model to predict wine class (20% as validation), and show validation accuracy." +PAPER_LIST_REQ = """ +Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, +and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables* +""" +ECOMMERCE_REQ = """ +Get products data from website https://scrapeme.live/shop/ and save it as a csv file. +**Notice: Firstly parse the web page encoding and the text HTML structure; +The first page product name, price, product URL, and image URL must be saved in the csv;** +""" +NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; +下面是一个大致流程, 你会根据每一步的运行结果对当前计划中的任务做出适当调整: +1. 爬取并本地保存html结构; +2. 直接打印第7个*`快讯`*关键词后2000个字符的html内容, 作为*快讯的html内容示例*; +3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; +4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 +5. 
将全部结果存在本地csv中
+**Notice: view the page element before writing scraping code**
+"""
+data_path = "data/titanic"
+train_path = f"{data_path}/split_train.csv"
+eval_path = f"{data_path}/split_eval.csv"
+TITANIC_REQ = f"This is a Titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{train_path}', eval data path: '{eval_path}'."
+CALIFORNIA_HOUSING_REQ = """
+Analyze the 'California-housing-dataset' using https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html#sklearn.datasets.fetch_california_housing to predict the median house value. You need to perform data preprocessing, feature engineering and finally modeling to predict the target. Use machine learning techniques such as linear regression (including ridge regression and lasso regression), random forest, XGBoost. You also need to report the MSE on the test dataset.
+"""
+STOCK_REQ = """Import NVIDIA Corporation (NVDA) stock price data from Yahoo Finance, focusing on historical closing prices from the past 5 years.
+Compute summary statistics (mean, median, standard deviation, etc.) to understand the central tendency and dispersion of closing prices. Analyze the data for any noticeable trends, patterns, or anomalies over time, potentially using rolling averages or percentage changes.
+Create a plot to visualize all the data analysis. Reserve 20% of the dataset for validation. Train a predictive model on the training set. Report the model's validation accuracy, and visualize the prediction results.
+"""
+FIX_ISSUE1 = """
+Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453,
+you can fix it on this repo https://github.com/garylin2099/langchain,
+checkout a branch named test-fix, commit your changes, push, and create a PR to the master branch of https://github.com/iorisa/langchain
+"""
+FIX_ISSUE2 = """
+Write a fix for this issue https://github.com/geekan/MetaGPT/issues/1275.
+You can fix it on the v0.8-release branch of this repo https://github.com/garylin2099/MetaGPT,
+during fixing, checkout a branch named test-fix-1275, commit your changes, push, and create a PR to the v0.8-release branch of https://github.com/garylin2099/MetaGPT
+"""
+FIX_ISSUE3 = """
+Write a fix for this issue https://github.com/geekan/MetaGPT/issues/1262.
+You can fix it on this repo https://github.com/garylin2099/MetaGPT,
+during fixing, checkout a branch named test-fix-1262, commit your changes, push, and create a PR to https://github.com/garylin2099/MetaGPT
+"""
+FIX_ISSUE_SIMPLE = """
+Write a fix for this issue: https://github.com/mannaandpoem/simple_calculator/issues/1,
+you can fix it on this repo https://github.com/garylin2099/simple_calculator,
+checkout a branch named test, commit your changes, push, and create a PR to the master branch of the original repo.
+"""
+PUSH_PR_REQ = """
+clone https://github.com/garylin2099/simple_calculator, checkout a new branch named test-branch, add an empty file test_file.py to the repo.
+Commit your changes and push, finally, create a PR to the master branch of https://github.com/mannaandpoem/simple_calculator.
+"""
+IMAGE2CODE_REQ = "Please write a frontend web page similar to this image /Users/gary/Files/temp/workspace/temp_img.png, I want the same title and color. 
code only" +DOC_QA_REQ1 = "Tell me what this paper is about /Users/gary/Files/temp/workspace/2308.09687.pdf" +DOC_QA_REQ2 = "Summarize this doc /Users/gary/Files/temp/workspace/2401.14295.pdf" +DOC_QA_REQ3 = "请总结/Users/gary/Files/temp/workspace/2309.04658.pdf里的关键点" +DOC_QA_REQ4 = "这份报表/Users/gary/Files/temp/workspace/9929550.md中,营业收入TOP3产品各自的收入占比是多少" + +TL_CHAT1 = """Summarize the paper for me""" # expecting clarification +TL_CHAT2 = """Solve the issue at this link""" # expecting clarification +TL_CHAT3 = """Who is the first man landing on Moon""" # expecting answering directly +TL_CHAT4 = """Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5""" # expecting answering directly +TL_CHAT5 = """Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.""" # expecting answering directly +TL_CHAT6 = """True or False? Statement 1 | A ring homomorphism is one to one if and only if the kernel is {{0}},. Statement 2 | Q is an ideal in R""" # expecting answering directly +TL_CHAT7 = """Jean has 30 lollipops. Jean eats 2 of the lollipops. With the remaining lollipops, Jean wants to package 2 lollipops in one bag. How many bags can Jean fill?""" # expecting answering directly +TL_CHAT9 = """What's your name?""" +TL_CHAT10 = "Hi" +TL_CHAT11 = "Tell me about your team" +TL_CHAT12 = "What can you do" +CODING_REQ1 = "写一个java的hello world程序" +CODING_REQ2 = "python里的装饰器是什么" +CODING_REQ3 = "python里的装饰器是怎么用的,给我个例子" + + +if __name__ == "__main__": + # NOTE: Add access_token to test github issue fixing + os.environ["access_token"] = "ghp_xxx" + # NOTE: Change the requirement to the one you want to test + # Set enable_human_input to True if you want to simulate sending messages in chatbox + asyncio.run(main(requirement=GAME_REQ, enable_human_input=False, use_fixed_sop=False)) diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py index 404f1c2064..ecdc4e1329 100644 --- a/tests/metagpt/environment/test_base_env.py +++ b/tests/metagpt/environment/test_base_env.py @@ -6,6 +6,7 @@ import pytest +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.api.env_api import EnvAPIAbstract from metagpt.environment.base_env import ( Environment, @@ -14,7 +15,6 @@ mark_as_readable, mark_as_writeable, ) -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams class ForTestEnv(Environment): diff --git a/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py new file mode 100644 index 0000000000..0a160fb42f --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py @@ -0,0 +1,32 @@ +import pytest + +from metagpt.exp_pool.context_builders.base import ( + EXP_TEMPLATE, + BaseContextBuilder, + Experience, +) +from metagpt.exp_pool.schema import Metric, Score + + +class TestBaseContextBuilder: + class ConcreteContextBuilder(BaseContextBuilder): + async def build(self, *args, **kwargs): + pass + + @pytest.fixture + def context_builder(self): + return self.ConcreteContextBuilder() + + def test_format_exps(self, context_builder): + exp1 = Experience(req="req1", resp="resp1", metric=Metric(score=Score(val=8))) + exp2 = Experience(req="req2", resp="resp2", metric=Metric(score=Score(val=9))) + context_builder.exps = [exp1, exp2] + + result = context_builder.format_exps() + expected = "\n".join( + [ + 
f"1. {EXP_TEMPLATE.format(req='req1', resp='resp1', score=8)}", + f"2. {EXP_TEMPLATE.format(req='req2', resp='resp2', score=9)}", + ] + ) + assert result == expected diff --git a/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py new file mode 100644 index 0000000000..82a3622a5e --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py @@ -0,0 +1,50 @@ +import pytest + +from metagpt.const import EXPERIENCE_MASK +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder + + +class TestRoleZeroContextBuilder: + @pytest.fixture + def context_builder(self): + return RoleZeroContextBuilder() + + @pytest.mark.asyncio + async def test_build_empty_req(self, context_builder): + result = await context_builder.build(req=[]) + assert result == [] + + @pytest.mark.asyncio + async def test_build_no_experiences(self, context_builder, mocker): + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="") + req = [{"content": "Original content"}] + result = await context_builder.build(req=req) + assert result == req + + @pytest.mark.asyncio + async def test_build_with_experiences(self, context_builder, mocker): + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="Formatted experiences") + mocker.patch.object(RoleZeroContextBuilder, "replace_example_content", return_value="Updated content") + req = [{"content": "Original content 1"}] + result = await context_builder.build(req=req) + assert result == [{"content": "Updated content"}] + + def test_replace_example_content(self, context_builder, mocker): + mocker.patch.object(RoleZeroContextBuilder, "fill_experience", return_value="Replaced content") + result = context_builder.replace_example_content("Original text", "New example content") + assert result == "Replaced content" + context_builder.fill_experience.assert_called_once_with("Original text", "New example content") + + def test_fill_experience(self): + text = f"Start\n# Past Experience\n{EXPERIENCE_MASK}\n\n# Instruction\nEnd" + new_content = "New content" + result = RoleZeroContextBuilder.fill_experience(text, new_content) + expected = "Start\n# Past Experience\nNew content\n\n# Instruction\nEnd" + assert result == expected + + def test_fill_experience_no_match(self): + text = "Start\nNo markers\nEnd" + new_content = "New content" + result = RoleZeroContextBuilder.fill_experience(text, new_content) + assert result == text diff --git a/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py new file mode 100644 index 0000000000..cf1a42f270 --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py @@ -0,0 +1,47 @@ +import pytest + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.simple import ( + SIMPLE_CONTEXT_TEMPLATE, + SimpleContextBuilder, +) + + +class TestSimpleContextBuilder: + @pytest.fixture + def context_builder(self): + return SimpleContextBuilder() + + @pytest.mark.asyncio + async def test_build_with_experiences(self, mocker, context_builder: SimpleContextBuilder): + # Mock the format_exps method + mock_exps = "Mocked experiences" + mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps) + + req = "Test request" 
+ result = await context_builder.build(req=req) + + expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=mock_exps) + assert result == expected + + @pytest.mark.asyncio + async def test_build_without_experiences(self, mocker, context_builder: SimpleContextBuilder): + # Mock the format_exps method to return an empty string + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="") + + req = "Test request" + result = await context_builder.build(req=req) + + expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps="") + assert result == expected + + @pytest.mark.asyncio + async def test_build_without_req(self, mocker, context_builder: SimpleContextBuilder): + # Mock the format_exps method + mock_exps = "Mocked experiences" + mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps) + + result = await context_builder.build(req="") + + expected = SIMPLE_CONTEXT_TEMPLATE.format(req="", exps=mock_exps) + assert result == expected diff --git a/tests/metagpt/exp_pool/test_decorator.py b/tests/metagpt/exp_pool/test_decorator.py new file mode 100644 index 0000000000..9d104fca4c --- /dev/null +++ b/tests/metagpt/exp_pool/test_decorator.py @@ -0,0 +1,206 @@ +import asyncio + +import pytest + +from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolConfig +from metagpt.exp_pool.context_builders import SimpleContextBuilder +from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache +from metagpt.exp_pool.manager import ExperienceManager +from metagpt.exp_pool.perfect_judges import SimplePerfectJudge +from metagpt.exp_pool.schema import Experience, QueryType, Score +from metagpt.exp_pool.scorers import SimpleScorer +from metagpt.rag.engines import SimpleEngine + + +class TestExpCacheHandler: + @pytest.fixture + def mock_func(self, mocker): + return mocker.AsyncMock() + + @pytest.fixture + def mock_exp_manager(self, mocker): + manager = mocker.MagicMock(spec=ExperienceManager) + manager.storage = mocker.MagicMock(spec=SimpleEngine) + manager.config = mocker.MagicMock(spec=Config) + manager.config.exp_pool = ExperiencePoolConfig() + manager.query_exps = mocker.AsyncMock() + manager.create_exp = mocker.MagicMock() + return manager + + @pytest.fixture + def mock_scorer(self, mocker): + scorer = mocker.MagicMock(spec=SimpleScorer) + scorer.evaluate = mocker.AsyncMock() + return scorer + + @pytest.fixture + def mock_perfect_judge(self, mocker): + return mocker.MagicMock(spec=SimplePerfectJudge) + + @pytest.fixture + def mock_context_builder(self, mocker): + return mocker.MagicMock(spec=SimpleContextBuilder) + + @pytest.fixture + def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer, mock_perfect_judge, mock_context_builder): + return ExpCacheHandler( + func=mock_func, + args=(), + kwargs={"req": "test_req"}, + exp_manager=mock_exp_manager, + exp_scorer=mock_scorer, + exp_perfect_judge=mock_perfect_judge, + context_builder=mock_context_builder, + ) + + @pytest.mark.asyncio + async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager): + mock_exp_manager.query_exps.return_value = [Experience(req="test_req", resp="test_resp")] + await exp_cache_handler.fetch_experiences() + mock_exp_manager.query_exps.assert_called_once_with( + "test_req", query_type=QueryType.SEMANTIC, tag=exp_cache_handler.tag + ) + assert len(exp_cache_handler._exps) == 1 + + @pytest.mark.asyncio + async def test_get_one_perfect_exp(self, exp_cache_handler, mock_perfect_judge): + exp = Experience(req="test_req", resp="perfect_resp") 
+ exp_cache_handler._exps = [exp] + mock_perfect_judge.is_perfect_exp.return_value = True + result = await exp_cache_handler.get_one_perfect_exp() + assert result == "perfect_resp" + + @pytest.mark.asyncio + async def test_execute_function(self, exp_cache_handler, mock_func, mock_context_builder): + mock_context_builder.build.return_value = "built_context" + mock_func.return_value = "function_result" + await exp_cache_handler.execute_function() + mock_context_builder.build.assert_called_once() + mock_func.assert_called_once_with(req="built_context") + assert exp_cache_handler._raw_resp == "function_result" + assert exp_cache_handler._resp == "function_result" + + @pytest.mark.asyncio + async def test_process_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager): + exp_cache_handler._resp = "test_resp" + mock_scorer.evaluate.return_value = Score(val=8) + await exp_cache_handler.process_experience() + mock_scorer.evaluate.assert_called_once() + mock_exp_manager.create_exp.assert_called_once() + + @pytest.mark.asyncio + async def test_evaluate_experience(self, exp_cache_handler, mock_scorer): + exp_cache_handler._resp = "test_resp" + mock_scorer.evaluate.return_value = Score(val=9) + await exp_cache_handler.evaluate_experience() + assert exp_cache_handler._score.val == 9 + + def test_save_experience(self, exp_cache_handler, mock_exp_manager): + exp_cache_handler._req = "test_req" + exp_cache_handler._resp = "test_resp" + exp_cache_handler._score = Score(val=7) + exp_cache_handler.save_experience() + mock_exp_manager.create_exp.assert_called_once() + + def test_choose_wrapper_async(self, mocker): + async def async_func(): + pass + + wrapper = ExpCacheHandler.choose_wrapper(async_func, mocker.AsyncMock()) + assert asyncio.iscoroutinefunction(wrapper) + + def test_choose_wrapper_sync(self, mocker): + def sync_func(): + pass + + wrapper = ExpCacheHandler.choose_wrapper(sync_func, mocker.AsyncMock()) + assert not asyncio.iscoroutinefunction(wrapper) + + def test_validate_params(self): + with pytest.raises(ValueError): + ExpCacheHandler(func=lambda x: x, args=(), kwargs={}) + + def test_generate_tag(self): + class TestClass: + def test_method(self): + pass + + handler = ExpCacheHandler(func=TestClass().test_method, args=(TestClass(),), kwargs={"req": "test"}) + assert handler._generate_tag() == "TestClass.test_method" + + handler = ExpCacheHandler(func=lambda x: x, args=(), kwargs={"req": "test"}) + assert handler._generate_tag() == "" + + +class TestExpCache: + @pytest.fixture + def mock_exp_manager(self, mocker, mock_config): + manager = mocker.MagicMock(spec=ExperienceManager) + manager.storage = mocker.MagicMock(spec=SimpleEngine) + manager.config = mock_config + manager.query_exps = mocker.AsyncMock() + manager.create_exp = mocker.MagicMock() + return manager + + @pytest.fixture + def mock_scorer(self, mocker): + scorer = mocker.MagicMock(spec=SimpleScorer) + scorer.evaluate = mocker.AsyncMock(return_value=Score()) + return scorer + + @pytest.fixture + def mock_perfect_judge(self, mocker): + return mocker.MagicMock(spec=SimplePerfectJudge) + + @pytest.fixture + def mock_config(self, mocker): + config = Config.default().model_copy(deep=True) + default = mocker.patch("metagpt.config2.Config.default") + default.return_value = config + return config + + @pytest.mark.asyncio + async def test_exp_cache_disabled(self, mock_config, mock_exp_manager): + mock_config.exp_pool.enabled = False + + @exp_cache(manager=mock_exp_manager) + async def test_func(req): + return "result" + + result = 
await test_func(req="test")
+        assert result == "result"
+        mock_exp_manager.query_exps.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
+        mock_config.exp_pool.enabled = True
+        mock_config.exp_pool.enable_read = True
+        mock_config.exp_pool.enable_write = True
+        mock_exp_manager.query_exps.return_value = []
+
+        @exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
+        async def test_func(req):
+            return "computed_result"
+
+        result = await test_func(req="test")
+        assert result == "computed_result"
+        mock_exp_manager.query_exps.assert_called()
+        mock_exp_manager.create_exp.assert_called()
+
+    @pytest.mark.asyncio
+    async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
+        mock_config.exp_pool.enabled = True
+        mock_config.exp_pool.enable_read = True
+        perfect_exp = Experience(req="test", resp="perfect_result")
+        mock_exp_manager.query_exps.return_value = [perfect_exp]
+        mock_perfect_judge.is_perfect_exp.return_value = True
+
+        @exp_cache(manager=mock_exp_manager, perfect_judge=mock_perfect_judge)
+        async def test_func(req):
+            return "should_not_be_called"
+
+        result = await test_func(req="test")
+        assert result == "perfect_result"
+        mock_exp_manager.query_exps.assert_called_once()
+        mock_exp_manager.create_exp.assert_not_called()
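+
+
+# --- Illustrative usage sketch (editor's note, not part of the test suite) ---
+# The decorator under test is typically applied directly to an LLM-backed function.
+# The no-argument form below is an assumption based on the config-driven defaults
+# exercised above, and `some_llm_call` is hypothetical. With `exp_pool.enabled: true`
+# in config2.yaml, a stored perfect Experience short-circuits the call; otherwise the
+# function runs and its scored result is written back to the pool.
+#
+# from metagpt.exp_pool.decorator import exp_cache
+#
+# @exp_cache()
+# async def answer(req: str) -> str:
+#     return await some_llm_call(req)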
diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py
new file mode 100644
index 0000000000..b0e4e8537d
--- /dev/null
+++ b/tests/metagpt/exp_pool/test_manager.py
@@ -0,0 +1,144 @@
+import pytest
+
+from metagpt.config2 import Config
+from metagpt.configs.exp_pool_config import (
+    ExperiencePoolConfig,
+    ExperiencePoolRetrievalType,
+)
+from metagpt.configs.llm_config import LLMConfig
+from metagpt.exp_pool.manager import Experience, ExperienceManager
+from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, QueryType
+
+
+class TestExperienceManager:
+    @pytest.fixture
+    def mock_config(self):
+        return Config(
+            llm=LLMConfig(),
+            exp_pool=ExperiencePoolConfig(
+                enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25
+            ),
+        )
+
+    @pytest.fixture
+    def mock_storage(self, mocker):
+        engine = mocker.MagicMock()
+        engine.add_objs = mocker.MagicMock()
+        engine.aretrieve = mocker.AsyncMock(return_value=[])
+        engine.count = mocker.MagicMock(return_value=10)
+        return engine
+
+    @pytest.fixture
+    def exp_manager(self, mock_config, mock_storage):
+        manager = ExperienceManager(config=mock_config)
+        manager._storage = mock_storage
+        return manager
+
+    def test_storage_property(self, exp_manager, mock_storage):
+        assert exp_manager.storage == mock_storage
+
+    def test_storage_property_initialization(self, mocker, mock_config):
+        mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=mocker.MagicMock())
+        manager = ExperienceManager(config=mock_config)
+        assert manager._storage is None
+        _ = manager.storage
+        assert manager._storage is not None
+
+    def test_create_exp_write_disabled(self, exp_manager, mock_config):
+        mock_config.exp_pool.enable_write = False
+        exp = Experience(req="test", resp="response")
+        exp_manager.create_exp(exp)
+        exp_manager.storage.add_objs.assert_not_called()
+
+    def test_create_exp_write_enabled(self, exp_manager):
+        exp = Experience(req="test", resp="response")
+        exp_manager.create_exp(exp)
+        exp_manager.storage.add_objs.assert_called_once_with([exp])
+        exp_manager.storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path)
+
+    @pytest.mark.asyncio
+    async def test_query_exps_read_disabled(self, exp_manager, mock_config):
+        mock_config.exp_pool.enable_read = False
+        result = await exp_manager.query_exps("query")
+        assert result == []
+
+    @pytest.mark.asyncio
+    async def test_query_exps_with_exact_match(self, exp_manager, mocker):
+        req = "exact query"
+        exp1 = Experience(req=req, resp="response1")
+        exp2 = Experience(req="different query", resp="response2")
+
+        mock_node1 = mocker.MagicMock(metadata={"obj": exp1})
+        mock_node2 = mocker.MagicMock(metadata={"obj": exp2})
+
+        exp_manager.storage.aretrieve.return_value = [mock_node1, mock_node2]
+
+        result = await exp_manager.query_exps(req, query_type=QueryType.EXACT)
+        assert len(result) == 1
+        assert result[0].req == req
+
+    @pytest.mark.asyncio
+    async def test_query_exps_with_tag_filter(self, exp_manager, mocker):
+        tag = "test_tag"
+        exp1 = Experience(req="query1", resp="response1", tag=tag)
+        exp2 = Experience(req="query2", resp="response2", tag="other_tag")
+
+        mock_node1 = mocker.MagicMock(metadata={"obj": exp1})
+        mock_node2 = mocker.MagicMock(metadata={"obj": exp2})
+
+        exp_manager.storage.aretrieve.return_value = [mock_node1, mock_node2]
+
+        result = await exp_manager.query_exps("query", tag=tag)
+        assert len(result) == 1
+        assert result[0].tag == tag
+
+    def test_get_exps_count(self, exp_manager):
+        assert exp_manager.get_exps_count() == 10
+
+    def test_resolve_storage_bm25(self, mocker, mock_config):
+        mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25
+        mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock())
+        manager = ExperienceManager(config=mock_config)
+        storage = manager._resolve_storage()
+        manager._create_bm25_storage.assert_called_once()
+        assert storage is not None
+
+    def test_resolve_storage_chroma(self, mocker, mock_config):
+        mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
+        mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock())
+        manager = ExperienceManager(config=mock_config)
+        storage = manager._resolve_storage()
+        manager._create_chroma_storage.assert_called_once()
+        assert storage is not None
+
+    def test_create_bm25_storage(self, mocker, mock_config):
+        mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock())
+        mocker.patch("metagpt.rag.engines.SimpleEngine.from_index", return_value=mocker.MagicMock())
+        mocker.patch("metagpt.rag.engines.SimpleEngine.get_obj_nodes", return_value=[])
+        mocker.patch("metagpt.rag.engines.SimpleEngine._resolve_embed_model", return_value=mocker.MagicMock())
+        mocker.patch("llama_index.core.VectorStoreIndex", return_value=mocker.MagicMock())
+        mocker.patch("metagpt.rag.schema.BM25RetrieverConfig", return_value=mocker.MagicMock())
+        mocker.patch("pathlib.Path.exists", return_value=False)
+
+        manager = ExperienceManager(config=mock_config)
+        storage = manager._create_bm25_storage()
+        assert storage is not None
+
+    def test_create_chroma_storage(self, mocker, mock_config):
+        mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock())
+        manager = ExperienceManager(config=mock_config)
+        storage = manager._create_chroma_storage()
+        assert storage is not None
+
+    def test_get_ranker_configs_use_llm_ranker_true(self, mock_config):
+        mock_config.exp_pool.use_llm_ranker = True
+        manager = ExperienceManager(config=mock_config)
+        
ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 1 + assert ranker_configs[0].top_n == DEFAULT_SIMILARITY_TOP_K + + def test_get_ranker_configs_use_llm_ranker_false(self, mock_config): + mock_config.exp_pool.use_llm_ranker = False + manager = ExperienceManager(config=mock_config) + ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 0 diff --git a/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py b/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py new file mode 100644 index 0000000000..5abd04f0db --- /dev/null +++ b/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py @@ -0,0 +1,40 @@ +import pytest + +from metagpt.exp_pool.perfect_judges import SimplePerfectJudge +from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score + + +class TestSimplePerfectJudge: + @pytest.fixture + def simple_perfect_judge(self): + return SimplePerfectJudge() + + @pytest.mark.asyncio + async def test_is_perfect_exp_perfect_match(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is True + + @pytest.mark.asyncio + async def test_is_perfect_exp_imperfect_score(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE - 1))) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_mismatched_request(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) + result = await simple_perfect_judge.is_perfect_exp(exp, "different_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_no_metric(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp") + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_no_score(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric()) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False diff --git a/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py b/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py new file mode 100644 index 0000000000..e17edfca81 --- /dev/null +++ b/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py @@ -0,0 +1,64 @@ +import json + +import pytest + +from metagpt.exp_pool.schema import Score +from metagpt.exp_pool.scorers.simple import SIMPLE_SCORER_TEMPLATE, SimpleScorer +from metagpt.llm import BaseLLM + + +class TestSimpleScorer: + @pytest.fixture + def mock_llm(self, mocker): + mock_llm = mocker.MagicMock(spec=BaseLLM) + return mock_llm + + @pytest.fixture + def simple_scorer(self, mock_llm): + return SimpleScorer(llm=mock_llm) + + def test_init(self, mock_llm): + scorer = SimpleScorer(llm=mock_llm) + assert isinstance(scorer.llm, BaseLLM) + + @pytest.mark.asyncio + async def test_evaluate(self, simple_scorer, mock_llm, mocker): + # Mock request and response + req = "What is the capital of France?" + resp = "The capital of France is Paris." 
+ + # Mock LLM response + mock_llm_response = '{"val": 9, "reason": "Accurate and concise answer"}' + mock_llm.aask.return_value = f"```json\n{mock_llm_response}\n```" + + # Mock CodeParser.parse_code + mocker.patch("metagpt.utils.common.CodeParser.parse_code", return_value=mock_llm_response) + + # Test evaluate method + result = await simple_scorer.evaluate(req, resp) + + # Assert LLM was called with correct prompt + expected_prompt = SIMPLE_SCORER_TEMPLATE.format(req=req, resp=resp) + mock_llm.aask.assert_called_once_with(expected_prompt) + + # Assert the result is correct + assert isinstance(result, Score) + assert result.val == 9 + assert result.reason == "Accurate and concise answer" + + @pytest.mark.asyncio + async def test_evaluate_invalid_response(self, simple_scorer, mock_llm, mocker): + # Mock request and response + req = "What is the capital of France?" + resp = "The capital of France is Paris." + + # Mock LLM response with invalid JSON + mock_llm_response = "Invalid JSON" + mock_llm.aask.return_value = f"```json\n{mock_llm_response}\n```" + + # Mock CodeParser.parse_code + mocker.patch("metagpt.utils.common.CodeParser.parse_code", return_value=mock_llm_response) + + # Test evaluate method with invalid response + with pytest.raises(json.JSONDecodeError): + await simple_scorer.evaluate(req, resp) diff --git a/tests/metagpt/exp_pool/test_serializers/test_action_node.py b/tests/metagpt/exp_pool/test_serializers/test_action_node.py new file mode 100644 index 0000000000..e4ab4684d2 --- /dev/null +++ b/tests/metagpt/exp_pool/test_serializers/test_action_node.py @@ -0,0 +1,35 @@ +from typing import Type + +import pytest + +from metagpt.actions.action_node import ActionNode +from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer + + +class TestActionNodeSerializer: + @pytest.fixture + def serializer(self): + return ActionNodeSerializer() + + @pytest.fixture + def action_node(self): + class InstructContent: + def __init__(self, json_data): + self.json_data = json_data + + def model_dump_json(self): + return self.json_data + + action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="") + action_node.instruct_content = InstructContent('{"key": "value"}') + + return action_node + + def test_serialize_resp(self, serializer: ActionNodeSerializer, action_node: ActionNode): + serialized = serializer.serialize_resp(action_node) + assert serialized == '{"key": "value"}' + + def test_deserialize_resp(self, serializer: ActionNodeSerializer): + deserialized = serializer.deserialize_resp('{"key": "value"}') + assert isinstance(deserialized, ActionNode) + assert deserialized.instruct_content.model_dump_json() == '{"key": "value"}' diff --git a/tests/metagpt/exp_pool/test_serializers/test_role_zero.py b/tests/metagpt/exp_pool/test_serializers/test_role_zero.py new file mode 100644 index 0000000000..964443f292 --- /dev/null +++ b/tests/metagpt/exp_pool/test_serializers/test_role_zero.py @@ -0,0 +1,46 @@ +import json + +import pytest + +from metagpt.exp_pool.serializers import RoleZeroSerializer + + +class TestRoleZeroSerializer: + @pytest.fixture + def serializer(self) -> RoleZeroSerializer: + return RoleZeroSerializer() + + @pytest.fixture + def last_item(self) -> dict: + return { + "role": "user", + "content": "# Current Plan\nsome plan\n# Current Plan\nsome plan\n# Instruction\nsome instruction", + } + + @pytest.fixture + def sample_req(self): + return [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + + def 
test_serialize_req_empty_input(self, serializer: RoleZeroSerializer): + assert serializer.serialize_req(req=[]) == "" + + def test_serialize_req_with_content(self, serializer: RoleZeroSerializer, last_item: dict): + req = [ + {"role": "user", "content": "Command Editor.read executed: file_path=test.py"}, + {"role": "assistant", "content": "Some other content"}, + last_item, + ] + expected_output = json.dumps([{"role": "user", "content": "Command Editor.read executed: file_path=test.py"}]) + assert serializer.serialize_req(req=req) == expected_output + + def test_filter_req(self, serializer: RoleZeroSerializer): + req = [ + {"role": "user", "content": "Command Editor.read executed: file_path=test1.py"}, + {"role": "assistant", "content": "Some other content"}, + {"role": "user", "content": "Command Editor.read executed: file_path=test2.py"}, + {"role": "assistant", "content": "Final content"}, + ] + filtered_req = serializer._filter_req(req) + assert len(filtered_req) == 2 + assert filtered_req[0]["content"] == "Command Editor.read executed: file_path=test1.py" + assert filtered_req[1]["content"] == "Command Editor.read executed: file_path=test2.py" diff --git a/tests/metagpt/exp_pool/test_serializers/test_simple.py b/tests/metagpt/exp_pool/test_serializers/test_simple.py new file mode 100644 index 0000000000..2a6bf96e3f --- /dev/null +++ b/tests/metagpt/exp_pool/test_serializers/test_simple.py @@ -0,0 +1,44 @@ +import pytest + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class TestSimpleSerializer: + @pytest.fixture + def serializer(self): + return SimpleSerializer() + + def test_serialize_req(self, serializer: SimpleSerializer): + # Test with different types of input + assert serializer.serialize_req(req=123) == "123" + assert serializer.serialize_req(req="test") == "test" + assert serializer.serialize_req(req=[1, 2, 3]) == "[1, 2, 3]" + assert serializer.serialize_req(req={"a": 1}) == "{'a': 1}" + + def test_serialize_resp(self, serializer: SimpleSerializer): + # Test with different types of input + assert serializer.serialize_resp(456) == "456" + assert serializer.serialize_resp("response") == "response" + assert serializer.serialize_resp([4, 5, 6]) == "[4, 5, 6]" + assert serializer.serialize_resp({"b": 2}) == "{'b': 2}" + + def test_deserialize_resp(self, serializer: SimpleSerializer): + # Test with different types of input + assert serializer.deserialize_resp("789") == "789" + assert serializer.deserialize_resp("test_response") == "test_response" + assert serializer.deserialize_resp("[7, 8, 9]") == "[7, 8, 9]" + assert serializer.deserialize_resp("{'c': 3}") == "{'c': 3}" + + def test_roundtrip(self, serializer: SimpleSerializer): + # Test serialization and deserialization roundtrip + original = "test_roundtrip" + serialized = serializer.serialize_resp(original) + deserialized = serializer.deserialize_resp(serialized) + assert deserialized == original + + @pytest.mark.parametrize("input_value", [123, "test", [1, 2, 3], {"a": 1}, None]) + def test_serialize_req_types(self, serializer: SimpleSerializer, input_value): + # Test serialize_req with various input types + result = serializer.serialize_req(req=input_value) + assert isinstance(result, str) + assert result == str(input_value) diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 990017feec..cbd161dfa1 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -28,10 +28,6 @@ async def 
test_ltm_search(mocker): ) role_id = "UTUserLtm(Product Manager)" - from metagpt.environment import Environment - - Environment - RoleContext.model_rebuild() rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"}) ltm = LongTermMemory() ltm.recover_memory(role_id, rc) diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py new file mode 100644 index 0000000000..80eb58e493 --- /dev/null +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -0,0 +1,164 @@ +from datetime import datetime, timedelta + +import pytest + +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME +from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage + + +class TestRoleZeroLongTermMemory: + @pytest.fixture + def mock_memory(self, mocker) -> RoleZeroLongTermMemory: + memory = RoleZeroLongTermMemory() + memory._resolve_rag_engine = mocker.Mock() + return memory + + def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True) + mock_memory._transfer_to_longterm_memory = mocker.Mock() + + message = UserMessage(content="test") + mock_memory.add(message) + + assert mock_memory.storage[-1] == message + mock_memory._transfer_to_longterm_memory.assert_called_once() + + def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True) + mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query") + mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")]) + + mock_memory.storage = [Message(content="short-term")] + + result = mock_memory.get() + + assert len(result) == 2 + assert result[0].content == "long-term" + assert result[1].content == "short-term" + + def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mocker.patch.object(mock_memory, "storage", [None] * 201) + + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_add() == True + + mocker.patch.object(mock_memory, "storage", [None] * 199) + assert mock_memory._should_use_longterm_memory_for_add() == False + + @pytest.mark.parametrize( + "k,is_last_from_user,count,expected", + [ + (0, True, 201, False), + (1, False, 201, False), + (1, True, 199, False), + (1, True, 201, True), + ], + ) + def test_should_use_longterm_memory_for_get( + self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected + ): + mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user) + mocker.patch.object(mock_memory, "storage", [None] * count) + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_get(k) == expected + + def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_item = mocker.Mock() + mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item) + mock_memory._add_to_longterm_memory = mocker.Mock() + + mock_memory._transfer_to_longterm_memory() + + mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item) + + def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="test") + mock_memory.storage = [mock_message, mock_message] + mock_memory.memory_k = 1 + + result = 
mock_memory._get_longterm_memory_item() + + assert isinstance(result, LongTermMemoryItem) + assert result.message == mock_message + + def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory): + item = LongTermMemoryItem(message=Message(content="test")) + mock_memory._add_to_longterm_memory(item) + + mock_memory.rag_engine.add_objs.assert_called_once_with([item]) + + def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="query") + mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message) + + result = mock_memory._build_longterm_memory_query() + + assert result == "query" + + def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory): + mock_memory.storage = [Message(content="1"), Message(content="2")] + + result = mock_memory._get_the_last_message() + + assert result.content == "2" + + @pytest.mark.parametrize( + "message,expected", + [ + (UserMessage(content="test", cause_by=UserRequirement), True), + (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True), + (UserMessage(content="test"), True), + (AIMessage(content="test"), False), + (None, False), + ], + ) + def test_is_last_message_from_user_requirement( + self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected + ): + mock_memory._get_the_last_message = mocker.Mock(return_value=message) + + assert mock_memory._is_last_message_from_user_requirement() == expected + + def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_nodes = [mocker.Mock(), mocker.Mock()] + mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes) + mock_items = [ + LongTermMemoryItem(message=UserMessage(content="user1")), + LongTermMemoryItem(message=AIMessage(content="ai1")), + ] + mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items) + + result = mock_memory._fetch_longterm_memories("query") + + assert len(result) == 2 + assert result[0].content == "user1" + assert result[1].content == "ai1" + + def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory): + now = datetime.now() + mock_nodes = [ + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp() + ) + } + ), + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp() + ) + } + ), + mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}), + ] + + result = mock_memory._get_items_from_nodes(mock_nodes) + + assert len(result) == 3 + assert [item.message.content for item in result] == ["1", "2", "3"] diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 40a9fda920..62083a769e 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -8,9 +8,11 @@ import pytest +from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig +from metagpt.const import IMAGES from metagpt.provider.base_llm import BaseLLM -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message, UserMessage from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( default_resp_cont, @@ -104,3 +106,99 @@ async def test_async_base_llm(): # resp = await base_llm.aask_code([prompt]) # assert resp == 
default_resp_cont
+
+
+@pytest.mark.parametrize("compress_type", list(CompressType))
+def test_compress_messages_no_effect(compress_type):
+    base_llm = MockBaseLLM()
+    messages = [
+        {"role": "system", "content": "first system msg"},
+        {"role": "system", "content": "second system msg"},
+    ]
+    for i in range(5):
+        messages.append({"role": "user", "content": f"u{i}"})
+        messages.append({"role": "assistant", "content": f"a{i}"})
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type)
+    # should have no effect for a short context
+    assert compressed == messages
+
+
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
+def test_compress_messages_long(compress_type):
+    base_llm = MockBaseLLM()
+    base_llm.config.model = "test_llm"
+    max_token_limit = 100
+
+    messages = [
+        {"role": "system", "content": "first system msg"},
+        {"role": "system", "content": "second system msg"},
+    ]
+    for i in range(100):
+        messages.append({"role": "user", "content": f"u{i}" * 10})  # ~2x10x0.5 = 10 tokens
+        messages.append({"role": "assistant", "content": f"a{i}" * 10})
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)
+
+    print(compressed)
+    print(len(compressed))
+    assert 3 <= len(compressed) < len(messages)
+    assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
+    assert compressed[2]["role"] != "system"
+
+
+def test_long_messages_no_compress():
+    base_llm = MockBaseLLM()
+    messages = [{"role": "user", "content": "1" * 10000}] * 10000
+    compressed = base_llm.compress_messages(messages)
+    assert len(compressed) == len(messages)
+
+
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
+def test_compress_messages_long_no_sys_msg(compress_type):
+    base_llm = MockBaseLLM()
+    base_llm.config.model = "test_llm"
+    max_token_limit = 100
+
+    messages = [{"role": "user", "content": "1" * 10000}]
+    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)
+
+    print(compressed)
+    assert compressed
+    assert len(compressed[0]["content"]) < len(messages[0]["content"])
+
+
+def test_format_msg(mocker):
+    base_llm = MockBaseLLM()
+    messages = [UserMessage(content="req"), AIMessage(content="rsp")]
+    formatted_msgs = base_llm.format_msg(messages)
+    assert formatted_msgs == [{"role": "user", "content": "req"}, {"role": "assistant", "content": "rsp"}]
+
+
+def test_format_msg_w_images(mocker):
+    base_llm = MockBaseLLM()
+    base_llm.config.model = "gpt-4o"
+    msg_w_images = UserMessage(content="req1")
+    msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"])
+    msg_w_empty_images = UserMessage(content="req2")
+    msg_w_empty_images.add_metadata(IMAGES, [])
+    messages = [
+        msg_w_images,  # should be converted
+        AIMessage(content="rsp"),
+        msg_w_empty_images,  # should not be converted
+    ]
+    formatted_msgs = base_llm.format_msg(messages)
+    assert formatted_msgs == [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "req1"},
+                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 1"}},
+                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 2"}},
+            ],
+        },
+        {"role": "assistant", "content": "rsp"},
+        {"role": "user", "content": "req2"},
+    ]
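+
+
+# --- Illustrative usage sketch (editor's note, not part of the test suite) ---
+# The call pattern these tests exercise: trim chat history to a model's context
+# budget before sending. The `prepare_context` helper and the 4096 budget are
+# assumptions for illustration; the tested surface is only
+# `compress_messages(messages, compress_type=..., max_token=...)`, which keeps the
+# leading system messages and drops conversation turns to fit, per the assertions above.
+#
+# def prepare_context(llm, messages, max_token=4096):
+#     return llm.compress_messages(
+#         messages, compress_type=CompressType.POST_CUT_BY_TOKEN, max_token=max_token
+#     )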
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py
index 3ce38d2a5a..d292a82861 100644
--- a/tests/metagpt/provider/test_openai.py
+++ b/tests/metagpt/provider/test_openai.py
@@ -9,6 +9,7 @@
 from openai.types.chat.chat_completion_message_tool_call import Function
 from PIL import Image
 
+from metagpt.configs.compress_msg_config import CompressType
 from metagpt.const import TEST_DATA_PATH
 from metagpt.llm import LLM
 from metagpt.logs import logger
@@ -164,3 +165,63 @@
     assert resp.usage == usage
 
     await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
+
+
+def test_count_tokens():
+    llm = LLM()
+    llm.config.model = "gpt-4o"
+    messages = [
+        llm._system_msg("some system msg"),
+        llm._system_msg("some system message 2"),
+        llm._user_msg("user 1"),
+        llm._assistant_msg("assistant 1"),
+        llm._user_msg("user 1"),
+        llm._assistant_msg("assistant 2"),
+    ]
+    cnt = llm.count_tokens(messages)
+    assert cnt == 47
+
+
+def test_count_tokens_long():
+    llm = LLM()
+    llm.config.model = "gpt-4-0613"
+    test_msg_content = " ".join([str(i) for i in range(100000)])
+    messages = [
+        llm._system_msg("You are a helpful assistant"),
+        llm._user_msg(test_msg_content + " what's the first number you see?"),
+    ]
+    cnt = llm.count_tokens(messages)  # 299023, ~300k
+    assert 290000 <= cnt <= 300000
+
+    llm.config.model = "test_llm"  # a non-openai model; falls back to the heuristics-based count_tokens
+    cnt = llm.count_tokens(messages)  # 294474, ~300k, ~2% difference
+    assert 290000 <= cnt <= 300000
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_aask_long():
+    llm = LLM()
+    llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"  # deepseek-coder on siliconflow, limit 32k
+    llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN
+    test_msg_content = " ".join([str(i) for i in range(100000)])  # corresponds to ~300k tokens
+    messages = [
+        llm._system_msg("You are a helpful assistant"),
+        llm._user_msg(test_msg_content + " what's the first number you see?"),
+    ]
+    await llm.aask(messages)  # should not fail with context truncated
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_aask_long_no_compress():
+    llm = LLM()
+    llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"  # deepseek-coder on siliconflow, limit 32k
+    # Not specifying llm.config.compress_type uses the default "", i.e. no compression
+    test_msg_content = " ".join([str(i) for i in range(100000)])  # corresponds to ~300k tokens
+    messages = [
+        llm._system_msg("You are a helpful assistant"),
+        llm._user_msg(test_msg_content + " what's the first number you see?"),
+    ]
+    with pytest.raises(Exception):
+        await llm.aask(messages)  # should fail
diff --git a/tests/metagpt/rag/__init__.py b/tests/metagpt/rag/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py
index 9262ccb07e..e0a174ed2e 100644
--- a/tests/metagpt/rag/engines/test_simple.py
+++ b/tests/metagpt/rag/engines/test_simple.py
@@ -25,10 +25,6 @@ def mock_embedding(self):
     def mock_simple_directory_reader(self, mocker):
         return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
 
-    @pytest.fixture
-    def mock_vector_store_index(self, mocker):
-        return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
-
     @pytest.fixture
     def mock_get_retriever(self, mocker):
         return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@@ -45,7 +41,6 @@ def test_from_docs(
         self,
         mocker,
         mock_simple_directory_reader,
-        mock_vector_store_index,
         mock_get_retriever,
         mock_get_rankers,
         mock_get_response_synthesizer,
@@ -80,12 +75,9 @@ def test_from_docs(
         )
 
         # Assert
- mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) - mock_vector_store_index.assert_called_once() - mock_get_retriever.assert_called_once_with( - configs=retriever_configs, index=mock_vector_store_index.return_value - ) - mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm) + mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None) + mock_get_retriever.assert_called_once() + mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) @@ -119,7 +111,7 @@ def model_dump_json(self): # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is not None + assert engine._transformations is not None def test_from_objs_with_bm25_config(self): # Setup @@ -137,6 +129,7 @@ def test_from_objs_with_bm25_config(self): def test_from_index(self, mocker, mock_llm, mock_embedding): # Mock mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_index.as_retriever.return_value = "retriever" mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index") mock_get_index.return_value = mock_index @@ -149,7 +142,7 @@ def test_from_index(self, mocker, mock_llm, mock_embedding): # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is mock_index + assert engine._retriever == "retriever" @pytest.mark.asyncio async def test_asearch(self, mocker): @@ -200,14 +193,11 @@ def test_add_docs(self, mocker): mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) - mock_index = mocker.MagicMock(spec=VectorStoreIndex) - mock_index._transformations = mocker.MagicMock() - mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations") mock_run_transformations.return_value = ["node1", "node2"] # Setup - engine = SimpleEngine(retriever=mock_retriever, index=mock_index) + engine = SimpleEngine(retriever=mock_retriever) input_files = ["test_file1", "test_file2"] # Exec @@ -230,7 +220,7 @@ def model_dump_json(self): return "" objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] - engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) + engine = SimpleEngine(retriever=mock_retriever) # Exec engine.add_objs(objs=objs) diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 1d41e18721..0b0a449761 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -97,6 +97,5 @@ def test_val_from_config_or_kwargs_fallback_to_kwargs(self): def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) - with pytest.raises(KeyError) as exc_info: - ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) - assert "The key 'missing_key' is required but not provided" in str(exc_info.value) + val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) + assert val is None diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1ded6b4a8d..03bdfab1d9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,7 @@ import pytest +from metagpt.config2 import Config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from 
metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -10,30 +12,54 @@ def mock_embedding_factory(self): self.embedding_factory = RAGEmbeddingFactory() @pytest.fixture - def mock_openai_embedding(self, mocker): + def mock_config(self, mocker): + config = Config.default().model_copy(deep=True) + default = mocker.patch("metagpt.config2.Config.default") + default.return_value = config + return config + + @staticmethod + def mock_openai_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - @pytest.fixture - def mock_azure_embedding(self, mocker): + @staticmethod + def mock_azure_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - def test_get_rag_embedding_openai(self, mock_openai_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + @staticmethod + def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - # Assert - mock_openai_embedding.assert_called_once() + @staticmethod + def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") + + @pytest.mark.parametrize( + ("mock_func", "embedding_type"), + [ + (mock_openai_embedding, LLMType.OPENAI), + (mock_azure_embedding, LLMType.AZURE), + (mock_openai_embedding, EmbeddingType.OPENAI), + (mock_azure_embedding, EmbeddingType.AZURE), + (mock_gemini_embedding, EmbeddingType.GEMINI), + (mock_ollama_embedding, EmbeddingType.OLLAMA), + ], + ) + def test_get_rag_embedding(self, mock_func, embedding_type, mocker): + # Mock + mock = mock_func(mocker) - def test_get_rag_embedding_azure(self, mock_azure_embedding): # Exec - self.embedding_factory.get_rag_embedding(LLMType.AZURE) + self.embedding_factory.get_rag_embedding(embedding_type) # Assert - mock_azure_embedding.assert_called_once() + mock.assert_called_once() - def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + def test_get_rag_embedding_default(self, mocker, mock_config): # Mock - mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock_openai_embedding = self.mock_openai_embedding(mocker) + + mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec @@ -41,3 +67,44 @@ def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): # Assert mock_openai_embedding.assert_called_once() + + @pytest.mark.parametrize( + "model, embed_batch_size, expected_params", + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], + ) + def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): + # Mock + mock_config.embedding.model = model + mock_config.embedding.embed_batch_size = embed_batch_size + + # Setup + test_params = {} + + # Exec + self.embedding_factory._try_set_model_and_batch_size(test_params) + + # Assert + assert test_params == expected_params + + def test_resolve_embedding_type(self, mock_config): + # Mock + mock_config.embedding.api_type = EmbeddingType.OPENAI + + # Exec + embedding_type = self.embedding_factory._resolve_embedding_type() + + # Assert + assert embedding_type == EmbeddingType.OPENAI + + def test_resolve_embedding_type_exception(self, mock_config): + # Mock + mock_config.embedding.api_type = None + mock_config.llm.api_type = LLMType.GEMINI + + # Assert + with pytest.raises(TypeError): + self.embedding_factory._resolve_embedding_type() + + def test_raise_for_key(self): + with pytest.raises(ValueError): + 
self.embedding_factory._raise_for_key("key") diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index ef1cef7e00..cd55a32db6 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -1,6 +1,8 @@ import faiss import pytest from llama_index.core import VectorStoreIndex +from llama_index.core.embeddings import MockEmbedding +from llama_index.core.schema import TextNode from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -43,6 +45,14 @@ def mock_chroma_vector_store(self, mocker): def mock_es_vector_store(self, mocker): return mocker.MagicMock(spec=ElasticsearchStore) + @pytest.fixture + def mock_nodes(self, mocker): + return [TextNode(text="msg")] + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index): mock_config = FAISSRetrieverConfig(dimensions=128) mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) @@ -52,42 +62,40 @@ def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_ve assert isinstance(retriever, FAISSRetriever) - def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index): + def test_get_retriever_with_bm25_config(self, mocker, mock_nodes): mock_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes) assert isinstance(retriever, DynamicBM25Retriever) - def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index): - mock_faiss_config = FAISSRetrieverConfig(dimensions=128) + def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding): + mock_faiss_config = FAISSRetrieverConfig(dimensions=1) mock_bm25_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + retriever = self.retriever_factory.get_retriever( + configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding + ) assert isinstance(retriever, SimpleHybridRetriever) - def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store): + def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding): mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection") mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient") mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock() mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], 
embed_model=mock_embedding) assert isinstance(retriever, ChromaRetriever) - def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store): + def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding): mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ElasticsearchRetriever) @@ -111,3 +119,19 @@ def test_extract_index_from_kwargs(self, mock_vector_store_index): extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index) assert extracted_index == mock_vector_store_index + + def test_get_or_build_when_get(self, mocker): + want = "existing_index" + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want + + def test_get_or_build_when_build(self, mocker): + want = "call_build_es_index" + mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want diff --git a/tests/metagpt/rag/test_large_pdf.py b/tests/metagpt/rag/test_large_pdf.py new file mode 100644 index 0000000000..4f343aa874 --- /dev/null +++ b/tests/metagpt/rag/test_large_pdf.py @@ -0,0 +1,55 @@ +import pytest + +from metagpt.config2 import Config +from metagpt.const import TEST_DATA_PATH +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.utils.common import aread + + +@pytest.mark.skip +@pytest.mark.parametrize( + ("knowledge_filename", "query_filename", "answer_filename"), + [ + ( + TEST_DATA_PATH / "embedding/2.knowledge.md", + TEST_DATA_PATH / "embedding/2.query.md", + TEST_DATA_PATH / "embedding/2.answer.md", + ), + ( + TEST_DATA_PATH / "embedding/3.knowledge.md", + TEST_DATA_PATH / "embedding/3.query.md", + TEST_DATA_PATH / "embedding/3.answer.md", + ), + ], +) +@pytest.mark.asyncio +async def test_large_pdf(knowledge_filename, query_filename, answer_filename): + Config.default(reload=True) # `config.embedding.model = "text-embedding-ada-002"` changes the cache. 
+ + engine = SimpleEngine.from_docs( + input_files=[knowledge_filename], + ) + + query = await aread(filename=query_filename) + rsp = await engine.aretrieve(query) + assert rsp + + config = Config.default() + config.embedding.model = "text-embedding-ada-002" + factory = RAGEmbeddingFactory(config) + embedding = factory.get_rag_embedding() + answer = await aread(filename=answer_filename) + answer_embedding = await embedding.aget_text_embedding(answer) + similarity = 0 + for i in rsp: + rsp_embedding = await embedding.aget_query_embedding(i.text) + v = embedding.similarity(answer_embedding, rsp_embedding) + similarity = max(similarity, v) + + print(similarity) + assert similarity > 0.9 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/roles/di/run_architect.py b/tests/metagpt/roles/di/run_architect.py new file mode 100644 index 0000000000..455b60d92b --- /dev/null +++ b/tests/metagpt/roles/di/run_architect.py @@ -0,0 +1,39 @@ +import asyncio +import os + +from metagpt.roles.architect import Architect +from metagpt.schema import Message + +DESIGN_DOC_SNAKE = """ +{ + "Implementation approach": "We will use the Pygame library to create the CLI-based snake game. Pygame is a set of Python modules designed for writing video games, which will help us handle graphics, sound, and input. The game will be structured into different modules to handle the main game loop, snake movement, food generation, collision detection, and user interface. We will ensure the game is engaging and responsive by optimizing the game loop and input handling. The score display and different speed levels will be implemented to enhance the user experience.", + "File list": [ + "main.py", + "game.py", + "snake.py", + "food.py", + "ui.py" + ], + "Data structures and interfaces": "\nclassDiagram\n class Main {\n +main() void\n }\n class Game {\n -Snake snake\n -Food food\n -int score\n -int speed\n +__init__(speed: int)\n +run() void\n +restart() void\n +update_score() void\n }\n class Snake {\n -list body\n -str direction\n +__init__()\n +move() void\n +change_direction(new_direction: str) void\n +check_collision() bool\n +grow() void\n }\n class Food {\n -tuple position\n +__init__()\n +generate_new_position() void\n }\n class UI {\n +display_score(score: int) void\n +display_game_over() void\n +display_game(snake: Snake, food: Food) void\n }\n Main --> Game\n Game --> Snake\n Game --> Food\n Game --> UI\n", + "Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant U as UI\n M->>G: __init__(speed)\n M->>G: run()\n G->>S: __init__()\n G->>F: __init__()\n loop Game Loop\n G->>S: move()\n G->>S: check_collision()\n alt Collision Detected\n G->>G: restart()\n G->>U: display_game_over()\n else No Collision\n G->>F: generate_new_position()\n G->>S: grow()\n G->>G: update_score()\n G->>U: display_score(score)\n end\n G->>U: display_game(snake, food)\n end\n", + "Anything UNCLEAR": "Currently, all aspects of the project are clear." 
+}
+"""
+
+WRITE_SNAKE = """Write a system design for a cli snake game with pygame"""
+
+REWRITE_SNAKE = """Rewrite the system design at temp_design.json, add a web UI"""
+
+CASUAL_CHAT = """What's your name?"""
+
+
+async def main(requirement):
+    with open("temp_design.json", "w") as f:
+        f.write(DESIGN_DOC_SNAKE)
+    architect = Architect()
+    await architect.run(Message(content=requirement, send_to="Bob"))
+    os.remove("temp_design.json")
+
+
+if __name__ == "__main__":
+    asyncio.run(main(REWRITE_SNAKE))
diff --git a/tests/metagpt/roles/di/run_data_analyst.py b/tests/metagpt/roles/di/run_data_analyst.py
new file mode 100644
index 0000000000..247bc78070
--- /dev/null
+++ b/tests/metagpt/roles/di/run_data_analyst.py
@@ -0,0 +1,54 @@
+from metagpt.roles.di.data_analyst import DataAnalyst
+
+HOUSE_PRICE_TRAIN_PATH = "/data/house-prices-advanced-regression-techniques/split_train.csv"
+HOUSE_PRICE_EVAL_PATH = "/data/house-prices-advanced-regression-techniques/split_eval.csv"
+HOUSE_PRICE_REQ = f"""
+This is a house price dataset; your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: '{HOUSE_PRICE_TRAIN_PATH}', eval data path: '{HOUSE_PRICE_EVAL_PATH}'.
+"""
+
+CALIFORNIA_HOUSING_REQ = """
+Analyze the 'California-housing-dataset' using https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html#sklearn.datasets.fetch_california_housing to predict the median house value. You need to perform data preprocessing, feature engineering and finally modeling to predict the target. Use machine learning techniques such as linear regression (including ridge regression and lasso regression), random forest, CatBoost, LightGBM, XGBoost or other appropriate methods. You also need to report the MSE on the test dataset.
+"""
+
+# For web scraping tasks, please provide a url beginning with `https://` or `http://`
+PAPER_LIST_REQ = """
+Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
+and save it to a csv file. paper title must include `multiagent` or `large language model`.
+**Notice: view the page element before writing scraping code**
+"""
+
+ECOMMERCE_REQ = """
+Get products data from website https://scrapeme.live/shop/ and save it as a csv file.
+The first page product name, price, product URL, and image URL must be saved in the csv.
+**Notice: view the page element before writing scraping code**
+"""
+
+NEWS_36KR_REQ = """Scrape the financing information of all startups from the 36kr venture platform https://pitchhub.36kr.com/financing-flash , **Notice: this is a Chinese website**;
+Below is a rough workflow; you will make appropriate adjustments to the tasks in the current plan based on the result of each step:
+1. Scrape the page and save its html structure locally;
+2. Directly print the 2000 characters of html content following the 7th *`快讯`* keyword, as a *sample of the 快讯 html content*;
+3. Reflect on the patterns in the *sample of the 快讯 html content* and design regular expressions to extract the title, link and time of each *`快讯`*;
+4. Filter the startup financing *`快讯`* of the last 3 days and print the first 5 of them as a list[dict].
+5. Save all the results in a local csv
+**Notice: view the page element before writing scraping code**
+"""
+
+WIKIPEDIA_SEARCH_REQ = """
+Search for `LLM` on https://www.wikipedia.org/ and print all the meaningful senses of the entry.
+"""
+
+STACKOVERFLOW_CLICK_REQ = """
+Click the Questions tag in https://stackoverflow.com/ and scrape question name, votes, answers and views num to csv in the first result page.
+""" + + +async def main(): + di = DataAnalyst() + await di.browser.start() + await di.run(STACKOVERFLOW_CLICK_REQ) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/tests/metagpt/roles/di/run_engineer2.py b/tests/metagpt/roles/di/run_engineer2.py new file mode 100644 index 0000000000..a5ceec93e6 --- /dev/null +++ b/tests/metagpt/roles/di/run_engineer2.py @@ -0,0 +1,167 @@ +import asyncio +import sys +import uuid +from pathlib import Path + +from metagpt.logs import logger +from metagpt.roles.di.engineer2 import Engineer2 + +DESIGN_DOC_2048 = '{"Implementation approach":"We will use the Pygame library to implement the 2048 game logic and user interface. Pygame is a set of Python modules designed for writing video games, which will help us create a responsive and visually appealing UI. For the mobile responsiveness, we will ensure that the game scales appropriately on different screen sizes. We will also use the Pygame GUI library to create buttons for restarting the game and choosing difficulty levels.","File list":["main.py","game.py","ui.py"],"Data structures and interfaces":"\\nclassDiagram\\n class Game {\\n -grid: list[list[int]]\\n -score: int\\n +__init__()\\n +move(direction: str) bool\\n +merge() bool\\n +spawn_tile() None\\n +is_game_over() bool\\n +reset() None\\n }\\n class UI {\\n -game: Game\\n +__init__(game: Game)\\n +draw_grid() None\\n +draw_score() None\\n +draw_buttons() None\\n +handle_input() None\\n }\\n class Main {\\n -ui: UI\\n +main() None\\n }\\n Main --> UI\\n UI --> Game\\n","Program call flow":"\\nsequenceDiagram\\n participant M as Main\\n participant U as UI\\n participant G as Game\\n M->>U: __init__(game)\\n U->>G: __init__()\\n M->>U: draw_grid()\\n U->>G: move(direction)\\n G-->>U: return bool\\n U->>G: merge()\\n G-->>U: return bool\\n U->>G: spawn_tile()\\n G-->>U: return None\\n U->>G: is_game_over()\\n G-->>U: return bool\\n U->>G: reset()\\n G-->>U: return None\\n M->>U: draw_score()\\n M->>U: draw_buttons()\\n M->>U: handle_input()\\n","Anything UNCLEAR":"Clarification needed on the specific design elements for the UI to ensure it meets the \'beautiful\' requirement. Additionally, we need to confirm the exact difficulty levels and how they should affect the game mechanics."}' +TASK_DOC_2048 = '{"Required Python packages":["pygame==2.0.1","pygame_gui==0.5.7"],"Required Other language third-party packages":["No third-party dependencies required"],"Logic Analysis":[["game.py","Contains Game class with methods: __init__, move, merge, spawn_tile, is_game_over, reset"],["ui.py","Contains UI class with methods: __init__, draw_grid, draw_score, draw_buttons, handle_input"],["main.py","Contains Main class with method: main, initializes UI and Game"]],"Task list":["game.py","ui.py","main.py"],"Full API spec":"","Shared Knowledge":"`game.py` contains core game logic and state management. `ui.py` handles all user interface elements and interactions. `main.py` serves as the entry point to initialize and run the game.","Anything UNCLEAR":"Clarification needed on the specific design elements for the UI to ensure it meets the \'beautiful\' requirement. Additionally, we need to confirm the exact difficulty levels and how they should affect the game mechanics."}' +DESIGN_DOC_SNAKE = """ +{ + "Implementation approach": "We will use the Pygame library to create the CLI-based snake game. Pygame is a set of Python modules designed for writing video games, which will help us handle graphics, sound, and input. 
The game will be structured into different modules to handle the main game loop, snake movement, food generation, collision detection, and user interface. We will ensure the game is engaging and responsive by optimizing the game loop and input handling. The score display and different speed levels will be implemented to enhance the user experience.", + "File list": [ + "main.py", + "game.py", + "snake.py", + "food.py", + "ui.py" + ], + "Data structures and interfaces": "\nclassDiagram\n class Main {\n +main() void\n }\n class Game {\n -Snake snake\n -Food food\n -int score\n -int speed\n +__init__(speed: int)\n +run() void\n +restart() void\n +update_score() void\n }\n class Snake {\n -list body\n -str direction\n +__init__()\n +move() void\n +change_direction(new_direction: str) void\n +check_collision() bool\n +grow() void\n }\n class Food {\n -tuple position\n +__init__()\n +generate_new_position() void\n }\n class UI {\n +display_score(score: int) void\n +display_game_over() void\n +display_game(snake: Snake, food: Food) void\n }\n Main --> Game\n Game --> Snake\n Game --> Food\n Game --> UI\n", + "Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant U as UI\n M->>G: __init__(speed)\n M->>G: run()\n G->>S: __init__()\n G->>F: __init__()\n loop Game Loop\n G->>S: move()\n G->>S: check_collision()\n alt Collision Detected\n G->>G: restart()\n G->>U: display_game_over()\n else No Collision\n G->>F: generate_new_position()\n G->>S: grow()\n G->>G: update_score()\n G->>U: display_score(score)\n end\n G->>U: display_game(snake, food)\n end\n", + "Anything UNCLEAR": "Currently, all aspects of the project are clear." +} +""" +TASK_DOC_SNAKE = """ +{ + "Required Python packages": [ + "pygame==2.0.1" + ], + "Required Other language third-party packages": [ + "No third-party dependencies required" + ], + "Logic Analysis": [ + [ + "main.py", + "Contains the main function to initialize and start the game. Imports Game from game.py." + ], + [ + "game.py", + "Contains the Game class which manages the game loop, score, and speed. Imports Snake from snake.py, Food from food.py, and UI from ui.py." + ], + [ + "snake.py", + "Contains the Snake class which handles snake movement, direction changes, collision detection, and growth." + ], + [ + "food.py", + "Contains the Food class which handles food position generation." + ], + [ + "ui.py", + "Contains the UI class which handles displaying the score, game over screen, and the game state." + ] + ], + "Task list": [ + "snake.py", + "food.py", + "ui.py", + "game.py", + "main.py" + ], + "Full API spec": "", + "Shared Knowledge": "`game.py` contains the main game loop and integrates all other modules (snake, food, UI).", + "Anything UNCLEAR": "Currently, all aspects of the project are clear." +} +""" + +GAME_REQ_2048 = f""" +Create a 2048 game, follow the design doc and task doc. Write your code under /Users/gary/Files/temp/workspace/2048_game/src. +After writing all codes, write a code review for the codes, make improvement or adjustment based on the review. +Notice: You MUST implement the full code, don't leave comment without implementation! +Design doc: +{DESIGN_DOC_2048} +Task doc: +{TASK_DOC_2048} +""" +GAME_REQ_SNAKE = f""" +Create a snake game, follow the design doc and task doc. Write your code under /Users/gary/Files/temp/workspace/snake_game/src. 
+
+After writing all codes, write a code review for the codes, make improvement or adjustment based on the review.
+Notice: You MUST implement the full code, don't leave comment without implementation!
+Design doc:
+{DESIGN_DOC_SNAKE}
+Task doc:
+{TASK_DOC_SNAKE}
+"""
+GAME_REQ_2048_NO_DOC = """
+Create a 2048 game with pygame. Write your code under /Users/gary/Files/temp/workspace/2048_game/src.
+Consider what files you will write, break down the requests to multiple tasks and write one file in each task.
+After writing all codes, write a code review for the codes, make improvement or adjustment based on the review.
+Notice: You MUST implement the full code, don't leave comment without implementation!
+"""
+GAME_INC_REQ_2048 = """
+I found an issue with the 2048 code: when tiles are merged, no new tiles pop up.
+Write code review for the codes (game.py, main.py, ui.py) under /Users/gary/Files/temp/workspace/2048_game_bugs/src.
+Then correct any issues you find. You can review all code at once, and solve all issues at once.
+"""
+GAME_INC_REQ_SNAKE = """
+Found this issue, TypeError: generate_new_position() missing 1 required positional argument: 'snake_body'
+Write code review for the codes (food.py, game.py, main.py, snake.py, ui.py) under /Users/gary/Files/temp/workspace/snake_game_bugs/src.
+Then correct any issues you find. You can review all code at once, and solve all issues at once.
+"""
+CASUAL_CHAT = """what's your name?"""
+
+
+# incremental development
+INC_DEVELOPMENT_CASE1 = [
+    "Complete the Snake game with the root directory at '/home/mgx/mgx/MetaGPT/workspace/snake_game'",
+    "Use the up button to control the snake to move down, the left button to move right, and so on",
+    "Place the restart/start button at the top",
+    "Add a pause button",
+    "Display the score and leaderboard in real-time on the page",
+]
+
+INC_DEVELOPMENT_CASE2 = [
+    "Develop a Snake game using Python in the '/home/mgx/mgx/MetaGPT/workspace/snake_game_py' folder",
+    "Change the title to 'Special Snake'",
+    "Use the up button to control the snake to move down, the left button to move right, and so on",
+    "Add a pause button",
+    "Display the score and leaderboard in real-time on the page",
+    "Design a more attractive style for the leaderboard",
+]
+
+INC_DEVELOPMENT_CASE3 = [
+    "Complete the 2048 game with the root directory at '/home/mgx/mgx/MetaGPT/workspace/2048_game'",
+    "Place the start button at the top",
+    "Display the score and leaderboard in real-time on the page",
+    "Design a more attractive style for the leaderboard",
+    "Add a restart button",
+]
+
+INC_DEVELOPMENT_CASE4 = [
+    "Develop a 2048 game using Python in the '/home/mgx/mgx/MetaGPT/workspace/2048_game_py' folder",
+    "Display the score and leaderboard in real-time on the page",
+    "Add a restart button",
+]
+INC_DEVELOPMENT_CASE5 = [
+    "Root path is '/home/mgx/mgx/MetaGPT/workspace/to_list' Create a website widget for TODO list management. Users should be able to add, mark as complete, and delete tasks. Include features like prioritization, due dates, and categories. Make it visually appealing, responsive, and user-friendly. Use HTML, CSS, and JavaScript. Consider additional features like notifications or task export. Keep it simple and enjoyable for users. Don't use vue or react. Don't use third party library, use localstorage to save data.",
+    "Add a `clean all` button",
+]
+INC_DEVELOPMENT_CASE6 = [
+    '使用原生HTML开发一个塔罗牌角色介绍网站\n1. 主题是塔罗牌占卜的网站\n2. 超前的网页布局\n3. 页面需要是响应式的\n4. 页面需要美观大气 root path "/home/mgx/mgx/MetaGPT/workspace/taro"',
+    "扩充更多的角色,添加3个自己想出来的角色",
+    "让每一个角色的描述更加清楚",
+    "将中文内容全部替换为英文包括js里面的内容",
+]
+
+
+async def increment_development():
+    engineer2 = Engineer2(run_eval=True)
+    example = INC_DEVELOPMENT_CASE6
+    logger.remove()
+    logger.add(sys.stderr, level="INFO")
+    logger.add(Path("logs") / f"{str(uuid.uuid4())[-12:]}.log", level="DEBUG")
+    logger.info("user requirement:\n" + "\n".join(example))
+    try:
+        for user_requirement in example:
+            logger.info(f"input:{user_requirement}")
+            await engineer2.run(user_requirement)
+    except Exception as e:
+        print(e)
+
+
+if __name__ == "__main__":
+    asyncio.run(increment_development())
+    # engineer2 = Engineer2()
+    # asyncio.run(engineer2.run(GAME_REQ_2048_NO_DOC))
diff --git a/tests/metagpt/roles/di/run_product_manager.py b/tests/metagpt/roles/di/run_product_manager.py
new file mode 100644
index 0000000000..bb230b7d99
--- /dev/null
+++ b/tests/metagpt/roles/di/run_product_manager.py
@@ -0,0 +1,87 @@
+import asyncio
+import sys
+
+from metagpt.logs import logger
+from metagpt.roles import ProductManager
+
+CASE0_WRITE_2048 = """Write a PRD for a cli 2048 game"""
+CASE1_GREEDY_SNAKE = "设计一个贪吃蛇游戏"
+CASE2_SMART_HOME = "搜索并分析米家、华为智能家居和海尔智家在智能家居市场中的功能、用户需求和市场定位"
+CASE3_BEST_SELLING_REFRIGERATOR = "调研当前市场上最畅销的智能冰箱的五个关键特性"
+OLD_PRD = """
+Language
+en_us
+
+Programming Language
+N/A
+
+Original Requirements
+Write a PRD based on the current music streaming service.
+
+Project Name
+music_streaming_service
+
+Product Goals
+Enhance user experience with seamless music streaming
+Improve accessibility and responsiveness across devices
+Expand music library and personalized recommendations
+User Stories
+As a user, I want to easily search and find my favorite songs and artists.
+As a user, I want to create and manage my own playlists.
+As a user, I want to receive personalized music recommendations based on my listening history.
+As a user, I want to stream music without interruptions or buffering.
+As a user, I want to access the service on both desktop and mobile devices.
+Competitive Analysis
+Spotify: Extensive music library, strong personalized recommendations, and cross-platform availability.
+Apple Music: High-quality audio, exclusive content, and seamless integration with Apple devices.
+Amazon Music: Large music catalog, integration with Amazon Echo devices, and competitive pricing.
+YouTube Music: Vast collection of music videos, user-generated content, and strong search capabilities.
+Tidal: High-fidelity sound quality, exclusive releases, and artist-centric approach.
+Competitive Quadrant Chart
+quadrantChart title "Feature Richness vs. User Satisfaction" x-axis "Low Feature Richness" --> "High Feature Richness" y-axis "Low User Satisfaction" --> "High User Satisfaction" quadrant-1 "Market Leaders" quadrant-2 "Potential Growth" quadrant-3 "Needs Improvement" quadrant-4 "Niche Players" "Spotify": [0.9, 0.85] "Apple Music": [0.85, 0.8] "Amazon Music": [0.75, 0.7] "YouTube Music": [0.8, 0.75] "Tidal": [0.7, 0.65] "Our Target Product": [0.8, 0.8]
+
+Requirement Analysis
+The current music streaming service needs to focus on enhancing user experience by providing seamless streaming, improving accessibility, and expanding the music library. Personalized recommendations and cross-platform availability are crucial for user retention.
+
+Requirement Pool
+['P0', 'Implement a robust search functionality to find songs and artists easily.']
+['P0', 'Develop a feature for users to create and manage playlists.']
+['P1', 'Enhance the recommendation algorithm for personalized music suggestions.']
+['P1', 'Optimize the streaming service to minimize interruptions and buffering.']
+['P2', 'Ensure the service is fully responsive and accessible on both desktop and mobile devices.']
+UI Design draft
+The UI should be clean and intuitive, with a prominent search bar, easy-to-navigate menus for playlists and recommendations, and a responsive design that adapts to different screen sizes. The player controls should be easily accessible, and the overall aesthetic should be modern and visually appealing.
+
+Anything UNCLEAR
+Currently, all aspects of the project are clear.
+"""
+CASE4_MUSIC_STREAMING_MEDIA = f"""We have received feedback from users regarding the current music streaming service, stating that they need better personalized recommendations. Please readjust the content of PRD {OLD_PRD} based on this feedback."""
+CASE5_SMART_BIG_SCREEN = """分析2024年上半年中国家庭智能大屏行业的发展情况并输出市场分析报告"""
+CASE6_ELECTRONIC_CIGARETTE = """我想要生产一个电子烟产品,请帮我完成市场调研分析报告"""
+
+
+def main():
+    cases = [
+        # CASE0_WRITE_2048,
+        # CASE1_GREEDY_SNAKE,
+        # CASE2_SMART_HOME,
+        # CASE3_BEST_SELLING_REFRIGERATOR,
+        # CASE4_MUSIC_STREAMING_MEDIA,
+        CASE5_SMART_BIG_SCREEN,
+        # CASE6_ELECTRONIC_CIGARETTE,
+    ]
+    root_path = "/tmp"
+    logger.remove()
+    logger.add(sys.stderr, level="INFO")
+    for case in cases:
+        case += f"\nroot path: '{root_path}'"
+        logger.info(f"user requirement:\n{case}")
+        try:
+            product_manager = ProductManager()
+            asyncio.run(product_manager.run(case))
+        except Exception as e:
+            print(e)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/metagpt/roles/di/run_project_manager.py b/tests/metagpt/roles/di/run_project_manager.py
new file mode 100644
index 0000000000..30889c59c6
--- /dev/null
+++ b/tests/metagpt/roles/di/run_project_manager.py
@@ -0,0 +1,36 @@
+import asyncio
+import os
+
+from metagpt.roles.project_manager import ProjectManager
+from metagpt.schema import Message
+
+DESIGN_DOC_2048 = '{"Implementation approach":"We will use the Pygame library to implement the 2048 game logic and user interface. Pygame is a set of Python modules designed for writing video games, which will help us create a responsive and visually appealing UI. For the mobile responsiveness, we will ensure that the game scales appropriately on different screen sizes. 
We will also use the Pygame GUI library to create buttons for restarting the game and choosing difficulty levels.","File list":["main.py","game.py","ui.py"],"Data structures and interfaces":"\\nclassDiagram\\n class Game {\\n -grid: list[list[int]]\\n -score: int\\n +__init__()\\n +move(direction: str) bool\\n +merge() bool\\n +spawn_tile() None\\n +is_game_over() bool\\n +reset() None\\n }\\n class UI {\\n -game: Game\\n +__init__(game: Game)\\n +draw_grid() None\\n +draw_score() None\\n +draw_buttons() None\\n +handle_input() None\\n }\\n class Main {\\n -ui: UI\\n +main() None\\n }\\n Main --> UI\\n UI --> Game\\n","Program call flow":"\\nsequenceDiagram\\n participant M as Main\\n participant U as UI\\n participant G as Game\\n M->>U: __init__(game)\\n U->>G: __init__()\\n M->>U: draw_grid()\\n U->>G: move(direction)\\n G-->>U: return bool\\n U->>G: merge()\\n G-->>U: return bool\\n U->>G: spawn_tile()\\n G-->>U: return None\\n U->>G: is_game_over()\\n G-->>U: return bool\\n U->>G: reset()\\n G-->>U: return None\\n M->>U: draw_score()\\n M->>U: draw_buttons()\\n M->>U: handle_input()\\n","Anything UNCLEAR":"Clarification needed on the specific design elements for the UI to ensure it meets the \'beautiful\' requirement. Additionally, we need to confirm the exact difficulty levels and how they should affect the game mechanics."}' +DESIGN_DOC_SNAKE = """ +{ + "Implementation approach": "We will use the Pygame library to create the CLI-based snake game. Pygame is a set of Python modules designed for writing video games, which will help us handle graphics, sound, and input. The game will be structured into different modules to handle the main game loop, snake movement, food generation, collision detection, and user interface. We will ensure the game is engaging and responsive by optimizing the game loop and input handling. The score display and different speed levels will be implemented to enhance the user experience.", + "File list": [ + "main.py", + "game.py", + "snake.py", + "food.py", + "ui.py" + ], + "Data structures and interfaces": "\nclassDiagram\n class Main {\n +main() void\n }\n class Game {\n -Snake snake\n -Food food\n -int score\n -int speed\n +__init__(speed: int)\n +run() void\n +restart() void\n +update_score() void\n }\n class Snake {\n -list body\n -str direction\n +__init__()\n +move() void\n +change_direction(new_direction: str) void\n +check_collision() bool\n +grow() void\n }\n class Food {\n -tuple position\n +__init__()\n +generate_new_position() void\n }\n class UI {\n +display_score(score: int) void\n +display_game_over() void\n +display_game(snake: Snake, food: Food) void\n }\n Main --> Game\n Game --> Snake\n Game --> Food\n Game --> UI\n", + "Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant U as UI\n M->>G: __init__(speed)\n M->>G: run()\n G->>S: __init__()\n G->>F: __init__()\n loop Game Loop\n G->>S: move()\n G->>S: check_collision()\n alt Collision Detected\n G->>G: restart()\n G->>U: display_game_over()\n else No Collision\n G->>F: generate_new_position()\n G->>S: grow()\n G->>G: update_score()\n G->>U: display_score(score)\n end\n G->>U: display_game(snake, food)\n end\n", + "Anything UNCLEAR": "Currently, all aspects of the project are clear." 
+}
+"""
+REQ = """Write a project schedule based on the design at temp_design.json"""
+CASUAL_CHAT = """what's your name?"""
+
+
+async def main(requirement):
+    with open("temp_design.json", "w") as f:
+        f.write(DESIGN_DOC_2048)
+    project_manager = ProjectManager()
+    await project_manager.run(Message(content=requirement, send_to="Eve"))
+    os.remove("temp_design.json")
+
+
+if __name__ == "__main__":
+    asyncio.run(main(REQ))
diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py
new file mode 100644
index 0000000000..5ceba6dcc7
--- /dev/null
+++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py
@@ -0,0 +1,228 @@
+import argparse
+import asyncio
+import json
+import os
+import shutil
+import sys
+from datetime import datetime
+from pathlib import Path
+
+from metagpt.config2 import Config
+from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
+from metagpt.logs import logger
+from metagpt.roles.di.engineer2 import Engineer2
+from metagpt.tools.libs.editor import Editor
+from metagpt.tools.libs.terminal import Terminal
+from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset
+
+config = Config.default()
+# Specify by yourself
+TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo"
+DATA_DIR = METAGPT_ROOT / "data/hugging_face"
+
+INSTANCE_TEMPLATE = """
+## User Requirement
+Fix the bug in the repo. Because the environment is not available, you DO NOT need to run and modify any existing test case files or add new test case files to ensure that the bug is fixed.
+
+We're currently solving the following issue within our repository. You can use any bash commands or the special interface to help you. Here's the issue and hints text:
+## ISSUE
+{issue}
+
+## HINTS
+The hints text is the comment thread under the issue:
+{hints_text}
+
+The repository may already exist at the path `{repo_path}`. If it doesn't, please download the repository to this path.
+Your first action must be to navigate to the repository path `{repo_path}`.
+This issue occurred in version {version}, with the corresponding base commit being {base_commit}. You need to switch to the code version associated with this commit.
+All subsequent actions must be performed within this repository path. Do not leave this directory to execute any actions at any time.
+
+# INSTRUCTIONS:
+Now, you're going to solve this issue on your own from the perspective of a programmer. Your terminal session has started and you're in the repository's root directory. You can use any bash commands or the special interface to help you. Edit all the files you need.
+Remember, YOU CAN ONLY ENTER ONE COMMAND AT A TIME. You should always wait for feedback after every command.
+"""
+
+
+def check_instance_status(instance, swe_result_dir):
+    output_file = swe_result_dir / "all_preds.jsonl"
+    res = True
+    # First check whether all_preds.jsonl exists; if it doesn't, the instance has not been run yet.
+    if not output_file.exists():
+        return res
+    with open(output_file, "r") as fp:
+        for line in fp:
+            existing_instance = json.loads(line.strip())
+            if existing_instance["instance_id"] == instance["instance_id"]:
+                return False
+    return True
+
+
+async def terminal_run_command(cmd, terminal):
+    cmd_output = await terminal.run_command(cmd)
+    logger.info(f"command:{cmd} output:\n {cmd_output}")
+    return cmd_output
+
+
+async def refresh_repo(instance, test_repo_dir, reclone_existing_repo=False):
+    terminal = Terminal()
+    try:
+        repo_path = Path(test_repo_dir) / (
+            instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"]
+        )
+        repo_identifier = instance["repo"]
+        base_commit = instance["base_commit"]
+        if os.path.exists(repo_path) and reclone_existing_repo is True:
+            logger.info(f"remove exist repo path:{repo_path.absolute()}")
+            shutil.rmtree(repo_path)
+        if os.path.exists(repo_path):
+            logger.info(f"reset exist repo path:{repo_path.absolute()}")
+            for cmd in [
+                f"cd {repo_path.absolute()}",
+                "git reset --hard && git clean -n -d && git clean -f -d",
+                "BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')",
+                'git checkout "$BRANCH"',
+                "git branch",
+                "pwd",
+            ]:
+                await terminal_run_command(cmd, terminal)
+        else:
+            logger.info(f"clone repo to path:{repo_path}")
+            for cmd in [
+                f"git clone 'https://github.com/{repo_identifier}.git' {repo_path.absolute()}",
+                # Parenthesize the conditional so only the checkout part is optional; without the
+                # parentheses, the whole "cd && checkout" command is dropped when base_commit is empty.
+                f"cd {repo_path.absolute()}" + (f" && git checkout -f {base_commit}" if base_commit else ""),
+                "git branch",
+                "pwd",
+            ]:
+                await terminal_run_command(cmd, terminal)
+    except Exception as e:
+        logger.warning(e)
+    finally:
+        await terminal.close()
+    return repo_path
+
+
+async def get_git_diff(instance, test_repo_dir):
+    git_diff = ""
+    terminal = Terminal()
+    try:
+        repo_path = Path(test_repo_dir) / (
+            instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"]
+        )
+        # ignore backup file and submit stage
+        for cmd in [f"cd {repo_path.absolute()} ", "echo '.backup.*' >> .gitignore", "git add -A"]:
+            await terminal_run_command(cmd, terminal)
+        git_diff = await terminal_run_command("git diff --cached", terminal)
+    except Exception as e:
+        logger.error(f"Error during submission: {e}")
+    finally:
+        await terminal.close()
+    return git_diff
+
+
+async def run(instance, swe_result_dir, args):
+    if not check_instance_status(instance, swe_result_dir):
+        logger.info(f"Instance {instance['instance_id']} already exists, skipping execution.")
+        return
+
+    # preparation for the repo
+    logger.info(f"**** Preparing to run {instance['instance_id']}****")
+    test_repo_dir = args.test_repo_dir
+    repo_path = await refresh_repo(instance, test_repo_dir, args.reclone_existing_repo)
+
+    user_requirement_and_issue = INSTANCE_TEMPLATE.format(
+        issue=instance["problem_statement"],
+        hints_text=instance["hints_text"],
+        repo_path=repo_path.absolute(),
+        version=instance["version"],
+        base_commit=instance["base_commit"],
+    )
+
+    logger.info(f"**** Starting to run {instance['instance_id']}****")
+    logger.info("User Requirement:\n" + user_requirement_and_issue)
+    try:
+        editor = Editor(enable_auto_lint=True, working_dir=Path(repo_path))
+        engineer = Engineer2(run_eval=True, editor=editor)
+        await asyncio.wait_for(engineer.run(user_requirement_and_issue), timeout=args.max_wait_time_per_case * 60)
+    except Exception as e:
+        logger.warning(f"**** exception led to an early end: {instance['instance_id']}****\n\nerror:{e}")
+    # save the difference of repo
+    await save_predictions(engineer, instance, test_repo_dir, swe_result_dir)
+    logger.info(f"**** Finished running {instance['instance_id']}****")
+
+
+async def save_predictions(engineer, instance, test_repo_dir, swe_result_dir):
+    output_file = swe_result_dir / "all_preds.jsonl"
+    instance["model_name_or_path"] = engineer.config.llm.model
+    instance["model_patch"] = await get_git_diff(instance, test_repo_dir)
+    logger.info(f"'model_patch':\n{instance['model_patch']}")
+    logger.info(f"Preparing to save predictions to {output_file}")
+
+    # Save the predictions to a JSONL file
+    with open(output_file, "a+") as fp:
+        print(json.dumps(instance), file=fp, flush=True)
+
+    logger.info(f"Saved prediction of {instance['instance_id']} to {output_file}")
+
+
+async def async_main(args):
+    dataset_path = "manna-ai/SWE-bench_Nano"  # or "princeton-nlp/SWE-bench_Lite"
+    dataset = load_hf_dataset(dataset_name_or_path=dataset_path, cache_dir=DATA_DIR, split="test")
+    swe_result_dir = Path(args.save_folder)
+    if swe_result_dir.exists():
+        logger.info(f"{swe_result_dir} exists; resuming test from last checkpoint.")
+    swe_result_dir.mkdir(parents=True, exist_ok=True)
+    for index, instance in enumerate(dataset):
+        # switch to a new logger file
+        logger.remove()
+        logger.add(sys.stderr, level="INFO")
+        logger.add(swe_result_dir / "logs" / f"{index+1}_{instance['instance_id']}.log", level="DEBUG")
+        await run(instance, swe_result_dir, args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="the arguments of the script")
+    # Add command-line arguments
+    swe_result_dir = (
+        DEFAULT_WORKSPACE_ROOT
+        / f"result_{config.llm.model.replace('/', '_')}_start_time_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
+    )
+    test_repo_dir = TEST_REPO_DIR.absolute()
+    swe_result_dir = swe_result_dir.absolute()
+    parser.add_argument(
+        "-rw", "--test_repo_dir", default=test_repo_dir, help="The directory to save temporary repositories", type=str
+    )
+    parser.add_argument("-s", "--save_folder", default=swe_result_dir, help="Folder to save results and logs", type=str)
+    parser.add_argument(
+        "-mwtc",
+        "--max_wait_time_per_case",
+        default=10,
+        help="Maximum wait time allowed per test case (in minutes)",
+        type=int,
+    )
+    parser.add_argument(
+        "-o",
+        "--reclone_existing_repo",
+        action="store_true",
+        help="If set, the existing repository will be removed and recloned.",
+    )
+    # Parse command-line arguments
+    args = parser.parse_args()
+    asyncio.run(async_main(args))
+
+
+"""
+# Run the benchmark
+python tests/metagpt/roles/di/run_swe_agent_for_benchmark.py \
+--test_repo_dir "./data/test_repo" \
+--save_folder "./workspace/deepseek_coder_0907" \
+--max_wait_time_per_case 10
+"""

+"""
+# Re-clone existing repositories before running
+python tests/metagpt/roles/di/run_swe_agent_for_benchmark.py \
+--test_repo_dir "./data/test_repo" \
+--save_folder "./workspace/deepseek_coder_0907" \
+--max_wait_time_per_case 10 \
+--reclone_existing_repo
+"""
diff --git a/tests/metagpt/roles/di/run_swe_agent_open_source_issue.py b/tests/metagpt/roles/di/run_swe_agent_open_source_issue.py
new file mode 100644
index 0000000000..ec87dd7e2c
--- /dev/null
+++ b/tests/metagpt/roles/di/run_swe_agent_open_source_issue.py
@@ -0,0 +1,44 @@
+import asyncio
+
+from metagpt.logs import logger
+from metagpt.roles.di.swe_agent import SWEAgent
+
+FIX_ISSUE1 = """
+Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453,
+you can fix it on this repo 
https://github.com/garylin2099/langchain +""" +# + "checkout a branch named test-fix, commit your changes, push, +# and create a PR to the master branch of https://github.com/iorisa/langchain" +# """ +FIX_ISSUE2 = """ +Write a fix for this issue https://github.com/geekan/MetaGPT/issues/1275. +You can fix it on the v0.8-release branch of this repo https://github.com/garylin2099/MetaGPT +""" +# + "during fixing, checkout a branch named test-fix-1275, commit your changes, push, +# and create a PR to the v0.8-release branch of https://github.com/garylin2099/MetaGPT" + +FIX_ISSUE3 = """ +Write a fix for this issue https://github.com/geekan/MetaGPT/issues/1262. +You can fix it on this repo https://github.com/garylin2099/MetaGPT +""" +# during fixing, checkout a branch named test-fix-1262, commit your changes, push, +# and create a PR to https://github.com/garylin2099/MetaGPT +# """ +FIX_ISSUE_SIMPLE = """ +Write a fix for this issue: https://github.com/mannaandpoem/simple_calculator/issues/1, +you can fix it on this repo https://github.com/garylin2099/simple_calculator +""" +# checkout a branch named test, commit your changes, push, and create a PR to the master branch of original repo. +# """ + + +NO_ENV_TIP = """ +Because the environment is not available, you DO NOT need to run and modify any existing test case files or +add new test case files to ensure that the bug is fixed. +""" +if __name__ == "__main__": + swe_agent = SWEAgent() + logger.info("**** Starting run ****") + user_requirement_and_issue = FIX_ISSUE1 + NO_ENV_TIP + asyncio.run(swe_agent.run(user_requirement_and_issue)) + logger.info("**** Finished running ****") diff --git a/tests/metagpt/roles/di/test_data_analyst.py b/tests/metagpt/roles/di/test_data_analyst.py new file mode 100644 index 0000000000..0f285ecd7b --- /dev/null +++ b/tests/metagpt/roles/di/test_data_analyst.py @@ -0,0 +1,21 @@ +import pytest + +from metagpt.const import TEST_DATA_PATH +from metagpt.roles.di.data_analyst import DataAnalyst + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "filename"), [("similarity search about '有哪些需求描述?' in document ", TEST_DATA_PATH / "requirements/2.pdf")] +) +async def test_similarity_search(query, filename): + di = DataAnalyst() + query += f"'{str(filename)}'" + + rsp = await di.run(query) + assert rsp + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/roles/di/test_routing.py b/tests/metagpt/roles/di/test_routing.py new file mode 100644 index 0000000000..0cd94e5716 --- /dev/null +++ b/tests/metagpt/roles/di/test_routing.py @@ -0,0 +1,154 @@ +import asyncio + +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.logs import logger +from metagpt.roles import Architect, ProductManager, ProjectManager +from metagpt.roles.di.data_analyst import DataAnalyst +from metagpt.roles.di.engineer2 import Engineer2 +from metagpt.roles.di.team_leader import TeamLeader +from metagpt.schema import Message + +NORMAL_QUESTION = [ + "create a 2048 game", + "write a snake game", + "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard.", + "print statistic summary of sklearn iris dataset", + "Run data analysis on sklearn Wine recognition dataset, and train a model to predict wine class (20% as validation), and show validation accuracy.", + """ + Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, + and save it to a csv file. 
Paper title must include `multiagent` or `large language model`. *notice: print key variables*
+    """,
+    """
+    Get products data from website https://scrapeme.live/shop/ and save it as a csv file.
+    The first page product name, price, product URL, and image URL must be saved in the csv;**
+    """,
+    """
+    Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453,
+    you can fix it on this repo https://github.com/garylin2099/langchain,
+    checkout a branch named test-fix, commit your changes, push, and create a PR to the master branch of https://github.com/iorisa/langchain
+    """,
+    "Open this link and make a summary: https://github.com/geekan/MetaGPT",  # should not confuse with searching
+    "请查看这个网页https://platform.openai.com/docs/models",  # should not confuse with searching
+]
+
+
+SEARCH_QUESTION = [
+    "今天的天气怎样?",
+    "全球智能手机市场份额排名是什么?前三名的品牌各占多少百分比?",
+    "中国股市上市公司数量是多少?",
+    "奥运会将在哪里举行?有哪些新增的比赛项目?",
+    "最近一周全球原油价格的走势如何?",
+    "当前全球碳排放量最大的三个国家是哪些?",
+    "当前全球碳排放量最大的三个国家各占多少比例",
+    "最新的全球教育质量排名中,前五名的国家是哪些?",
+    "当前全球最大的几家电动汽车制造商是哪些?",
+    "奥运会的开幕式是什么时候",
+    "Recommend some gyms near Shenzhen University",
+    "Which university tops QS ranking?",
+    "Which university tops QS ranking this year?",
+    "The stock price of Nvidia?",
+    # longer questions
+    "请为我查找位于深圳大学附近1000米范围内,价格适中(性价比最高),且晚上关门时间晚于22:00的健身房。",
+    "When is the Olympic football final this year, where will it be held, and where can I buy tickets? If possible, please provide me with a link to buy tickets",
+    "Help me search for Inter Miami CF home games in the next 2 months and give me the link to buy tickets",
+]
+
+
+QUICK_QUESTION = [
+    ## general knowledge qa, logical, math ##
+    """Who is the first man landing on Moon""",
+    """In DNA adenine normally pairs with: A. cytosine. B. guanine. C. thymine. D. uracil. Answer:""",
+    """________________ occur(s) where there is no prior history of exchange and no future exchanges are expected between a buyer and seller. A. Relationship marketing. B. Service mix. C. Market exchanges. D. Service failure. Answer:""",
+    """Within American politics, the power to accord official recognition to other countries belongs to A. the Senate. B. the president. C. the Secretary of State. D. the chairman of the Joint Chiefs. Answer:""",
+    """Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.""",
+    """True or false? Statement 1 | A ring homomorphism is one to one if and only if the kernel is {{0}},. Statement 2 | Q is an ideal in R""",
+    """Jean has 30 lollipops. Jean eats 2 of the lollipops. With the remaining lollipops, Jean wants to package 2 lollipops in one bag. How many bags can Jean fill?""",
+    """Alisa biked 12 miles per hour for 4.5 hours. Stanley biked at 10 miles per hour for 2.5 hours. How many miles did Alisa and Stanley bike in total?""",
+    ## function filling (humaneval) ##
+    """
+    def has_close_elements(numbers: List[float], threshold: float) -> bool:
+        ''' Check if in given list of numbers, are any two numbers closer to each other than
+        given threshold.
+        >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
+        False
+        >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
+        True
+        '''
+    """,
+    """
+    def is_palindrome(string: str) -> bool:
+        ''' Test if given string is a palindrome '''
+        return string == string[::-1]
+
+
+    def make_palindrome(string: str) -> str:
+        ''' Find the shortest palindrome that begins with a supplied string.
+        Algorithm idea is simple:
+        - Find the longest postfix of supplied string that is a palindrome.
+ - Append to the end of the string reverse of a string prefix that comes before the palindromic suffix. + >>> make_palindrome('') + '' + >>> make_palindrome('cat') + 'catac' + >>> make_palindrome('cata') + 'catac' + ''' + """, + # casual chat + """What's your name?""", + "Who are you", + "What can you do", + "Hi", + "1+1", + # programming-related but not requiring software development SOP + "请写一个python入门教程", + "python里的装饰器是怎么用的,给我个例子", + "写一个java的hello world程序", +] + + +async def test_routing_acc(): + role = TeamLeader() + + env = MGXEnv() + env.add_roles( + [ + role, + ProductManager(), + Architect(), + ProjectManager(), + Engineer2(), + DataAnalyst(), + ] + ) + + for q in QUICK_QUESTION: + msg = Message(content=q) + role.put_message(msg) + await role._observe() + rsp, intent_result = await role._quick_think() + role.rc.memory.clear() + if "YES" not in intent_result: + logger.error(f"Quick question failed: {q}") + + for q in SEARCH_QUESTION: + msg = Message(content=q) + role.put_message(msg) + await role._observe() + rsp, intent_result = await role._quick_think() + role.rc.memory.clear() + if "SEARCH" not in intent_result: + logger.error(f"Search question failed: {q}") + + for q in NORMAL_QUESTION: + msg = Message(content=q) + role.put_message(msg) + await role._observe() + rsp, intent_result = await role._quick_think() + role.rc.memory.clear() + if "NO" not in intent_result: + logger.error(f"Normal question failed: {q}") + + +if __name__ == "__main__": + asyncio.run(test_routing_acc()) diff --git a/tests/metagpt/roles/di/test_team_leader.py b/tests/metagpt/roles/di/test_team_leader.py new file mode 100644 index 0000000000..1b33a6edc1 --- /dev/null +++ b/tests/metagpt/roles/di/test_team_leader.py @@ -0,0 +1,169 @@ +import pytest + +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.roles import ( + Architect, + Engineer, + ProductManager, + ProjectManager, + QaEngineer, +) +from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.roles.di.team_leader import TeamLeader +from metagpt.schema import Message + + +@pytest.fixture +def env(): + test_env = MGXEnv() + tl = TeamLeader() + da = DataInterpreter( + name="David", + profile="Data Analyst", + goal="Take on any data-related tasks, such as data analysis, machine learning, deep learning, web browsing, web scraping, web searching, web deployment, terminal operation, git operation, etc.", + react_mode="react", + ) + test_env.add_roles( + [ + tl, + ProductManager(), + Architect(), + ProjectManager(), + Engineer(n_borg=5, use_code_review=True), + QaEngineer(), + da, + ] + ) + return test_env + + +@pytest.mark.asyncio +async def test_plan_for_software_requirement(env): + requirement = "create a 2048 game" + + tl = env.get_role("Team Leader") + env.publish_message(Message(content=requirement, send_to=tl.name)) + await tl.run() + + # TL should assign tasks to 5 members first, then send message to the first assignee, 6 commands in total + assert len(tl.commands) == 6 + plan_cmd = tl.commands[:5] + route_cmd = tl.commands[5] + + task_assignment = [task["args"]["assignee"] for task in plan_cmd] + assert task_assignment == [ + ProductManager().name, + Architect().name, + ProjectManager().name, + Engineer().name, + QaEngineer().name, + ] + + assert route_cmd["command_name"] == "publish_message" + assert route_cmd["args"]["send_to"] == ProductManager().name + + +@pytest.mark.asyncio +async def test_plan_for_data_related_requirement(env): + requirement = "I want to use yolov5 for target detection, yolov5 all the 
information from the following link, please help me according to the content of the link (https://github.com/ultralytics/yolov5), set up the environment and download the model parameters, and finally provide a few pictures for inference, the inference results will be saved!"
+
+    tl = env.get_role("Team Leader")
+    env.publish_message(Message(content=requirement, send_to=tl.name))
+    await tl.run()
+
+    # TL should assign 1 task to Data Analyst and send message to it
+    assert len(tl.commands) == 2
+    plan_cmd = tl.commands[0]
+    route_cmd = tl.commands[-1]
+
+    da = env.get_role("Data Analyst")
+    assert plan_cmd["command_name"] == "append_task"
+    assert plan_cmd["args"]["assignee"] == da.name
+
+    assert route_cmd["command_name"] == "publish_message"
+    assert "https://github.com" in route_cmd["args"]["content"]  # necessary info must be in the message
+    assert route_cmd["args"]["send_to"] == da.name
+
+
+@pytest.mark.asyncio
+async def test_plan_for_mixed_requirement(env):
+    requirement = "Search the web for the new game 2048X, then replicate it"
+
+    tl = env.get_role("Team Leader")
+    env.publish_message(Message(content=requirement, send_to=tl.name))
+    await tl.run()
+
+    # TL should assign 6 tasks, first to Data Analyst to search the web, followed by the software team sequence
+    # TL should send message to Data Analyst after task assignment
+    assert len(tl.commands) == 7
+    plan_cmd = tl.commands[:6]
+    route_cmd = tl.commands[-1]
+
+    task_assignment = [task["args"]["assignee"] for task in plan_cmd]
+    da = env.get_role("Data Analyst")
+    assert task_assignment == [
+        da.name,
+        ProductManager().name,
+        Architect().name,
+        ProjectManager().name,
+        Engineer().name,
+        QaEngineer().name,
+    ]
+
+    assert route_cmd["command_name"] == "publish_message"
+    assert route_cmd["args"]["send_to"] == da.name
+
+
+PRD_MSG_CONTENT = """{'docs': {'20240424153821.json': {'root_path': 'docs/prd', 'filename': '20240424153821.json', 'content': '{"Language":"en_us","Programming Language":"Python","Original Requirements":"create a 2048 game","Project Name":"game_2048","Product Goals":["Develop an intuitive and addictive 2048 game variant","Ensure the game is accessible and performs well on various devices","Design a visually appealing and modern user interface"],"User Stories":["As a player, I want to be able to undo my last move so I can correct mistakes","As a player, I want to see my high scores to track my progress over time","As a player, I want to be able to play the game without any internet connection"],"Competitive Analysis":["2048 Original: Classic gameplay, minimalistic design, lacks social sharing features","2048 Hex: Unique hexagon board, but not mobile-friendly","2048 Multiplayer: Offers real-time competition, but overwhelming ads","2048 Bricks: Innovative gameplay with bricks, but poor performance on older devices","2048.io: Multiplayer battle royale mode, but complicated UI for new players","2048 Animated: Animated tiles add fun, but the game consumes a lot of battery","2048 3D: 3D version of the game, but has a steep learning curve"],"Competitive Quadrant Chart":"quadrantChart\\n title \\"User Experience and Feature Set of 2048 Games\\"\\n x-axis \\"Basic Features\\" --> \\"Rich Features\\"\\n y-axis \\"Poor Experience\\" --> \\"Great Experience\\"\\n quadrant-1 \\"Need Improvement\\"\\n quadrant-2 \\"Feature-Rich but Complex\\"\\n quadrant-3 \\"Simplicity with Poor UX\\"\\n quadrant-4 \\"Balanced\\"\\n \\"2048 Original\\": [0.2, 0.7]\\n \\"2048 Hex\\": [0.3, 0.4]\\n \\"2048 Multiplayer\\": [0.6, 0.5]\\n 
\\"2048 Bricks\\": [0.4, 0.3]\\n \\"2048.io\\": [0.7, 0.4]\\n \\"2048 Animated\\": [0.5, 0.6]\\n \\"2048 3D\\": [0.6, 0.3]\\n \\"Our Target Product\\": [0.8, 0.9]","Requirement Analysis":"The game must be engaging and retain players, which requires a balance of simplicity and challenge. Accessibility on various devices is crucial for a wider reach. A modern UI is needed to attract and retain the modern user. The ability to play offline is important for users on the go. High score tracking and the ability to undo moves are features that will enhance user experience.","Requirement Pool":[["P0","Implement core 2048 gameplay mechanics"],["P0","Design responsive UI for multiple devices"],["P1","Develop undo move feature"],["P1","Integrate high score tracking system"],["P2","Enable offline gameplay capability"]],"UI Design draft":"The UI will feature a clean and modern design with a minimalist color scheme. The game board will be center-aligned with smooth tile animations. Score and high score will be displayed at the top. Undo and restart buttons will be easily accessible. The design will be responsive to fit various screen sizes.","Anything UNCLEAR":"The monetization strategy for the game is not specified. Further clarification is needed on whether the game should include advertisements, in-app purchases, or be completely free."}'}}}""" +DESIGN_CONTENT = """{"docs":{"20240424214432.json":{"root_path":"docs/system_design","filename":"20240424214432.json","content":"{\\"Implementation approach\\":\\"We will develop the 2048 game using Python, leveraging the pygame library for rendering the game interface and handling user input. This library is suitable for creating games and is widely used in the open-source community. We will ensure that the game logic is separated from the UI code to maintain a clean architecture. 
The game will be designed to be responsive and accessible on both desktop and mobile devices using scalable dimensions and touch-friendly controls.\\",\\"File list\\":[\\"main.py\\",\\"game.py\\",\\"ui.py\\",\\"constants.py\\",\\"logic.py\\"],\\"Data structures and interfaces\\":\\"\\\\nclassDiagram\\\\n class Main {\\\\n +main() void\\\\n }\\\\n class Game {\\\\n -UI ui\\\\n -Logic logic\\\\n +start_game() void\\\\n +restart_game() void\\\\n }\\\\n class UI {\\\\n -current_score int\\\\n -high_score int\\\\n +draw_board(board: list) void\\\\n +update_score(score: int) void\\\\n +show_game_over() void\\\\n }\\\\n class Logic {\\\\n -board list\\\\n -score int\\\\n +move(direction: str) bool\\\\n +check_game_over() bool\\\\n +get_current_score() int\\\\n +get_high_score() int\\\\n +reset_game() void\\\\n }\\\\n class Constants {\\\\n +BOARD_SIZE int\\\\n +INITIAL_TILES int\\\\n }\\\\n Main --> Game\\\\n Game --> UI\\\\n Game --> Logic\\\\n\\",\\"Program call flow\\":\\"\\\\nsequenceDiagram\\\\n participant M as Main\\\\n participant G as Game\\\\n participant UI as UI\\\\n participant L as Logic\\\\n M->>G: start_game()\\\\n loop Game Loop\\\\n G->>UI: draw_board(board)\\\\n G->>L: move(direction)\\\\n alt if move successful\\\\n L-->>G: return true\\\\n G->>UI: update_score(score)\\\\n else if move not successful\\\\n L-->>G: return false\\\\n end\\\\n G->>L: check_game_over()\\\\n alt if game over\\\\n L-->>G: return true\\\\n G->>UI: show_game_over()\\\\n G->>G: restart_game()\\\\n else\\\\n L-->>G: return false\\\\n end\\\\n end\\\\n\\",\\"Anything UNCLEAR\\":\\"Clarification needed on the specific touch-friendly controls for mobile devices and how they will be implemented using pygame.\\"}"}}}""" + + +@pytest.mark.asyncio +async def test_plan_update_and_routing(env): + requirement = "create a 2048 game" + + tl = env.get_role("Team Leader") + env.publish_message(Message(content=requirement)) + await tl.run() + + # Assuming Product Manager finishes its task + env.publish_message(Message(content=PRD_MSG_CONTENT, role="Alice(Product Manager)", sent_from="Alice")) + await tl.run() + + # TL should mark current task as finished, and forward Product Manager's message to Architect + # Current task should be updated to the second task + plan_cmd = tl.commands[0] + route_cmd = tl.commands[-1] + assert plan_cmd["command_name"] == "finish_current_task" + assert route_cmd["command_name"] == "publish_message" + assert route_cmd["args"]["send_to"] == Architect().name + assert tl.planner.plan.current_task_id == "2" + + # Next step, assuming Architect finishes its task + env.publish_message(Message(content=DESIGN_CONTENT, role="Bob(Architect)", sent_from="Bob")) + await tl.run() + plan_cmd = tl.commands[0] + route_cmd = tl.commands[-1] + assert plan_cmd["command_name"] == "finish_current_task" + assert route_cmd["command_name"] == "publish_message" + assert route_cmd["args"]["send_to"] == ProjectManager().name + assert tl.planner.plan.current_task_id == "3" + + +@pytest.mark.asyncio +async def test_reply_to_human(env): + requirement = "create a 2048 game" + + tl = env.get_role("Team Leader") + env.publish_message(Message(content=requirement)) + await tl.run() + + # Assuming Product Manager finishes its task + env.publish_message(Message(content=PRD_MSG_CONTENT, role="Alice(Product Manager)", sent_from="Alice")) + await tl.run() + + # Human inquires about the progress + env.publish_message(Message(content="Who is working? 
How does the project go?")) + await tl.run() + + assert tl.commands[0]["command_name"] == "reply_to_human" diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index d263a8a2fe..d5eae662f8 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -91,7 +91,7 @@ def test_parse_file_list(): def test_parse_code(): - code = CodeParser.parse_code("Task list", TASKS, lang="python") + code = CodeParser.parse_code(block="Task list", text=TASKS, lang="python") logger.info(code) assert isinstance(code, str) assert target_code == code diff --git a/tests/metagpt/roles/test_product_manager.py b/tests/metagpt/roles/test_product_manager.py index 59b5aa81a5..9a28a32961 100644 --- a/tests/metagpt/roles/test_product_manager.py +++ b/tests/metagpt/roles/test_product_manager.py @@ -10,7 +10,6 @@ import pytest from metagpt.actions import WritePRD -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.const import REQUIREMENT_FILENAME from metagpt.context import Context from metagpt.logs import logger @@ -30,12 +29,8 @@ async def test_product_manager(new_filename): rsp = await product_manager.run(MockMessages.req) assert context.git_repo assert context.repo - assert rsp.cause_by == any_to_str(PrepareDocuments) - assert REQUIREMENT_FILENAME in context.repo.docs.changed_files - - # write prd - rsp = await product_manager.run(rsp) assert rsp.cause_by == any_to_str(WritePRD) + assert REQUIREMENT_FILENAME in context.repo.docs.changed_files logger.info(rsp) assert len(rsp.content) > 0 doc = list(rsp.instruct_content.docs.values())[0] diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index 9c076f4e69..29393bb137 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -5,10 +5,12 @@ @Author : alexanderwu @File : test_document.py """ -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.document import Repo from metagpt.logs import logger +config = Config.default() + def set_existing_repo(path): repo1 = Repo.from_path(path) diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 7559655d36..522013804e 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -11,17 +11,45 @@ import pytest from metagpt.actions import UserRequirement +from metagpt.actions.prepare_documents import PrepareDocuments +from metagpt.context import Context from metagpt.environment import Environment from metagpt.logs import logger -from metagpt.roles import Architect, ProductManager, Role -from metagpt.schema import Message +from metagpt.roles import ( + Architect, + Engineer, + ProductManager, + ProjectManager, + QaEngineer, + Role, +) +from metagpt.schema import Message, UserMessage +from metagpt.utils.common import any_to_str, is_send_to serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") +class MockEnv(Environment): + def publish_message(self, message: Message, peekable: bool = True) -> bool: + logger.info(f"{message.metadata}:{message.content}") + consumers = [] + for role, addrs in self.member_addrs.items(): + if is_send_to(message, addrs): + role.put_message(message) + consumers.append(role) + if not consumers: + logger.warning(f"Message no recipients: {message.dump()}") + if message.cause_by in [any_to_str(UserRequirement), any_to_str(PrepareDocuments)]: + assert len(consumers) == 1 + + return True + + @pytest.fixture def env(): - return Environment() + context 
= Context() + context.kwargs.tag = __file__ + return MockEnv(context=context) def test_add_role(env: Environment): @@ -54,10 +82,57 @@ async def test_publish_and_process_message(env: Environment): env.add_roles([product_manager, architect]) - env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement)) + env.publish_message(UserMessage(content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement, send_to=product_manager)) await env.run(k=2) - logger.info(f"{env.history=}") - assert len(env.history) > 10 + logger.info(f"{env.history}") + assert len(env.history.storage) == 0 + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("content", "send_to"), + [ + ("snake game", any_to_str(ProductManager)), + ( + "Rewrite the PRD file of the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game', add 'moving enemy' to the original requirement", + any_to_str(ProductManager), + ), + ( + "Add 'random moving enemy, and dispears after 10 seconds' design to the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'", + any_to_str(Architect), + ), + ( + 'Rewrite the tasks file of the project at "/Users/iorishinier/github/MetaGPT/workspace/snake_game"', + any_to_str(ProjectManager), + ), + ( + "src filename is 'game.py', Uncaught SyntaxError: Identifier 'Position' has already been declared (at game.js:1:1), the project at '/Users/iorishinier/github/bak/MetaGPT/workspace/snake_game'", + any_to_str(Engineer), + ), + ( + "Rewrite the unit test of 'main.py' at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'", + any_to_str(QaEngineer), + ), + ], +) +async def test_env(content, send_to): + context = Context() + env = MockEnv(context=context) + env.add_roles( + [ + ProductManager(context=context), + Architect(context=context), + ProjectManager(context=context), + Engineer(n_borg=5, use_code_review=True, context=context), + QaEngineer(context=context, test_round_allowed=2), + ] + ) + msg = UserMessage(content=content, send_to=send_to) + env.publish_message(msg) + while not env.is_idle: + await env.run() + pass if __name__ == "__main__": diff --git a/tests/metagpt/test_reporter.py b/tests/metagpt/test_reporter.py new file mode 100644 index 0000000000..41d9634487 --- /dev/null +++ b/tests/metagpt/test_reporter.py @@ -0,0 +1,182 @@ +import ast +from contextlib import asynccontextmanager + +import aiohttp.web +import pytest + +from metagpt.logs import log_llm_stream +from metagpt.utils.report import ( + END_MARKER_NAME, + BlockType, + BrowserReporter, + DocsReporter, + EditorReporter, + NotebookReporter, + ServerReporter, + TaskReporter, + TerminalReporter, +) + + +class MockFileLLM: + def __init__(self, data: str): + self.data = data + + async def aask(self, *args, **kwargs) -> str: + for i in self.data.splitlines(keepends=True): + log_llm_stream(i) + log_llm_stream("\n") + return self.data + + +@asynccontextmanager +async def callback_server(http_server): + callback_data = [] + + async def handler(request): + callback_data.append(await request.json()) + return aiohttp.web.json_response({}) + + server, url = await http_server(handler) + yield url, callback_data + await server.stop() + + +@pytest.mark.asyncio +async def test_terminal_report(http_server): + async with callback_server(http_server) as (url, callback_data): + async with TerminalReporter(callback_url=url) as reporter: + await reporter.async_report("ls -a", "cmd") + await reporter.async_report("main.py\n", "output") + await reporter.async_report("setup.py\n", "output") + 
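+    # The assertions below spell out the callback contract this test implies
+    # (our reading of the reporter, not a documented schema): every chunk
+    # reports the TERMINAL block type, all chunks share the uuid of the first
+    # chunk, and the non-end-marker values concatenate back to the exact
+    # command and output reported above.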
assert all(BlockType.TERMINAL is BlockType(i["block"]) for i in callback_data) + assert all(i["uuid"] == callback_data[0]["uuid"] for i in callback_data[1:]) + assert "".join(i["value"] for i in callback_data if i["name"] != END_MARKER_NAME) == "ls -amain.py\nsetup.py\n" + + +@pytest.mark.asyncio +async def test_browser_report(http_server): + img = b"\x89PNG\r\n\x1a\n\x00\x00" + web_url = "https://docs.deepwisdom.ai" + + class AsyncPage: + async def screenshot(self): + return img + + async with callback_server(http_server) as (url, callback_data): + async with BrowserReporter(callback_url=url) as reporter: + await reporter.async_report(web_url, "url") + await reporter.async_report(AsyncPage(), "page") + + assert all(BlockType.BROWSER is BlockType(i["block"]) for i in callback_data) + assert all(i["uuid"] == callback_data[0]["uuid"] for i in callback_data[1:]) + assert len(callback_data) == 3 + assert callback_data[-1]["name"] == END_MARKER_NAME + assert callback_data[0]["name"] == "url" + assert callback_data[0]["value"] == web_url + assert callback_data[1]["name"] == "page" + assert ast.literal_eval(callback_data[1]["value"]) == img + + +@pytest.mark.asyncio +async def test_server_reporter(http_server): + local_url = "http://127.0.0.1:8080/index.html" + async with callback_server(http_server) as (url, callback_data): + reporter = ServerReporter(callback_url=url) + await reporter.async_report(local_url) + assert all(BlockType.BROWSER_RT is BlockType(i["block"]) for i in callback_data) + assert len(callback_data) == 1 + assert callback_data[0]["name"] == "local_url" + assert callback_data[0]["value"] == local_url + assert not callback_data[0]["is_chunk"] + + +@pytest.mark.asyncio +async def test_task_reporter(http_server): + task = {"current_task_id": "", "tasks": []} + async with callback_server(http_server) as (url, callback_data): + reporter = TaskReporter(callback_url=url) + await reporter.async_report(task) + + assert all(BlockType.TASK is BlockType(i["block"]) for i in callback_data) + assert len(callback_data) == 1 + assert callback_data[0]["name"] == "object" + assert callback_data[0]["value"] == task + + +@pytest.mark.asyncio +async def test_notebook_reporter(http_server): + code = { + "cell_type": "code", + "execution_count": None, + "id": "e1841c44", + "metadata": {}, + "outputs": [], + "source": ["\n", "import time\n", "print('will sleep 1s.')\n", "time.sleep(1)\n", "print('end.')\n", ""], + } + output1 = {"name": "stdout", "output_type": "stream", "text": ["will sleep 1s.\n"]} + output2 = {"name": "stdout", "output_type": "stream", "text": ["will sleep 1s.\n"]} + code_path = "/data/main.ipynb" + async with callback_server(http_server) as (url, callback_data): + async with NotebookReporter(callback_url=url) as reporter: + await reporter.async_report(code, "content") + await reporter.async_report(output1, "content") + await reporter.async_report(output2, "content") + await reporter.async_report(code_path, "path") + + assert all(BlockType.NOTEBOOK is BlockType(i["block"]) for i in callback_data) + assert len(callback_data) == 5 + assert callback_data[-1]["name"] == END_MARKER_NAME + assert callback_data[-2]["name"] == "path" + assert callback_data[-2]["value"] == code_path + assert all(i["uuid"] == callback_data[0]["uuid"] for i in callback_data[1:]) + assert [i["value"] for i in callback_data if i["name"] == "content"] == [code, output1, output2] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("data", "file_path", "meta", "block", "report_cls"), + ( + ( + "## 
Language\n\nen_us\n\n## Programming Language\n\nPython\n\n## Original Requirements\n\nCreate a 2048 gam...", + "/data/prd.md", + {"type": "write_prd"}, + BlockType.DOCS, + DocsReporter, + ), + ( + "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nprint('Hello World')\n", + "/data/main.py", + {"type": "write_code"}, + BlockType.EDITOR, + EditorReporter, + ), + ), + ids=["test_docs_reporter", "test_editor_reporter"], +) +async def test_llm_stream_reporter(data, file_path, meta, block, report_cls, http_server): + async with callback_server(http_server) as (url, callback_data): + async with report_cls(callback_url=url, enable_llm_stream=True) as reporter: + await reporter.async_report(meta, "meta") + await MockFileLLM(data).aask("") + await reporter.wait_llm_stream_report() + await reporter.async_report(file_path, "path") + assert callback_data + assert all(block is BlockType(i["block"]) for i in callback_data) + assert all(i["uuid"] == callback_data[0]["uuid"] for i in callback_data[1:]) + chunks, names = [], set() + for i in callback_data: + name = i["name"] + names.add(name) + if name == "meta": + assert i["value"] == meta + elif name == "path": + assert i["value"] == file_path + elif name == END_MARKER_NAME: + pass + elif name == "content": + chunks.append(i["value"]) + else: + raise ValueError + assert "".join(chunks[:-1]) == data + assert names == {"meta", "path", "content", END_MARKER_NAME} diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 22f6ae9fbe..bc2bdd02a5 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,13 +9,15 @@ """ import json +from typing import Annotated import pytest +from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, CodeSummarizeContext, @@ -23,6 +25,7 @@ Message, MessageQueue, Plan, + SerializationMixin, SystemMessage, Task, UMLClassAttribute, @@ -350,5 +353,112 @@ def test_update_current_task(self): assert plan.current_task_id == "2" +@pytest.mark.parametrize( + ("content", "key_descriptions"), + [ + ( + """ +Traceback (most recent call last): + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/main.py", line 38, in + Main().main() + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/main.py", line 28, in main + self.user_interface.draw() + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/user_interface.py", line 16, in draw + if grid[i][j] != 0: +TypeError: 'Grid' object is not subscriptable + """, + { + "filename": "the string type of the path name of the source code where the bug resides", + "line": "the integer type of the line error occurs", + "function_name": "the string type of the function name the error occurs in", + "code": "the string type of the codes where the error occurs at", + "info": "the string type of the error information", + }, + ), + ( + "将代码提交到github上的iorisa/repo1的branch1分支,发起pull request ,合并到master分支。", + { + "repo_name": "the string type of github repo to create pull", + "head": "the string type of github branch to be pushed", + "base": "the string type of github branch to merge the changes into", + }, + ), + ], +) +async def test_parse_resources(context, content: str, key_descriptions): + msg = 
Message(content=content)
+    llm = context.llm_with_cost_manager_from_llm_config(context.config.llm)
+    result = await msg.parse_resources(llm=llm, key_descriptions=key_descriptions)
+    assert result
+    assert result.get("resources")
+    for k in key_descriptions.keys():
+        assert k in result
+
+
+@pytest.mark.parametrize(("name", "value"), [("c1", {"age": 10, "name": "Alice"}), ("", {"path": __file__})])
+def test_create_instruct_value(name, value):
+    obj = Message.create_instruct_value(kvs=value, class_name=name)
+    assert obj.model_dump() == value
+
+
+class TestUserModel(SerializationMixin, BaseModel):
+    name: str
+    value: int
+
+
+class TestUserModelWithExclude(TestUserModel):
+    age: Annotated[int, Field(exclude=True)]
+
+
+class TestSerializationMixin:
+    @pytest.fixture
+    def mock_write_json_file(self, mocker):
+        return mocker.patch("metagpt.schema.write_json_file")
+
+    @pytest.fixture
+    def mock_read_json_file(self, mocker):
+        return mocker.patch("metagpt.schema.read_json_file")
+
+    @pytest.fixture
+    def mock_user_model(self):
+        return TestUserModel(name="test", value=42)
+
+    def test_serialize(self, mock_write_json_file, mock_user_model):
+        file_path = "test.json"
+
+        mock_user_model.serialize(file_path)
+
+        mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())
+
+    def test_deserialize(self, mock_read_json_file):
+        file_path = "test.json"
+        data = {"name": "test", "value": 42}
+        mock_read_json_file.return_value = data
+
+        model = TestUserModel.deserialize(file_path)
+
+        mock_read_json_file.assert_called_once_with(file_path)
+        assert model == TestUserModel(**data)
+
+    def test_serialize_with_exclude(self, mock_write_json_file):
+        model = TestUserModelWithExclude(name="test", value=42, age=10)
+        file_path = "test.json"
+
+        model.serialize(file_path)
+
+        expected_data = {
+            "name": "test",
+            "value": 42,
+            "__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude",
+        }
+
+        mock_write_json_file.assert_called_once_with(file_path, expected_data)
+
+    def test_get_serialization_path(self):
+        expected_path = str(SERDESER_PATH / "TestUserModel.json")
+
+        assert TestUserModel.get_serialization_path() == expected_path
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/tools/libs/test_browser.py b/tests/metagpt/tools/libs/test_browser.py
new file mode 100644
index 0000000000..ec0b5c848c
--- /dev/null
+++ b/tests/metagpt/tools/libs/test_browser.py
@@ -0,0 +1,90 @@
+import pytest
+
+from metagpt.const import TEST_DATA_PATH
+from metagpt.tools.libs.browser import Browser, get_scroll_position
+
+TEST_URL = "https://docs.deepwisdom.ai/main/en/guide/get_started/quickstart.html"
+
+TEST_SCREENSHOT_PATH = TEST_DATA_PATH / "screenshot.png"
+
+
+@pytest.fixture(autouse=True)
+def llm_mock(rsp_cache, mocker, request):
+    # An empty fixture to overwrite the global llm_mock fixture,
+    # because in the provider folder we want to test the real aask calls of the specific models
+    pass
+
+
+@pytest.fixture
+def browser():
+    browser_instance = Browser()
+    yield browser_instance
+
+
+@pytest.mark.asyncio
+async def test_open_and_switch_page(browser):
+    await browser.start()
+
+    await browser.open_new_page("https://baidu.com")
+    await browser.open_new_page("https://tencent.com")
+    assert browser.current_page_url == "https://tencent.com"
+    await browser.switch_page("https://baidu.com")
+    assert browser.current_page_url == "https://baidu.com"
+
+    await browser.close()
+
+
+@pytest.mark.asyncio
+async def test_search(browser):
+    await 
+
+    # search all
+    await browser.open_new_page(TEST_URL)
+    search_term = "startup example"
+    search_results = await browser.search_content_all(search_term)
+    print(search_results)
+    # expected search result as of 20240410:
+    # [{'index': 0, 'content': {'text_block': 'Below is a breakdown of the software startup example. If you install MetaGPT with the git clone approach, simply run', 'links': [{'text': 'software startup example', 'href': 'https://github.com/geekan/MetaGPT/blob/main/metagpt/software_company.py'}]}, 'position': {'from_top': 640, 'from_left': 225}, 'element_obj': <Locator ... selector='text=startup example >> nth=0'>}]
+    first_result = search_results[0]["content"]
+    assert "software startup example" in first_result["text_block"]
+    assert first_result["links"]
+    assert first_result["links"][0]["href"] == "https://github.com/geekan/MetaGPT/blob/main/metagpt/software_company.py"
+    assert search_results[0]["position"]
+
+    # scroll to search result
+    await browser.scroll_to_search_result(search_results, index=0)
+
+    await browser.close()
+
+
+# @pytest.mark.asyncio
+# async def test_find_links(browser):
+#     await browser.start()
+
+#     await browser.open_new_page(TEST_URL)
+#     link_info = await browser.find_links()
+#     assert link_info
+
+#     await browser.close()
+
+
+@pytest.mark.asyncio
+async def test_scroll(browser):
+    await browser.start()
+
+    await browser.open_new_page(TEST_URL)
+
+    await browser.scroll_current_page(offset=-500)
+    assert await get_scroll_position(browser.current_page) == {"x": 0, "y": 0}  # no change if you scroll up from the top
+    initial_view = await browser._view()
+
+    await browser.scroll_current_page(offset=500)  # scroll down
+    assert await get_scroll_position(browser.current_page) == {"x": 0, "y": 500}
+    scrolled_view = await browser._view()
+
+    assert initial_view != scrolled_view
+
+    await browser.scroll_current_page(offset=-200)  # scroll up
+    assert await get_scroll_position(browser.current_page) == {"x": 0, "y": 300}
+
+    await browser.close()
diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py
new file mode 100644
index 0000000000..8d6e923afd
--- /dev/null
+++ b/tests/metagpt/tools/libs/test_editor.py
@@ -0,0 +1,832 @@
+import os
+import shutil
+from pathlib import Path
+
+import pytest
+
+from metagpt.const import TEST_DATA_PATH
+from metagpt.tools.libs.editor import Editor
+from metagpt.tools.libs.index_repo import (
+    CHATS_INDEX_ROOT,
+    CHATS_ROOT,
+    DEFAULT_MIN_TOKEN_COUNT,
+    UPLOAD_ROOT,
+    IndexRepo,
+)
+from metagpt.utils.common import list_files
+
+TEST_FILE_CONTENT = """
+# this is line one
+def test_function_for_fm():
+    "some docstring"
+    a = 1
+    b = 2
+    c = 3
+    # this is the 7th line
+""".strip()
+
+WINDOW = 200
+
+
+@pytest.fixture
+def temp_file_path(tmp_path):
+    assert tmp_path is not None
+    temp_file_path = tmp_path / "a.txt"
+    yield temp_file_path
+    temp_file_path.unlink()
+
+
+@pytest.fixture
+def temp_py_file(tmp_path):
+    assert tmp_path is not None
+    temp_file_path = tmp_path / "test_script_for_editor.py"
+    temp_file_path.write_text(TEST_FILE_CONTENT)
+    yield temp_file_path
+    temp_file_path.unlink()
+
+
+@pytest.fixture
+def empty_file(tmp_path):
+    assert tmp_path is not None
+    temp_file_path = tmp_path / "test_script_empty_file_for_editor.py"
+    temp_file_path.write_text("")
+    yield temp_file_path
+    temp_file_path.unlink()
+
+
+EXPECTED_CONTENT_AFTER_REPLACE = """
+# this is line one
+def test_function_for_fm():
+    # This is the new line A replacing lines 3 to 5.
+    # This is the new line B.
+    c = 3
+    # this is the 7th line
+""".strip()
+
+
+def test_replace_content(temp_py_file):
+    editor = Editor()
+    editor._edit_file_impl(
+        file_name=temp_py_file,
+        start=3,
+        end=5,
+        content="    # This is the new line A replacing lines 3 to 5.\n    # This is the new line B.",
+        is_insert=False,
+        is_append=False,
+    )
+    with open(temp_py_file, "r") as f:
+        new_content = f.read()
+    assert new_content.strip() == EXPECTED_CONTENT_AFTER_REPLACE.strip()
+
+
+EXPECTED_CONTENT_AFTER_DELETE = """
+# this is line one
+def test_function_for_fm():
+
+    c = 3
+    # this is the 7th line
+""".strip()
+
+
+def test_delete_content(temp_py_file):
+    editor = Editor()
+    editor._edit_file_impl(
+        file_name=temp_py_file,
+        start=3,
+        end=5,
+        content="",
+        is_insert=False,
+        is_append=False,
+    )
+    with open(temp_py_file, "r") as f:
+        new_content = f.read()
+    assert new_content.strip() == EXPECTED_CONTENT_AFTER_DELETE.strip()
+
+
+EXPECTED_CONTENT_AFTER_INSERT = """
+# this is line one
+def test_function_for_fm():
+    # This is the new line to be inserted, at line 3
+    "some docstring"
+    a = 1
+    b = 2
+    c = 3
+    # this is the 7th line
+""".strip()
+
+
+def test_insert_content(temp_py_file):
+    editor = Editor(enable_auto_lint=True)
+    editor.insert_content_at_line(
+        file_name=temp_py_file,
+        line_number=3,
+        insert_content="    # This is the new line to be inserted, at line 3",
+    )
+    with open(temp_py_file, "r") as f:
+        new_content = f.read()
+    assert new_content.strip() == EXPECTED_CONTENT_AFTER_INSERT.strip()
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        TEST_DATA_PATH / "requirements/1.txt",
+        TEST_DATA_PATH / "requirements/1.json",
+        TEST_DATA_PATH / "requirements/1.constraint.md",
+        TEST_DATA_PATH / "requirements/pic/1.png",
+        TEST_DATA_PATH / "docx_for_test.docx",
+        TEST_DATA_PATH / "requirements/2.pdf",
+        TEST_DATA_PATH / "audio/hello.mp3",
+        TEST_DATA_PATH / "code/python/1.py",
+        TEST_DATA_PATH / "code/js/1.js",
+        TEST_DATA_PATH / "ui/1b.png.html",
+        TEST_DATA_PATH / "movie/trailer.mp4",
+    ],
+)
+@pytest.mark.asyncio
+async def test_read_files(filename):
+    editor = Editor()
+    file_block = await editor.read(filename)
+    assert file_block
+    assert file_block.file_path
+    if filename.suffix not in [".png", ".mp3", ".mp4"]:
+        assert file_block.block_content
+
+
+def _numbered_test_lines(start, end) -> str:
+    return ("\n".join(f"{i}|" for i in range(start, end + 1))) + "\n"
+
+
+def _generate_test_file_with_lines(temp_path, num_lines) -> Path:
+    file_path = temp_path / "test_file.py"
+    file_path.write_text("\n" * num_lines)
+    return file_path
+
+
+def _generate_ruby_test_file_with_lines(temp_path, num_lines) -> Path:
+    file_path = temp_path / "test_file.rb"
+    file_path.write_text("\n" * num_lines)
+    return file_path
+
+
+def _calculate_window_bounds(current_line, total_lines, window_size):
+    half_window = window_size // 2
+    if current_line - half_window < 0:
+        start = 1
+        end = window_size
+    else:
+        start = current_line - half_window
+        end = current_line + half_window
+    return start, end
+
+
+def test_open_file_unexist_path():
+    editor = Editor()
+    with pytest.raises(FileNotFoundError):
+        editor.open_file("/unexist/path/a.txt")
+
+
+def test_open_file(temp_file_path):
+    editor = Editor()
+    temp_file_path.write_text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
+
+    result = editor.open_file(str(temp_file_path))
+
+    assert result is not None
+    expected = (
+        f"[File: {temp_file_path} (5 lines total)]\n"
+        "(this is the beginning of the file)\n"
+        "001|Line 1\n"
+        "002|Line 2\n"
+        "003|Line 3\n"
+        "004|Line 4\n"
+        "005|Line 5\n"
+ "(this is the end of the file)" + ) + assert result.split("\n") == expected.split("\n") + + +def test_open_file_with_indentation(temp_file_path): + editor = Editor() + temp_file_path.write_text("Line 1\n Line 2\nLine 3\nLine 4\nLine 5") + + result = editor.open_file(str(temp_file_path)) + assert result is not None + expected = ( + f"[File: {temp_file_path} (5 lines total)]\n" + "(this is the beginning of the file)\n" + "001|Line 1\n" + "002| Line 2\n" + "003|Line 3\n" + "004|Line 4\n" + "005|Line 5\n" + "(this is the end of the file)" + ) + assert result.split("\n") == expected.split("\n") + + +def test_open_file_long(temp_file_path): + editor = Editor() + content = "\n".join([f"Line {i}" for i in range(1, 1001)]) + temp_file_path.write_text(content) + + result = editor.open_file(str(temp_file_path), 1, 50) + assert result is not None + expected = f"[File: {temp_file_path} (1000 lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, 51): + expected += f"{i:03d}|Line {i}\n" + expected += "(950 more lines below)" + assert result.split("\n") == expected.split("\n") + + +def test_open_file_long_with_lineno(temp_file_path): + editor = Editor() + content = "\n".join([f"Line {i}" for i in range(1, 1001)]) + temp_file_path.write_text(content) + + cur_line = 300 + + result = editor.open_file(str(temp_file_path), cur_line) + assert result is not None + expected = f"[File: {temp_file_path} (1000 lines total)]\n" + start, end = _calculate_window_bounds(cur_line, 1000, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == 1000: + expected += "(this is the end of the file)\n" + else: + expected += f"({1000 - end} more lines below)" + assert result.split("\n") == expected.split("\n") + + +def test_create_file_unexist_path(): + editor = Editor() + with pytest.raises(FileNotFoundError): + editor.create_file("/unexist/path/a.txt") + + +@pytest.mark.asyncio +async def test_create_file(temp_file_path): + editor = Editor() + result = await editor.create_file(str(temp_file_path)) + + expected = f"[File {temp_file_path} created.]" + assert result.split("\n") == expected.split("\n") + + +def test_goto_line(temp_file_path): + editor = Editor() + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + + result = editor.open_file(str(temp_file_path)) + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, WINDOW + 1): + expected += f"{i:03d}|Line {i}\n" + expected += f"({total_lines - WINDOW} more lines below)" + assert result.split("\n") == expected.split("\n") + + result = editor.goto_line(500) + + assert result is not None + + cur_line = 500 + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + assert result.split("\n") == expected.split("\n") + + +def test_goto_line_negative(temp_file_path): + editor = 
Editor() + content = "\n".join([f"Line {i}" for i in range(1, 5)]) + temp_file_path.write_text(content) + + editor.open_file(str(temp_file_path)) + with pytest.raises(ValueError): + editor.goto_line(-1) + + +def test_goto_line_out_of_bound(temp_file_path): + editor = Editor() + content = "\n".join([f"Line {i}" for i in range(1, 5)]) + temp_file_path.write_text(content) + + editor.open_file(str(temp_file_path)) + with pytest.raises(ValueError): + editor.goto_line(100) + + +def test_scroll_down(temp_file_path): + editor = Editor() + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + result = editor.open_file(str(temp_file_path)) + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(1, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)" + else: + expected += f"({total_lines - end} more lines below)" + assert result.split("\n") == expected.split("\n") + + result = editor.scroll_down() + + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(WINDOW + 1, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + assert result.split("\n") == expected.split("\n") + + +def test_scroll_up(temp_file_path): + editor = Editor() + total_lines = 1000 + content = "\n".join([f"Line {i}" for i in range(1, total_lines + 1)]) + temp_file_path.write_text(content) + + cur_line = 500 + + result = editor.open_file(str(temp_file_path), cur_line) + assert result is not None + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + + assert result.split("\n") == expected.split("\n") + result = editor.scroll_up() + assert result is not None + + cur_line = cur_line - WINDOW + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(cur_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i:03d}|Line {i}\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + print(result) + print(expected) + assert result.split("\n") == expected.split("\n") + + +def test_scroll_down_edge(temp_file_path): + editor = Editor() + content = "\n".join([f"Line {i}" for i in range(1, 10)]) + temp_file_path.write_text(content) + + result = 
editor.open_file(str(temp_file_path)) + assert result is not None + + expected = f"[File: {temp_file_path} (9 lines total)]\n" + expected += "(this is the beginning of the file)\n" + for i in range(1, 10): + expected += f"{i:03d}|Line {i}\n" + expected += "(this is the end of the file)" + + result = editor.scroll_down() + assert result is not None + + assert result.split("\n") == expected.split("\n") + + +def test_print_window_internal(temp_file_path): + editor = Editor() + editor.create_file(str(temp_file_path)) + with open(temp_file_path, "w") as file: + for i in range(1, 101): + file.write(f"Line `{i}`\n") + + current_line = 50 + window = 2 + + result = editor._print_window(temp_file_path, current_line, window) + expected = "(48 more lines above)\n" "049|Line `49`\n" "050|Line `50`\n" "051|Line `51`\n" "(49 more lines below)" + assert result == expected + + +def test_open_file_large_line_number(temp_file_path): + editor = Editor() + editor.create_file(str(temp_file_path)) + with open(temp_file_path, "w") as file: + for i in range(1, 1000): + file.write(f"Line `{i}`\n") + + current_line = 800 + window = 100 + + result = editor.open_file(str(temp_file_path), current_line, window) + + expected = f"[File: {temp_file_path} (999 lines total)]\n" + expected += "(749 more lines above)\n" + for i in range(750, 850 + 1): + expected += f"{i}|Line `{i}`\n" + expected += "(149 more lines below)" + assert result == expected + + +def test_open_file_large_line_number_consecutive_diff_window(temp_file_path): + editor = Editor() + editor.create_file(str(temp_file_path)) + total_lines = 1000 + with open(temp_file_path, "w") as file: + for i in range(1, total_lines + 1): + file.write(f"Line `{i}`\n") + + current_line = 800 + cur_window = 300 + + result = editor.open_file(str(temp_file_path), current_line, cur_window) + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(current_line, total_lines, cur_window) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(current_line - cur_window // 2, current_line + cur_window // 2 + 1): + expected += f"{i}|Line `{i}`\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + assert result == expected + + current_line = current_line - WINDOW + + result = editor.scroll_up() + + expected = f"[File: {temp_file_path} ({total_lines} lines total)]\n" + start, end = _calculate_window_bounds(current_line, total_lines, WINDOW) + if start == 1: + expected += "(this is the beginning of the file)\n" + else: + expected += f"({start - 1} more lines above)\n" + for i in range(start, end + 1): + expected += f"{i}|Line `{i}`\n" + if end == total_lines: + expected += "(this is the end of the file)\n" + else: + expected += f"({total_lines - end} more lines below)" + assert result.split("\n") == expected.split("\n") + + +EXPECTED_CONTENT_AFTER_REPLACE_TEXT = """ +# this is line one +def test_function_for_fm(): + "some docstring" + a = 1 + b = 9 + c = 3 + # this is the 7th line +""".strip() + + +def test_edit_file_by_replace(temp_py_file): + editor = Editor() + editor.edit_file_by_replace( + file_name=str(temp_py_file), + first_replaced_line_number=5, + first_replaced_line_content=" b = 2", + new_content=" b = 9", + last_replaced_line_number=5, + last_replaced_line_content=" b = 2", + ) + with open(temp_py_file, "r") as f: + new_content = f.read() + assert 
new_content.strip() == EXPECTED_CONTENT_AFTER_REPLACE_TEXT.strip()
+
+
+MISMATCH_ERROR = """
+Error: The `first_replaced_line_number` does not match the `first_replaced_line_content`. Please correct the parameters.
+The `first_replaced_line_number` is 5 and the corresponding content is "    b = 2".
+But the `first_replaced_line_content ` is "".
+The content around the specified line is:
+The 002 line is "def test_function_for_fm():"
+The 003 line is "    "some docstring""
+The 004 line is "    a = 1"
+The 005 line is "    b = 2"
+The 006 line is "    c = 3"
+The 007 line is "    # this is the 7th line"
+Pay attention to the new content. Ensure that it aligns with the new parameters.
+Error: The `last_replaced_line_number` does not match the `last_replaced_line_content`. Please correct the parameters.
+The `last_replaced_line_number` is 5 and the corresponding content is "    b = 2".
+But the `last_replaced_line_content ` is "".
+The content around the specified line is:
+The 002 line is "def test_function_for_fm():"
+The 003 line is "    "some docstring""
+The 004 line is "    a = 1"
+The 005 line is "    b = 2"
+The 006 line is "    c = 3"
+The 007 line is "    # this is the 7th line"
+Pay attention to the new content. Ensure that it aligns with the new parameters.
+""".strip()
+
+
+def test_edit_file_by_replace_mismatch(temp_py_file):
+    editor = Editor()
+    output = editor.edit_file_by_replace(
+        file_name=str(temp_py_file),
+        first_replaced_line_number=5,
+        first_replaced_line_content="",
+        new_content="    b = 9",
+        last_replaced_line_number=5,
+        last_replaced_line_content="",
+    )
+    assert output.strip() == MISMATCH_ERROR.strip()
+
+
+def test_append_file(temp_file_path):
+    editor = Editor()
+    # Write the initial content
+    initial_content = "Line 1\nLine 2\nLine 3\n"
+    temp_file_path.write_text(initial_content)
+
+    # Append content to the file
+    append_content = "Line 4\nLine 5\n"
+
+    result = editor.append_file(str(temp_file_path), append_content)
+
+    # Expected content
+    expected_content = initial_content + append_content
+
+    # Read the file and assert that its content matches the expectation
+    with open(temp_file_path, "r") as f:
+        new_content = f.read()
+    assert new_content == expected_content
+
+    # Expected output
+    expected_output = (
+        f"[File: {temp_file_path.resolve()} (5 lines total after edit)]\n"
+        "(this is the beginning of the file)\n"
+        "001|Line 1\n"
+        "002|Line 2\n"
+        "003|Line 3\n"
+        "004|Line 4\n"
+        "005|Line 5\n"
+        "(this is the end of the file)\n"
+        "[File updated (edited at line 3)]."
+    )
+
+    assert result.split("\n") == expected_output.split("\n")
+
+
+def test_search_dir(tmp_path):
+    editor = Editor()
+    dir_path = tmp_path / "test_dir"
+    dir_path.mkdir()
+
+    # Create some files with specific content
+    (dir_path / "file1.txt").write_text("This is a test file with some content.")
+    (dir_path / "file2.txt").write_text("Another file with different content.")
+    sub_dir = dir_path / "sub_dir"
+    sub_dir.mkdir()
+    (sub_dir / "file3.txt").write_text("This file is inside a sub directory with some content.")
+
+    search_term = "some content"
+
+    result = editor.search_dir(search_term, str(dir_path))
+
+    assert "file1.txt" in result
+    assert "file3.txt" in result
+    assert "Another file with different content."
not in result + + +def test_search_dir_in_default_dir(tmp_path): + editor = Editor() + dir_path = editor.working_dir / "test_dir" + dir_path.mkdir(exist_ok=True) + + # Create some files with specific content + (dir_path / "file1.txt").write_text("This is a test file with some content.") + (dir_path / "file2.txt").write_text("Another file with different content.") + sub_dir = dir_path / "sub_dir" + sub_dir.mkdir(exist_ok=True) + (sub_dir / "file3.txt").write_text("This file is inside a sub directory with some content.") + + search_term = "some content" + + result = editor.search_dir(search_term) + + assert "file1.txt" in result + assert "file3.txt" in result + assert "Another file with different content." not in result + + +def test_search_file(temp_file_path): + editor = Editor() + file_path = temp_file_path + file_path.write_text("This is a test file with some content.\nAnother line with more content.") + + search_term = "some content" + + result = editor.search_file(search_term, str(file_path)) + + assert "Line 1: This is a test file with some content." in result + assert "Line 2: Another line with more content." not in result + + +def test_find_file(tmp_path): + editor = Editor() + dir_path = tmp_path / "test_dir" + dir_path.mkdir() + + # Create some files with specific names + (dir_path / "file1.txt").write_text("Content of file 1.") + (dir_path / "file2.txt").write_text("Content of file 2.") + sub_dir = dir_path / "sub_dir" + sub_dir.mkdir() + (sub_dir / "file3.txt").write_text("Content of file 3.") + + file_name = "file1.txt" + + result = editor.find_file(file_name, str(dir_path)) + + assert "file1.txt" in result + assert "file2.txt" not in result + assert "file3.txt" not in result + + +# Test data for _append_impl method +TEST_LINES = ["First line\n", "Second line\n", "Third line\n"] + +NEW_CONTENT = "Appended line\n" + +EXPECTED_APPEND_NON_EMPTY_FILE = ["First line\n", "Second line\n", "Third line\n", "Appended line\n"] + +EXPECTED_APPEND_EMPTY_FILE = ["Appended line\n"] + + +def test_append_non_empty_file(): + editor = Editor() + lines = TEST_LINES.copy() + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_NON_EMPTY_FILE + assert n_added_lines == 1 + + +def test_append_empty_file(): + editor = Editor() + lines = [] + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_EMPTY_FILE + assert n_added_lines == 1 + + +def test_append_to_single_empty_line_file(): + editor = Editor() + lines = [""] + content, n_added_lines = editor._append_impl(lines, NEW_CONTENT) + + assert content.splitlines(keepends=True) == EXPECTED_APPEND_EMPTY_FILE + assert n_added_lines == 1 + + +async def mock_index_repo(): + chat_id = "1" + chat_path = Path(CHATS_ROOT) / chat_id + chat_path.mkdir(parents=True, exist_ok=True) + src_path = TEST_DATA_PATH / "requirements" + command = f"cp -rf {str(src_path)} {str(chat_path)}" + os.system(command) + filenames = list_files(chat_path) + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + chat_repo = IndexRepo( + persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 + ) + await chat_repo.add(chat_files) + assert chat_files + + Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) + command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" + os.system(command) + filenames = list_files(UPLOAD_ROOT) + uploads_files = [i for i in filenames if 
Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + assert uploads_files + + filenames = list_files(src_path) + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + assert other_files + + return chat_path, UPLOAD_ROOT, src_path + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_index_repo(): + # mock data + chat_path, upload_path, src_path = await mock_index_repo() + + editor = Editor() + rsp = await editor.similarity_search(query="业务线", path=chat_path) + assert rsp + rsp = await editor.similarity_search(query="业务线", path=upload_path) + assert rsp + rsp = await editor.similarity_search(query="业务线", path=src_path) + assert rsp + + shutil.rmtree(CHATS_ROOT) + shutil.rmtree(UPLOAD_ROOT) + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "filename"), + [ + ( + "In this document, who are the legal representatives of both parties?", + TEST_DATA_PATH / "pdf/20210709逗你学云豆付费课程协议.pdf", + ), + ( + "What is the short name of the company in this document?", + TEST_DATA_PATH / "pdf/company_stock_code.pdf", + ), + ("平安创新推出中国版的什么模式,将差异化的医疗健康服务与作为支付方的金融业务无缝结合", TEST_DATA_PATH / "pdf/9112674.pdf"), + ( + "What principle is introduced by the author to explain the conditions necessary for the emergence of complexity?", + TEST_DATA_PATH / "pdf/9781444323498.ch2_1.pdf", + ), + ("行高的继承性的代码示例是?", TEST_DATA_PATH / "pdf/02-CSS.pdf"), + ], +) +async def test_similarity_search(query, filename): + filename = Path(filename) + save_to = Path(UPLOAD_ROOT) / filename.name + save_to.parent.mkdir(parents=True, exist_ok=True) + os.system(f"cp {str(filename)} {str(save_to)}") + + editor = Editor() + rsp = await editor.similarity_search(query=query, path=save_to) + assert rsp + + save_to.unlink(missing_ok=True) + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_read(): + editor = Editor() + filename = TEST_DATA_PATH / "pdf/9112674.pdf" + content = await editor.read(str(filename)) + size = filename.stat().st_size + assert "similarity_search" in content.block_content and size > 5 * DEFAULT_MIN_TOKEN_COUNT + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_git.py b/tests/metagpt/tools/libs/test_git.py new file mode 100644 index 0000000000..f200b900ed --- /dev/null +++ b/tests/metagpt/tools/libs/test_git.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from __future__ import annotations + +import os +import uuid + +import pytest +from github import Auth, Github +from pydantic import BaseModel + +from metagpt.context import Context +from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.schema import UserMessage +from metagpt.tools.libs.git import git_checkout, git_clone +from metagpt.utils.common import awrite +from metagpt.utils.git_repository import GitRepository + + +class SWEBenchItem(BaseModel): + base_commit: str + repo: str + + +async def get_env(key: str, app_name: str = ""): + return os.environ.get(key) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ["url", "commit_id"], [("https://github.com/sqlfluff/sqlfluff.git", "d19de0ecd16d298f9e3bfb91da122734c40c01e5")] +) +@pytest.mark.skip +async def test_git(url: str, commit_id: str): + repo_dir = await git_clone(url) + assert repo_dir + + await git_checkout(repo_dir, commit_id) + + repo = GitRepository(repo_dir, auto_init=False) + repo.delete_repository() + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_login(): + auth = Auth.Login(await get_env("GITHUB_USER"), await 
get_env("GITHUB_PWD")) + g = Github(auth=auth) + repo = g.get_repo("geekan/MetaGPT") + topics = repo.get_topics() + assert topics + open_issues = repo.get_issues(state="open") + issues = [i for i in open_issues] + assert issues + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_new_issue(): + issue = await GitRepository.create_issue( + repo_name="iorisa/MetaGPT", + title="This is a new issue", + body="This is the issue body", + access_token=await get_env(key="access_token", app_name="github"), + ) + print(issue) + assert issue.number + pass + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_new_pr(): + body = """ + >>> SUMMARY + >>> Change HTTP library used to send requests + >>> + >>> TESTS + >>> - [x] Send 'GET' request + >>> - [x] Send 'POST' request with/without body + """ + pr = await GitRepository.create_pull( + base_repo_name="iorisa/MetaGPT", + base="send18", + head="fixbug/gbk", + title="Test pr", + body=body, + access_token=await get_env(key="access_token", app_name="github"), + ) + print(pr) + assert pr + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_new_pr1(): + body = """ + >>> SUMMARY + >>> Change HTTP library used to send requests + >>> + >>> TESTS + >>> - [x] Send 'GET' request + >>> - [x] Send 'POST' request with/without body + """ + pr = await GitRepository.create_pull( + head_repo_name="iorisa/MetaGPT", + head="fixbug/vscode", + base_repo_name="send18/MetaGPT", + base="dev", + title="Test pr", + body=body, + access_token=await get_env(key="access_token", app_name="github"), + ) + print(pr) + assert pr + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_auth(): + access_token = await get_env(key="access_token", app_name="github") + auth = Auth.Token(access_token) + g = Github(auth=auth) + u = g.get_user() + v = u.get_repos(visibility="public") + a = [i.full_name for i in v] + assert a + print(a) + pass + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_github(context): + repo = await GitRepository.clone_from(url="https://github.com/iorisa/snake-game.git") + content = uuid.uuid4().hex + await awrite(filename=repo.workdir / "README.md", data=content) + branch = await repo.push( + new_branch=f"feature/{content[0:8]}", access_token=await get_env(key="access_token", app_name="github") + ) + pr = await GitRepository.create_pull( + base=branch.base, + head=branch.head, + base_repo_name=branch.repo_name, + title=f"new pull {content[0:8]}", + access_token=await get_env(key="access_token", app_name="github"), + ) + assert pr + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + "content", + [ + # "create a new issue to github repo 'iorisa/snake-game' :'The snake did not grow longer after eating'", + "Resolve the issue #1 'Snake not growing longer after eating' in the GitHub repository https://github.com/iorisa/snake-game.git', and create a new pull request about the issue" + ], +) +async def test_git_create_issue(content: str): + context = Context() + di = DataInterpreter(context=context, tools=[""]) + + prerequisite = "from metagpt.tools.libs import get_env" + await di.execute_code.run(code=prerequisite, language="python") + di.put_message(UserMessage(content=content)) + while not di.is_idle: + await di.run() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py new file mode 100644 index 0000000000..aec1e3f5e6 --- /dev/null +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -0,0 +1,55 @@ +import 
shutil +from pathlib import Path + +import pytest + +from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) +async def test_index_repo(path, query): + index_path = DEFAULT_WORKSPACE_ROOT / ".index" + repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + await repo.add([path]) + await repo.add([path]) + assert index_path.exists() + + rsp = await repo.search(query) + assert rsp + + repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + rsp2 = await repo2.search(query) + assert rsp2 + + merged_rsp = await repo.merge(query=query, indices_list=[rsp, rsp2]) + assert merged_rsp + + shutil.rmtree(index_path) + + +@pytest.mark.parametrize( + ("paths", "path_type", "root"), + [ + (["/data/uploads"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/uploads/"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/chats/1/1.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/1/2.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/2/2.txt", "/data/chats/2/2.txt"], str(Path(CHATS_INDEX_ROOT) / "2"), "/data/chats/2"), + (["/data/chats.txt"], "other", ""), + ], +) +def test_classify_path(paths, path_type, root): + result, result_root = IndexRepo.classify_path(paths) + assert path_type in set(result.keys()) + assert root == result_root.get(path_type, "") + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_shell.py b/tests/metagpt/tools/libs/test_shell.py new file mode 100644 index 0000000000..ce25d49b0f --- /dev/null +++ b/tests/metagpt/tools/libs/test_shell.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import pytest + +from metagpt.tools.libs.shell import shell_execute + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ["command", "expect_stdout", "expect_stderr"], + [ + (["file", f"{__file__}"], "Python script text executable, ASCII text", ""), + (f"file {__file__}", "Python script text executable, ASCII text", ""), + ], +) +async def test_shell(command, expect_stdout, expect_stderr): + stdout, stderr = await shell_execute(command) + assert expect_stdout in stdout + assert stderr == expect_stderr + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_software_development.py b/tests/metagpt/tools/libs/test_software_development.py new file mode 100644 index 0000000000..c622583535 --- /dev/null +++ b/tests/metagpt/tools/libs/test_software_development.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from typing import Dict + +import pytest + +from metagpt.tools.libs.software_development import import_git_repo + + +async def get_env_description() -> Dict[str, str]: + return {'await get_env(key="access_token", app_name="github")': "get the access token for github authentication."} + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_import_repo(): + url = "https://github.com/spec-first/connexion.git" + path = await import_git_repo(url) + assert path + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_terminal.py b/tests/metagpt/tools/libs/test_terminal.py new file mode 100644 index 0000000000..9c64009aea --- /dev/null +++ 
b/tests/metagpt/tools/libs/test_terminal.py @@ -0,0 +1,22 @@ +import pytest + +from metagpt.const import DATA_PATH, METAGPT_ROOT +from metagpt.tools.libs.terminal import Terminal + + +@pytest.mark.asyncio +async def test_terminal(): + terminal = Terminal() + + await terminal.run_command(f"cd {METAGPT_ROOT}") + output = await terminal.run_command("pwd") + assert output.strip() == str(METAGPT_ROOT) + + # pwd now should be METAGPT_ROOT, cd data should land in DATA_PATH + await terminal.run_command("cd data") + output = await terminal.run_command("pwd") + assert output.strip() == str(DATA_PATH) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index f72b5663b1..ee55616d27 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -12,9 +12,11 @@ import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.tools.azure_tts import AzureTTS +config = Config.default() + @pytest.mark.asyncio async def test_azure_tts(mocker): diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py index d3797a460e..bd0fcaf8bc 100644 --- a/tests/metagpt/tools/test_metagpt_text_to_image.py +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -10,9 +10,11 @@ import pytest -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image +config = Config.default() + @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 8dc9e9d5e7..0f921887fe 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,10 +8,12 @@ import pytest -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.llm import LLM from metagpt.tools.moderation import Moderation +config = Config.default() + @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 3f9169ddd4..4856342d15 100644 --- a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -11,7 +11,7 @@ import pytest from pydantic import BaseModel -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.llm import LLM from metagpt.tools.openai_text_to_image import ( OpenAIText2Image, @@ -19,6 +19,8 @@ ) from metagpt.utils.s3 import S3 +config = Config.default() + @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index 061a619cee..5aa53ce4fe 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -2,7 +2,10 @@ import pandas as pd -from metagpt.tools.tool_convert import convert_code_to_tool_schema +from metagpt.tools.tool_convert import ( + convert_code_to_tool_schema, + convert_code_to_tool_schema_ast, +) class DummyClass: @@ -45,6 +48,14 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: pass +class DummySubClass(DummyClass): + """sub class docstring""" + + def sub_method(self, df: pd.DataFrame): + """sub method""" + pass + + def dummy_fn( df: pd.DataFrame, s: str, @@ -114,6 +125,18 @@ def 
test_convert_code_to_tool_schema_class(): assert schema == expected +def test_convert_code_to_tool_schema_subclass(): + schema = convert_code_to_tool_schema(DummySubClass) + assert "sub_method" in schema["methods"] # sub class method should be included + assert "fit" in schema["methods"] # parent class method should be included + + +def test_convert_code_to_tool_schema_include(): + schema = convert_code_to_tool_schema(DummyClass, include=["fit"]) + assert "fit" in schema["methods"] + assert "transform" not in schema["methods"] + + def test_convert_code_to_tool_schema_function(): expected = { "type": "function", @@ -128,3 +151,91 @@ def test_convert_code_to_tool_schema_function(): def test_convert_code_to_tool_schema_async_function(): schema = convert_code_to_tool_schema(dummy_async_fn) assert schema.get("type") == "async_function" + + +TEST_CODE_FILE_TEXT = ''' +import pandas as pd # imported obj should not be parsed +from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed +from ..some_module2 import some_imported_function2 # relative import should not result in an error + +class MyClass: + """This is a MyClass docstring.""" + def __init__(self, arg1): + """This is the constructor docstring.""" + self.arg1 = arg1 + + def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]: + """ + This is a method docstring. + + Args: + arg2 (Union[list[str], str]): A union of a list of strings and a string. + ... + + Returns: + Tuple[int, str]: A tuple of an integer and a string. + """ + return self.arg4 + arg5 + + async def my_async_method(self, some_arg) -> str: + return "hi" + + def _private_method(self): # private should not be parsed + return "private" + +def my_function(arg1, arg2) -> dict: + """This is a function docstring.""" + return arg1 + arg2 + +def my_async_function(arg1, arg2) -> dict: + return arg1 + arg2 + +def _private_function(): # private should not be parsed + return "private" +''' + + +def test_convert_code_to_tool_schema_ast(): + expected = { + "MyClass": { + "type": "class", + "description": "This is a MyClass docstring.", + "methods": { + "__init__": { + "type": "function", + "description": "This is the constructor docstring.", + "signature": "(self, arg1)", + "parameters": "", + }, + "my_method": { + "type": "function", + "description": "This is a method docstring. ", + "signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]", + "parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... 
Returns: Tuple[int, str]: A tuple of an integer and a string.", + }, + "my_async_method": { + "type": "async_function", + "description": "", + "signature": "(self, some_arg) -> str", + "parameters": "", + }, + }, + "code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"', + }, + "my_function": { + "type": "function", + "description": "This is a function docstring.", + "signature": "(arg1, arg2) -> dict", + "parameters": "", + "code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2', + }, + "my_async_function": { + "type": "function", + "description": "", + "signature": "(arg1, arg2) -> dict", + "parameters": "", + "code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2", + }, + } + schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT) + assert schemas == expected diff --git a/tests/metagpt/tools/test_ut_writer.py b/tests/metagpt/tools/test_ut_writer.py index 3cc7e86bbb..3ebbe6d9d9 100644 --- a/tests/metagpt/tools/test_ut_writer.py +++ b/tests/metagpt/tools/test_ut_writer.py @@ -20,10 +20,12 @@ Function, ) -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator +config = Config.default() + class TestUTWriter: @pytest.mark.asyncio diff --git a/tests/metagpt/utils/test_code_parser.py b/tests/metagpt/utils/test_code_parser.py index 294324b8fa..f4d822f85f 100644 --- a/tests/metagpt/utils/test_code_parser.py +++ b/tests/metagpt/utils/test_code_parser.py @@ -119,7 +119,7 @@ def test_parse_block(self, parser, text): assert "game.py" in result def test_parse_code(self, parser, text): - result = parser.parse_code("Task list", text, "python") + result = parser.parse_code(block="Task list", text=text, lang="python") print(result) assert "game.py" in result diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 75e8ef4adb..b85fe229be 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -29,6 +29,8 @@ awrite, check_cmd_exists, concat_namespace, + extract_and_encode_images, + extract_image_paths, import_class_inst, parse_recipient, print_members, @@ -215,5 +217,23 @@ async def test_read_write_error_charset(self): assert data == content +def test_extract_image_paths(): + content = """ + Here are some image paths /home/user/images/photo1.jpg /home/user/images/photo2.png + # /absolute/path/to/image.gif""" + assert extract_image_paths(content) == [ + "/home/user/images/photo1.jpg", + "/home/user/images/photo2.png", + "/absolute/path/to/image.gif", + ] + + content = "no image path" + assert not extract_image_paths(content) + + +def test_extract_and_encode_images(): + assert not extract_and_encode_images("a non-existing.jpg") + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git 
a/tests/metagpt/utils/test_di_graph_repository.py b/tests/metagpt/utils/test_di_graph_repository.py index 966aaf1b08..d2d0e2b3c8 100644 --- a/tests/metagpt/utils/test_di_graph_repository.py +++ b/tests/metagpt/utils/test_di_graph_repository.py @@ -62,6 +62,7 @@ class Input(BaseModel): @pytest.mark.asyncio +@pytest.mark.skip async def test_codes(): path = DEFAULT_WORKSPACE_ROOT / "snake_game" repo_parser = RepoParser(base_directory=path) @@ -81,5 +82,13 @@ async def test_codes(): print(data) +@pytest.mark.asyncio +async def test_graph_select(): + gdb_path = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json" + gdb = await DiGraphRepository.load_from(gdb_path) + rows = await gdb.select() + assert rows + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_mermaid.py b/tests/metagpt/utils/test_mermaid.py index 7367463dc5..1fbf060fe7 100644 --- a/tests/metagpt/utils/test_mermaid.py +++ b/tests/metagpt/utils/test_mermaid.py @@ -8,28 +8,32 @@ import pytest -from metagpt.utils.common import check_cmd_exists +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.utils.common import check_cmd_exists, new_transaction_id from metagpt.utils.mermaid import MMC1, mermaid_to_file @pytest.mark.asyncio -@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer -async def test_mermaid(engine, context, mermaid_mocker): +@pytest.mark.parametrize( + ("engine", "suffixes"), [("nodejs", None), ("nodejs", ["png", "svg", "pdf"]), ("ink", None)] +) # TODO: playwright and pyppeteer +async def test_mermaid(engine, suffixes, context, mermaid_mocker): # nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli # ink prerequisites: connected to internet # playwright prerequisites: playwright install --with-deps chromium assert check_cmd_exists("npm") == 0 - save_to = context.git_repo.workdir / f"{engine}/1" - await mermaid_to_file(engine, MMC1, save_to) + save_to = DEFAULT_WORKSPACE_ROOT / f"{new_transaction_id()}/{engine}/1" + await mermaid_to_file(engine, MMC1, save_to, suffixes=suffixes) # ink does not support pdf + exts = ["." + i for i in suffixes] if suffixes else [".png"] if engine == "ink": - for ext in [".svg", ".png"]: + for ext in exts: assert save_to.with_suffix(ext).exists() save_to.with_suffix(ext).unlink(missing_ok=True) else: - for ext in [".pdf", ".svg", ".png"]: + for ext in exts: assert save_to.with_suffix(ext).exists() save_to.with_suffix(ext).unlink(missing_ok=True) diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 7a29ea3ee2..75bd9f1656 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -2,7 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : unittest of repair_llm_raw_output -from metagpt.config2 import config +from metagpt.config2 import Config + +config = Config.default() """ CONFIG.repair_llm_output should be True before retry_parse_json_text imported. 
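Note: the hunk above is one of several in this patch that replace the shared `from metagpt.config2 import config` module-level singleton with an explicitly constructed instance. Below is a minimal sketch of the resulting pattern, assuming only what the hunks themselves show (`Config.default()` builds the default configuration and exposes sections such as `config.llm`); the print line is illustrative:

from metagpt.config2 import Config

# Construct the configuration explicitly instead of importing a pre-built
# module-level singleton; each test module now controls when config is created.
config = Config.default()

# Sections are accessed the same way as before, e.g. mock_llm.py reads:
print(config.llm.api_type)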
diff --git a/tests/metagpt/utils/test_repo_to_markdown.py b/tests/metagpt/utils/test_repo_to_markdown.py
index 914c50dd7c..28bdf87b77 100644
--- a/tests/metagpt/utils/test_repo_to_markdown.py
+++ b/tests/metagpt/utils/test_repo_to_markdown.py
@@ -10,7 +10,12 @@
 @pytest.mark.parametrize(
     ["repo_path", "output"],
-    [(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
+    [
+        (
+            Path(__file__).parent.parent.parent,
+            Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}.md",
+        ),
+    ],
 )
 @pytest.mark.asyncio
 async def test_repo_to_markdown(repo_path: Path, output: Path):
diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py
index c4262e0806..fdbf86825a 100644
--- a/tests/mock/mock_llm.py
+++ b/tests/mock/mock_llm.py
@@ -1,14 +1,17 @@
 import json
 from typing import Optional, Union
 
-from metagpt.config2 import config
+from metagpt.config2 import Config
 from metagpt.configs.llm_config import LLMType
+from metagpt.const import LLM_API_TIMEOUT
 from metagpt.logs import logger
 from metagpt.provider.azure_openai_api import AzureOpenAILLM
 from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.schema import Message
 
+config = Config.default()
+
 OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
 
@@ -22,7 +25,7 @@ def __init__(self, allow_open_api_call):
         self.rsp_cache: dict = {}
         self.rsp_candidates: list[dict] = []  # a test can have multiple calls with the same llm, thus a list
 
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
         """Overwrite original acompletion_text to cancel retry"""
         if stream:
             resp = await self._achat_completion_stream(messages, timeout=timeout)
@@ -37,7 +40,7 @@ async def original_aask(
         system_msgs: Optional[list[str]] = None,
         format_msgs: Optional[list[dict[str, str]]] = None,
         images: Optional[Union[str, list[str]]] = None,
-        timeout=3,
+        timeout=LLM_API_TIMEOUT,
         stream=True,
     ) -> str:
         if system_msgs:
@@ -56,7 +59,7 @@ async def original_aask(
         rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
         return rsp
 
-    async def original_aask_batch(self, msgs: list, timeout=3) -> str:
+    async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
         """A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
         context = []
         for msg in msgs:
@@ -83,8 +86,8 @@ async def aask(
         system_msgs: Optional[list[str]] = None,
         format_msgs: Optional[list[dict[str, str]]] = None,
         images: Optional[Union[str, list[str]]] = None,
-        timeout=3,
-        stream=True,
+        timeout=LLM_API_TIMEOUT,
+        stream=False,
     ) -> str:
         # used to identify if a message has been called before
         if isinstance(msg, list):
@@ -98,7 +101,7 @@ async def aask(
         rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
         return rsp
 
-    async def aask_batch(self, msgs: list, timeout=3) -> str:
+    async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
         msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
         rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
         return rsp
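Note: to make the mock_llm.py changes above easier to follow, here is a minimal, self-contained sketch of the record-and-replay caching that `aask`/`aask_batch` implement. The `rsp_cache` attribute and the `#MSG_SEP#` key separator come from the hunks above; the class name, the echo fallback, and the demo are hypothetical stand-ins, not part of the patch:

import asyncio


class TinyMockLLM:
    def __init__(self):
        self.rsp_cache: dict[str, str] = {}  # message key -> recorded response

    async def aask_batch(self, msgs: list[str]) -> str:
        msg_key = "#MSG_SEP#".join(msgs)  # same keying scheme as the patch
        if msg_key not in self.rsp_cache:
            # A real MockLLM would call the underlying LLM here and record the
            # result; we record a deterministic stand-in instead.
            self.rsp_cache[msg_key] = f"echo: {msgs[-1]}"
        return self.rsp_cache[msg_key]


async def _demo():
    llm = TinyMockLLM()
    first = await llm.aask_batch(["hello"])
    second = await llm.aask_batch(["hello"])  # replayed from the cache, no second "call"
    assert first == second


asyncio.run(_demo())

Keying the cache on the full message content is what lets the test suite replay previously recorded LLM responses deterministically; it also explains why changing a prompt string in a test invalidates the corresponding cache entry.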