Skip to content

Commit

Permalink
added raising an exception on an async reply function in sync chat
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Jan 11, 2024
1 parent e1ba382 commit c00958e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
28 changes: 28 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)

# we will ignore async reply functions in sync chats for the initial reply functions and output warnings for user defined ones
self._initial_reply_functions = {f["reply_func"] for f in self._reply_func_list}

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
Expand Down Expand Up @@ -601,6 +604,26 @@ def _prepare_chat(self, recipient, clear_history):
self.clear_history(recipient)
recipient.clear_history(self)

def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
Raises:
RuntimeError: if any async reply functions are registered.
"""
reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(self._initial_reply_functions)
async_reply_functions = [f for f in reply_functions if asyncio.coroutines.iscoroutinefunction(f)]
if async_reply_functions != []:
msg = (
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
+ ", ".join([f.__name__ for f in async_reply_functions])
)

raise RuntimeError(msg)

if hasattr(self, "_groupchat"):
for agent in self._groupchat.agents:
agent._raise_exception_on_async_reply_functions()

def initiate_chat(
self,
recipient: "ConversableAgent",
Expand All @@ -620,7 +643,12 @@ def initiate_chat(
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Raises:
RuntimeError: if any async reply functions are registered.
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)

Expand Down
3 changes: 3 additions & 0 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ def __init__(
system_message=system_message,
**kwargs,
)
# Store groupchat
self._groupchat = groupchat

# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
Expand Down
13 changes: 6 additions & 7 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,19 @@ def test_async_trigger_in_sync_chat():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

a_reply_mock = unittest.mock.MagicMock()

async def a_reply(recipient, messages, sender, config):
a_reply_mock()
print("hello from a_reply")
return (True, "hello from reply function")

agent.register_reply(agent1, a_reply)

agent1.initiate_chat(agent, message="hi")

a_reply_mock.assert_not_called()
with pytest.raises(RuntimeError) as e:
agent1.initiate_chat(agent, message="hi")

assert agent1.last_message(agent)["content"] == "hi"
assert (
e.value.args[0] == "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). "
"The following async reply functions are found: a_reply"
)


@pytest.mark.asyncio
Expand Down

0 comments on commit c00958e

Please sign in to comment.