Skip to content

Commit

Permalink
function call filter in group chat (#294)
Browse files Browse the repository at this point in the history
* function call filter in group chat

* find agents with function_map
  • Loading branch information
sonichi committed Oct 19, 2023
1 parent bed85a3 commit 8d4afe4
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 21 deletions.
9 changes: 9 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,12 @@ def register_function(self, function_map: Dict[str, Callable]):
function_map: a dictionary mapping function names to functions.
"""
self._function_map.update(function_map)

def can_execute_function(self, name: str) -> bool:
"""Whether the agent can execute the function."""
return name in self._function_map

@property
def function_map(self) -> Dict[str, Callable]:
"""Return the function map."""
return self._function_map
73 changes: 54 additions & 19 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@

@dataclass
class GroupChat:
"""A group chat class that contains a list of agents and the maximum number of rounds."""
"""A group chat class that contains the following data fields:
- agents: a list of participating agents.
- messages: a list of messages in the group chat.
- max_round: the maximum number of rounds.
- admin_name: the name of the admin agent if there is one. Default is "Admin".
KeyBoardInterrupt will make the admin agent take over.
- func_call_filter: whether to enforce function call filter. Default is True.
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
"""

agents: List[Agent]
messages: List[Dict]
max_round: int = 10
admin_name: str = "Admin" # the name of the admin agent
admin_name: str = "Admin"
func_call_filter: bool = True

@property
def agent_names(self) -> List[str]:
Expand All @@ -30,45 +41,69 @@ def agent_by_name(self, name: str) -> Agent:
"""Find the next speaker based on the message."""
return self.agents[self.agent_names.index(name)]

def next_agent(self, agent: Agent) -> Agent:
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
"""Return the next agent in the list."""
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]

def select_speaker_msg(self):
if agents == self.agents:
return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
else:
offset = self.agent_names.index(agent.name) + 1
for i in range(len(self.agents)):
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]

def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Read the following conversation.
Then select the next role from {self.agent_names} to play. Only return the role."""
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selector.update_system_message(self.select_speaker_msg())

# Warn if GroupChat is underpopulated, without established changing behavior
n_agents = len(self.agent_names)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)

if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
# find agents with the right function_map which contains the function name
agents = [
agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
]
if len(agents) == 1:
# only one agent can execute the function
return agents[0]
elif not agents:
# find all the agents with function_map
agents = [agent for agent in self.agents if agent.function_map]
if len(agents) == 1:
return agents[0]
elif not agents:
raise ValueError(
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)
else:
agents = self.agents
# Warn if GroupChat is underpopulated
n_agents = len(agents)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)
selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
return self.next_agent(last_speaker)
return self.next_agent(last_speaker, agents)
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker)
return self.next_agent(last_speaker, agents)

def _participant_roles(self):
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
Expand Down
56 changes: 54 additions & 2 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,54 @@
import pytest
import autogen


def test_func_call_groupchat():
agent1 = autogen.ConversableAgent(
"alice",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice sepaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
function_map={"test_func": lambda x: x},
)
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})

assert len(groupchat.messages) == 3
assert (
groupchat.messages[-2]["role"] == "function"
and groupchat.messages[-2]["name"] == "test_func"
and groupchat.messages[-2]["content"] == "1"
)
assert groupchat.messages[-1]["name"] == "alice"

agent3 = autogen.ConversableAgent(
"carol",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is carol speaking.",
function_map={"test_func": lambda x: x + 1},
)
groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})

assert (
groupchat.messages[-2]["role"] == "function"
and groupchat.messages[-2]["name"] == "test_func"
and groupchat.messages[-2]["content"] == "1"
)
assert groupchat.messages[-1]["name"] == "carol"

agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})


def test_chat_manager():
agent1 = autogen.ConversableAgent(
"alice",
Expand Down Expand Up @@ -30,6 +78,9 @@ def test_chat_manager():
agent2.initiate_chat(group_chat_manager, message="hello")
assert len(groupchat.messages) == 2

with pytest.raises(ValueError):
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})


def test_plugin():
# Give another Agent class ability to manage group chat
Expand Down Expand Up @@ -62,6 +113,7 @@ def test_plugin():


if __name__ == "__main__":
test_func_call_groupchat()
# test_broadcast()
# test_chat_manager()
test_plugin()
test_chat_manager()
# test_plugin()

0 comments on commit 8d4afe4

Please sign in to comment.