Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow up of #665: respond to more run states #899

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 127 additions & 57 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections import defaultdict
import openai
import json
import time
import logging

from openai.types.beta.thread import Thread
from openai.types.beta.threads.run import Run

from autogen import OpenAIWrapper
from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.agentchat.agent import Agent
Expand Down Expand Up @@ -112,7 +114,13 @@ def __init__(
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=2)
gagb marked this conversation as resolved.
Show resolved Hide resolved

def _check_for_cancellation(self):
"""
Checks for cancellation used during _get_run_response
"""
return self.cancellation_requested

def _invoke_assistant(
self,
Expand All @@ -131,7 +139,6 @@ def _invoke_assistant(
Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""

if messages is None:
messages = self._oai_messages[sender]
unread_index = self._unread_index[sender] or 0
Expand All @@ -158,68 +165,136 @@ def _invoke_assistant(
# pass the latest system message as instructions
instructions=self.system_message,
)
self.cancellation_requested = False
response = self._get_run_response(assistant_thread, run)
self._unread_index[sender] = len(self._oai_messages[sender]) + 1
if response["content"]:
return True, response
else:
return False, "No response from the assistant."

run_response_messages = self._get_run_response(assistant_thread, run)
assert len(run_response_messages) > 0, "No response from the assistant."
def _process_messages(self, assistant_thread: Thread, run: Run) -> dict:
    """
    Processes the status of a terminal run and generates an appropriate response.

    This method checks the status of a given run and generates a response
    based on that status. The response is a dictionary with a 'role' and 'content' field.
    For failed/expired/cancelled runs the 'role' is always 'assistant' and the
    'content' describes the outcome; for completed runs the new messages produced
    by the run are collected and concatenated into one reply.

    Args:
        assistant_thread (Thread): The Thread object associated with the current task.
        run (Run): The Run object associated with the current task.

    Returns:
        dict: A dictionary containing the 'role' and 'content' of the response.
    """
    if run.status == "failed":
        logger.error(f"Run: {run.id} Thread: {assistant_thread.id}: failed...")
        # Surface the API-provided error when available so the caller can see why.
        content = str(run.last_error) if run.last_error else "Failed"
        return {"role": "assistant", "content": content}
    elif run.status == "expired":
        # logger.warn is a deprecated alias of logger.warning
        logger.warning(f"Run: {run.id} Thread: {assistant_thread.id}: expired...")
        return {"role": "assistant", "content": "Expired"}
    elif run.status == "cancelled":
        logger.warning(f"Run: {run.id} Thread: {assistant_thread.id}: cancelled...")
        return {"role": "assistant", "content": "Cancelled"}
    elif run.status == "completed":
        logger.info(f"Run: {run.id} Thread: {assistant_thread.id}: completed...")
        response_messages = self._openai_client.beta.threads.messages.list(assistant_thread.id, order="asc")
        new_messages = []
        for msg in response_messages:
            # Only keep messages produced by this run, not earlier thread history.
            if msg.run_id != run.id:
                continue
            for content in msg.content:
                if content.type == "text":
                    new_messages.append(
                        {"role": msg.role, "content": self._format_assistant_message(content.text)}
                    )
                elif content.type == "image_file":
                    new_messages.append(
                        {
                            "role": msg.role,
                            # typo fix: "Recieved" -> "Received"
                            "content": f"Received file id={content.image_file.file_id}",
                        }
                    )
        # Guard: a completed run could, in principle, produce no text/image parts;
        # previously new_messages[-1] would raise IndexError in that case.
        if not new_messages:
            return {"role": "assistant", "content": ""}
        # Intermediate messages are concatenated, separated by blank lines.
        return {
            "role": new_messages[-1]["role"],
            "content": "\n\n".join(message["content"] for message in new_messages),
        }

def _get_run_response(self, assistant_thread: Thread, run: Run) -> dict:
    """
    This method waits for a run initiated with the OpenAI assistant to complete,
    and then processes its response.
    It continuously checks the status of the run. If the run is still in progress
    or queued, it waits for a specified amount of time before checking again. If
    the run is completed, cancelled, expired, or failed, it stops waiting and
    processes the response. Runs that require an action have their tool calls
    executed and the outputs submitted back before polling continues.

    Args:
        assistant_thread (Thread): The thread object associated with the task.
        run (Run): The run object that was initiated with the task.

    Returns:
        dict: The processed response from the run.

    Raises:
        ValueError: If the run reaches an unexpected status.
    """
    while True:
        # Honour a pending cancellation request before re-inspecting the run.
        if self._check_for_cancellation():
            # Bug fix: _cancel_run requires the run and thread ids; calling it
            # with no arguments raised TypeError whenever cancellation fired.
            self._cancel_run(run.id, assistant_thread.id)
        run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
        if run.status in ("in_progress", "queued"):
            # Poll interval is configurable via llm_config["check_every_ms"].
            time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
            run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
        elif run.status in ("completed", "cancelled", "expired", "failed"):
            # Terminal states: hand off to _process_messages to build the reply.
            return self._process_messages(assistant_thread, run)
        elif run.status == "cancelling":
            # logger.warn is a deprecated alias of logger.warning
            logger.warning(f"Run: {run.id} Thread: {assistant_thread.id}: cancelling...")
        elif run.status == "requires_action":
            logger.info(f"Run: {run.id} Thread: {assistant_thread.id}: required action...")
            actions = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                function = tool_call.function
                is_exec_success, tool_response = self.execute_function(function.dict(), self._verbose)
                tool_response["metadata"] = {
                    "tool_call_id": tool_call.id,
                    "run_id": run.id,
                    "thread_id": assistant_thread.id,
                }

                logger.info(
                    "Intermediate executing(%s, Success: %s) : %s",
                    tool_response["name"],
                    is_exec_success,
                    tool_response["content"],
                )
                actions.append(tool_response)

            submit_tool_outputs = {
                "tool_outputs": [
                    {"output": action["content"], "tool_call_id": action["metadata"]["tool_call_id"]}
                    for action in actions
                ],
                "run_id": run.id,
                "thread_id": assistant_thread.id,
            }

            run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs)

        else:
            run_info = json.dumps(run.dict(), indent=2)
            raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})")

def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
def _cancel_run(self, run_id: str, thread_id: str):
    """
    Best-effort cancellation of a run.

    Sends a cancellation signal to the OpenAI API for the given run; failures
    are logged rather than raised, since the run may already be terminal.

    Args:
        run_id: The ID of the run.
        thread_id: The ID of the thread associated with the run.
    """
    try:
        self._openai_client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
        logger.info(f"Run: {run_id} Thread: {thread_id}: successfully sent cancellation signal.")
    except Exception as err:
        logger.error(f"Run: {run_id} Thread: {thread_id}: failed to send cancellation signal: {err}")

def _format_assistant_message(self, message_content):
"""
Formats the assistant's message to include annotations and citations.
"""

annotations = message_content.annotations
citations = []

Expand Down
36 changes: 36 additions & 0 deletions test/agentchat/contrib/test_gpt_assistant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import os
import sys
from unittest.mock import Mock
import autogen
from autogen import OpenAIWrapper

Expand Down Expand Up @@ -246,9 +247,44 @@ def test_assistant_retrieval():
assert candidate_first == candidate_second


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows or dependency is not installed",
)
def test_process_messages_not_complete():
    """Verify _process_messages maps each non-completed terminal run state to
    the expected reply content (error text, "Failed", "Expired", "Cancelled")."""
    # Create a mock assistant thread and run
    mock_thread = Mock()
    mock_thread.id = "thread1"

    mock_run = Mock()
    # (status, last_error) -> expected reply content; a failed run surfaces its
    # last_error text when present, otherwise the generic "Failed" marker.
    gt_map = {
        ("failed", "Here is the last error message"): "Here is the last error message",
        ("failed", None): "Failed",
        ("expired", None): "Expired",
        ("cancelled", None): "Cancelled",
    }

    for (status, last_error), expected_content in gt_map.items():
        mock_run.status = status
        mock_run.last_error = last_error

        instance = GPTAssistantAgent(
            "assistant",
            llm_config={
                "config_list": config_list,
            },
        )

        # Call _process_messages and assert it returns the correct response
        response = instance._process_messages(mock_thread, mock_run)
        # Clean up the remote assistant created above before asserting.
        instance.delete_assistant()
        assert response == {"role": "assistant", "content": expected_content}


if __name__ == "__main__":
    # Allow running this module directly; executes the tests without pytest,
    # so the skipif markers above are not applied.
    test_gpt_assistant_chat()
    test_get_assistant_instructions()
    test_gpt_assistant_instructions_overwrite()
    test_gpt_assistant_existing_no_instructions()
    test_get_assistant_files()
    test_process_messages_not_complete()
Loading