Skip to content

Commit

Permalink
Improve test for function call in groupchat (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
sonichi authored Jan 14, 2024
1 parent 3b2955d commit 1565795
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions test/agentchat/test_function_call_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import autogen
import pytest
import asyncio
import sys
import os
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
Expand Down Expand Up @@ -34,11 +35,16 @@
("tools", [{"type": "function", "function": func_def}], True),
],
)
def test_function_call_groupchat(key, value, sync):
@pytest.mark.asyncio
async def test_function_call_groupchat(key, value, sync):
import random

def get_random_number():
return random.randint(0, 100)
class Function:
call_count = 0

def get_random_number(self):
self.call_count += 1
return random.randint(0, 100)

config_list_gpt4 = autogen.config_list_from_json(
OAI_CONFIG_LIST,
Expand All @@ -56,25 +62,28 @@ def get_random_number():
llm_config_no_function = llm_config.copy()
del llm_config_no_function[key]

func = Function()
user_proxy = autogen.UserProxyAgent(
name="Executor",
description="An executor that will execute function_calls.",
function_map={"get_random_number": get_random_number},
description="An executor that executes function_calls.",
function_map={"get_random_number": func.get_random_number},
human_input_mode="NEVER",
)
player = autogen.AssistantAgent(
name="Player",
system_message="You will use function `get_random_number` to get a random number. Stop only when you get at least 1 even number and 1 odd number. Reply TERMINATE to stop.",
description="A player that will make function_calls.",
description="A player that makes function_calls.",
llm_config=llm_config,
)
observer = autogen.AssistantAgent(
name="Observer",
system_message="You observe the conversation between the executor and the player. Summarize the conversation in 1 sentence.",
description="An observer that will observe the conversation.",
system_message="You observe the the player's actions and results. Summarize in 1 sentence.",
description="An observer.",
llm_config=llm_config_no_function,
)
groupchat = autogen.GroupChat(agents=[user_proxy, player, observer], messages=[], max_round=7)
groupchat = autogen.GroupChat(
agents=[player, user_proxy, observer], messages=[], max_round=7, speaker_selection_method="round_robin"
)

# pass in llm_config with functions
with pytest.raises(
Expand All @@ -86,9 +95,10 @@ def get_random_number():
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config_no_function)

if sync:
user_proxy.initiate_chat(manager, message="Let's start the game!")
observer.initiate_chat(manager, message="Let's start the game!")
else:
user_proxy.a_initiate_chat(manager, message="Let's start the game!")
await observer.a_initiate_chat(manager, message="Let's start the game!")
assert func.call_count >= 1, "The function get_random_number should be called at least once."


def test_no_function_map():
Expand Down Expand Up @@ -119,5 +129,5 @@ def test_no_function_map():


if __name__ == "__main__":
test_function_call_groupchat()
asyncio.run(test_function_call_groupchat("functions", [func_def], True))
# test_no_function_map()

0 comments on commit 1565795

Please sign in to comment.