respond to more run status states and other GPT assistant housekeeping changes #665
@@ -1,5 +1,4 @@
from collections import defaultdict
import openai
import json
import time
import logging

@@ -12,7 +11,6 @@

logger = logging.getLogger(__name__)


class GPTAssistantAgent(ConversableAgent):
    """
    An experimental AutoGen agent class that leverages the OpenAI Assistant API for conversational capabilities.
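For review context, a minimal instantiation sketch of the class under change. The llm_config keys are the ones this diff reads; the agent name, API key value, and config_list shape are illustrative assumptions, not part of the PR:

```python
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

# Keys read by this diff: "tools", "model", "file_ids", "check_every_ms".
llm_config = {
    "config_list": [{"model": "gpt-4-1106-preview", "api_key": "sk-..."}],  # illustrative
    "tools": [],                   # Assistant API tool specs
    "file_ids": ["file-abc123"],   # new in this PR: files attached at assistant creation
    "check_every_ms": 1000,        # polling interval used by _get_run_response
}

assistant = GPTAssistantAgent(
    name="gpt_assistant",          # illustrative name
    instructions="You are a helpful assistant.",
    llm_config=llm_config,
)
```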
@@ -25,6 +23,7 @@ def __init__(
        instructions: Optional[str] = None,
        llm_config: Optional[Union[Dict, bool]] = None,
        overwrite_instructions: bool = False,
        **kwargs,
    ):
        """
        Args:
@@ -64,6 +63,7 @@ def __init__(
                instructions=instructions,
                tools=llm_config.get("tools", []),
                model=llm_config.get("model", "gpt-4-1106-preview"),
                file_ids=llm_config.get("file_ids", []),
            )
        else:
            # retrieve an existing assistant
@@ -86,18 +86,23 @@ def __init__( | |
logger.warning( | ||
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API." | ||
) | ||
|
||
super().__init__( | ||
name=name, | ||
system_message=instructions, | ||
human_input_mode="NEVER", | ||
llm_config=llm_config, | ||
**kwargs | ||
) | ||
|
||
self.cancellation_requested = False | ||
# lazly create thread | ||
self._openai_threads = {} | ||
self._unread_index = defaultdict(int) | ||
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant) | ||
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=1) | ||
|
||
def check_for_cancellation(self): | ||
""" | ||
Checks for cancellation used during _get_run_response | ||
""" | ||
return self.cancellation_requested | ||
|
||
def _invoke_assistant( | ||
self, | ||
|
@@ -116,7 +121,6 @@ def _invoke_assistant(
        Returns:
            A tuple containing a boolean indicating success and the assistant's reply.
        """
        if messages is None:
            messages = self._oai_messages[sender]
        unread_index = self._unread_index[sender] or 0
@@ -143,68 +147,108 @@ def _invoke_assistant(
            # pass the latest system message as instructions
            instructions=self.system_message,
        )

        run_response_messages = self._get_run_response(assistant_thread, run)
        assert len(run_response_messages) > 0, "No response from the assistant."

        response = {
            "role": run_response_messages[-1]["role"],
            "content": "",
        }
        for message in run_response_messages:
            # just logging or do something with the intermediate messages?
            # if current response is not empty and there is more, append new lines
            if len(response["content"]) > 0:
                response["content"] += "\n\n"
            response["content"] += message["content"]

        self.cancellation_requested = False
        response = self._get_run_response(assistant_thread, run)
        self._unread_index[sender] = len(self._oai_messages[sender]) + 1
        return True, response
        if response["content"]:
            return True, response
        else:
            return False, "No response from the assistant."

    def _get_run_response(self, thread, run):
    def _process_messages(self, assistant_thread, run):
Review comment: Decorate the arguments; what types of object are thread and run?
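A sketch of what the decorated signature could look like, assuming the openai v1 SDK's Thread and Run models (the import paths are from that SDK and may vary by version):

```python
from typing import Dict

from openai.types.beta import Thread
from openai.types.beta.threads import Run


def _process_messages(self, assistant_thread: Thread, run: Run) -> Dict[str, str]:
    """Processes and provides a response based on the run status."""
    ...
```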
""" | ||
Processes and provides a response based on the run status. | ||
Args: | ||
assistant_thread: The thread object for the assistant. | ||
run: The run object initiated with the OpenAI assistant. | ||
""" | ||
if run.status == "failed": | ||
logger.error(f'Run: {run.id} Thread: {assistant_thread.id}: failed...') | ||
if run.last_error: | ||
response = { | ||
"role": "assistant", | ||
"content": str(run.last_error), | ||
} | ||
else: | ||
response = { | ||
"role": "assistant", | ||
"content": 'Failed', | ||
} | ||
return response | ||
elif run.status == "expired": | ||
logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: expired...') | ||
response = { | ||
"role": "assistant", | ||
"content": 'Expired', | ||
} | ||
return new_messages | ||
Review comment: new_messages is not defined here.

Reply: Fixed this in this PR: https://github.com/syscoin/autogen/pull/1. @sidhujag @IANTHEREAL please take a look and add it to your PR; I don't have write access to your repo.
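A minimal sketch of the fix, inferred from context (the linked PR itself is not shown here): the expired branch should return the response dict it just built, since new_messages only exists in the completed branch.

```python
elif run.status == "expired":
    logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: expired...')
    response = {
        "role": "assistant",
        "content": 'Expired',
    }
    return response  # was `return new_messages`, which is undefined in this branch
```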
elif run.status == "cancelled": | ||
logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelled...') | ||
response = { | ||
"role": "assistant", | ||
"content": 'Cancelled', | ||
} | ||
return response | ||
elif run.status == "completed": | ||
logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: completed...') | ||
response_messages = self._openai_client.beta.threads.messages.list(assistant_thread.id, order="asc") | ||
new_messages = [] | ||
for msg in response_messages: | ||
if msg.run_id == run.id: | ||
for content in msg.content: | ||
if content.type == "text": | ||
new_messages.append( | ||
{"role": msg.role, "content": self._format_assistant_message(content.text)} | ||
) | ||
elif content.type == "image_file": | ||
new_messages.append( | ||
{ | ||
"role": msg.role, | ||
"content": f"Recieved file id={content.image_file.file_id}", | ||
} | ||
) | ||
response = { | ||
"role": new_messages[-1]["role"], | ||
"content": "", | ||
} | ||
for message in new_messages: | ||
# just logging or do something with the intermediate messages? | ||
# if current response is not empty and there is more, append new lines | ||
if len(response["content"]) > 0: | ||
response["content"] += "\n\n" | ||
response["content"] += message["content"] | ||
return response | ||
|
||
    def _get_run_response(self, assistant_thread, run):
        """
        Waits for and processes the response of a run from the OpenAI assistant.

        Args:
            assistant_thread: The thread object for the assistant.
            run: The run object initiated with the OpenAI assistant.

        Returns:
            Updated run object, status of the run, and response messages.
        """
        while True:
            run = self._wait_for_run(run.id, thread.id)
            if run.status == "completed":
                response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc")

                new_messages = []
                for msg in response_messages:
                    if msg.run_id == run.id:
                        for content in msg.content:
                            if content.type == "text":
                                new_messages.append(
                                    {"role": msg.role, "content": self._format_assistant_message(content.text)}
                                )
                            elif content.type == "image_file":
                                new_messages.append(
                                    {
                                        "role": msg.role,
                                        "content": f"Recieved file id={content.image_file.file_id}",
                                    }
                                )
                return new_messages
            run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
            if run.status == "in_progress" or run.status == "queued":
                time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
                run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
            elif run.status == "completed" or run.status == "cancelled" or run.status == "expired" or run.status == "failed":
                return self._process_messages(assistant_thread, run)
            elif run.status == "cancelling":
                logger.warn(f'Run: {run.id} Thread: {assistant_thread.id}: cancelling...')
            elif run.status == "requires_action":
                logger.info(f'Run: {run.id} Thread: {assistant_thread.id}: required action...')
                actions = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    function = tool_call.function
                    is_exec_success, tool_response = self.execute_function(function.dict())
                    tool_response["metadata"] = {
                        "tool_call_id": tool_call.id,
                        "run_id": run.id,
                        "thread_id": thread.id,
                        "thread_id": assistant_thread.id,
                    }

                    logger.info(
                        "Intermediate executing(%s, Sucess: %s) : %s",
                        "Intermediate executing(%s, Success: %s) : %s",
                        tool_response["name"],
                        is_exec_success,
                        tool_response["content"],
@@ -217,38 +261,36 @@ def _get_run_response(self, thread, run):
                            for action in actions
                        ],
                        "run_id": run.id,
                        "thread_id": thread.id,
                        "thread_id": assistant_thread.id,
                    }

                    run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs)
                if self.check_for_cancellation():
                    self._cancel_run()
            else:
                run_info = json.dumps(run.dict(), indent=2)
                raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})")

    def _wait_for_run(self, run_id: str, thread_id: str) -> Any:

    def _cancel_run(self, run_id: str, thread_id: str):
        """
        Waits for a run to complete or reach a final state.
        Cancels a run.

        Args:
            run_id: The ID of the run.
            thread_id: The ID of the thread associated with the run.

        Returns:
            The updated run object after completion or reaching a final state.
        """
        in_progress = True
        while in_progress:
            run = self._openai_client.beta.threads.runs.retrieve(run_id, thread_id=thread_id)
            in_progress = run.status in ("in_progress", "queued")
            if in_progress:
                time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
        return run
        try:
            self._openai_client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
            logger.info(f'Run: {run_id} Thread: {thread_id}: successfully sent cancellation signal.')
        except Exception as e:
            logger.error(f'Run: {run_id} Thread: {thread_id}: failed to send cancellation signal: {e}')

    def _format_assistant_message(self, message_content):
        """
        Formats the assistant's message to include annotations and citations.
        """
        annotations = message_content.annotations
        citations = []
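The diff is truncated here. For orientation, the annotation-handling loop that typically follows this setup looks like the sketch below; this is based on the OpenAI Assistants annotation format, not the PR's exact code, and the bracketed-index citation style is an assumption:

```python
# Replace each inline annotation marker with a bracketed index and collect citations.
for idx, annotation in enumerate(annotations):
    message_content.value = message_content.value.replace(annotation.text, f" [{idx}]")
    file_citation = getattr(annotation, "file_citation", None)
    if file_citation is not None:
        citations.append(f"[{idx}] {file_citation.quote} (file id: {file_citation.file_id})")
```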
Review comment: Is there a usage example for this argument?
Reply: It was assumed an external object may want to cancel because the run is looping in the state machine; this is a way to try to cancel before it finishes, wasting tokens or making other calls that are not intended.
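To make that concrete, a sketch of how an external caller might request cancellation under this design. The user_proxy agent and the chat call are assumed from typical AutoGen usage; only the cancellation_requested flag comes from this PR:

```python
import threading

# Drive the conversation from a worker thread so the caller keeps control.
worker = threading.Thread(
    target=lambda: user_proxy.initiate_chat(assistant, message="Summarize the attached files.")
)
worker.start()

# ... later, if the run is looping in the state machine or no longer needed:
assistant.cancellation_requested = True  # picked up via check_for_cancellation()
worker.join()
```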
Reply: OK. Implementing a timeout mechanism for the GPT Assistant Agent would be even better.
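A sketch of what such a timeout could look like: a deadline checked inside the polling loop, reusing the PR's cancellation path. The run_timeout_s config key is hypothetical, not part of this PR:

```python
import time

run_timeout_s = self.llm_config.get("run_timeout_s", 300)  # hypothetical config key
deadline = time.monotonic() + run_timeout_s

while True:
    run = self._openai_client.beta.threads.runs.retrieve(run.id, thread_id=assistant_thread.id)
    if run.status in ("in_progress", "queued", "requires_action") and time.monotonic() > deadline:
        # Give up instead of polling forever; reuse the PR's cancellation helper.
        self.cancellation_requested = True
        self._cancel_run(run.id, assistant_thread.id)
    # ... existing status handling from the PR continues here ...
```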