Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Mar 18, 2024
2 parents fe528d0 + cc836de commit 75414d7
Show file tree
Hide file tree
Showing 10 changed files with 1,039 additions and 733 deletions.
111 changes: 75 additions & 36 deletions autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from functools import partial
import logging
from collections import defaultdict
from collections import defaultdict, abc
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import datetime
import warnings
from termcolor import colored
from .utils import consolidate_chat_info


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -135,22 +137,22 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
Args:
chat_queue (List[Dict]): a list of dictionaries containing the information of the chats.
Each dictionary should contain the input arguments for `ConversableAgent.initiate_chat`.
More specifically, each dictionary could include the following fields:
- "sender": the sender agent.
- "recipient": the recipient agent.
- "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- "silent" (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
- "cache" (Cache or None): the cache client to be used for this conversation. Default is None.
- "max_turns" (int or None): the maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.
- "message" (str, callable or None): if None, input() will be called to get the initial message.
- **context: additional context information to be passed to the chat.
- "carryover": It can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
chat_queue (List[Dict]): a list of dictionaries containing the information about the chats.
Each dictionary should contain the input arguments for [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). For example:
- "sender": the sender agent.
- "recipient": the recipient agent.
- "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- "silent" (bool or None): (Experimental) whether to print the messages in this conversation. Default is False.
- "cache" (Cache or None): the cache client to use for this conversation. Default is None.
- "max_turns" (int or None): maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.
- "message" (str, callable or None): if None, input() will be called to get the initial message.
- **context: additional context information to be passed to the chat.
- "carryover": It can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
Expand All @@ -173,6 +175,49 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
return finished_chats


def __system_now_str():
ct = datetime.datetime.now()
return f" System time at {ct}. "


def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
"""
Update ChatResult when async Task for Chat is completed.
"""
logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
chat_result = chat_future.result()
chat_result.chat_id = chat_id


async def _dependent_chat_future(
chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
) -> asyncio.Task:
"""
Create an async Task for each chat.
"""
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
_chat_carryover = chat_info.get("carryover", [])
finished_chats = dict()
for chat in prerequisite_chat_futures:
chat_future = prerequisite_chat_futures[chat]
if chat_future.cancelled():
raise RuntimeError(f"Chat {chat} is cancelled.")

# wait for prerequisite chat results for the new chat carryover
finished_chats[chat] = await chat_future

if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
chat_res_future.add_done_callback(call_back_with_args)
logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
return chat_res_future


async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
Expand All @@ -183,31 +228,25 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatRe
returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""

consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
finished_chat_futures = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res

pre_chat_futures = dict()
for pre_chat_id in prerequisite_chat_ids:
pre_chat_future = finished_chat_futures[pre_chat_id]
pre_chat_futures[pre_chat_id] = pre_chat_future
current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
finished_chat_futures[chat_id] = current_chat_future
await asyncio.gather(*list(finished_chat_futures.values()))
finished_chats = dict()
for chat in finished_chat_futures:
chat_result = finished_chat_futures[chat].result()
finished_chats[chat] = chat_result
return finished_chats
Loading

0 comments on commit 75414d7

Please sign in to comment.