Skip to content

Commit

Permalink
[Fix] Improves Token Limiter (#2350)
Browse files Browse the repository at this point in the history
* improves token limiter

* improve docstr

* rename arg
  • Loading branch information
WaelKarkoub authored Apr 11, 2024
1 parent 72bd0bd commit 97b5433
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class MessageTokenLimiter:
2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
and other types of content, only the text content is truncated.
3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
exceeds this limit, the current message being processed as well as any remaining messages are discarded.
exceeds this limit, the current message being processed gets truncated to meet the total token count and any
remaining messages get discarded.
4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
original message order.
"""
Expand Down Expand Up @@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

for msg in reversed(temp_messages):
msg["content"] = self._truncate_str_to_tokens(msg["content"])
msg_tokens = _count_tokens(msg["content"])
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

# If adding this message would exceed the token limit, discard it and all remaining messages
if processed_messages_tokens + msg_tokens > self._max_tokens:
# If adding this message would exceed the token limit, truncate the last message to meet the total token
# limit and discard all remaining messages
if expected_tokens_remained < 0:
msg["content"] = self._truncate_str_to_tokens(
msg["content"], self._max_tokens - processed_messages_tokens
)
processed_messages.insert(0, msg)
break

msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
msg_tokens = _count_tokens(msg["content"])

# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
processed_messages.insert(0, msg)
Expand All @@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

return processed_messages

def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents)
return self._truncate_tokens(contents, n_tokens)
elif isinstance(contents, list):
return self._truncate_multimodal_text(contents)
return self._truncate_multimodal_text(contents, n_tokens)
else:
raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
"""Truncates text content within a list of multimodal elements, preserving the overall structure."""
tmp_contents = []
for content in contents:
if content["type"] == "text":
truncated_text = self._truncate_tokens(content["text"])
truncated_text = self._truncate_tokens(content["text"], n_tokens)
tmp_contents.append({"type": "text", "text": truncated_text})
else:
tmp_contents.append(content)
return tmp_contents

def _truncate_tokens(self, text: str) -> str:
def _truncate_tokens(self, text: str, n_tokens: int) -> str:
encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer

encoded_tokens = encoding.encode(text)
truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
truncated_tokens = encoded_tokens[:n_tokens]
truncated_text = encoding.decode(truncated_tokens) # Decode back to text

return truncated_text
Expand Down

0 comments on commit 97b5433

Please sign in to comment.