From 8d4afe42635bf040ace461e0279d50e0ef768f30 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Thu, 19 Oct 2023 07:43:36 -0700 Subject: [PATCH] function call filter in group chat (#294) * function call filter in group chat * find agents with function_map --- autogen/agentchat/conversable_agent.py | 9 ++++ autogen/agentchat/groupchat.py | 73 +++++++++++++++++++------- test/agentchat/test_groupchat.py | 56 +++++++++++++++++++- 3 files changed, 117 insertions(+), 21 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 87845f910d1..3a0e1959881 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1017,3 +1017,12 @@ def register_function(self, function_map: Dict[str, Callable]): function_map: a dictionary mapping function names to functions. """ self._function_map.update(function_map) + + def can_execute_function(self, name: str) -> bool: + """Whether the agent can execute the function.""" + return name in self._function_map + + @property + def function_map(self) -> Dict[str, Callable]: + """Return the function map.""" + return self._function_map diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index d2f53002b41..9ed2ff77464 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -10,12 +10,23 @@ @dataclass class GroupChat: - """A group chat class that contains a list of agents and the maximum number of rounds.""" + """A group chat class that contains the following data fields: + - agents: a list of participating agents. + - messages: a list of messages in the group chat. + - max_round: the maximum number of rounds. + - admin_name: the name of the admin agent if there is one. Default is "Admin". + KeyBoardInterrupt will make the admin agent take over. + - func_call_filter: whether to enforce function call filter. Default is True. + When set to True and when a message is a function call suggestion, + the next speaker will be chosen from an agent which contains the corresponding function name + in its `function_map`. + """ agents: List[Agent] messages: List[Dict] max_round: int = 10 - admin_name: str = "Admin" # the name of the admin agent + admin_name: str = "Admin" + func_call_filter: bool = True @property def agent_names(self) -> List[str]: @@ -30,45 +41,69 @@ def agent_by_name(self, name: str) -> Agent: """Find the next speaker based on the message.""" return self.agents[self.agent_names.index(name)] - def next_agent(self, agent: Agent) -> Agent: + def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent: """Return the next agent in the list.""" - return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] - - def select_speaker_msg(self): + if agents == self.agents: + return agents[(self.agent_names.index(agent.name) + 1) % len(agents)] + else: + offset = self.agent_names.index(agent.name) + 1 + for i in range(len(self.agents)): + if self.agents[(offset + i) % len(self.agents)] in agents: + return self.agents[(offset + i) % len(self.agents)] + + def select_speaker_msg(self, agents: List[Agent]): """Return the message for selecting the next speaker.""" return f"""You are in a role play game. The following roles are available: {self._participant_roles()}. Read the following conversation. -Then select the next role from {self.agent_names} to play. Only return the role.""" +Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" def select_speaker(self, last_speaker: Agent, selector: ConversableAgent): """Select the next speaker.""" - selector.update_system_message(self.select_speaker_msg()) - - # Warn if GroupChat is underpopulated, without established changing behavior - n_agents = len(self.agent_names) - if n_agents < 3: - logger.warning( - f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient." - ) - + if self.func_call_filter and self.messages and "function_call" in self.messages[-1]: + # find agents with the right function_map which contains the function name + agents = [ + agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"]) + ] + if len(agents) == 1: + # only one agent can execute the function + return agents[0] + elif not agents: + # find all the agents with function_map + agents = [agent for agent in self.agents if agent.function_map] + if len(agents) == 1: + return agents[0] + elif not agents: + raise ValueError( + f"No agent can execute the function {self.messages[-1]['name']}. " + "Please check the function_map of the agents." + ) + else: + agents = self.agents + # Warn if GroupChat is underpopulated + n_agents = len(agents) + if n_agents < 3: + logger.warning( + f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient." + ) + selector.update_system_message(self.select_speaker_msg(agents)) final, name = selector.generate_oai_reply( self.messages + [ { "role": "system", - "content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.", + "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.", } ] ) if not final: # i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id - return self.next_agent(last_speaker) + return self.next_agent(last_speaker, agents) try: return self.agent_by_name(name) except ValueError: - return self.next_agent(last_speaker) + return self.next_agent(last_speaker, agents) def _participant_roles(self): return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents]) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 5c5d3fb8257..c50ef45cdcc 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -1,6 +1,54 @@ +import pytest import autogen +def test_func_call_groupchat(): + agent1 = autogen.ConversableAgent( + "alice", + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice sepaking.", + ) + agent2 = autogen.ConversableAgent( + "bob", + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is bob speaking.", + function_map={"test_func": lambda x: x}, + ) + groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3) + group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) + agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}}) + + assert len(groupchat.messages) == 3 + assert ( + groupchat.messages[-2]["role"] == "function" + and groupchat.messages[-2]["name"] == "test_func" + and groupchat.messages[-2]["content"] == "1" + ) + assert groupchat.messages[-1]["name"] == "alice" + + agent3 = autogen.ConversableAgent( + "carol", + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is carol speaking.", + function_map={"test_func": lambda x: x + 1}, + ) + groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3) + group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) + agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}}) + + assert ( + groupchat.messages[-2]["role"] == "function" + and groupchat.messages[-2]["name"] == "test_func" + and groupchat.messages[-2]["content"] == "1" + ) + assert groupchat.messages[-1]["name"] == "carol" + + agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}}) + + def test_chat_manager(): agent1 = autogen.ConversableAgent( "alice", @@ -30,6 +78,9 @@ def test_chat_manager(): agent2.initiate_chat(group_chat_manager, message="hello") assert len(groupchat.messages) == 2 + with pytest.raises(ValueError): + agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}}) + def test_plugin(): # Give another Agent class ability to manage group chat @@ -62,6 +113,7 @@ def test_plugin(): if __name__ == "__main__": + test_func_call_groupchat() # test_broadcast() - # test_chat_manager() - test_plugin() + test_chat_manager() + # test_plugin()