Skip to content

Commit

Permalink
clean up run loop
Browse files Browse the repository at this point in the history
1) remove is_termination_msg
2) add external run cancellation
3) remove _wait_for_run and internalize through _get_run_response
4) process responses through _process_messages
  • Loading branch information
jagdeep sidhu committed Nov 17, 2023
1 parent b89bbe2 commit 88a6343
Showing 1 changed file with 103 additions and 111 deletions.
214 changes: 103 additions & 111 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import json
import time
import logging
import threading

from autogen import OpenAIWrapper
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent
from typing import Dict, Optional, Union, List, Tuple, Any, Callable
from typing import Dict, Optional, Union, List, Tuple, Any

logger = logging.getLogger(__name__)


class GPTAssistantAgent(ConversableAgent):
"""
An experimental AutoGen agent class that leverages the OpenAI Assistant API for conversational capabilities.
Expand All @@ -22,7 +22,6 @@ class GPTAssistantAgent(ConversableAgent):
def __init__(
self,
name="GPT Assistant",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
overwrite_instructions: bool = False,
Expand Down Expand Up @@ -66,6 +65,7 @@ def __init__(
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
else:
# retrieve an existing assistant
Expand All @@ -88,24 +88,23 @@ def __init__(
logger.warning(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)
_is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: "TERMINATE" in x.get("content", ""))
)
super().__init__(
name=name,
system_message=instructions,
llm_config=llm_config,
is_termination_msg=_is_termination_msg,
**kwargs
)

self.cancellation_requested = False
        # lazily create thread
self._openai_threads = {}
self._unread_index = defaultdict(int)
self._reply_func_list = []
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=1)

def check_for_cancellation(self):
    """
    Report whether an external caller has asked for the current run to be
    cancelled.

    Polled by ``_get_run_response`` between run-status checks so a run can
    be aborted from outside the reply loop.

    Returns:
        bool: the current value of ``self.cancellation_requested``.
    """
    requested = self.cancellation_requested
    return requested

def _invoke_assistant(
self,
Expand All @@ -124,7 +123,6 @@ def _invoke_assistant(
Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""

if messages is None:
messages = self._oai_messages[sender]
unread_index = self._unread_index[sender] or 0
Expand All @@ -151,108 +149,104 @@ def _invoke_assistant(
# pass the latest system message as instructions
instructions=self.system_message,
)

run_response_messages = self._get_run_response(assistant_thread, run)
assert len(run_response_messages) > 0, "No response from the assistant."

response = {
"role": run_response_messages[-1]["role"],
"content": "",
}
for message in run_response_messages:
# just logging or do something with the intermediate messages?
# if current response is not empty and there is more, append new lines
if len(response["content"]) > 0:
response["content"] += "\n\n"
response["content"] += message["content"]
self.cancellation_requested = False
response = self._get_run_response(assistant_thread, run)
self._unread_index[sender] = len(self._oai_messages[sender]) + 1
return True, response
if response["content"]:
return True, response
else:
return False, "No response from the assistant."

def _get_run_response(self, thread, run):
def _process_messages(self, assistant_thread, run):
    """
    Build a single chat response dict from a run that reached a terminal state.

    Args:
        assistant_thread: The thread object for the assistant.
        run: The run object initiated with the OpenAI assistant; expected to
            have status "failed", "expired", "cancelled" or "completed".

    Returns:
        A dict with "role" and "content" keys describing the run outcome,
        or None for a status this method does not recognize (callers gate
        on terminal statuses before invoking it).
    """
    if run.status == "failed":
        logger.error(f'Run: {run.id} Thread: {assistant_thread.id}: failed...')
        # Surface the API-provided error detail when present so callers can
        # see why the run failed instead of a bare "Failed".
        if run.last_error:
            return {"role": "assistant", "content": str(run.last_error)}
        return {"role": "assistant", "content": 'Failed'}
    elif run.status == "expired":
        # logger.warn is a deprecated alias; use logger.warning.
        logger.warning(f'Run: {run.id} Thread: {assistant_thread.id}: expired...')
        # BUG FIX: the original returned the undefined name `new_messages`
        # here (NameError); return the constructed response instead.
        return {"role": "assistant", "content": 'Expired'}
    elif run.status == "cancelled":
        logger.warning(f'Run: {run.id} Thread: {assistant_thread.id}: cancelled...')
        return {"role": "assistant", "content": 'Cancelled'}
    elif run.status == "completed":
        logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: completed...')
        response_messages = self._openai_client.beta.threads.messages.list(assistant_thread.id, order="asc")
        # Collect only the messages this run produced, preserving order.
        new_messages = []
        for msg in response_messages:
            if msg.run_id != run.id:
                continue
            for content in msg.content:
                if content.type == "text":
                    new_messages.append(
                        {"role": msg.role, "content": self._format_assistant_message(content.text)}
                    )
                elif content.type == "image_file":
                    # BUG FIX: typo "Recieved" -> "Received" in emitted content.
                    new_messages.append(
                        {
                            "role": msg.role,
                            "content": f"Received file id={content.image_file.file_id}",
                        }
                    )
        # ROBUSTNESS: the original indexed new_messages[-1] unconditionally,
        # raising IndexError when a completed run produced no messages.
        if not new_messages:
            return {"role": "assistant", "content": ""}
        response = {
            "role": new_messages[-1]["role"],
            "content": "",
        }
        for message in new_messages:
            # Join intermediate messages with blank lines between them.
            if len(response["content"]) > 0:
                response["content"] += "\n\n"
            response["content"] += message["content"]
        return response

def _get_run_response(self, assistant_thread, run):
    """
    Poll a run until it reaches a terminal state, servicing tool calls and
    honoring external cancellation along the way.

    Args:
        assistant_thread: The thread object for the assistant.
        run: The run object initiated with the OpenAI assistant.

    Returns:
        The response dict produced by ``_process_messages`` once the run
        reaches a terminal state (completed/cancelled/expired/failed).

    Raises:
        ValueError: if the API reports a run status this loop does not handle.
    """
    while True:
        # Refresh run state once per iteration. (The original also
        # re-retrieved inside the in_progress/queued branch, issuing a
        # redundant API call per poll.)
        run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
        if run.status == "in_progress" or run.status == "queued":
            # Poll interval is configurable via llm_config["check_every_ms"].
            time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
        elif run.status in ("completed", "cancelled", "expired", "failed"):
            return self._process_messages(assistant_thread, run)
        elif run.status == "cancelling":
            # logger.warn is a deprecated alias; use logger.warning.
            logger.warning(f'Run: {run.id} Thread: {assistant_thread.id}: cancelling...')
        elif run.status == "requires_action":
            logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: required action...')
            actions = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                function = tool_call.function
                is_exec_success, tool_response = self.execute_function(function.dict())
                tool_response["metadata"] = {
                    "tool_call_id": tool_call.id,
                    "run_id": run.id,
                    "thread_id": assistant_thread.id,
                }
                logger.info(
                    "Intermediate executing(%s, Success: %s) : %s",
                    tool_response["name"],
                    is_exec_success,
                    tool_response["content"],
                )
                actions.append(tool_response)
            submit_tool_outputs = {
                "tool_outputs": [
                    {"output": action["content"], "tool_call_id": action["metadata"]["tool_call_id"]}
                    for action in actions
                ],
                "run_id": run.id,
                "thread_id": assistant_thread.id,
            }
            run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs)
        else:
            run_info = json.dumps(run.dict(), indent=2)
            raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})")
        # Honor external cancellation on every non-terminal iteration, not
        # only after tool calls.
        # BUG FIX: the original called self._cancel_run() with no arguments,
        # but _cancel_run(run_id, thread_id) requires both -> TypeError.
        if self.check_for_cancellation():
            self._cancel_run(run.id, assistant_thread.id)

def _wait_for_run(self, run_id: str, thread_id: str) -> Any:

def _cancel_run(self, run_id: str, thread_id: str):
    """
    Send a best-effort cancellation signal for an in-flight run.

    A failure to cancel is logged rather than raised, so the polling loop
    that invoked this can keep running.

    Args:
        run_id: The ID of the run to cancel.
        thread_id: The ID of the thread the run belongs to.
    """
    try:
        self._openai_client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
        logger.info(f'Run: {run_id} Thread: {thread_id}: successfully sent cancellation signal.')
    except Exception as e:
        # Swallow the error deliberately: cancellation is advisory.
        logger.error(f'Run: {run_id} Thread: {thread_id}: failed to send cancellation signal: {e}')


def _format_assistant_message(self, message_content):
"""
Formats the assistant's message to include annotations and citations.
"""

annotations = message_content.annotations
citations = []

Expand Down

0 comments on commit 88a6343

Please sign in to comment.