add autogen.initiate_chats (microsoft#1638)
* add initiate_chats

* update notebook

* different user

* add notebook

* add link to website

* update notebook title

* remove redundancy

* notebook

* return list

* tag

* update notebook

* update notebook

* return finished tasks

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
2 people authored and gunnarku committed Feb 13, 2024
1 parent 265e8ba commit 5ce3133
Showing 10 changed files with 1,939 additions and 663 deletions.
5 changes: 5 additions & 0 deletions autogen/agentchat/__init__.py
@@ -3,6 +3,8 @@
from .conversable_agent import ConversableAgent, register_function
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
from .chat import initiate_chats, ChatResult
from .utils import gather_usage_summary

__all__ = (
"Agent",
@@ -12,4 +14,7 @@
"GroupChat",
"GroupChatManager",
"register_function",
"initiate_chats",
"gather_usage_summary",
"ChatResult",
)
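
With these exports in place, the new entry points become importable from the package namespace. A trivial sketch (the top-level `autogen.initiate_chats` alias is inferred from the commit title; only the `autogen.agentchat` exports appear in this file):

```python
# The names added to __all__ above are importable from autogen.agentchat
# (and, per the commit title, also as autogen.initiate_chats).
from autogen.agentchat import ChatResult, gather_usage_summary, initiate_chats
```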
94 changes: 93 additions & 1 deletion autogen/agentchat/chat.py
@@ -1,6 +1,16 @@
import logging
from typing import Dict, List
from typing import Dict, List, Any
from dataclasses import dataclass
from .utils import consolidate_chat_info
import warnings

try:
from termcolor import colored
except ImportError:

def colored(x, *args, **kwargs):
return x


logger = logging.getLogger(__name__)

@@ -17,3 +27,85 @@ class ChatResult:
"""The cost of the chat. a tuple of (total_cost, total_actual_cost), where total_cost is a dictionary of cost information, and total_actual_cost is a dictionary of information on the actual incurred cost with cache."""
human_input: List[str] = None
"""A list of human input solicited during the chat."""


def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
Args:
chat_queue (List[Dict]): a list of dictionaries, each containing the information for one chat.
Each dictionary should contain the following fields:
- "recipient": the recipient agent.
- "context": any context information, e.g., the request message. The following fields are reserved:
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- Supported strings are "last_msg" and "reflection_with_llm":
when set to "last_msg", it returns the last message of the dialog as the summary.
when set to "reflection_with_llm", it returns a summary extracted using an LLM client.
"reflection_with_llm" requires `llm_config` to be set in either the sender or the recipient.
- A callable summary_method should take the recipient and sender agents of a chat as input and return the summary as a string. E.g.,
```python
def my_summary_method(
sender: ConversableAgent,
recipient: ConversableAgent,
):
return recipient.last_message(sender)["content"]
```
"summary_prompt" can be used to specify the prompt used to extract a summary when summary_method is "reflection_with_llm".
Default is None, in which case the following default prompt is used:
"Identify and extract the final solution to the originally asked question based on the conversation."
"carryover" can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
consolidate_chat_info(chat_queue)
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)
current_chat_queue = chat_queue.copy()
finished_chats = []
while current_chat_queue:
chat_info = current_chat_queue.pop(0)
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
chat_info["recipient"]
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
return finished_chats
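
Taken together, the new module-level API can be exercised roughly as follows. This is a minimal sketch based only on the fields documented in the docstring above; the agent names, task messages, and `llm_config` contents are illustrative assumptions rather than part of this commit.

```python
# Minimal sketch of the new module-level initiate_chats API.
# Assumes a working OAI_CONFIG_LIST; agent names and messages are illustrative.
import autogen
from autogen import AssistantAgent, UserProxyAgent, initiate_chats

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}

researcher = AssistantAgent(name="researcher", llm_config=llm_config)
writer = AssistantAgent(name="writer", llm_config=llm_config)
user = UserProxyAgent(name="user", human_input_mode="NEVER", code_execution_config=False)

# Each entry mirrors the documented fields: sender, recipient, message,
# summary_method, and (optionally) carryover / clear_history.
chat_results = initiate_chats(
    [
        {
            "sender": user,
            "recipient": researcher,
            "message": "Collect three facts about solar energy.",
            "summary_method": "reflection_with_llm",
        },
        {
            "sender": user,
            "recipient": writer,
            "message": "Write a short paragraph using the facts.",
            "summary_method": "last_msg",
            # Summaries of earlier chats are appended to this chat's carryover automatically.
        },
    ]
)

for result in chat_results:  # one ChatResult per finished chat
    print(result.summary)
```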
81 changes: 13 additions & 68 deletions autogen/agentchat/conversable_agent.py
@@ -24,8 +24,8 @@
extract_code,
infer_lang,
)
from ..agent_utils import gather_usage_summary
from .chat import ChatResult
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats


from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
@@ -801,7 +801,7 @@ def my_summary_method(
"""
_chat_info = context.copy()
_chat_info["recipient"] = recipient
self._consolidate_chat_info(_chat_info)
consolidate_chat_info(_chat_info, uniform_sender=self)
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
agent.previous_cache = agent.client_cache
@@ -846,7 +846,7 @@ async def a_initiate_chat(
"""
_chat_info = context.copy()
_chat_info["recipient"] = recipient
self._consolidate_chat_info(_chat_info)
consolidate_chat_info(_chat_info, uniform_sender=self)
self._prepare_chat(recipient, clear_history)
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
@@ -944,23 +944,7 @@ def _reflection_with_llm(
response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
return response

def _consolidate_chat_info(self, chat_info: Union[Dict, List[Dict]]):
if isinstance(chat_info, dict):
chat_info = [chat_info]
for c in chat_info:
assert "recipient" in c, "recipient must be provided."
summary_method = c.get("summary_method")
assert (
summary_method is None
or isinstance(summary_method, Callable)
or summary_method in ("last_msg", "reflection_with_llm")
), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
if summary_method == "reflection_with_llm":
assert (
self.client is not None or c["recipient"].client is not None
), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."

def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[Agent, ChatResult]:
def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""(Experimental) Initiate chats with multiple agents.
TODO: add async version of this method.
@@ -992,57 +976,18 @@ def my_summary_method(
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
Returns: a dictionary of ChatResult object from the finished chats of particular agents.
Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
self._consolidate_chat_info(chat_queue)
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)
self._chat_queue = chat_queue.copy()
self._finished_chats = {}
while self._chat_queue:
chat_info = self._chat_queue.pop(0)
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in self._finished_chats.values()]
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
current_agent = chat_info["recipient"]
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
chat_res = self.initiate_chat(**chat_info)
self._finished_chats[current_agent] = chat_res
_chat_queue = chat_queue.copy()
for chat_info in _chat_queue:
chat_info["sender"] = self
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats

def get_chat_results(self, agent: Optional[Agent] = None) -> Union[Dict[Agent, ChatResult], ChatResult]:
def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
"""A summary from the finished chats of particular agents."""
if agent is not None:
return self._finished_chats.get(agent)
if chat_index is not None:
return self._finished_chats[chat_index]
else:
return self._finished_chats

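
After this refactor, the agent-level method simply injects itself as the sender of every queued chat and delegates to the module-level `initiate_chats`, and finished chats are now retrieved by position rather than by agent. A rough sketch of the changed call pattern, continuing the hypothetical agents from the previous example:

```python
# Continues the assumed user/researcher/writer agents from the earlier sketch;
# no "sender" key is needed because the method fills in self as the sender.
chat_results = user.initiate_chats(
    [
        {"recipient": researcher, "message": "Summarize the collected facts."},
        {"recipient": writer, "message": "Draft release notes from the summary.",
         "summary_method": "last_msg"},
    ]
)

# get_chat_results now takes an integer index into the chat queue instead of
# an Agent key, matching the new List[ChatResult] return type.
first_result = user.get_chat_results(0)   # ChatResult of the first finished chat
all_results = user.get_chat_results()     # the full list
print(first_result.summary)
```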
27 changes: 25 additions & 2 deletions autogen/agent_utils.py → autogen/agentchat/utils.py
@@ -1,7 +1,30 @@
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Callable
from .agent import Agent


def gather_usage_summary(agents: List) -> Tuple[Dict[str, any], Dict[str, any]]:
def consolidate_chat_info(chat_info, uniform_sender=None) -> None:
if isinstance(chat_info, dict):
chat_info = [chat_info]
for c in chat_info:
if uniform_sender is None:
assert "sender" in c, "sender must be provided."
sender = c["sender"]
else:
sender = uniform_sender
assert "recipient" in c, "recipient must be provided."
summary_method = c.get("summary_method")
assert (
summary_method is None
or isinstance(summary_method, Callable)
or summary_method in ("last_msg", "reflection_with_llm")
), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
if summary_method == "reflection_with_llm":
assert (
sender.client is not None or c["recipient"].client is not None
), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."


def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str, any]]:
"""Gather usage summary from all agents.
Args:
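
The helpers relocated to `autogen/agentchat/utils.py` can also be called directly. A brief sketch, again assuming the agents from the first example; the reading of the two returned dictionaries as totals with and without cached completions is an assumption, since that part of the docstring is not shown in this diff.

```python
# Sketch of the relocated helpers; assumes the user/researcher/writer agents
# defined in the first example.
from autogen.agentchat.utils import consolidate_chat_info, gather_usage_summary

# consolidate_chat_info validates a chat_queue entry (or a list of them):
# it asserts that sender/recipient are present and that summary_method is
# None, a callable, "last_msg", or "reflection_with_llm".
consolidate_chat_info(
    {"sender": user, "recipient": researcher, "summary_method": "last_msg"}
)

# gather_usage_summary aggregates usage across agents and returns a pair of
# dictionaries (assumed: totals including and excluding cached completions).
total_usage, actual_usage = gather_usage_summary([user, researcher, writer])
print(total_usage)
```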