Skip to content

Commit

Permalink
more tests + cleanup code
Browse files Browse the repository at this point in the history
  • Loading branch information
WaelKarkoub committed Mar 23, 2024
1 parent 5c1dc1c commit e7e2b0a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
18 changes: 6 additions & 12 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,24 +2246,18 @@ def _process_carryover(self, message: str, context: dict) -> str:
return message

def _process_multimodal_carryover(self, message: List, context: dict) -> List[Dict]:
carryover = context.get("carryover")
reconstructed_messages = [{"type": "text", "text": ""}]
for msg in message:
text_msg = ""
if msg.get("type") == "text":
text_msg += msg["text"]
reconstructed_messages[0]["text"] += "\n" + msg["text"]
else:
reconstructed_messages.append(msg)

if carryover:
if isinstance(carryover, str):
reconstructed_messages[0]["text"] += "\nContext: \n" + carryover
elif isinstance(carryover, list):
reconstructed_messages[0]["text"] += "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
reconstructed_messages[0]["text"] = self._process_carryover(reconstructed_messages[0]["text"], context)

# Delete the text message if it is empty
if reconstructed_messages[0]["text"] == "":
del reconstructed_messages[0]

return reconstructed_messages

Expand Down
5 changes: 5 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,19 +1278,23 @@ def test_messages_with_carryover():
)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2

context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2

context = dict(message=mm_message, carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)

# Test without carryover
print(mm_message)
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2

# Test without text in multimodal message
mm_message = [
Expand All @@ -1299,6 +1303,7 @@ def test_messages_with_carryover():
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 1


if __name__ == "__main__":
Expand Down

0 comments on commit e7e2b0a

Please sign in to comment.