diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index b4fc5e8096e..c788d24803d 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -1,5 +1,4 @@ from collections import defaultdict -import openai import json import time import logging @@ -12,7 +11,6 @@ logger = logging.getLogger(__name__) - class GPTAssistantAgent(ConversableAgent): """ An experimental AutoGen agent class that leverages the OpenAI Assistant API for conversational capabilities. @@ -25,6 +23,7 @@ def __init__( instructions: Optional[str] = None, llm_config: Optional[Union[Dict, bool]] = None, overwrite_instructions: bool = False, + **kwargs, ): """ Args: @@ -64,6 +63,7 @@ def __init__( instructions=instructions, tools=llm_config.get("tools", []), model=llm_config.get("model", "gpt-4-1106-preview"), + file_ids=llm_config.get("file_ids", []), ) else: # retrieve an existing assistant @@ -86,18 +86,23 @@ def __init__( logger.warning( "overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API." ) - super().__init__( name=name, system_message=instructions, - human_input_mode="NEVER", llm_config=llm_config, + **kwargs ) - + self.cancellation_requested = False # lazly create thread self._openai_threads = {} self._unread_index = defaultdict(int) - self.register_reply(Agent, GPTAssistantAgent._invoke_assistant) + self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=1) + + def check_for_cancellation(self): + """ + Checks for cancellation used during _get_run_response + """ + return self.cancellation_requested def _invoke_assistant( self, @@ -116,7 +121,6 @@ def _invoke_assistant( Returns: A tuple containing a boolean indicating success and the assistant's reply. """ - if messages is None: messages = self._oai_messages[sender] unread_index = self._unread_index[sender] or 0 @@ -143,56 +147,96 @@ def _invoke_assistant( # pass the latest system message as instructions instructions=self.system_message, ) - - run_response_messages = self._get_run_response(assistant_thread, run) - assert len(run_response_messages) > 0, "No response from the assistant." - - response = { - "role": run_response_messages[-1]["role"], - "content": "", - } - for message in run_response_messages: - # just logging or do something with the intermediate messages? - # if current response is not empty and there is more, append new lines - if len(response["content"]) > 0: - response["content"] += "\n\n" - response["content"] += message["content"] - + self.cancellation_requested = False + response = self._get_run_response(assistant_thread, run) self._unread_index[sender] = len(self._oai_messages[sender]) + 1 - return True, response + if response["content"]: + return True, response + else: + return False, "No response from the assistant." - def _get_run_response(self, thread, run): + def _process_messages(self, assistant_thread, run): + """ + Processes and provides a response based on the run status. + Args: + assistant_thread: The thread object for the assistant. + run: The run object initiated with the OpenAI assistant. + """ + if run.status == "failed": + logger.error(f'Run: {run.id} Thread: {assistant_thread.id}: failed...') + if run.last_error: + response = { + "role": "assistant", + "content": str(run.last_error), + } + else: + response = { + "role": "assistant", + "content": 'Failed', + } + return response + elif run.status == "expired": + logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: expired...') + response = { + "role": "assistant", + "content": 'Expired', + } + return new_messages + elif run.status == "cancelled": + logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelled...') + response = { + "role": "assistant", + "content": 'Cancelled', + } + return response + elif run.status == "completed": + logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: completed...') + response_messages = self._openai_client.beta.threads.messages.list(assistant_thread.id, order="asc") + new_messages = [] + for msg in response_messages: + if msg.run_id == run.id: + for content in msg.content: + if content.type == "text": + new_messages.append( + {"role": msg.role, "content": self._format_assistant_message(content.text)} + ) + elif content.type == "image_file": + new_messages.append( + { + "role": msg.role, + "content": f"Recieved file id={content.image_file.file_id}", + } + ) + response = { + "role": new_messages[-1]["role"], + "content": "", + } + for message in new_messages: + # just logging or do something with the intermediate messages? + # if current response is not empty and there is more, append new lines + if len(response["content"]) > 0: + response["content"] += "\n\n" + response["content"] += message["content"] + return response + + def _get_run_response(self, assistant_thread, run): """ Waits for and processes the response of a run from the OpenAI assistant. - Args: + assistant_thread: The thread object for the assistant. run: The run object initiated with the OpenAI assistant. - - Returns: - Updated run object, status of the run, and response messages. """ while True: - run = self._wait_for_run(run.id, thread.id) - if run.status == "completed": - response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc") - - new_messages = [] - for msg in response_messages: - if msg.run_id == run.id: - for content in msg.content: - if content.type == "text": - new_messages.append( - {"role": msg.role, "content": self._format_assistant_message(content.text)} - ) - elif content.type == "image_file": - new_messages.append( - { - "role": msg.role, - "content": f"Recieved file id={content.image_file.file_id}", - } - ) - return new_messages + run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id) + if run.status == "in_progress" or run.status == "queued": + time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000) + run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id) + elif run.status == "completed" or run.status == "cancelled" or run.status == "expired" or run.status == "failed": + return self._process_messages(assistant_thread, run) + elif run.status == "cancelling": + logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelling...') elif run.status == "requires_action": + logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: required action...') actions = [] for tool_call in run.required_action.submit_tool_outputs.tool_calls: function = tool_call.function @@ -200,11 +244,11 @@ def _get_run_response(self, thread, run): tool_response["metadata"] = { "tool_call_id": tool_call.id, "run_id": run.id, - "thread_id": thread.id, + "thread_id": assistant_thread.id, } logger.info( - "Intermediate executing(%s, Sucess: %s) : %s", + "Intermediate executing(%s, Success: %s) : %s", tool_response["name"], is_exec_success, tool_response["content"], @@ -217,38 +261,36 @@ def _get_run_response(self, thread, run): for action in actions ], "run_id": run.id, - "thread_id": thread.id, + "thread_id": assistant_thread.id, } run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs) + if self.check_for_cancellation(): + self._cancel_run() else: run_info = json.dumps(run.dict(), indent=2) raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})") - def _wait_for_run(self, run_id: str, thread_id: str) -> Any: + + def _cancel_run(self, run_id: str, thread_id: str): """ - Waits for a run to complete or reach a final state. + Cancels a run. Args: run_id: The ID of the run. thread_id: The ID of the thread associated with the run. - - Returns: - The updated run object after completion or reaching a final state. """ - in_progress = True - while in_progress: - run = self._openai_client.beta.threads.runs.retrieve(run_id, thread_id=thread_id) - in_progress = run.status in ("in_progress", "queued") - if in_progress: - time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000) - return run + try: + self._openai_client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id) + logger.info(f'Run: {run_id} Thread: {thread_id}: successfully sent cancellation signal.') + except Exception as e: + logger.error(f'Run: {run_id} Thread: {thread_id}: failed to send cancellation signal: {e}') + def _format_assistant_message(self, message_content): """ Formats the assistant's message to include annotations and citations. """ - annotations = message_content.annotations citations = []