From c00958ee28fea261af6d65baea0ee72fa79c380b Mon Sep 17 00:00:00 2001 From: Davor Runje Date: Thu, 11 Jan 2024 11:39:00 +0000 Subject: [PATCH] added raising an exception on an async reply function in sync chat --- autogen/agentchat/conversable_agent.py | 28 ++++++++++++++++++++++++ autogen/agentchat/groupchat.py | 3 +++ test/agentchat/test_conversable_agent.py | 13 +++++------ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 3d78c512549..657aea3e138 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -152,6 +152,9 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply) + # we will ignore async reply functions in sync chats for the initial reply functions and output warnings for user defined ones + self._initial_reply_functions = {f["reply_func"] for f in self._reply_func_list} + # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration. # New hookable methods should be added to this list as required to support new agent capabilities. self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method. @@ -601,6 +604,26 @@ def _prepare_chat(self, recipient, clear_history): self.clear_history(recipient) recipient.clear_history(self) + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(self._initial_reply_functions) + async_reply_functions = [f for f in reply_functions if asyncio.coroutines.iscoroutinefunction(f)] + if async_reply_functions != []: + msg = ( + "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: " + + ", ".join([f.__name__ for f in async_reply_functions]) + ) + + raise RuntimeError(msg) + + if hasattr(self, "_groupchat"): + for agent in self._groupchat.agents: + agent._raise_exception_on_async_reply_functions() + def initiate_chat( self, recipient: "ConversableAgent", @@ -620,7 +643,12 @@ def initiate_chat( silent (bool or None): (Experimental) whether to print the messages for this conversation. **context: any context information. "message" needs to be provided if the `generate_init_message` method is not overridden. + + Raises: + RuntimeError: if any async reply functions are registered. """ + for agent in [self, recipient]: + agent._raise_exception_on_async_reply_functions() self._prepare_chat(recipient, clear_history) self.send(self.generate_init_message(**context), recipient, silent=silent) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index cc87881c544..238bbd5ed97 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -332,6 +332,9 @@ def __init__( system_message=system_message, **kwargs, ) + # Store groupchat + self._groupchat = groupchat + # Order of register_reply is important. # Allow sync chat if initiated using initiate_chat self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset) diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 99c62b96921..bc4e861183d 100644 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -144,20 +144,19 @@ def test_async_trigger_in_sync_chat(): agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER") agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER") - a_reply_mock = unittest.mock.MagicMock() - async def a_reply(recipient, messages, sender, config): - a_reply_mock() print("hello from a_reply") return (True, "hello from reply function") agent.register_reply(agent1, a_reply) - agent1.initiate_chat(agent, message="hi") - - a_reply_mock.assert_not_called() + with pytest.raises(RuntimeError) as e: + agent1.initiate_chat(agent, message="hi") - assert agent1.last_message(agent)["content"] == "hi" + assert ( + e.value.args[0] == "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). " + "The following async reply functions are found: a_reply" + ) @pytest.mark.asyncio