From e677de8e280758f82c7f887b8170466595569acc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 8 Dec 2023 19:55:47 +0800 Subject: [PATCH] feat: merge geekan:main --- examples/agent_creator.py | 1 + examples/search_kb.py | 22 ++++- metagpt/actions/action.py | 4 +- metagpt/actions/summarize_code.py | 4 +- metagpt/actions/write_code.py | 4 +- metagpt/actions/write_code_review.py | 4 +- metagpt/provider/openai_api.py | 6 +- metagpt/provider/zhipuai_api.py | 4 +- metagpt/roles/__init__.py | 2 +- metagpt/roles/sales.py | 2 +- metagpt/roles/{seacher.py => searcher.py} | 2 +- metagpt/subscription.py | 101 +++++++++++++++++++++ tests/conftest.py | 11 +++ tests/metagpt/test_subscription.py | 102 ++++++++++++++++++++++ 14 files changed, 251 insertions(+), 18 deletions(-) rename metagpt/roles/{seacher.py => searcher.py} (99%) create mode 100644 metagpt/subscription.py create mode 100644 tests/metagpt/test_subscription.py diff --git a/examples/agent_creator.py b/examples/agent_creator.py index e724105a37..05417d24ae 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -49,6 +49,7 @@ def parse_code(rsp): pattern = r"```python(.*)```" match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" + CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) with open(CONFIG.workspace_path / "agent_created_agent.py", "w") as f: f.write(code_text) return code_text diff --git a/examples/search_kb.py b/examples/search_kb.py index 0b5d593857..7a9911ca22 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -5,17 +5,35 @@ """ import asyncio +from metagpt.actions import Action from metagpt.const import DATA_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger from metagpt.roles import Sales +from metagpt.schema import Message + +""" example.json, e.g. +[ + { + "source": "Which facial cleanser is good for oily skin?", + "output": "ABC cleanser is preferred by many with oily skin." + }, + { + "source": "Is L'Oreal good to use?", + "output": "L'Oreal is a popular brand with many positive reviews." + } +] +""" async def search(): store = FaissStore(DATA_PATH / "example.json") role = Sales(profile="Sales", store=store) - - queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"] + role._watch({Action}) + queries = [ + Message("Which facial cleanser is good for oily skin?", cause_by=Action), + Message("Is L'Oreal good to use?", cause_by=Action), + ] for query in queries: logger.info(f"User: {query}") result = await role.run(query) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index f8016b8a2a..dc96699a9f 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -9,7 +9,7 @@ from abc import ABC from typing import Optional -from tenacity import retry, stop_after_attempt, wait_fixed +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM @@ -53,7 +53,7 @@ async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> s system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def _aask_v1( self, prompt: str, diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index d10cd6c553..413ac2a219 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -7,7 +7,7 @@ """ from pathlib import Path -from tenacity import retry, stop_after_attempt, wait_fixed +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -92,7 +92,7 @@ class SummarizeCode(Action): def __init__(self, name="SummarizeCode", context=None, llm=None): super().__init__(name, context, llm) - @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) + @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): code_rsp = await self._aask(prompt) return code_rsp diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 9b20843c7b..4c138a1246 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -15,7 +15,7 @@ RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. """ -from tenacity import retry, stop_after_attempt, wait_fixed +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -81,7 +81,7 @@ class WriteCode(Action): def __init__(self, name="WriteCode", context=None, llm=None): super().__init__(name, context, llm) - @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) + @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: code_rsp = await self._aask(prompt) code = CodeParser.parse_code(block="", text=code_rsp) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index f7c6845d20..e1acc714bb 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -8,7 +8,7 @@ WriteCode object, rather than passing them in when calling the run function. """ -from tenacity import retry, stop_after_attempt, wait_fixed +from tenacity import retry, stop_after_attempt from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -94,7 +94,7 @@ class WriteCodeReview(Action): def __init__(self, name="WriteCodeReview", context=None, llm=None): super().__init__(name, context, llm) - @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) + @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def write_code_review_and_rewrite(self, prompt): code_rsp = await self._aask(prompt) result = CodeParser.parse_block("Code Review Result", code_rsp) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 8ac0c4b21d..a73bb0aa0c 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -15,7 +15,7 @@ retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) from metagpt.config import CONFIG @@ -231,8 +231,8 @@ async def acompletion(self, messages: list[dict]) -> dict: return await self._achat_completion(messages) @retry( - stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index edd9084e36..92119b764d 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -13,7 +13,7 @@ retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) from metagpt.config import CONFIG @@ -122,7 +122,7 @@ async def _achat_completion_stream(self, messages: list[dict]) -> str: @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, diff --git a/metagpt/roles/__init__.py b/metagpt/roles/__init__.py index 1768b786c0..f033a5dfa2 100644 --- a/metagpt/roles/__init__.py +++ b/metagpt/roles/__init__.py @@ -12,7 +12,7 @@ from metagpt.roles.product_manager import ProductManager from metagpt.roles.engineer import Engineer from metagpt.roles.qa_engineer import QaEngineer -from metagpt.roles.seacher import Searcher +from metagpt.roles.searcher import Searcher from metagpt.roles.sales import Sales from metagpt.roles.customer_service import CustomerService diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 18282a494d..d5aac18246 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -28,7 +28,7 @@ def __init__( def _set_store(self, store): if store: - action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.search) + action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) else: action = SearchAndSummarize() self._init_actions([action]) diff --git a/metagpt/roles/seacher.py b/metagpt/roles/searcher.py similarity index 99% rename from metagpt/roles/seacher.py rename to metagpt/roles/searcher.py index 587698d1d8..bee8d39869 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/searcher.py @@ -3,7 +3,7 @@ """ @Time : 2023/5/23 17:25 @Author : alexanderwu -@File : seacher.py +@File : searcher.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ diff --git a/metagpt/subscription.py b/metagpt/subscription.py new file mode 100644 index 0000000000..0d2b308216 --- /dev/null +++ b/metagpt/subscription.py @@ -0,0 +1,101 @@ +import asyncio +from typing import AsyncGenerator, Awaitable, Callable + +from pydantic import BaseModel, Field + +from metagpt.logs import logger +from metagpt.roles import Role +from metagpt.schema import Message + + +class SubscriptionRunner(BaseModel): + """A simple wrapper to manage subscription tasks for different roles using asyncio. + + Example: + >>> import asyncio + >>> from metagpt.subscription import SubscriptionRunner + >>> from metagpt.roles import Searcher + >>> from metagpt.schema import Message + + >>> async def trigger(): + ... while True: + ... yield Message("the latest news about OpenAI") + ... await asyncio.sleep(3600 * 24) + + >>> async def callback(msg: Message): + ... print(msg.content) + + >>> async def main(): + ... pb = SubscriptionRunner() + ... await pb.subscribe(Searcher(), trigger(), callback) + ... await pb.run() + + >>> asyncio.run(main()) + """ + + tasks: dict[Role, asyncio.Task] = Field(default_factory=dict) + + class Config: + arbitrary_types_allowed = True + + async def subscribe( + self, + role: Role, + trigger: AsyncGenerator[Message, None], + callback: Callable[ + [ + Message, + ], + Awaitable[None], + ], + ): + """Subscribes a role to a trigger and sets up a callback to be called with the role's response. + + Args: + role: The role to subscribe. + trigger: An asynchronous generator that yields Messages to be processed by the role. + callback: An asynchronous function to be called with the response from the role. + """ + loop = asyncio.get_running_loop() + + async def _start_role(): + async for msg in trigger: + resp = await role.run(msg) + await callback(resp) + + self.tasks[role] = loop.create_task(_start_role(), name=f"Subscription-{role}") + + async def unsubscribe(self, role: Role): + """Unsubscribes a role from its trigger and cancels the associated task. + + Args: + role: The role to unsubscribe. + """ + task = self.tasks.pop(role) + task.cancel() + + async def run(self, raise_exception: bool = True): + """Runs all subscribed tasks and handles their completion or exception. + + Args: + raise_exception: _description_. Defaults to True. + + Raises: + task.exception: _description_ + """ + while True: + for role, task in self.tasks.items(): + if task.done(): + if task.exception(): + if raise_exception: + raise task.exception() + logger.opt(exception=task.exception()).error(f"Task {task.get_name()} run error") + else: + logger.warning( + f"Task {task.get_name()} has completed. " + "If this is unexpected behavior, please check the trigger function." + ) + self.tasks.pop(role) + break + else: + await asyncio.sleep(1) diff --git a/tests/conftest.py b/tests/conftest.py index 8e4422700c..0cef6a4c93 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,17 @@ async def handle_client(reader, writer): return "http://{}:{}".format(*server.sockets[0].getsockname()) +# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978 +@pytest.fixture +def loguru_caplog(caplog): + class PropogateHandler(logging.Handler): + def emit(self, record): + logging.getLogger(record.name).handle(record) + + logger.add(PropogateHandler(), format="{message}") + yield caplog + + # init & dispose git repo @pytest.fixture(scope="session", autouse=True) def setup_and_teardown_git_repo(request): diff --git a/tests/metagpt/test_subscription.py b/tests/metagpt/test_subscription.py new file mode 100644 index 0000000000..2e898424dd --- /dev/null +++ b/tests/metagpt/test_subscription.py @@ -0,0 +1,102 @@ +import asyncio + +import pytest + +from metagpt.roles import Role +from metagpt.schema import Message +from metagpt.subscription import SubscriptionRunner + + +@pytest.mark.asyncio +async def test_subscription_run(): + callback_done = 0 + + async def trigger(): + while True: + yield Message("the latest news about OpenAI") + await asyncio.sleep(3600 * 24) + + class MockRole(Role): + async def run(self, message=None): + return Message("") + + async def callback(message): + nonlocal callback_done + callback_done += 1 + + runner = SubscriptionRunner() + + roles = [] + for _ in range(2): + role = MockRole() + roles.append(role) + await runner.subscribe(role, trigger(), callback) + + task = asyncio.get_running_loop().create_task(runner.run()) + + for _ in range(10): + if callback_done == 2: + break + await asyncio.sleep(0) + else: + raise TimeoutError("callback not call") + + role = roles[0] + assert role in runner.tasks + await runner.unsubscribe(roles[0]) + + for _ in range(10): + if role not in runner.tasks: + break + await asyncio.sleep(0) + else: + raise TimeoutError("callback not call") + + task.cancel() + for i in runner.tasks.values(): + i.cancel() + + +@pytest.mark.asyncio +async def test_subscription_run_error(loguru_caplog): + async def trigger1(): + while True: + yield Message("the latest news about OpenAI") + await asyncio.sleep(3600 * 24) + + async def trigger2(): + yield Message("the latest news about OpenAI") + + class MockRole1(Role): + async def run(self, message=None): + raise RuntimeError + + class MockRole2(Role): + async def run(self, message=None): + return Message("") + + async def callback(msg: Message): + print(msg) + + runner = SubscriptionRunner() + await runner.subscribe(MockRole1(), trigger1(), callback) + with pytest.raises(RuntimeError): + await runner.run() + + await runner.subscribe(MockRole2(), trigger2(), callback) + task = asyncio.get_running_loop().create_task(runner.run(False)) + + for _ in range(10): + if not runner.tasks: + break + await asyncio.sleep(0) + else: + raise TimeoutError("wait runner tasks empty timeout") + + task.cancel() + for i in runner.tasks.values(): + i.cancel() + assert len(loguru_caplog.records) >= 2 + logs = "".join(loguru_caplog.messages) + assert "run error" in logs + assert "has completed" in logs