From e7e2b0af763089b6096d24db9ec61a66a6e44159 Mon Sep 17 00:00:00 2001 From: Wael Karkoub Date: Sat, 23 Mar 2024 02:05:15 +0100 Subject: [PATCH] more tests + cleanup code --- autogen/agentchat/conversable_agent.py | 18 ++++++------------ test/agentchat/test_conversable_agent.py | 5 +++++ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index edf042706a8..1c40634ecf9 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2246,24 +2246,18 @@ def _process_carryover(self, message: str, context: dict) -> str: return message def _process_multimodal_carryover(self, message: List, context: dict) -> List[Dict]: - carryover = context.get("carryover") reconstructed_messages = [{"type": "text", "text": ""}] for msg in message: - text_msg = "" if msg.get("type") == "text": - text_msg += msg["text"] + reconstructed_messages[0]["text"] += "\n" + msg["text"] else: reconstructed_messages.append(msg) - if carryover: - if isinstance(carryover, str): - reconstructed_messages[0]["text"] += "\nContext: \n" + carryover - elif isinstance(carryover, list): - reconstructed_messages[0]["text"] += "\nContext: \n" + ("\n").join([t for t in carryover]) - else: - raise InvalidCarryOverType( - "Carryover should be a string or a list of strings. Not adding carryover to the message." - ) + reconstructed_messages[0]["text"] = self._process_carryover(reconstructed_messages[0]["text"], context) + + # Delete the text message if it is empty + if reconstructed_messages[0]["text"] == "": + del reconstructed_messages[0] return reconstructed_messages diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 89be2d23c88..9d7c696f36b 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1278,19 +1278,23 @@ def test_messages_with_carryover(): ) generated_message = agent1.generate_init_message(**context) assert isinstance(generated_message, dict) + assert len(generated_message["content"]) == 2 context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"]) generated_message = agent1.generate_init_message(**context) assert isinstance(generated_message, dict) + assert len(generated_message["content"]) == 2 context = dict(message=mm_message, carryover=3) with pytest.raises(InvalidCarryOverType): agent1.generate_init_message(**context) # Test without carryover + print(mm_message) context = dict(message=mm_message) generated_message = agent1.generate_init_message(**context) assert isinstance(generated_message, dict) + assert len(generated_message["content"]) == 2 # Test without text in multimodal message mm_message = [ @@ -1299,6 +1303,7 @@ def test_messages_with_carryover(): context = dict(message=mm_message) generated_message = agent1.generate_init_message(**context) assert isinstance(generated_message, dict) + assert len(generated_message["content"]) == 1 if __name__ == "__main__":