Skip to content

Commit

Permalink
refactored code to simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
dkirsche committed Feb 11, 2024
1 parent b4a2c6e commit 2b9b486
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions autogen/agentchat/contrib/capabilities/context_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,48 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
messages: List of messages to process.
Returns:
List of messages with the first system message and the last max_messages messages.
List of messages with the first system message and the last max_messages messages,
ensuring each message does not exceed max_tokens_per_message.
"""
temp_messages = messages.copy()
processed_messages = []
messages = messages.copy()
rest_messages = messages
system_message = None
processed_messages_tokens = 0

# check if the first message is a system message and append it to the processed messages
if len(messages) > 0:
if messages[0]["role"] == "system":
msg = messages[0]
processed_messages.append(msg)
rest_messages = messages[1:]
if messages[0]["role"] == "system":
system_message = messages[0].copy()
temp_messages.pop(0)

processed_messages_tokens = 0
for msg in messages:
msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
total_tokens = sum(
token_count_utils.count_token(msg["content"]) for msg in temp_messages
) # Calculate tokens for all messages

# Truncate each message's content to a maximum token limit of each message

# iterate through rest of the messages and append them to the processed messages
for msg in rest_messages[-self.max_messages :]:
for msg in temp_messages[-self.max_messages :]:
msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
msg_tokens = token_count_utils.count_token(msg["content"])
if processed_messages_tokens + msg_tokens > self.max_tokens:
break
processed_messages.append(msg)
processed_messages_tokens += msg_tokens

total_tokens = 0
for msg in messages:
total_tokens += token_count_utils.count_token(msg["content"])

if system_message:
processed_messages.insert(0, system_message)
# Optionally, log the number of truncated messages and tokens if needed
num_truncated = len(messages) - len(processed_messages)

if num_truncated > 0 or total_tokens > processed_messages_tokens:
print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow"))
print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
print(
colored(
f"Truncated {num_truncated} messages from {len(messages)} to {len(processed_messages)}.", "yellow"
)
)
print(
colored(
f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens truncated to {processed_messages_tokens}",
"yellow",
)
)
return processed_messages


Expand Down

0 comments on commit 2b9b486

Please sign in to comment.