Skip to content

Commit

Permalink
support assistant retrieval using name
Browse files Browse the repository at this point in the history
  • Loading branch information
IANTHEREAL committed Nov 19, 2023
1 parent 75cc763 commit 2d8138e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 19 deletions.
40 changes: 26 additions & 14 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging

from autogen import OpenAIWrapper
from autogen.oai.openai_utils import retrieval_assistants_by_name
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent
Expand Down Expand Up @@ -52,20 +53,31 @@ def __init__(
self._openai_client = oai_wrapper._clients[0]
openai_assistant_id = llm_config.get("assistant_id", None)
if openai_assistant_id is None:
logger.warning("assistant_id was None, creating a new assistant")
# create a new assistant
if instructions is None:
logger.warning(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
# try to find assistant by name first
candidate_assistants = retrieval_assistants_by_name(self._openai_client, name)

if len(candidate_assistants) == 0:
logger.warning(f"assistant {name} does not exist, creating a new assistant")
# create a new assistant
if instructions is None:
logger.warning(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
file_ids=llm_config.get("file_ids", []),
)
else:
if len(candidate_assistants) > 1:
logger.warning(
f"Multiple assistants with name {name} found. Using the first assistant in the list. "
f"Please specify the assistant ID in llm_config to use a specific assistant."
)
self._openai_assistant = candidate_assistants[0]
else:
# retrieve an existing assistant
self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
Expand Down Expand Up @@ -95,7 +107,7 @@ def __init__(
llm_config=llm_config,
)

# lazly create thread
# lazly create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
Expand Down
12 changes: 12 additions & 0 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,15 @@ def config_list_from_dotenv(

logging.info(f"Models available: {[config['model'] for config in config_list]}")
return config_list


def retrieval_assistants_by_name(client, name) -> str:
"""
Return the assistants with the given name from OAI assistant API
"""
assistants = client.beta.assistants.list()
candidate_assistants = []
for assistant in assistants.data:
if assistant.name == name:
candidate_assistants.append(assistant)
return candidate_assistants
38 changes: 33 additions & 5 deletions test/agentchat/contrib/test_gpt_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

try:
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
from autogen.oai.openai_utils import retrieval_assistants_by_name

skip_test = False
except ImportError:
Expand Down Expand Up @@ -57,14 +58,16 @@ def test_gpt_assistant_chat():
ok, response = analyst._invoke_assistant(
[{"role": "user", "content": "What is the most popular open source project on GitHub?"}]
)
executable = analyst.can_execute_function("ossinsight_data_api")
analyst.reset()
threads_count = len(analyst._openai_threads)
analyst.delete_assistant()

assert ok is True
assert response.get("role", "") == "assistant"
assert len(response.get("content", "")) > 0

assert analyst.can_execute_function("ossinsight_data_api") is False

analyst.reset()
assert len(analyst._openai_threads) == 0
assert executable is False
assert threads_count == 0


@pytest.mark.skipif(
Expand Down Expand Up @@ -203,6 +206,31 @@ def test_get_assistant_files():
assert expected_file_id in retrived_file_ids


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip_test,
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_assistant_retrieval():
name = "For GPTAssistantAgent retrieval testing"

assistant = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
)
candidate_first = retrieval_assistants_by_name(assistant.openai_client, name)

assistant = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
)
candidate_second = retrieval_assistants_by_name(assistant.openai_client, name)

assistant.delete_assistant()
assert candidate_first == candidate_second


if __name__ == "__main__":
test_gpt_assistant_chat()
test_get_assistant_instructions()
Expand Down

0 comments on commit 2d8138e

Please sign in to comment.