Resyncing to base #3

Merged: 11 commits, May 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
@@ -400,7 +400,7 @@ jobs:
pip install pytest-cov>=5
- name: Install packages and dependencies for Transform Messages
run: |
pip install -e .
pip install -e '.[long-context]'
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
2 changes: 1 addition & 1 deletion .github/workflows/lfs-check.yml
@@ -10,6 +10,6 @@ jobs:
uses: actions/checkout@v4
with:
lfs: true
- name: Check Git LFS files for consistency
- name: "Check Git LFS files for consistency, if you see error like 'pointer: unexpectedGitObject ... should have been a pointer but was not', please install Git LFS locally, delete the problematic file, and then add it back again. This ensures it's properly tracked."
run: |
git lfs fsck
11 changes: 11 additions & 0 deletions README.md
@@ -266,6 +266,17 @@ In addition, you can find:
}
```

[AgentOptimizer](https://arxiv.org/pdf/2402.11359)

```
@article{zhang2024training,
title={Training Language Model Agents without Modifying Language Models},
author={Zhang, Shaokun and Zhang, Jieyu and Liu, Jiale and Song, Linxin and Wang, Chi and Krishna, Ranjay and Wu, Qingyun},
journal={ICML'24},
year={2024}
}
```

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: blue; font-weight: bold;">
↑ Back to Top ↑
68 changes: 68 additions & 0 deletions autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Protocol

IMPORT_ERROR: Optional[Exception] = None
try:
import llmlingua
except ImportError:
IMPORT_ERROR = ImportError(
"LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
)
PromptCompressor = object
else:
from llmlingua import PromptCompressor


class TextCompressor(Protocol):
"""Defines a protocol for text compression to optimize agent interactions."""

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
"""This method takes a string as input and returns a dictionary containing the compressed text and other
relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
"""
...


class LLMLingua:
"""Compresses text messages using LLMLingua for improved efficiency in processing and response generation.

NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
and the specific configurations used for the PromptCompressor.
"""

def __init__(
self,
prompt_compressor_kwargs: Dict = dict(
model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2=True,
device_map="cpu",
),
structured_compression: bool = False,
) -> None:
"""
Args:
prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2 set to True, and device_map set to "cpu".
structured_compression (bool): A flag indicating whether to use structured compression. If True, the
structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
is used. Defaults to False.

Raises:
ImportError: If the llmlingua library is not installed.
"""
if IMPORT_ERROR:
raise IMPORT_ERROR

self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
self._compression_method = (
self._prompt_compressor.structured_compress_prompt
if structured_compression
else self._prompt_compressor.compress_prompt
)

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
return self._compression_method([text], **compression_params)
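
For orientation (not part of the diff), here is a minimal usage sketch of the new `LLMLingua` wrapper. It assumes the optional dependency is installed via `pip install 'pyautogen[long-context]'`; the `rate` keyword is an illustrative llmlingua compression parameter forwarded through `compression_params`.

```python
# Minimal sketch, assuming the long-context extra is installed.
# `rate` is forwarded to llmlingua's compress_prompt and is illustrative.
from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua

compressor = LLMLingua()  # defaults to llmlingua-2 on CPU, per the kwargs above
result = compressor.compress_text("A long transcript to compress ...", rate=0.3)

print(result["compressed_prompt"])
print(result["origin_tokens"] - result["compressed_tokens"], "tokens saved")
```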
178 changes: 163 additions & 15 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -1,11 +1,15 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache

from .text_compressors import LLMLingua, TextCompressor


class MessageTransform(Protocol):
@@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not self._are_min_tokens_reached(messages):
if not _min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
@@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
return logs_str, True
return "No tokens were truncated.", False

def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
"""
Returns True if the total number of tokens in the messages meets or exceeds `min_tokens`,
or if no minimum token threshold is set.
"""
if not self._min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= self._min_tokens

def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
@@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int

return max_tokens if max_tokens is not None else sys.maxsize

def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
if min_tokens is None:
return 0
if min_tokens < 0:
@@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
return min_tokens


class TextMessageCompressor:
"""A transform for compressing text messages in a conversation history.

It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
processing and response generation by downstream models.
"""

def __init__(
self,
text_compressor: Optional[TextCompressor] = None,
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
):
"""
Args:
text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
protocol. If None, it defaults to LLMLingua.
min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
than or equal to 0 if not None. If None, no threshold-based compression is applied.
compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
"""

if text_compressor is None:
text_compressor = LLMLingua()

self._validate_min_tokens(min_tokens)

self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._cache = cache

# Cache the most recent token savings so log generation does not recompute them
self._recent_tokens_savings = 0

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies compression to messages in a conversation history based on the specified configuration.

The function processes each message according to the `compression_args` and `min_tokens` settings, applying
the specified compression configuration and returning a new list of messages with reduced token counts
where possible.

Args:
messages (List[Dict]): A list of message dictionaries to be compressed.

Returns:
List[Dict]: A list of dictionaries with the message content compressed according to the configured
method and scope.
"""
# Make sure there is at least one message
if not messages:
return messages

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
return messages

total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
continue

if _is_content_text_empty(message["content"]):
continue

cached_content = self._cache_get(message["content"])
if cached_content is not None:
savings, compressed_content = cached_content
else:
savings, compressed_content = self._compress(message["content"])

self._cache_set(message["content"], compressed_content, savings)

message["content"] = compressed_content
total_savings += savings

self._recent_tokens_savings = total_savings
return processed_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
if self._recent_tokens_savings > 0:
return f"{self._recent_tokens_savings} tokens saved with text compression.", True
else:
return "No tokens saved with text compression.", False

def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content

def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
tokens_saved += savings
return tokens_saved, content

def _compress_text(self, text: str) -> Tuple[int, str]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

return savings, compressed_text["compressed_prompt"]

def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value

def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"

def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
@@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
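
To show the new transform end to end, a hedged sketch follows; the messages and the `min_tokens` value are illustrative, and in practice the transform would typically be attached to an agent through the existing `TransformMessages` capability.

```python
# Hedged sketch: applying TextMessageCompressor to a conversation history.
# The messages and min_tokens value are illustrative, not from the PR, and
# the default compressor requires the long-context extra to be installed.
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor

compressor = TextMessageCompressor(min_tokens=64)  # skip histories under 64 tokens

messages = [
    {"role": "user", "content": "Summarize this long meeting transcript ..."},
    {"role": "assistant", "content": [{"type": "text", "text": "The key points are ..."}]},
]

compressed = compressor.apply_transform(messages)
print(compressor.get_logs(messages, compressed)[0])
```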
8 changes: 7 additions & 1 deletion autogen/agentchat/conversable_agent.py
@@ -149,7 +149,13 @@ def __init__(
)
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
llm_config = copy.deepcopy(llm_config)
try:
llm_config = copy.deepcopy(llm_config)
except TypeError as e:
raise TypeError(
"Please implement __deepcopy__ method for each value class in llm_config to support deepcopy."
" Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy"
) from e

self._validate_llm_config(llm_config)

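
The linked documentation describes the pattern this new error message points to. Below is a hedged sketch of it, with the class name and proxy URL purely illustrative: values placed in `llm_config` must survive `copy.deepcopy`, so a shared HTTP client can implement `__deepcopy__` to return itself.

```python
# Hedged sketch of the workaround the TypeError message references.
# SharedHttpClient and the proxy URL are illustrative, not from the PR.
import httpx

class SharedHttpClient(httpx.Client):
    def __deepcopy__(self, memo):
        return self  # share the client rather than copying it

llm_config = {
    "config_list": [
        {
            "model": "gpt-4",
            "api_key": "sk-...",
            "http_client": SharedHttpClient(proxy="http://localhost:8030"),
        }
    ]
}
```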
@@ -19,7 +19,7 @@ public async Task ChatWithAnAgent(IStreamingAgent agent)

#region ChatWithAnAgent_GenerateStreamingReplyAsync
var textMessage = new TextMessage(Role.User, "Hello");
await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message]))
await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message]))
{
if (streamingReply is TextMessageUpdate update)
{
@@ -11,7 +11,7 @@ public async Task StreamingCallCodeSnippetAsync()
IStreamingAgent agent = default;
#region StreamingCallCodeSnippet
var helloTextMessage = new TextMessage(Role.User, "Hello");
var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
await foreach (var message in reply)
{
@@ -24,7 +24,7 @@ public async Task StreamingCallCodeSnippetAsync()
#endregion StreamingCallCodeSnippet

#region StreamingCallWithFinalMessage
reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
TextMessage finalMessage = null;
await foreach (var message in reply)
{
@@ -38,7 +38,7 @@ public async Task CreateMistralAIClientAsync()
#endregion create_mistral_agent

#region streaming_chat
var reply = await agent.GenerateStreamingReplyAsync(
var reply = agent.GenerateStreamingReplyAsync(
messages: [new TextMessage(Role.User, "Hello, how are you?")]
);

@@ -75,7 +75,7 @@ public async Task MistralAIChatAgentGetWeatherToolUsageAsync()
#endregion create_get_weather_function_call_middleware

#region register_function_call_middleware
agent = agent.RegisterMiddleware(functionCallMiddleware);
agent = agent.RegisterStreamingMiddleware(functionCallMiddleware);
#endregion register_function_call_middleware

#region send_message_with_function_call