From fb0e28a3abd03231e1b570ac80eb53f02b205591 Mon Sep 17 00:00:00 2001
From: dkirsche
Date: Sun, 11 Feb 2024 20:29:27 +0000
Subject: [PATCH] optimize function. Instead of iterating over each character, guess at size and then iterate by token.

---
 .../contrib/capabilities/context_handling.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
index 5c2d5d3ab4b..4bf60046301 100644
--- a/autogen/agentchat/contrib/capabilities/context_handling.py
+++ b/autogen/agentchat/contrib/capabilities/context_handling.py
@@ -100,7 +100,7 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
 
 def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
     """
-    Truncate a string so that number of tokens in less than max_tokens.
+    Truncate a string so that number of tokens is less than max_tokens.
 
     Args:
         content: String to process.
@@ -109,9 +109,11 @@ def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
     Returns:
         Truncated string.
     """
-    truncated_string = ""
-    for char in text:
-        truncated_string += char
-        if token_count_utils.count_token(truncated_string) == max_tokens:
-            break
-    return truncated_string
+
+    tokens = text.split()
+    for token_count in range(max_tokens, 0, -1):
+        truncated_text_tokens = tokens[:token_count]
+        actual_token_count = token_count_utils.count_token(" ".join(truncated_text_tokens))
+        if actual_token_count <= max_tokens:
+            return " ".join(truncated_text_tokens)
+    return "" # Return empty string if no tokens are found