diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index deaca3036f9..295954ac802 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -66,10 +66,15 @@ def agent_by_name(self, name: str) -> Agent: def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent: """Return the next agent in the list.""" + + # What index is the agent? (-1 if not present) + idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 + + # Return the next agent if agents == self.agents: - return agents[(self.agent_names.index(agent.name) + 1) % len(agents)] + return agents[(idx + 1) % len(agents)] else: - offset = self.agent_names.index(agent.name) + 1 + offset = idx + 1 for i in range(len(self.agents)): if self.agents[(offset + i) % len(self.agents)] in agents: return self.agents[(offset + i) % len(self.agents)] diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 9cbbe814cef..27fda1fd524 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -382,6 +382,53 @@ def test_termination(): assert len(groupchat.messages) == 3 +def test_next_agent(): + agent1 = autogen.ConversableAgent( + "alice", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice speaking.", + ) + agent2 = autogen.ConversableAgent( + "bob", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is bob speaking.", + ) + agent3 = autogen.ConversableAgent( + "sam", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is sam speaking.", + ) + agent4 = autogen.ConversableAgent( + "sally", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is sally speaking.", + ) + + # Test empty is_termination_msg function + groupchat = autogen.GroupChat( + agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10 + ) + + assert groupchat.next_agent(agent1, [agent1, agent2, agent3]) == agent2 + assert groupchat.next_agent(agent2, [agent1, agent2, agent3]) == agent3 + assert groupchat.next_agent(agent3, [agent1, agent2, agent3]) == agent1 + + assert groupchat.next_agent(agent1, [agent1, agent3]) == agent3 + assert groupchat.next_agent(agent3, [agent1, agent3]) == agent1 + + assert groupchat.next_agent(agent2, [agent1, agent3]) == agent3 + assert groupchat.next_agent(agent4, [agent1, agent3]) == agent1 + assert groupchat.next_agent(agent4, [agent1, agent2, agent3]) == agent1 + + if __name__ == "__main__": # test_func_call_groupchat() # test_broadcast() @@ -390,4 +437,5 @@ def test_termination(): # test_speaker_selection_method() # test_n_agents_less_than_3() # test_agent_mentions() - test_termination() + # test_termination() + test_next_agent()