Min tokens in token limiter #2400

Merged
merged 20 commits into main from min_tokens_in_history_limiter
Apr 29, 2024

Changes from all commits
20 commits
3b8586f
Add minimum token threshold in MessageHistoryLimiter
giorgossideris Apr 16, 2024
b55a68d
Update transforms tests for the threshold
giorgossideris Apr 16, 2024
64737d6
Move min_threshold_tokens from Message to Token Limiter
giorgossideris Apr 17, 2024
509f4bc
Optimize _check_tokens_threshold
giorgossideris Apr 18, 2024
1997ad9
Apply requested changes (renaming, phrasing, validations)
giorgossideris Apr 18, 2024
4d0759d
Merge branch 'min_tokens_in_history_limiter' of https://github.com/gi…
giorgossideris Apr 18, 2024
93b10d0
Fix format
giorgossideris Apr 18, 2024
f3e1284
Fix _check_tokens_threshold logic
giorgossideris Apr 18, 2024
73ff56a
Update docs and notebook
giorgossideris Apr 18, 2024
763e888
Improve phrasing
giorgossideris Apr 19, 2024
5437731
Add min_tokens example in notebook
giorgossideris Apr 19, 2024
fd38a5c
Add min_tokens example in website docs
giorgossideris Apr 19, 2024
d14e438
Add min_tokens example in notebook
giorgossideris Apr 19, 2024
7b44d9f
Update website docs to be in sync with get_logs change
giorgossideris Apr 20, 2024
899c472
Merge branch 'main' into min_tokens_in_history_limiter
WaelKarkoub Apr 22, 2024
4cd0b01
Merge branch 'main' into min_tokens_in_history_limiter
giorgossideris Apr 23, 2024
6b05ab3
Merge branch 'main' into min_tokens_in_history_limiter
WaelKarkoub Apr 24, 2024
3d232de
Merge branch 'main' into min_tokens_in_history_limiter
giorgossideris Apr 26, 2024
df20e0b
Merge branch 'main' into min_tokens_in_history_limiter
giorgossideris Apr 28, 2024
3b949d4
Merge branch 'main' into min_tokens_in_history_limiter
sonichi Apr 29, 2024
45 changes: 39 additions & 6 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -51,8 +51,7 @@ class MessageHistoryLimiter:
def __init__(self, max_messages: Optional[int] = None):
"""
Args:
-            max_messages (None or int): Maximum number of messages to keep in the context.
-                Must be greater than 0 if not None.
+            max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
@@ -70,6 +69,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
Returns:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""
+
if self._max_messages is None:
return messages

@@ -108,20 +108,23 @@ class MessageTokenLimiter:

The truncation process follows these steps in order:

-    1. Messages are processed in reverse order (newest to oldest).
-    2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
+    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in the
+       messages is below this threshold, the messages are returned as is. Otherwise, the following process is applied.
+    2. Messages are processed in reverse order (newest to oldest).
+    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
        and other types of content, only the text content is truncated.
-    3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
+    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
        exceeds this limit, the current message being processed gets truncated to meet the total token count and any
        remaining messages get discarded.
-    4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
+    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
       original message order.
"""

def __init__(
self,
max_tokens_per_message: Optional[int] = None,
max_tokens: Optional[int] = None,
+        min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
):
"""
@@ -130,11 +133,14 @@ def __init__(
Must be greater than or equal to 0 if not None.
max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
Must be greater than or equal to 0 if not None.
+            min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
+                Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
+        self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -147,6 +153,11 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""
assert self._max_tokens_per_message is not None
assert self._max_tokens is not None
+        assert self._min_tokens is not None
+
+        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+        if not self._are_min_tokens_reached(messages):
+            return messages

temp_messages = copy.deepcopy(messages)
processed_messages = []
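And a sketch of the new early return (again assuming this change is installed; the token counts are approximate):

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

limiter = MessageTokenLimiter(max_tokens=200, min_tokens=100)
short_history = [{"role": "user", "content": "hi"}]

# The history holds far fewer than min_tokens=100 tokens, so it is
# returned unchanged and no truncation is attempted.
assert limiter.apply_transform(short_history) == short_history
```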
@@ -194,6 +205,19 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
return logs_str, True
return "No tokens were truncated.", False

+    def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
+        """
+        Returns True unless the minimum tokens restriction blocks the transformation.
+
+        That is, either the total number of tokens in the messages is greater than or equal to `min_tokens`,
+        or no minimum tokens threshold is set.
+        """
+        if not self._min_tokens:
+            return True
+
+        messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
+        return messages_tokens >= self._min_tokens
+
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
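The helper's sum can be approximated standalone. In the sketch below, `tiktoken` stands in for the module's own `_count_tokens` helper, and the message without a "content" key is a hypothetical illustration of the filter:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")

def rough_total_tokens(messages: list[dict]) -> int:
    # Mirrors _are_min_tokens_reached: sum token counts over messages
    # that actually carry a "content" key.
    return sum(len(enc.encode(m["content"])) for m in messages if "content" in m)

history = [
    {"role": "user", "content": "What does the token limiter do?"},
    {"role": "assistant"},  # no "content" key, so it is skipped
]
min_tokens = 100
print(rough_total_tokens(history) >= min_tokens)  # the comparison the helper performs
```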
@@ -244,6 +268,15 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:

return max_tokens if max_tokens is not None else sys.maxsize

+    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
+        if min_tokens is None:
+            return 0
+        if min_tokens < 0:
+            raise ValueError("min_tokens must be None or greater than or equal to 0.")
+        if max_tokens is not None and min_tokens > max_tokens:
+            raise ValueError("min_tokens must not be greater than max_tokens.")
+        return min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
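A hedged sketch of the validation paths (a hypothetical test snippet, assuming pytest; `_min_tokens` is an internal attribute and is inspected here only for illustration):

```python
import pytest
from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

with pytest.raises(ValueError):
    MessageTokenLimiter(min_tokens=-1)  # negative values are rejected

with pytest.raises(ValueError):
    MessageTokenLimiter(max_tokens=10, min_tokens=20)  # min above max is rejected

assert MessageTokenLimiter()._min_tokens == 0  # None falls back to 0
```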