Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add documentation and raise exception when registering async reply function in sync chat #1208

Merged
58 changes: 54 additions & 4 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,21 @@ def __init__(
)
self._default_auto_reply = default_auto_reply
self._reply_func_list = []
self._ignore_async_func_in_sync_chat_list = []
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
self.register_reply(
[Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True
)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
self.register_reply(
[Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True
)

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
Expand All @@ -163,13 +168,22 @@ def register_reply(
position: int = 0,
config: Optional[Any] = None,
reset_config: Optional[Callable] = None,
*,
ignore_async_in_sync_chat: bool = False,
):
"""Register a reply function.

The reply function will be called when the trigger matches the sender.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.

Both sync and async reply functions can be registered. The sync reply function will be triggered
from both sync and async chats. However, an async reply function will only be triggered from async
chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered
and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows:
- if `ignore_in_sync_chat` is set to `False` (default value), an exception will be raised, and
- if `ignore_in_sync_chat` is set to `True`, the reply function will be ignored.
davorrunje marked this conversation as resolved.
Show resolved Hide resolved

Args:
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
- If a class is provided, the reply function will be called when the sender is an instance of the class.
Expand All @@ -181,6 +195,12 @@ def register_reply(
Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`.
reply_func (Callable): the reply function.
The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
position: the position of the reply function in the reply function list.
config: the config to be passed to the reply function, see below.
reset_config: the function to reset the config, see below.
ignore_async_in_sync_chat: whether to ignore the reply function in sync chats. If `False`, an exception
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
will be raised if an async reply function is registered and a chat is initialized with a sync
function.
```python
def reply_func(
recipient: ConversableAgent,
Expand Down Expand Up @@ -209,6 +229,8 @@ def reply_func(
"reset_config": reset_config,
},
)
if ignore_async_in_sync_chat and asyncio.coroutines.iscoroutinefunction(reply_func):
self._ignore_async_func_in_sync_chat_list.append(reply_func)

@property
def system_message(self) -> Union[str, List]:
Expand Down Expand Up @@ -597,6 +619,29 @@ def _prepare_chat(self, recipient, clear_history):
self.clear_history(recipient)
recipient.clear_history(self)

def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.

Raises:
RuntimeError: if any async reply functions are registered.
"""
reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(
self._ignore_async_func_in_sync_chat_list
)

async_reply_functions = [f for f in reply_functions if asyncio.coroutines.iscoroutinefunction(f)]
if async_reply_functions != []:
msg = (
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
+ ", ".join([f.__name__ for f in async_reply_functions])
)

raise RuntimeError(msg)

if hasattr(self, "_groupchat"):
sonichi marked this conversation as resolved.
Show resolved Hide resolved
for agent in self._groupchat.agents:
agent._raise_exception_on_async_reply_functions()

def initiate_chat(
self,
recipient: "ConversableAgent",
Expand All @@ -616,7 +661,12 @@ def initiate_chat(
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.

Raises:
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)

Expand Down
11 changes: 10 additions & 1 deletion autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,20 @@ def __init__(
system_message=system_message,
**kwargs,
)
# Store groupchat
self._groupchat = groupchat

# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
# Allow async chat if initiated using a_initiate_chat
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)
self.register_reply(
Agent,
GroupChatManager.a_run_chat,
config=groupchat,
reset_config=GroupChat.reset,
ignore_async_in_sync_chat=True,
)

def run_chat(
self,
Expand Down
110 changes: 109 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def conversable_agent():
)


def test_trigger():
def test_sync_trigger():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent.register_reply(agent1, lambda recipient, messages, sender, config: (True, "hello"))
Expand Down Expand Up @@ -72,6 +72,114 @@ def test_trigger():
pytest.raises(ValueError, agent._match_trigger, 1, agent1)


@pytest.mark.asyncio
async def test_async_trigger():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

async def a_reply(recipient, messages, sender, config):
print("hello from a_reply")
return (True, "hello")

agent.register_reply(agent1, a_reply)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello"

async def a_reply_a1(recipient, messages, sender, config):
print("hello from a_reply_a1")
return (True, "hello a1")

agent.register_reply("a1", a_reply_a1)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a1"

async def a_reply_conversable_agent(recipient, messages, sender, config):
print("hello from a_reply_conversable_agent")
return (True, "hello conversable agent")

agent.register_reply(ConversableAgent, a_reply_conversable_agent)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello conversable agent"

async def a_reply_a(recipient, messages, sender, config):
print("hello from a_reply_a")
return (True, "hello a")

agent.register_reply(lambda sender: sender.name.startswith("a"), a_reply_a)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a"

async def a_reply_b(recipient, messages, sender, config):
print("hello from a_reply_b")
return (True, "hello b")

agent.register_reply(lambda sender: sender.name.startswith("b"), a_reply_b)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello a"

async def a_reply_agent2_or_agent1(recipient, messages, sender, config):
print("hello from a_reply_agent2_or_agent1")
return (True, "hello agent2 or agent1")

agent.register_reply(["agent2", agent1], a_reply_agent2_or_agent1)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"

async def a_reply_agent2_or_agent3(recipient, messages, sender, config):
print("hello from a_reply_agent2_or_agent3")
return (True, "hello agent2 or agent3")

agent.register_reply(["agent2", "agent3"], a_reply_agent2_or_agent3)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"

with pytest.raises(ValueError):
agent.register_reply(1, a_reply)

with pytest.raises(ValueError):
agent._match_trigger(1, agent1)


def test_async_trigger_in_sync_chat():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent2 = ConversableAgent("a2", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

reply_mock = unittest.mock.MagicMock()

async def a_reply(recipient, messages, sender, config):
reply_mock()
print("hello from a_reply")
return (True, "hello from reply function")

agent.register_reply(agent1, a_reply)

with pytest.raises(RuntimeError) as e:
agent1.initiate_chat(agent, message="hi")

assert (
e.value.args[0] == "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). "
"The following async reply functions are found: a_reply"
)

agent2.register_reply(agent1, a_reply, ignore_async_in_sync_chat=True)
reply_mock.assert_not_called()


@pytest.mark.asyncio
async def test_sync_trigger_in_async_chat():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

def a_reply(recipient, messages, sender, config):
print("hello from a_reply")
return (True, "hello from reply function")

agent.register_reply(agent1, a_reply)
await agent1.a_initiate_chat(agent, message="hi")
assert agent1.last_message(agent)["content"] == "hello from reply function"


def test_context():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
Expand Down
Loading