respond to more run status states and other GPT assistant housekeeping changes #665

Closed
wants to merge 10 commits
170 changes: 106 additions & 64 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -1,5 +1,4 @@
from collections import defaultdict
import openai
import json
import time
import logging
@@ -12,7 +11,6 @@

logger = logging.getLogger(__name__)


class GPTAssistantAgent(ConversableAgent):
"""
An experimental AutoGen agent class that leverages the OpenAI Assistant API for conversational capabilities.
@@ -25,6 +23,7 @@ def __init__(
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
overwrite_instructions: bool = False,
**kwargs,
):
"""
Args:
@@ -64,6 +63,7 @@ def __init__(
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
else:
# retrieve an existing assistant
@@ -86,18 +86,23 @@ def __init__(
logger.warning(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

super().__init__(
name=name,
system_message=instructions,
human_input_mode="NEVER",
llm_config=llm_config,
**kwargs
)

self.cancellation_requested = False
Collaborator
Is there a usage example for this argument?

Author
@sidhujag Nov 29, 2023
It was assumed that an external object may want to cancel a run while it is looping in the state machine. This flag is a way to try to cancel the run before it finishes, so we avoid wasting tokens or making other calls that are not intended.
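A minimal usage sketch, assuming the conversation is driven from a worker thread (the agent and user_proxy names here are hypothetical, not part of this PR):

import threading

# Drive the chat in the background so the caller's thread stays responsive.
worker = threading.Thread(
    target=user_proxy.initiate_chat,
    kwargs={"recipient": agent, "message": "Summarize this 500-page report."},
)
worker.start()

# Later, e.g. on a user interrupt, ask the agent to abandon the in-flight run.
# check_for_cancellation() picks this flag up inside _get_run_response's polling loop.
agent.cancellation_requested = True
worker.join()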

Collaborator
OK. Implementing a timeout mechanism for the GPT Assistant Agent would be even better.
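A rough sketch of what such a timeout could look like inside the polling loop, assuming a hypothetical run_timeout_seconds key in llm_config (not part of this PR; requires_action handling omitted for brevity):

import time

# Hypothetical: abandon a run after a wall-clock budget by reusing the PR's cancellation path.
deadline = time.monotonic() + self.llm_config.get("run_timeout_seconds", 300)
while True:
    run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
    if run.status in ("completed", "cancelled", "expired", "failed"):
        return self._process_messages(assistant_thread, run)
    if time.monotonic() > deadline:
        # Resolves the run to "cancelling"/"cancelled" on a later poll.
        self._cancel_run(run.id, assistant_thread.id)
    time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)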

# lazily create thread
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=1)

def check_for_cancellation(self):
"""
Checks whether cancellation has been requested; polled by _get_run_response.
"""
return self.cancellation_requested

def _invoke_assistant(
self,
@@ -116,7 +121,6 @@ def _invoke_assistant(
Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""

if messages is None:
messages = self._oai_messages[sender]
unread_index = self._unread_index[sender] or 0
@@ -143,68 +147,108 @@
# pass the latest system message as instructions
instructions=self.system_message,
)

run_response_messages = self._get_run_response(assistant_thread, run)
assert len(run_response_messages) > 0, "No response from the assistant."

response = {
"role": run_response_messages[-1]["role"],
"content": "",
}
for message in run_response_messages:
# just logging or do something with the intermediate messages?
# if current response is not empty and there is more, append new lines
if len(response["content"]) > 0:
response["content"] += "\n\n"
response["content"] += message["content"]

self.cancellation_requested = False
response = self._get_run_response(assistant_thread, run)
self._unread_index[sender] = len(self._oai_messages[sender]) + 1
return True, response
if response["content"]:
return True, response
else:
return False, "No response from the assistant."

def _get_run_response(self, thread, run):
def _process_messages(self, assistant_thread, run):
Collaborator
Decorate the arguments. What types of object are thread and run?
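One possible annotation, assuming the typed models shipped with the openai v1 SDK (a sketch; exact import paths may vary by SDK version):

from openai.types.beta import Thread
from openai.types.beta.threads import Run

def _process_messages(self, assistant_thread: Thread, run: Run) -> dict:
    """Processes and provides a response based on the run status."""
    ...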

"""
Processes and provides a response based on the run status.
Args:
assistant_thread: The thread object for the assistant.
run: The run object initiated with the OpenAI assistant.
"""
if run.status == "failed":
logger.error(f'Run: {run.id} Thread: {assistant_thread.id}: failed...')
if run.last_error:
response = {
"role": "assistant",
"content": str(run.last_error),
}
else:
response = {
"role": "assistant",
"content": 'Failed',
}
return response
elif run.status == "expired":
logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: expired...')
response = {
"role": "assistant",
"content": 'Expired',
}
return new_messages
Collaborator
new_messages is not defined here.

Collaborator
@gagb Nov 30, 2023
Fixed this in this PR: https://github.com/syscoin/autogen/pull/1 @sidhujag @IANTHEREAL please take a look and add to your PR. I don't have write access to your repo.
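Presumably the fix returns the response dict built just above instead of the undefined new_messages; a sketch of the corrected branch (guessed from the surrounding code, since the linked PR is not shown here):

elif run.status == "expired":
    logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: expired...')
    response = {
        "role": "assistant",
        "content": 'Expired',
    }
    return response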

elif run.status == "cancelled":
logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelled...')
response = {
"role": "assistant",
"content": 'Cancelled',
}
return response
elif run.status == "completed":
logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: completed...')
response_messages = self._openai_client.beta.threads.messages.list(assistant_thread.id, order="asc")
new_messages = []
for msg in response_messages:
if msg.run_id == run.id:
for content in msg.content:
if content.type == "text":
new_messages.append(
{"role": msg.role, "content": self._format_assistant_message(content.text)}
)
elif content.type == "image_file":
new_messages.append(
{
"role": msg.role,
"content": f"Recieved file id={content.image_file.file_id}",
}
)
response = {
"role": new_messages[-1]["role"],
"content": "",
}
for message in new_messages:
# just logging or do something with the intermediate messages?
# if current response is not empty and there is more, append new lines
if len(response["content"]) > 0:
response["content"] += "\n\n"
response["content"] += message["content"]
return response

def _get_run_response(self, assistant_thread, run):
"""
Waits for and processes the response of a run from the OpenAI assistant.

Args:
assistant_thread: The thread object for the assistant.
run: The run object initiated with the OpenAI assistant.

Returns:
The response dictionary produced by _process_messages once the run reaches a terminal state.
"""
while True:
run = self._wait_for_run(run.id, thread.id)
if run.status == "completed":
response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc")

new_messages = []
for msg in response_messages:
if msg.run_id == run.id:
for content in msg.content:
if content.type == "text":
new_messages.append(
{"role": msg.role, "content": self._format_assistant_message(content.text)}
)
elif content.type == "image_file":
new_messages.append(
{
"role": msg.role,
"content": f"Recieved file id={content.image_file.file_id}",
}
)
return new_messages
run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
if run.status == "in_progress" or run.status == "queued":
time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
elif run.status == "completed" or run.status == "cancelled" or run.status == "expired" or run.status == "failed":
return self._process_messages(assistant_thread, run)
elif run.status == "cancelling":
logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelling...')
elif run.status == "requires_action":
logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: required action...')
actions = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
function = tool_call.function
is_exec_success, tool_response = self.execute_function(function.dict())
tool_response["metadata"] = {
"tool_call_id": tool_call.id,
"run_id": run.id,
"thread_id": thread.id,
"thread_id": assistant_thread.id,
}

logger.info(
"Intermediate executing(%s, Sucess: %s) : %s",
"Intermediate executing(%s, Success: %s) : %s",
tool_response["name"],
is_exec_success,
tool_response["content"],
@@ -217,38 +261,36 @@ def _get_run_response(self, thread, run):
for action in actions
],
"run_id": run.id,
"thread_id": thread.id,
"thread_id": assistant_thread.id,
}

run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs)
if self.check_for_cancellation():
self._cancel_run(run.id, assistant_thread.id)
else:
run_info = json.dumps(run.dict(), indent=2)
raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info}")

def _wait_for_run(self, run_id: str, thread_id: str) -> Any:

def _cancel_run(self, run_id: str, thread_id: str):
"""
Waits for a run to complete or reach a final state.
Cancels a run.

Args:
run_id: The ID of the run.
thread_id: The ID of the thread associated with the run.

Returns:
The updated run object after completion or reaching a final state.
"""
in_progress = True
while in_progress:
run = self._openai_client.beta.threads.runs.retrieve(run_id, thread_id=thread_id)
in_progress = run.status in ("in_progress", "queued")
if in_progress:
time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
return run
try:
self._openai_client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
logger.info(f'Run: {run_id} Thread: {thread_id}: successfully sent cancellation signal.')
except Exception as e:
logger.error(f'Run: {run_id} Thread: {thread_id}: failed to send cancellation signal: {e}')


def _format_assistant_message(self, message_content):
"""
Formats the assistant's message to include annotations and citations.
"""

annotations = message_content.annotations
citations = []
