Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message history limiter for tool call #3178

Merged
merged 10 commits into from
Aug 9, 2024
17 changes: 15 additions & 2 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ class MessageHistoryLimiter:
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""

def __init__(self, max_messages: Optional[int] = None):
def __init__(self, max_messages: Optional[int] = None, keep_first_message: Optional[bool] = False):
marklysze marked this conversation as resolved.
Show resolved Hide resolved
"""
Args:
max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
keep_first_message Optional[bool]: Whether to keep the original first message in the conversation history.
Defaults to False. Does not count towards truncation.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
self._keep_first_message = keep_first_message

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
Expand All @@ -78,7 +81,17 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
if self._max_messages is None:
return messages

return messages[-self._max_messages :]
truncated_messages = messages[-self._max_messages :]
# If the last message is a tool message, include its preceding message that must be a tool_calls message
if truncated_messages[0].get("role") == "tool":
start_index = max(-self._max_messages - 1, -len(messages))
truncated_messages = messages[start_index:]

# Keep the first message if required
if self._keep_first_message and messages[0] != truncated_messages[0]:
truncated_messages = [messages[0]] + truncated_messages

return truncated_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_len = len(pre_transform_messages)
Expand Down
27 changes: 22 additions & 5 deletions test/agentchat/contrib/capabilities/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ def get_no_content_messages() -> List[Dict]:
return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]


def get_tool_messages() -> List[Dict]:
    """Return a fixed 5-message conversation containing a tool_calls/tool pair.

    Used to verify that MessageHistoryLimiter keeps the tool_calls message
    together with its following tool response when truncating history.
    """
    roles_and_contents = [
        ("user", "hello"),
        ("tool_calls", "calling_tool"),
        ("tool", "tool_response"),
        ("user", "how are you"),
        ("assistant", [{"type": "text", "text": "are you doing?"}]),
    ]
    return [{"role": role, "content": content} for role, content in roles_and_contents]


def get_text_compressors() -> List[TextCompressor]:
compressors: List[TextCompressor] = [_MockTextCompressor()]
try:
Expand Down Expand Up @@ -96,19 +106,24 @@ def _filter_dict_test(

@pytest.mark.parametrize(
    "messages, expected_messages_len",
    [(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2), (get_tool_messages(), 4)],
)
def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
    """Check that the limiter truncates to the expected length for each fixture.

    For the tool-message fixture, additionally verify that the truncated
    history starts with the intact tool_calls/tool pair.
    """
    result = message_history_limiter.apply_transform(messages)
    assert len(result) == expected_messages_len

    if messages == get_tool_messages():
        leading_roles = [msg["role"] for msg in result[:2]]
        assert leading_roles == ["tool_calls", "tool"]


@pytest.mark.parametrize(
"messages, expected_logs, expected_effect",
[
(get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
(get_short_messages(), "No messages were removed.", False),
(get_no_content_messages(), "No messages were removed.", False),
(get_tool_messages(), "Removed 1 messages. Number of messages reduced from 5 to 4.", True),
],
)
def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
Expand Down Expand Up @@ -272,24 +287,26 @@ def test_text_compression_cache(text_compressor):
long_messages = get_long_messages()
short_messages = get_short_messages()
no_content_messages = get_no_content_messages()
tool_messages = get_tool_messages()
msg_history_limiter = MessageHistoryLimiter(max_messages=3)
msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)

# Test Parameters
message_history_limiter_apply_transform_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"expected_messages_len": [3, 3, 2],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_messages_len": [3, 3, 2, 4],
}

message_history_limiter_get_logs_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_logs": [
"Removed 2 messages. Number of messages reduced from 5 to 3.",
"No messages were removed.",
"No messages were removed.",
"Removed 1 messages. Number of messages reduced from 5 to 4.",
],
"expected_effect": [True, False, False],
"expected_effect": [True, False, False, True],
}

message_token_limiter_apply_transform_parameters = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,28 @@ pprint.pprint(processed_messages)
{'content': 'very very very very very very long string', 'role': 'user'}]
```

By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages.
By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages. However, if the truncation point would fall between a "tool_calls" message and its "tool" response, the complete pair is kept together to satisfy the OpenAI API's requirement that a tool response be preceded by its tool call.

```python
max_msg_transform = transforms.MessageHistoryLimiter(max_messages=3)

messages = [
    {"role": "user", "content": "hello"},
    {"role": "tool_calls", "content": "calling_tool"},
    {"role": "tool", "content": "tool_response"},
    {"role": "user", "content": "how are you"},
    {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]

processed_messages = max_msg_transform.apply_transform(copy.deepcopy(messages))
pprint.pprint(processed_messages)
```
```console
[{'content': 'calling_tool', 'role': 'tool_calls'},
{'content': 'tool_response', 'role': 'tool'},
{'content': 'how are you', 'role': 'user'},
{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
```

#### Example 2: Limiting the Number of Tokens

Expand Down
Loading