diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index f6e5d30264b..3e01139a1cb 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -1160,7 +1160,7 @@ async def a_run_chat( def resume( self, messages: Union[List[Dict], str], - remove_termination_string: str = None, + remove_termination_string: Union[str, Callable[[str], str]] = None, silent: Optional[bool] = False, ) -> Tuple[ConversableAgent, Dict]: """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established @@ -1168,7 +1168,9 @@ def resume( Args: - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. - - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination + - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function. - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. Returns: @@ -1263,7 +1265,7 @@ def resume( async def a_resume( self, messages: Union[List[Dict], str], - remove_termination_string: str = None, + remove_termination_string: Union[str, Callable[[str], str]], silent: Optional[bool] = False, ) -> Tuple[ConversableAgent, Dict]: """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established @@ -1271,7 +1273,9 @@ async def a_resume( Args: - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. - - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination + - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. Returns: @@ -1390,11 +1394,15 @@ def _valid_resume_messages(self, messages: List[Dict]): ): raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}") - def _process_resume_termination(self, remove_termination_string: str, messages: List[Dict]): + def _process_resume_termination( + self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict] + ): """Removes termination string, if required, and checks if termination may occur. args: - remove_termination_string (str): termination string to remove from the last message + remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. returns: None @@ -1403,9 +1411,17 @@ def _process_resume_termination(self, remove_termination_string: str, messages: last_message = messages[-1] # Replace any given termination string in the last message - if remove_termination_string: - if messages[-1].get("content") and remove_termination_string in messages[-1]["content"]: - messages[-1]["content"] = messages[-1]["content"].replace(remove_termination_string, "") + if isinstance(remove_termination_string, str): + + def _remove_termination_string(content: str) -> str: + return content.replace(remove_termination_string, "") + + else: + _remove_termination_string = remove_termination_string + + if _remove_termination_string: + if messages[-1].get("content"): + messages[-1]["content"] = _remove_termination_string(messages[-1]["content"]) # Check if the last message meets termination (if it has one) if self._is_termination_msg: diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 933b4ce7df2..20a83685178 100755 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -1916,6 +1916,51 @@ def test_manager_resume_functions(): # TERMINATE should be removed assert messages[-1]["content"] == final_msg.replace("TERMINATE", "") + # Tests termination message replacement with function + def termination_func(x: str) -> str: + if "APPROVED" in x: + x = x.replace("APPROVED", "") + else: + x = x.replace("TERMINATE", "") + return x + + final_msg1 = "Product_Manager has created 3 new product ideas. APPROVED" + messages1 = [ + { + "content": "You are an expert at finding the next speaker.", + "role": "system", + }, + { + "content": final_msg1, + "name": "Coder", + "role": "assistant", + }, + ] + + manager._process_resume_termination(remove_termination_string=termination_func, messages=messages1) + + # APPROVED should be removed + assert messages1[-1]["content"] == final_msg1.replace("APPROVED", "") + + final_msg2 = "Idea has been approved. TERMINATE" + messages2 = [ + { + "content": "You are an expert at finding the next speaker.", + "role": "system", + }, + { + "content": final_msg2, + "name": "Coder", + "role": "assistant", + }, + ] + + manager._process_resume_termination(remove_termination_string=termination_func, messages=messages2) + + # TERMINATE should be removed, "approved" should still be present as the termination_func only replaces upper-cased "APPROVED". + assert messages2[-1]["content"] == final_msg2.replace("TERMINATE", "") + assert "approved" in messages2[-1]["content"] + # Check if the termination string doesn't exist there's no replacing of content final_msg = ( "Let's get this meeting started. First the Product_Manager will create 3 new product ideas. TERMINATE this." @@ -2027,7 +2072,7 @@ def test_manager_resume_messages(): # test_clear_agents_history() # test_custom_speaker_selection_overrides_transition_graph() # test_role_for_select_speaker_messages() - test_select_speaker_message_and_prompt_templates() + # test_select_speaker_message_and_prompt_templates() # test_speaker_selection_agent_name_match() # test_role_for_reflection_summary() # test_speaker_selection_auto_process_result() @@ -2036,7 +2081,7 @@ def test_manager_resume_messages(): # test_select_speaker_auto_messages() # test_manager_messages_to_string() # test_manager_messages_from_string() - # test_manager_resume_functions() + test_manager_resume_functions() # test_manager_resume_returns() # test_manager_resume_messages() pass