support retrieving assistant by name (#718)
* support assistant retrieval by name

* address comment

* Update autogen/agentchat/contrib/gpt_assistant_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* make code more readable

* ignore test error

* format code

* Update autogen/agentchat/contrib/gpt_assistant_agent.py

typo fix

* fix test case

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
3 people authored Nov 24, 2023
1 parent 6087b5a commit c705c6a
Showing 3 changed files with 95 additions and 28 deletions.
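In practice, the change means that constructing a `GPTAssistantAgent` with only a name now reuses an OpenAI assistant that already exists under that name instead of always creating a new one. A minimal usage sketch, assuming a valid `OAI_CONFIG_LIST` file and a hypothetical assistant name:

```python
from autogen import config_list_from_json
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

# Assumption: OAI_CONFIG_LIST points at a config file with a working OpenAI key.
config_list = config_list_from_json("OAI_CONFIG_LIST")

# First run: no assistant named "Repo_Analyst" (hypothetical name) exists yet,
# so a new OpenAI assistant is created under that name.
analyst = GPTAssistantAgent(
    name="Repo_Analyst",
    instructions="You analyze open source projects on GitHub.",
    llm_config={"config_list": config_list},
)

# Later run with the same name: the constructor now finds the existing assistant
# via retrieve_assistants_by_name and reuses it instead of creating a duplicate.
analyst_again = GPTAssistantAgent(
    name="Repo_Analyst",
    instructions="You analyze open source projects on GitHub.",
    llm_config={"config_list": config_list},
)
```

If several assistants already share the name, the constructor logs a warning and uses the first match, as the diff below shows.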
42 changes: 27 additions & 15 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -5,6 +5,7 @@
import logging

from autogen import OpenAIWrapper
from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent
@@ -28,7 +29,7 @@ def __init__(
):
"""
Args:
name (str): name of the agent.
name (str): name of the agent. It will be used to find the existing assistant by name. Please remember to delete an old assistant with the same name if you intend to create a new assistant with the same name.
instructions (str): instructions for the OpenAI assistant configuration.
When instructions is not None, the system message of the agent will be
set to the provided instructions and used in the assistant run, irrespective
@@ -52,20 +53,31 @@
self._openai_client = oai_wrapper._clients[0]
openai_assistant_id = llm_config.get("assistant_id", None)
if openai_assistant_id is None:
logger.warning("assistant_id was None, creating a new assistant")
# create a new assistant
if instructions is None:
logger.warning(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)

if len(candidate_assistants) == 0:
logger.warning(f"assistant {name} does not exist, creating a new assistant")
# create a new assistant
if instructions is None:
logger.warning(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
else:
if len(candidate_assistants) > 1:
logger.warning(
f"Multiple assistants with name {name} found. Using the first assistant in the list. "
f"Please specify the assistant ID in llm_config to use a specific assistant."
)
self._openai_assistant = candidate_assistants[0]
else:
# retrieve an existing assistant
self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
@@ -95,7 +107,7 @@ def __init__(
llm_config=llm_config,
)

# lazly create thread
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
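The updated docstring above asks callers to delete a stale assistant before creating a new one under the same name, since the constructor will otherwise attach to the old assistant. A short sketch of that cleanup, assuming the same hypothetical name and config file as above; `delete_assistant()` is the same call the tests below use:

```python
from autogen import config_list_from_json
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

config_list = config_list_from_json("OAI_CONFIG_LIST")  # assumption: valid config file

# Attach to (or, if absent, create) the assistant under the hypothetical name,
# then delete it so the next construction with fresh instructions starts clean.
stale = GPTAssistantAgent(
    name="Repo_Analyst",
    instructions="Old instructions",
    llm_config={"config_list": config_list},
)
stale.delete_assistant()
```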
12 changes: 12 additions & 0 deletions autogen/oai/openai_utils.py
@@ -381,3 +381,15 @@ def config_list_from_dotenv(

logging.info(f"Models available: {[config['model'] for config in config_list]}")
return config_list


def retrieve_assistants_by_name(client, name) -> str:
"""
Return the assistants with the given name from OAI assistant API
"""
assistants = client.beta.assistants.list()
candidate_assistants = []
for assistant in assistants.data:
if assistant.name == name:
candidate_assistants.append(assistant)
return candidate_assistants
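For reference, the new helper returns a list of matching assistant objects (despite the `-> str` annotation), so callers check its length or take the first element, as the agent code above does. A small usage sketch, assuming the client is obtained from an `OpenAIWrapper` the way the tests below do, and a hypothetical name:

```python
from autogen import OpenAIWrapper, config_list_from_json
from autogen.oai.openai_utils import retrieve_assistants_by_name

config_list = config_list_from_json("OAI_CONFIG_LIST")  # assumption: valid config file
client = OpenAIWrapper(config_list=config_list)._clients[0]

matches = retrieve_assistants_by_name(client, "Repo_Analyst")  # hypothetical name
if not matches:
    print("No assistant with that name exists yet.")
else:
    # Assistant names are not unique, so several matches are possible;
    # GPTAssistantAgent simply uses the first one.
    print(f"Found {len(matches)} match(es); first id: {matches[0].id}")
```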
69 changes: 56 additions & 13 deletions test/agentchat/contrib/test_gpt_assistant.py
@@ -8,7 +8,9 @@
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402

try:
import openai
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
from autogen.oai.openai_utils import retrieve_assistants_by_name

skip_test = False
except ImportError:
@@ -43,8 +45,9 @@ def test_gpt_assistant_chat():
"description": "This is an API endpoint allowing users (analysts) to input question about GitHub in text format to retrieve the realted and structured data.",
}

name = "For test_gpt_assistant_chat"
analyst = GPTAssistantAgent(
name="Open_Source_Project_Analyst",
name=name,
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
)
@@ -57,14 +60,16 @@
ok, response = analyst._invoke_assistant(
[{"role": "user", "content": "What is the most popular open source project on GitHub?"}]
)
executable = analyst.can_execute_function("ossinsight_data_api")
analyst.reset()
threads_count = len(analyst._openai_threads)
analyst.delete_assistant()

assert ok is True
assert response.get("role", "") == "assistant"
assert len(response.get("content", "")) > 0

assert analyst.can_execute_function("ossinsight_data_api") is False

analyst.reset()
assert len(analyst._openai_threads) == 0
assert executable is False
assert threads_count == 0


@pytest.mark.skipif(
@@ -76,9 +81,9 @@ def test_get_assistant_instructions():
Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
and assert that the retrieved instructions match the set instructions.
"""

name = "For test_get_assistant_instructions"
assistant = GPTAssistantAgent(
"assistant",
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
@@ -107,11 +112,12 @@ def test_gpt_assistant_instructions_overwrite():
4. Check that the instructions of the assistant have been overwritten with the new ones.
"""

name = "For test_gpt_assistant_instructions_overwrite"
instructions1 = "This is a test #1"
instructions2 = "This is a test #2"

assistant = GPTAssistantAgent(
"assistant",
name,
instructions=instructions1,
llm_config={
"config_list": config_list,
@@ -120,7 +126,7 @@

assistant_id = assistant.assistant_id
assistant = GPTAssistantAgent(
"assistant",
name,
instructions=instructions2,
llm_config={
"config_list": config_list,
@@ -144,10 +150,11 @@ def test_gpt_assistant_existing_no_instructions():
Test function to check if the GPTAssistantAgent can retrieve instructions for an existing assistant
even if the assistant was created with no instructions initially.
"""
name = "For test_gpt_assistant_existing_no_instructions"
instructions = "This is a test #1"

assistant = GPTAssistantAgent(
"assistant",
name,
instructions=instructions,
llm_config={
"config_list": config_list,
@@ -158,7 +165,7 @@

# create a new assistant with the same ID but no instructions
assistant = GPTAssistantAgent(
"assistant",
name,
llm_config={
"config_list": config_list,
"assistant_id": assistant_id,
@@ -182,9 +189,10 @@ def test_get_assistant_files():
current_file_path = os.path.abspath(__file__)
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
name = "For test_get_assistant_files"

assistant = GPTAssistantAgent(
"assistant",
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
@@ -203,6 +211,41 @@
assert expected_file_id in retrived_file_ids


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip_test,
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_assistant_retrieval():
"""
Test function to check if the GPTAssistantAgent can retrieve the same assistant
"""

name = "For test_assistant_retrieval"

assistant_first = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
)
candidate_first = retrieve_assistants_by_name(assistant_first.openai_client, name)

assistant_second = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
)
candidate_second = retrieve_assistants_by_name(assistant_second.openai_client, name)

try:
assistant_first.delete_assistant()
assistant_second.delete_assistant()
except openai.NotFoundError:
# Not found error is expected because the same assistant can not be deleted twice
pass

assert candidate_first == candidate_second


if __name__ == "__main__":
test_gpt_assistant_chat()
test_get_assistant_instructions()
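When more than one assistant carries the same name, the new constructor logic warns and suggests pinning `assistant_id` in `llm_config`, the same key the existing tests pass to reattach to a specific assistant. A sketch with a placeholder ID:

```python
from autogen import config_list_from_json
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

config_list = config_list_from_json("OAI_CONFIG_LIST")  # assumption: valid config file

# "asst_..." is a placeholder; substitute the ID of the assistant you actually want.
agent = GPTAssistantAgent(
    name="Repo_Analyst",  # hypothetical name shared by several assistants
    llm_config={
        "config_list": config_list,
        "assistant_id": "asst_...",
    },
)
```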
