diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index f9e77468007..f319f2b489c 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -5,6 +5,7 @@ import logging from autogen import OpenAIWrapper +from autogen.oai.openai_utils import retrieval_assistants_by_name from autogen.agentchat.agent import Agent from autogen.agentchat.assistant_agent import ConversableAgent from autogen.agentchat.assistant_agent import AssistantAgent @@ -52,20 +53,31 @@ def __init__( self._openai_client = oai_wrapper._clients[0] openai_assistant_id = llm_config.get("assistant_id", None) if openai_assistant_id is None: - logger.warning("assistant_id was None, creating a new assistant") - # create a new assistant - if instructions is None: - logger.warning( - "No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE." + # try to find assistant by name first + candidate_assistants = retrieval_assistants_by_name(self._openai_client, name) + + if len(candidate_assistants) == 0: + logger.warning(f"assistant {name} does not exist, creating a new assistant") + # create a new assistant + if instructions is None: + logger.warning( + "No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE." 
+ ) + instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE + self._openai_assistant = self._openai_client.beta.assistants.create( + name=name, + instructions=instructions, + tools=llm_config.get("tools", []), + model=llm_config.get("model", "gpt-4-1106-preview"), + file_ids=llm_config.get("file_ids", []), ) - instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE - self._openai_assistant = self._openai_client.beta.assistants.create( - name=name, - instructions=instructions, - tools=llm_config.get("tools", []), - model=llm_config.get("model", "gpt-4-1106-preview"), - file_ids=llm_config.get("file_ids", []), - ) + else: + if len(candidate_assistants) > 1: + logger.warning( + f"Multiple assistants with name {name} found. Using the first assistant in the list. " + f"Please specify the assistant ID in llm_config to use a specific assistant." + ) + self._openai_assistant = candidate_assistants[0] else: # retrieve an existing assistant self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id) @@ -95,7 +107,7 @@ def __init__( llm_config=llm_config, ) - # lazly create thread + # lazily create threads self._openai_threads = {} self._unread_index = defaultdict(int) self.register_reply(Agent, GPTAssistantAgent._invoke_assistant) diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index 1c42c5c3503..b1325183b47 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -381,3 +381,15 @@ def config_list_from_dotenv( logging.info(f"Models available: {[config['model'] for config in config_list]}") return config_list + + +def retrieval_assistants_by_name(client, name) -> list: + """ + Return the assistants with the given name from OAI assistant API + """ + assistants = client.beta.assistants.list() + candidate_assistants = [] + for assistant in assistants.data: + if assistant.name == name: + candidate_assistants.append(assistant) + return candidate_assistants diff --git a/test/agentchat/contrib/test_gpt_assistant.py 
b/test/agentchat/contrib/test_gpt_assistant.py index ae1e0bdf88b..c344bc645bd 100644 --- a/test/agentchat/contrib/test_gpt_assistant.py +++ b/test/agentchat/contrib/test_gpt_assistant.py @@ -9,6 +9,7 @@ try: from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent + from autogen.oai.openai_utils import retrieval_assistants_by_name skip_test = False except ImportError: @@ -57,14 +58,16 @@ def test_gpt_assistant_chat(): ok, response = analyst._invoke_assistant( [{"role": "user", "content": "What is the most popular open source project on GitHub?"}] ) + executable = analyst.can_execute_function("ossinsight_data_api") + analyst.reset() + threads_count = len(analyst._openai_threads) + analyst.delete_assistant() + assert ok is True assert response.get("role", "") == "assistant" assert len(response.get("content", "")) > 0 - - assert analyst.can_execute_function("ossinsight_data_api") is False - - analyst.reset() - assert len(analyst._openai_threads) == 0 + assert executable is False + assert threads_count == 0 @pytest.mark.skipif( @@ -203,6 +206,31 @@ def test_get_assistant_files(): assert expected_file_id in retrived_file_ids +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"] or skip_test, + reason="do not run on MacOS or windows or dependency is not installed", +) +def test_assistant_retrieval(): + name = "For GPTAssistantAgent retrieval testing" + + assistant = GPTAssistantAgent( + name, + instructions="This is a test", + llm_config={"config_list": config_list}, + ) + candidate_first = retrieval_assistants_by_name(assistant.openai_client, name) + + assistant = GPTAssistantAgent( + name, + instructions="This is a test", + llm_config={"config_list": config_list}, + ) + candidate_second = retrieval_assistants_by_name(assistant.openai_client, name) + + assistant.delete_assistant() + assert candidate_first == candidate_second + + if __name__ == "__main__": test_gpt_assistant_chat() test_get_assistant_instructions()