Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

respond to more run status states and other GPT assistant housekeeping changes #665

Closed
wants to merge 10 commits into from
70 changes: 61 additions & 9 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent
from typing import Dict, Optional, Union, List, Tuple, Any
from typing import Dict, Optional, Union, List, Tuple, Any, Callable

logger = logging.getLogger(__name__)

Expand All @@ -22,9 +22,11 @@ class GPTAssistantAgent(ConversableAgent):
def __init__(
self,
name="GPT Assistant",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
sidhujag marked this conversation as resolved.
Show resolved Hide resolved
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
overwrite_instructions: bool = False,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -86,18 +88,24 @@ def __init__(
logger.warning(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

_is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: "TERMINATE" in x.get("content", ""))
)
super().__init__(
name=name,
system_message=instructions,
human_input_mode="NEVER",
llm_config=llm_config,
is_termination_msg=_is_termination_msg,
**kwargs
)

# lazily create thread
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
self._reply_func_list = []
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
sonichi marked this conversation as resolved.
Show resolved Hide resolved


def _invoke_assistant(
self,
Expand Down Expand Up @@ -146,7 +154,7 @@ def _invoke_assistant(

run_response_messages = self._get_run_response(assistant_thread, run)
assert len(run_response_messages) > 0, "No response from the assistant."

response = {
"role": run_response_messages[-1]["role"],
"content": "",
Expand All @@ -157,7 +165,6 @@ def _invoke_assistant(
if len(response["content"]) > 0:
response["content"] += "\n\n"
response["content"] += message["content"]

self._unread_index[sender] = len(self._oai_messages[sender]) + 1
return True, response

Expand All @@ -173,9 +180,53 @@ def _get_run_response(self, thread, run):
"""
while True:
run = self._wait_for_run(run.id, thread.id)
if run.status == "completed":
if run.status == "failed":
new_messages = []
logger.error(f'Run: {run.id} Thread: {thread.id}: failed...')
if run.last_error:
sidhujag marked this conversation as resolved.
Show resolved Hide resolved
new_messages.append(
{
"role": "assistant",
"content": str(run.last_error),
}
)
else:
new_messages.append(
{
"role": "assistant",
"content": 'Failed',
}
)
return new_messages
elif run.status == "cancelling":
logger.warn(f'Run: {run.id} Thread: {thread.id}: cancelling...')
elif run.status == "expired":
logger.warn(f'Run: {run.id} Thread: {thread.id}: expired...')
new_messages = []
new_messages.append(
{
"role": "assistant",
"content": 'Expired',
}
)
return new_messages
elif run.status == "cancelled":
logger.warn(f'Run: {run.id} Thread: {thread.id}: cancelled...')
new_messages = []
new_messages.append(
{
"role": "assistant",
"content": 'Cancelled',
}
)
return new_messages
elif run.status == "in_progress":
logger.info(f'Run: {run.id} Thread: {thread.id}: in progress...')
elif run.status == "queued":
logger.info(f'Run: {run.id} Thread: {thread.id}: queued...')
elif run.status == "completed":
logger.info(f'Run: {run.id} Thread: {thread.id}: completed...')
response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc")

new_messages = []
for msg in response_messages:
if msg.run_id == run.id:
Expand All @@ -193,6 +244,7 @@ def _get_run_response(self, thread, run):
)
return new_messages
elif run.status == "requires_action":
logger.info(f'Run: {run.id} Thread: {thread.id}: required action...')
actions = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
function = tool_call.function
Expand All @@ -204,7 +256,7 @@ def _get_run_response(self, thread, run):
}

logger.info(
"Intermediate executing(%s, Sucess: %s) : %s",
"Intermediate executing(%s, Success: %s) : %s",
tool_response["name"],
is_exec_success,
tool_response["content"],
Expand Down
Loading