Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Improves Token Limiter #2350

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class MessageTokenLimiter:
2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
and other types of content, only the text content is truncated.
3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
exceeds this limit, the current message being processed as well as any remaining messages are discarded.
exceeds this limit, the current message being processed get truncated to meet the total token count and any
remaining messages get discarded.
4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
original message order.
"""
Expand Down Expand Up @@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

for msg in reversed(temp_messages):
msg["content"] = self._truncate_str_to_tokens(msg["content"])
msg_tokens = _count_tokens(msg["content"])
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

# If adding this message would exceed the token limit, discard it and all remaining messages
if processed_messages_tokens + msg_tokens > self._max_tokens:
# If adding this message would exceed the token limit, truncate the last message to meet the total token
# limit and discard all remaining messages
if expected_tokens_remained < 0:
msg["content"] = self._truncate_str_to_tokens(
msg["content"], self._max_tokens - processed_messages_tokens
)
processed_messages.insert(0, msg)
break

msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
msg_tokens = _count_tokens(msg["content"])

# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
processed_messages.insert(0, msg)
Expand All @@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

return processed_messages

def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents)
return self._truncate_tokens(contents, n_tokens)
elif isinstance(contents, list):
return self._truncate_multimodal_text(contents)
return self._truncate_multimodal_text(contents, n_tokens)
else:
raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
"""Truncates text content within a list of multimodal elements, preserving the overall structure."""
tmp_contents = []
for content in contents:
if content["type"] == "text":
truncated_text = self._truncate_tokens(content["text"])
truncated_text = self._truncate_tokens(content["text"], n_tokens)
tmp_contents.append({"type": "text", "text": truncated_text})
else:
tmp_contents.append(content)
return tmp_contents

def _truncate_tokens(self, text: str) -> str:
def _truncate_tokens(self, text: str, tokens: int) -> str:
encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer

encoded_tokens = encoding.encode(text)
truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
truncated_tokens = encoded_tokens[:tokens]
truncated_text = encoding.decode(truncated_tokens) # Decode back to text

return truncated_text
Expand Down
Loading