Skip to content

Commit

Permalink
Add documentation and raise exception when registering async reply fu…
Browse files Browse the repository at this point in the history
…nction in sync chat (microsoft#1208)

* documentation update and added tests for register_reply function

* added raising an exception on an async reply function in sync chat

* big fixing

* test expanded

* Update autogen/agentchat/conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update autogen/agentchat/conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* refactorization

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
  • Loading branch information
davorrunje and sonichi authored Jan 13, 2024
1 parent 4ae8bcf commit eef3239
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 6 deletions.
54 changes: 50 additions & 4 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,21 @@ def __init__(
)
self._default_auto_reply = default_auto_reply
self._reply_func_list = []
self._ignore_async_func_in_sync_chat_list = []
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
self.register_reply(
[Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True
)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
self.register_reply(
[Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True
)

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
Expand All @@ -163,13 +168,22 @@ def register_reply(
position: int = 0,
config: Optional[Any] = None,
reset_config: Optional[Callable] = None,
*,
ignore_async_in_sync_chat: bool = False,
):
"""Register a reply function.
The reply function will be called when the trigger matches the sender.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.
Both sync and async reply functions can be registered. The sync reply function will be triggered
from both sync and async chats. However, an async reply function will only be triggered from async
chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered
and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows:
- if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and
- if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored.
Args:
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
- If a class is provided, the reply function will be called when the sender is an instance of the class.
Expand All @@ -181,6 +195,12 @@ def register_reply(
Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`.
reply_func (Callable): the reply function.
The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
position: the position of the reply function in the reply function list.
config: the config to be passed to the reply function, see below.
reset_config: the function to reset the config, see below.
ignore_async_in_sync_chat: whether to ignore the async reply function in sync chats. If `False`, an exception
will be raised if an async reply function is registered and a chat is initialized with a sync
function.
```python
def reply_func(
recipient: ConversableAgent,
Expand Down Expand Up @@ -209,6 +229,8 @@ def reply_func(
"reset_config": reset_config,
},
)
if ignore_async_in_sync_chat and asyncio.coroutines.iscoroutinefunction(reply_func):
self._ignore_async_func_in_sync_chat_list.append(reply_func)

@property
def system_message(self) -> Union[str, List]:
Expand Down Expand Up @@ -597,6 +619,25 @@ def _prepare_chat(self, recipient, clear_history):
self.clear_history(recipient)
recipient.clear_history(self)

def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
Raises:
RuntimeError: if any async reply functions are registered.
"""
reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(
self._ignore_async_func_in_sync_chat_list
)

async_reply_functions = [f for f in reply_functions if asyncio.coroutines.iscoroutinefunction(f)]
if async_reply_functions != []:
msg = (
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
+ ", ".join([f.__name__ for f in async_reply_functions])
)

raise RuntimeError(msg)

def initiate_chat(
self,
recipient: "ConversableAgent",
Expand All @@ -616,7 +657,12 @@ def initiate_chat(
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Raises:
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)

Expand Down
22 changes: 21 additions & 1 deletion autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,20 @@ def __init__(
system_message=system_message,
**kwargs,
)
# Store groupchat
self._groupchat = groupchat

# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
# Allow async chat if initiated using a_initiate_chat
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)
self.register_reply(
Agent,
GroupChatManager.a_run_chat,
config=groupchat,
reset_config=GroupChat.reset,
ignore_async_in_sync_chat=True,
)

def run_chat(
self,
Expand Down Expand Up @@ -438,3 +447,14 @@ async def a_run_chat(
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
return True, None

def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
Raises:
RuntimeError: if any async reply functions are registered.
"""
super()._raise_exception_on_async_reply_functions()

for agent in self._groupchat.agents:
agent._raise_exception_on_async_reply_functions()
110 changes: 109 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def conversable_agent():
)


def test_trigger():
def test_sync_trigger():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent.register_reply(agent1, lambda recipient, messages, sender, config: (True, "hello"))
Expand Down Expand Up @@ -72,6 +72,114 @@ def test_trigger():
pytest.raises(ValueError, agent._match_trigger, 1, agent1)


@pytest.mark.asyncio
async def test_async_trigger():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

async def a_reply(recipient, messages, sender, config):
print("hello from a_reply")
return (True, "hello")

agent.register_reply(agent1, a_reply)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello"

async def a_reply_a1(recipient, messages, sender, config):
print("hello from a_reply_a1")
return (True, "hello a1")

agent.register_reply("a1", a_reply_a1)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a1"

async def a_reply_conversable_agent(recipient, messages, sender, config):
print("hello from a_reply_conversable_agent")
return (True, "hello conversable agent")

agent.register_reply(ConversableAgent, a_reply_conversable_agent)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello conversable agent"

async def a_reply_a(recipient, messages, sender, config):
print("hello from a_reply_a")
return (True, "hello a")

agent.register_reply(lambda sender: sender.name.startswith("a"), a_reply_a)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a"

async def a_reply_b(recipient, messages, sender, config):
print("hello from a_reply_b")
return (True, "hello b")

agent.register_reply(lambda sender: sender.name.startswith("b"), a_reply_b)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a"

async def a_reply_agent2_or_agent1(recipient, messages, sender, config):
print("hello from a_reply_agent2_or_agent1")
return (True, "hello agent2 or agent1")

agent.register_reply(["agent2", agent1], a_reply_agent2_or_agent1)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"

async def a_reply_agent2_or_agent3(recipient, messages, sender, config):
print("hello from a_reply_agent2_or_agent3")
return (True, "hello agent2 or agent3")

agent.register_reply(["agent2", "agent3"], a_reply_agent2_or_agent3)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"

with pytest.raises(ValueError):
agent.register_reply(1, a_reply)

with pytest.raises(ValueError):
agent._match_trigger(1, agent1)


def test_async_trigger_in_sync_chat():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent2 = ConversableAgent("a2", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

reply_mock = unittest.mock.MagicMock()

async def a_reply(recipient, messages, sender, config):
reply_mock()
print("hello from a_reply")
return (True, "hello from reply function")

agent.register_reply(agent1, a_reply)

with pytest.raises(RuntimeError) as e:
agent1.initiate_chat(agent, message="hi")

assert (
e.value.args[0] == "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). "
"The following async reply functions are found: a_reply"
)

agent2.register_reply(agent1, a_reply, ignore_async_in_sync_chat=True)
reply_mock.assert_not_called()


@pytest.mark.asyncio
async def test_sync_trigger_in_async_chat():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

def a_reply(recipient, messages, sender, config):
print("hello from a_reply")
return (True, "hello from reply function")

agent.register_reply(agent1, a_reply)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello from reply function"


def test_context():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
Expand Down

0 comments on commit eef3239

Please sign in to comment.