Skip to content

Commit

Permalink
respond to more run status states
Browse files Browse the repository at this point in the history
1) according to the spec there are other states we can handle in wait_for_run function, so I added those.
2) added termination msg param.
3) register_reply using invoke_assistant and check_termination_and_human_reply in order, so we can check for exit/human reply for human_input_mode != "NEVER". Remove the hardcoded human_input_mode.
4) return empty array if while loop terminates for some reason without returning messages from the state machine (while loop)
  • Loading branch information
jagdeep sidhu committed Nov 13, 2023
1 parent 841b533 commit 12e285e
Showing 1 changed file with 67 additions and 9 deletions.
76 changes: 67 additions & 9 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent
from typing import Dict, Optional, Union, List, Tuple, Any
from typing import Dict, Optional, Union, List, Tuple, Any, Callable

logger = logging.getLogger(__name__)

Expand All @@ -22,9 +22,11 @@ class GPTAssistantAgent(ConversableAgent):
def __init__(
self,
name="GPT Assistant",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
overwrite_instructions: bool = False,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -86,18 +88,24 @@ def __init__(
logger.warning(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

_is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: "TERMINATE" in x.get("content", ""))
)
super().__init__(
name=name,
system_message=instructions,
human_input_mode="NEVER",
llm_config=llm_config,
is_termination_msg=_is_termination_msg,
**kwargs
)

# lazly create thread
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
self._reply_func_list = []
self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)


def _invoke_assistant(
self,
Expand Down Expand Up @@ -146,7 +154,7 @@ def _invoke_assistant(

run_response_messages = self._get_run_response(assistant_thread, run)
assert len(run_response_messages) > 0, "No response from the assistant."

response = {
"role": run_response_messages[-1]["role"],
"content": "",
Expand All @@ -157,7 +165,6 @@ def _invoke_assistant(
if len(response["content"]) > 0:
response["content"] += "\n\n"
response["content"] += message["content"]

self._unread_index[sender] = len(self._oai_messages[sender]) + 1
return True, response

Expand All @@ -173,10 +180,60 @@ def _get_run_response(self, thread, run):
"""
while True:
run = self._wait_for_run(run.id, thread.id)
if run.status == "completed":
if run.status == "failed":
new_messages = []
print(f'Run: {run.id} Thread: {thread.id}: failed...')
if run.last_error:
new_messages.append(
{
"role": msg.role,
"content": f'Last error: {run.last_error}',
}
)
new_messages.append(
{
"role": msg.role,
"content": 'Failed',
}
)
return new_messages
elif run.status == "cancelling":
print(f'Run: {run.id} Thread: {thread.id}: cancelling...')
elif run.status == "expired":
print(f'Run: {run.id} Thread: {thread.id}: expired...')
new_messages = []
new_messages.append(
{
"role": msg.role,
"content": 'Expired',
}
)
return new_messages
elif run.status == "cancelled":
print(f'Run: {run.id} Thread: {thread.id}: cancelled...')
new_messages = []
new_messages.append(
{
"role": msg.role,
"content": 'Cancelled',
}
)
return new_messages
elif run.status == "in_progress":
print(f'Run: {run.id} Thread: {thread.id}: in progress...')
elif run.status == "queued":
print(f'Run: {run.id} Thread: {thread.id}: queued...')
elif run.status == "completed":
print(f'Run: {run.id} Thread: {thread.id}: completed...')
response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc")

new_messages = []
if run.last_error:
new_messages.append(
{
"role": msg.role,
"content": f'Last error: {run.last_error}',
}
)
for msg in response_messages:
if msg.run_id == run.id:
for content in msg.content:
Expand All @@ -193,6 +250,7 @@ def _get_run_response(self, thread, run):
)
return new_messages
elif run.status == "requires_action":
print(f'Run: {run.id} Thread: {thread.id}: required action...')
actions = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
function = tool_call.function
Expand Down Expand Up @@ -224,7 +282,7 @@ def _get_run_response(self, thread, run):
else:
run_info = json.dumps(run.dict(), indent=2)
raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})")

return []
def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
"""
Waits for a run to complete or reach a final state.
Expand Down

0 comments on commit 12e285e

Please sign in to comment.