From 156579565d30c1e21a32a654ca486c7f8b5cde9b Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Sun, 14 Jan 2024 14:55:53 -0800 Subject: [PATCH] Improve test for function call in groupchat (#1252) --- .../agentchat/test_function_call_groupchat.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/agentchat/test_function_call_groupchat.py b/test/agentchat/test_function_call_groupchat.py index f6e2f06898e..43ce508a4cc 100644 --- a/test/agentchat/test_function_call_groupchat.py +++ b/test/agentchat/test_function_call_groupchat.py @@ -1,5 +1,6 @@ import autogen import pytest +import asyncio import sys import os from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST @@ -34,11 +35,16 @@ ("tools", [{"type": "function", "function": func_def}], True), ], ) -def test_function_call_groupchat(key, value, sync): +@pytest.mark.asyncio +async def test_function_call_groupchat(key, value, sync): import random - def get_random_number(): - return random.randint(0, 100) + class Function: + call_count = 0 + + def get_random_number(self): + self.call_count += 1 + return random.randint(0, 100) config_list_gpt4 = autogen.config_list_from_json( OAI_CONFIG_LIST, @@ -56,25 +62,28 @@ def get_random_number(): llm_config_no_function = llm_config.copy() del llm_config_no_function[key] + func = Function() user_proxy = autogen.UserProxyAgent( name="Executor", - description="An executor that will execute function_calls.", - function_map={"get_random_number": get_random_number}, + description="An executor that executes function_calls.", + function_map={"get_random_number": func.get_random_number}, human_input_mode="NEVER", ) player = autogen.AssistantAgent( name="Player", system_message="You will use function `get_random_number` to get a random number. Stop only when you get at least 1 even number and 1 odd number. Reply TERMINATE to stop.", - description="A player that will make function_calls.", + description="A player that makes function_calls.", llm_config=llm_config, ) observer = autogen.AssistantAgent( name="Observer", - system_message="You observe the conversation between the executor and the player. Summarize the conversation in 1 sentence.", - description="An observer that will observe the conversation.", + system_message="You observe the the player's actions and results. Summarize in 1 sentence.", + description="An observer.", llm_config=llm_config_no_function, ) - groupchat = autogen.GroupChat(agents=[user_proxy, player, observer], messages=[], max_round=7) + groupchat = autogen.GroupChat( + agents=[player, user_proxy, observer], messages=[], max_round=7, speaker_selection_method="round_robin" + ) # pass in llm_config with functions with pytest.raises( @@ -86,9 +95,10 @@ def get_random_number(): manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config_no_function) if sync: - user_proxy.initiate_chat(manager, message="Let's start the game!") + observer.initiate_chat(manager, message="Let's start the game!") else: - user_proxy.a_initiate_chat(manager, message="Let's start the game!") + await observer.a_initiate_chat(manager, message="Let's start the game!") + assert func.call_count >= 1, "The function get_random_number should be called at least once." def test_no_function_map(): @@ -119,5 +129,5 @@ def test_no_function_map(): if __name__ == "__main__": - test_function_call_groupchat() + asyncio.run(test_function_call_groupchat("functions", [func_def], True)) # test_no_function_map()