Skip to content

Commit

Permalink
Fix a initiate chats (#1938)
Browse files Browse the repository at this point in the history
* Fix async a_initiate_chats

* Fix type compatibility for python 3.8

* Use partial func to fix context error

* website/docs/tutorial/assets/conversable-agent.jpg: convert to Git LFS

* Update notebook examples

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Davor Runje <davor@airt.ai>
  • Loading branch information
3 people authored Mar 17, 2024
1 parent 2cefff9 commit 96cbaf7
Show file tree
Hide file tree
Showing 2 changed files with 898 additions and 601 deletions.
79 changes: 59 additions & 20 deletions autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from functools import partial
import logging
from collections import defaultdict
from collections import defaultdict, abc
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import datetime
import warnings
from termcolor import colored
from .utils import consolidate_chat_info


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -173,6 +175,49 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
return finished_chats


def __system_now_str():
ct = datetime.datetime.now()
return f" System time at {ct}. "


def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
"""
Update ChatResult when async Task for Chat is completed.
"""
logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
chat_result = chat_future.result()
chat_result.chat_id = chat_id


async def _dependent_chat_future(
chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
) -> asyncio.Task:
"""
Create an async Task for each chat.
"""
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
_chat_carryover = chat_info.get("carryover", [])
finished_chats = dict()
for chat in prerequisite_chat_futures:
chat_future = prerequisite_chat_futures[chat]
if chat_future.cancelled():
raise RuntimeError(f"Chat {chat} is cancelled.")

# wait for prerequisite chat results for the new chat carryover
finished_chats[chat] = await chat_future

if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
chat_res_future.add_done_callback(call_back_with_args)
logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
return chat_res_future


async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
Expand All @@ -183,31 +228,25 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatRe
returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""

consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
finished_chat_futures = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res

pre_chat_futures = dict()
for pre_chat_id in prerequisite_chat_ids:
pre_chat_future = finished_chat_futures[pre_chat_id]
pre_chat_futures[pre_chat_id] = pre_chat_future
current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
finished_chat_futures[chat_id] = current_chat_future
await asyncio.gather(*list(finished_chat_futures.values()))
finished_chats = dict()
for chat in finished_chat_futures:
chat_result = finished_chat_futures[chat].result()
finished_chats[chat] = chat_result
return finished_chats
Loading

0 comments on commit 96cbaf7

Please sign in to comment.