Skip to content

Commit

Permalink
process message before send (#1783)
Browse files Browse the repository at this point in the history
* process message before send

* rename
  • Loading branch information
sonichi authored Feb 25, 2024
1 parent 085bf6c commit 8ec1c3e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 13 deletions.
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/capabilities/context_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def add_to_agent(self, agent: ConversableAgent):
"""
Adds TransformChatHistory capability to the given agent.
"""
agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages)
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
"""
Expand Down
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/capabilities/teachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def add_to_agent(self, agent: ConversableAgent):
self.teachable_agent = agent

# Register a hook for processing the last message.
agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message)
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

# Was an llm_config passed to the constructor?
if self.llm_config is None:
Expand All @@ -82,7 +82,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()

def process_last_message(self, text):
def process_last_received_message(self, text):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
Expand Down
33 changes: 24 additions & 9 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ def __init__(

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {"process_last_message": [], "process_all_messages": []}
self.hook_lists = {
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
}

@property
def name(self) -> str:
Expand Down Expand Up @@ -467,6 +471,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
self._oai_messages[conversation_id].append(oai_message)
return True

def _process_message_before_send(
self, message: Union[Dict, str], recipient: Agent, silent: bool
) -> Union[Dict, str]:
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
message = hook(message, recipient, silent)
return message

def send(
self,
message: Union[Dict, str],
Expand Down Expand Up @@ -509,6 +522,7 @@ def send(
Returns:
ChatResult: a ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
Expand Down Expand Up @@ -561,6 +575,7 @@ async def a_send(
Returns:
ChatResult: an ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
Expand Down Expand Up @@ -1634,11 +1649,11 @@ def generate_reply(

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
Expand Down Expand Up @@ -1695,11 +1710,11 @@ async def a_generate_reply(

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
Expand Down Expand Up @@ -2333,11 +2348,11 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
"""
hook_list = self.hook_lists["process_all_messages"]
hook_list = self.hook_lists["process_all_messages_before_reply"]
# If no hooks are registered, or if there are no messages to process, return the original message list.
if len(hook_list) == 0 or messages is None:
return messages
Expand All @@ -2348,14 +2363,14 @@ def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
processed_messages = hook(processed_messages)
return processed_messages

def process_last_message(self, messages):
def process_last_received_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""

# If any required condition is not met, return the original message list.
hook_list = self.hook_lists["process_last_message"]
hook_list = self.hook_lists["process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
Expand Down
21 changes: 20 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,11 +1074,30 @@ def test_max_turn():
assert len(res.chat_history) <= 6


def test_process_before_send():
print_mock = unittest.mock.MagicMock()

def send_to_frontend(message, recipient, silent):
if not silent:
print(f"Message sent to {recipient.name}: {message}")
print_mock(message=message)
return message

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("process_message_before_send", send_to_frontend)
dummy_agent_1.send("hello", dummy_agent_2)
print_mock.assert_called_once_with(message="hello")
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
print_mock.assert_called_once_with(message="hello")


if __name__ == "__main__":
# test_trigger()
# test_context()
# test_max_consecutive_auto_reply()
# test_generate_code_execution_reply()
# test_conversable_agent()
# test_no_llm_config()
test_max_turn()
# test_max_turn()
test_process_before_send()

0 comments on commit 8ec1c3e

Please sign in to comment.