Commit
Merge branch 'main' into types_1
jackgerrits authored Mar 18, 2024
2 parents 77b408f + cc836de commit 8649ac4
Showing 52 changed files with 5,030 additions and 2,048 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -182,3 +182,4 @@ test/agentchat/test_agent_scripts/*


notebook/result.png
samples/apps/autogen-studio/autogenstudio/models/test/
111 changes: 75 additions & 36 deletions autogen/agentchat/chat.py
@@ -1,11 +1,13 @@
import asyncio
from functools import partial
import logging
from collections import defaultdict
from collections import defaultdict, abc
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import datetime
import warnings
from termcolor import colored
from .utils import consolidate_chat_info


logger = logging.getLogger(__name__)
@@ -135,22 +137,22 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
Args:
chat_queue (List[Dict]): a list of dictionaries containing the information of the chats.
Each dictionary should contain the input arguments for `ConversableAgent.initiate_chat`.
More specifically, each dictionary could include the following fields:
- "sender": the sender agent.
- "recipient": the recipient agent.
- "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- "silent" (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
- "cache" (Cache or None): the cache client to be used for this conversation. Default is None.
- "max_turns" (int or None): the maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.
- "message" (str, callable or None): if None, input() will be called to get the initial message.
- **context: additional context information to be passed to the chat.
- "carryover": It can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
chat_queue (List[Dict]): a list of dictionaries containing the information about the chats.
Each dictionary should contain the input arguments for [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). For example:
- "sender": the sender agent.
- "recipient": the recipient agent.
- "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- "silent" (bool or None): (Experimental) whether to print the messages in this conversation. Default is False.
- "cache" (Cache or None): the cache client to use for this conversation. Default is None.
- "max_turns" (int or None): maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.
- "message" (str, callable or None): if None, input() will be called to get the initial message.
- **context: additional context information to be passed to the chat.
- "carryover": It can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
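
A minimal usage sketch of the API this docstring describes (illustrative only, not part of the diff; `user_proxy`, `writer`, and `critic` are assumed to be already-configured ConversableAgent instances):

chat_results = initiate_chats(
    [
        {
            "sender": user_proxy,
            "recipient": writer,
            "message": "Draft a two-paragraph summary of the report.",
            "max_turns": 2,
            "summary_method": "last_msg",
        },
        {
            "sender": user_proxy,
            "recipient": critic,
            "message": "Critique the draft.",
            # extra carryover is combined with the previous chat's summary
            "carryover": "The audience is non-technical.",
        },
    ]
)
print(chat_results[0].summary)
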
@@ -173,6 +175,49 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
return finished_chats


def __system_now_str():
ct = datetime.datetime.now()
return f" System time at {ct}. "


def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
"""
Update ChatResult when async Task for Chat is completed.
"""
logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
chat_result = chat_future.result()
chat_result.chat_id = chat_id


async def _dependent_chat_future(
chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
) -> asyncio.Task:
"""
Create an async Task for each chat.
"""
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
_chat_carryover = chat_info.get("carryover", [])
finished_chats = dict()
for chat in prerequisite_chat_futures:
chat_future = prerequisite_chat_futures[chat]
if chat_future.cancelled():
raise RuntimeError(f"Chat {chat} is cancelled.")

# wait for prerequisite chat results for the new chat carryover
finished_chats[chat] = await chat_future

if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
chat_res_future.add_done_callback(call_back_with_args)
logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
return chat_res_future
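
The `partial` + `add_done_callback` pairing above is how each task stamps its own chat id onto the result once it finishes. A standalone sketch of the same idiom (toy names, not from this commit):

import asyncio
from functools import partial

def on_done(fut: asyncio.Future, task_id: int) -> None:
    # The future is already finished when the callback fires, so result() does not block.
    print(f"task {task_id} finished with {fut.result()!r}")

async def main() -> None:
    task = asyncio.create_task(asyncio.sleep(0, result="ok"))
    # partial binds task_id now; the event loop supplies the finished future later.
    task.add_done_callback(partial(on_done, task_id=7))
    await task

asyncio.run(main())
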


async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
@@ -183,31 +228,25 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
Returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""

consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
finished_chat_futures = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res

pre_chat_futures = dict()
for pre_chat_id in prerequisite_chat_ids:
pre_chat_future = finished_chat_futures[pre_chat_id]
pre_chat_futures[pre_chat_id] = pre_chat_future
current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
finished_chat_futures[chat_id] = current_chat_future
await asyncio.gather(*list(finished_chat_futures.values()))
finished_chats = dict()
for chat in finished_chat_futures:
chat_result = finished_chat_futures[chat].result()
finished_chats[chat] = chat_result
return finished_chats
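
A hedged end-to-end sketch of the rewritten scheduler (agent names are placeholders, not from this diff): chats are keyed by `chat_id`, and `prerequisites` tells the loop above which futures to await before a chat's carryover is assembled.

import asyncio

async def run_chats():
    results = await a_initiate_chats(
        [
            {"chat_id": 1, "sender": user_proxy, "recipient": planner, "message": "Outline the post."},
            {
                "chat_id": 2,
                "sender": user_proxy,
                "recipient": writer,
                "message": "Write the post from the outline.",
                "prerequisites": [1],  # chat 2 starts only after chat 1's future resolves
            },
        ]
    )
    print(results[2].summary)

asyncio.run(run_chats())
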
69 changes: 51 additions & 18 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -28,6 +28,7 @@ def __init__(
name="GPT Assistant",
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
assistant_config: Optional[Dict] = None,
overwrite_instructions: bool = False,
overwrite_tools: bool = False,
**kwargs,
@@ -43,8 +44,9 @@ def __init__(
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
system message will be set to the existing assistant instructions.
llm_config (dict or False): llm inference configuration.
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
assistant_config
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
@@ -57,23 +59,19 @@ def __init__(
"""

self._verbose = kwargs.pop("verbose", False)
openai_client_cfg, openai_assistant_cfg = self._process_assistant_config(llm_config, assistant_config)

super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=llm_config, **kwargs
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")
# Use AutoGen OpenAIWrapper to create a client
openai_client_cfg = copy.deepcopy(llm_config)
# Use the class variable
model_name = GPTAssistantAgent.DEFAULT_MODEL_NAME

# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
model_name = self.DEFAULT_MODEL_NAME
if openai_client_cfg.get("config_list") is not None and len(openai_client_cfg["config_list"]) > 0:
model_name = openai_client_cfg["config_list"][0].pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg["config_list"][0].pop("model", self.DEFAULT_MODEL_NAME)
else:
model_name = openai_client_cfg.pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg.pop("model", self.DEFAULT_MODEL_NAME)

logger.warning("OpenAI client config of GPTAssistantAgent(%s) - model: %s", name, model_name)

@@ -82,14 +80,17 @@ def __init__(
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")

self._openai_client = oai_wrapper._clients[0]._oai_client
openai_assistant_id = llm_config.get("assistant_id", None)
openai_assistant_id = openai_assistant_cfg.get("assistant_id", None)
if openai_assistant_id is None:
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
if len(candidate_assistants) > 0:
# Filter out candidates with the same name but different instructions, file IDs, and function names.
candidate_assistants = self.find_matching_assistant(
candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
openai_assistant_cfg.get("file_ids", []),
)

if len(candidate_assistants) == 0:
@@ -103,9 +104,9 @@ def __init__(
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
model=model_name,
file_ids=llm_config.get("file_ids", []),
file_ids=openai_assistant_cfg.get("file_ids", []),
)
else:
logger.warning(
@@ -135,8 +136,8 @@ def __init__(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

# Check if tools are specified in llm_config
specified_tools = llm_config.get("tools", None)
# Check if tools are specified in assistant_config
specified_tools = openai_assistant_cfg.get("tools", None)

if specified_tools is None:
# Check if the current assistant has tools defined
@@ -155,7 +156,7 @@ def __init__(
)
self._openai_assistant = self._openai_client.beta.assistants.update(
assistant_id=openai_assistant_id,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
@@ -414,6 +415,10 @@ def assistant_id(self):
def openai_client(self):
return self._openai_client

@property
def openai_assistant(self):
return self._openai_assistant

def get_assistant_instructions(self):
"""Return the assistant instructions from OAI assistant API"""
return self._openai_assistant.instructions
@@ -472,3 +477,31 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
matching_assistants.append(assistant)

return matching_assistants

def _process_assistant_config(self, llm_config, assistant_config):
"""
Process the llm_config and assistant_config to extract the model name and assistant related configurations.
"""

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")

if llm_config is None:
openai_client_cfg = {}
else:
openai_client_cfg = copy.deepcopy(llm_config)

if assistant_config is None:
openai_assistant_cfg = {}
else:
openai_assistant_cfg = copy.deepcopy(assistant_config)

# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
openai_client_cfg.pop(item, None)

return openai_client_cfg, openai_assistant_cfg
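
A hedged sketch of the intended call pattern after this change (all values are placeholders): assistant-level options move into `assistant_config`, while `llm_config` keeps only client/inference settings. Thanks to the fallback loop above, an `assistant_id` or `tools` entry left in `llm_config` is still migrated automatically, which preserves forward compatibility for existing callers.

agent = GPTAssistantAgent(
    name="coder-assistant",
    instructions="You write and explain Python code.",
    # client/inference settings stay in llm_config
    llm_config={"config_list": [{"model": "gpt-4-1106-preview", "api_key": "sk-..."}]},
    # assistant-specific settings now travel separately
    assistant_config={
        "tools": [{"type": "code_interpreter"}],
        "check_every_ms": 1000,  # polling interval for the thread run status
    },
)
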