Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function call filter in group chat #294

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,12 @@ def register_function(self, function_map: Dict[str, Callable]):
function_map: a dictionary mapping function names to functions.
"""
self._function_map.update(function_map)

def can_execute_function(self, name: str) -> bool:
"""Whether the agent can execute the function."""
return name in self._function_map

@property
def function_map(self) -> Dict[str, Callable]:
"""Return the function map."""
return self._function_map
73 changes: 54 additions & 19 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@

@dataclass
class GroupChat:
"""A group chat class that contains a list of agents and the maximum number of rounds."""
"""A group chat class that contains the following data fields:
- agents: a list of participating agents.
- messages: a list of messages in the group chat.
- max_round: the maximum number of rounds.
- admin_name: the name of the admin agent if there is one. Default is "Admin".
KeyBoardInterrupt will make the admin agent take over.
- func_call_filter: whether to enforce function call filter. Default is True.
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
"""

agents: List[Agent]
messages: List[Dict]
max_round: int = 10
admin_name: str = "Admin" # the name of the admin agent
admin_name: str = "Admin"
func_call_filter: bool = True

@property
def agent_names(self) -> List[str]:
Expand All @@ -30,45 +41,69 @@ def agent_by_name(self, name: str) -> Agent:
"""Find the next speaker based on the message."""
return self.agents[self.agent_names.index(name)]

def next_agent(self, agent: Agent) -> Agent:
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
"""Return the next agent in the list."""
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]

def select_speaker_msg(self):
if agents == self.agents:
return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
else:
offset = self.agent_names.index(agent.name) + 1
for i in range(len(self.agents)):
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]

def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.

Read the following conversation.
Then select the next role from {self.agent_names} to play. Only return the role."""
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selector.update_system_message(self.select_speaker_msg())

# Warn if GroupChat is underpopulated, without established changing behavior
n_agents = len(self.agent_names)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)

if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
# find agents with the right function_map which contains the function name
agents = [
agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
]
if len(agents) == 1:
# only one agent can execute the function
return agents[0]
elif not agents:
# find all the agents with function_map
agents = [agent for agent in self.agents if agent.function_map]
if len(agents) == 1:
return agents[0]
elif not agents:
raise ValueError(
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)
else:
agents = self.agents
# Warn if GroupChat is underpopulated
n_agents = len(agents)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)
selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
return self.next_agent(last_speaker)
return self.next_agent(last_speaker, agents)
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker)
return self.next_agent(last_speaker, agents)

def _participant_roles(self):
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
Expand Down
56 changes: 54 additions & 2 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,54 @@
import pytest
import autogen


def test_func_call_groupchat():
agent1 = autogen.ConversableAgent(
"alice",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice sepaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
function_map={"test_func": lambda x: x},
)
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})

assert len(groupchat.messages) == 3
assert (
groupchat.messages[-2]["role"] == "function"
and groupchat.messages[-2]["name"] == "test_func"
and groupchat.messages[-2]["content"] == "1"
)
assert groupchat.messages[-1]["name"] == "alice"

agent3 = autogen.ConversableAgent(
"carol",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is carol speaking.",
function_map={"test_func": lambda x: x + 1},
)
groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})

assert (
groupchat.messages[-2]["role"] == "function"
and groupchat.messages[-2]["name"] == "test_func"
and groupchat.messages[-2]["content"] == "1"
)
assert groupchat.messages[-1]["name"] == "carol"

agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})


def test_chat_manager():
agent1 = autogen.ConversableAgent(
"alice",
Expand Down Expand Up @@ -30,6 +78,9 @@ def test_chat_manager():
agent2.initiate_chat(group_chat_manager, message="hello")
assert len(groupchat.messages) == 2

with pytest.raises(ValueError):
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})


def test_plugin():
# Give another Agent class ability to manage group chat
Expand Down Expand Up @@ -62,6 +113,7 @@ def test_plugin():


if __name__ == "__main__":
test_func_call_groupchat()
# test_broadcast()
# test_chat_manager()
test_plugin()
test_chat_manager()
# test_plugin()
Loading