-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
莘权 马
committed
Dec 8, 2023
1 parent
769dc82
commit e677de8
Showing
14 changed files
with
251 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import asyncio | ||
from typing import AsyncGenerator, Awaitable, Callable | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from metagpt.logs import logger | ||
from metagpt.roles import Role | ||
from metagpt.schema import Message | ||
|
||
|
||
class SubscriptionRunner(BaseModel): | ||
"""A simple wrapper to manage subscription tasks for different roles using asyncio. | ||
Example: | ||
>>> import asyncio | ||
>>> from metagpt.subscription import SubscriptionRunner | ||
>>> from metagpt.roles import Searcher | ||
>>> from metagpt.schema import Message | ||
>>> async def trigger(): | ||
... while True: | ||
... yield Message("the latest news about OpenAI") | ||
... await asyncio.sleep(3600 * 24) | ||
>>> async def callback(msg: Message): | ||
... print(msg.content) | ||
>>> async def main(): | ||
... pb = SubscriptionRunner() | ||
... await pb.subscribe(Searcher(), trigger(), callback) | ||
... await pb.run() | ||
>>> asyncio.run(main()) | ||
""" | ||
|
||
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict) | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
async def subscribe( | ||
self, | ||
role: Role, | ||
trigger: AsyncGenerator[Message, None], | ||
callback: Callable[ | ||
[ | ||
Message, | ||
], | ||
Awaitable[None], | ||
], | ||
): | ||
"""Subscribes a role to a trigger and sets up a callback to be called with the role's response. | ||
Args: | ||
role: The role to subscribe. | ||
trigger: An asynchronous generator that yields Messages to be processed by the role. | ||
callback: An asynchronous function to be called with the response from the role. | ||
""" | ||
loop = asyncio.get_running_loop() | ||
|
||
async def _start_role(): | ||
async for msg in trigger: | ||
resp = await role.run(msg) | ||
await callback(resp) | ||
|
||
self.tasks[role] = loop.create_task(_start_role(), name=f"Subscription-{role}") | ||
|
||
async def unsubscribe(self, role: Role): | ||
"""Unsubscribes a role from its trigger and cancels the associated task. | ||
Args: | ||
role: The role to unsubscribe. | ||
""" | ||
task = self.tasks.pop(role) | ||
task.cancel() | ||
|
||
async def run(self, raise_exception: bool = True): | ||
"""Runs all subscribed tasks and handles their completion or exception. | ||
Args: | ||
raise_exception: _description_. Defaults to True. | ||
Raises: | ||
task.exception: _description_ | ||
""" | ||
while True: | ||
for role, task in self.tasks.items(): | ||
if task.done(): | ||
if task.exception(): | ||
if raise_exception: | ||
raise task.exception() | ||
logger.opt(exception=task.exception()).error(f"Task {task.get_name()} run error") | ||
else: | ||
logger.warning( | ||
f"Task {task.get_name()} has completed. " | ||
"If this is unexpected behavior, please check the trigger function." | ||
) | ||
self.tasks.pop(role) | ||
break | ||
else: | ||
await asyncio.sleep(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.