diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py index 286b09eb3e2..5c2d5d3ab4b 100644 --- a/autogen/agentchat/contrib/capabilities/context_handling.py +++ b/autogen/agentchat/contrib/capabilities/context_handling.py @@ -53,39 +53,48 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: messages: List of messages to process. Returns: - List of messages with the first system message and the last max_messages messages. + List of messages with the first system message and the last max_messages messages, + ensuring each message does not exceed max_tokens_per_message. """ + temp_messages = messages.copy() processed_messages = [] - messages = messages.copy() - rest_messages = messages + system_message = None + processed_messages_tokens = 0 - # check if the first message is a system message and append it to the processed messages - if len(messages) > 0: - if messages[0]["role"] == "system": - msg = messages[0] - processed_messages.append(msg) - rest_messages = messages[1:] + if messages[0]["role"] == "system": + system_message = messages[0].copy() + temp_messages.pop(0) - processed_messages_tokens = 0 - for msg in messages: - msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message) + total_tokens = sum( + token_count_utils.count_token(msg["content"]) for msg in temp_messages + ) # Calculate tokens for all messages + + # Truncate each message's content to a maximum token limit of each message - # iterate through rest of the messages and append them to the processed messages - for msg in rest_messages[-self.max_messages :]: + for msg in temp_messages[-self.max_messages :]: + msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message) msg_tokens = token_count_utils.count_token(msg["content"]) if processed_messages_tokens + msg_tokens > self.max_tokens: break processed_messages.append(msg) processed_messages_tokens += msg_tokens - - total_tokens = 0 - for msg in messages: - total_tokens += token_count_utils.count_token(msg["content"]) - + if system_message: + processed_messages.insert(0, system_message) + # Optionally, log the number of truncated messages and tokens if needed num_truncated = len(messages) - len(processed_messages) + if num_truncated > 0 or total_tokens > processed_messages_tokens: - print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow")) - print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow")) + print( + colored( + f"Truncated {num_truncated} messages from {len(messages)} to {len(processed_messages)}.", "yellow" + ) + ) + print( + colored( + f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens truncated to {processed_messages_tokens}", + "yellow", + ) + ) return processed_messages