Add compression agent #131

Closed
wants to merge 26 commits into from
26 commits
791019c
add compression agent
yiranwu0 Oct 6, 2023
2ff2263
add changes from new updates
yiranwu0 Oct 6, 2023
400943e
fix workflow failure
yiranwu0 Oct 6, 2023
556b80a
update
yiranwu0 Oct 6, 2023
e4c4b91
update
yiranwu0 Oct 6, 2023
a9ce4c7
Merge remote-tracking branch 'origin/main' into compression
yiranwu0 Oct 7, 2023
a81341d
update
yiranwu0 Oct 7, 2023
f872a6f
update
yiranwu0 Oct 7, 2023
1938d75
Merge remote-tracking branch 'origin/main' into compression
yiranwu0 Oct 8, 2023
0d4d7dd
add test
yiranwu0 Oct 8, 2023
e95797f
update
yiranwu0 Oct 8, 2023
cb3df56
update
yiranwu0 Oct 10, 2023
f6ebf11
Merge branch 'main' into compression
yiranwu0 Oct 11, 2023
bf88321
Merge branch 'compression' of github.com:microsoft/autogen into compr…
yiranwu0 Oct 11, 2023
e09464c
Merge remote-tracking branch 'origin/main' into compression
yiranwu0 Oct 11, 2023
0e9168a
update
yiranwu0 Oct 11, 2023
1a63430
Merge remote-tracking branch 'origin/main' into compression
yiranwu0 Oct 12, 2023
4109f58
update to resolve comments
yiranwu0 Oct 12, 2023
a60dd8e
clean up
yiranwu0 Oct 12, 2023
bbc90e5
update
yiranwu0 Oct 12, 2023
19551d5
Merge branch 'main' into compression
thinkall Oct 16, 2023
26c8cf7
fix bug, revise prompt
yiranwu0 Oct 17, 2023
5e826e9
Merge branch 'compression' of github.com:microsoft/autogen into compr…
yiranwu0 Oct 17, 2023
ffb314c
Merge remote-tracking branch 'origin/main' into compression
yiranwu0 Oct 17, 2023
ec9ff7f
fix bug, remove affect to originally functionality
yiranwu0 Oct 17, 2023
f64a074
Merge branch 'main' into compression
yiranwu0 Oct 21, 2023
167 changes: 167 additions & 0 deletions autogen/agentchat/contrib/compression_agent.py
@@ -0,0 +1,167 @@
from typing import Callable, Dict, Optional, Union, Tuple, List, Any
from autogen import oai
from autogen import Agent, ConversableAgent
import copy
from autogen.token_count_utils import count_token

try:
    from termcolor import colored
except ImportError:

    def colored(x, *args, **kwargs):
        return x


class CompressionAgent(ConversableAgent):
    """(Experimental) Compression agent, designed to compress a list of messages.

    CompressionAgent is a subclass of ConversableAgent configured with a default system message.
    The default system message is designed to compress chat history.
    `human_input_mode` defaults to "NEVER"
    and `code_execution_config` defaults to False.
    This agent doesn't execute code or function calls by default.
    """

    DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant that will compress messages.
Rules:
1. Please summarize each of the messages and preserve the titles: ##USER##, ##ASSISTANT##, ##FUNCTION_CALL##, ##FUNCTION_RETURN##, ##SYSTEM##, ##<Name>(<Title>)## (e.g. ##Bob(ASSISTANT)##).
2. Context after ##USER##, ##ASSISTANT## (and ##<Name>(<Title>)##): compress the content and preserve important information. If there is a big chunk of code, use ##CODE## to indicate it and summarize what the code is doing with as few words as possible, including details like exact numbers and defined variables.
3. Context after ##FUNCTION_CALL##: Keep the exact content if it is short. Otherwise, summarize/compress it and preserve names (func_name, argument names).
4. Context after ##FUNCTION_RETURN## (or code return): Keep the exact content if it is short. Summarize/compress if it is too long; note what the function has achieved and what the return value is.
"""

    def __init__(
        self,
        name: str = "compressor",
        system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
        llm_config: Optional[Union[Dict, bool]] = None,
        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Optional[str] = "NEVER",
        code_execution_config: Optional[Union[Dict, bool]] = False,
        **kwargs,
    ):
        """
        Args:
            name (str): agent name.
            system_message (str): system message for the ChatCompletion inference.
                Please override this attribute if you want to reprogram the agent.
            llm_config (dict): llm inference configuration.
                Please refer to [Completion.create](/docs/reference/oai/completion#create)
                for available options.
            is_termination_msg (function): a function that takes a message in the form of a dictionary
                and returns a boolean value indicating if this received message is a termination message.
                The dict can contain the following keys: "content", "role", "name", "function_call".
            max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
                Defaults to None (no limit provided; the class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
                The limit only plays a role when human_input_mode is not "ALWAYS".
            **kwargs (dict): Please refer to other kwargs in
                [ConversableAgent](../conversable_agent#__init__).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg,
            max_consecutive_auto_reply,
            human_input_mode,
            code_execution_config=code_execution_config,
            llm_config=llm_config,
            **kwargs,
        )

        self._reply_func_list.clear()
        self.register_reply([Agent, None], CompressionAgent.generate_compressed_reply)

    def generate_compressed_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[str, Dict, None, List]]:
        """Compress a list of messages into one message.

        The first message (the initial prompt) will not be compressed.
        The rest of the messages will be compressed into one message; the model is asked to distinguish the role of each message: USER, ASSISTANT, FUNCTION_CALL, FUNCTION_RETURN.
        Check out the DEFAULT_SYSTEM_MESSAGE prompt above.

        TODO: the model used by the compression agent may differ from the assistant agent's model. For example, if the original model is gpt-4, we start compressing at 70% of usage (70% of 8192 = 5734); if gpt-3.5 (max_token = 4096) is used here, it will raise an error. Choose the model automatically?
        """
        # Print a banner marking the content to be compressed
        print(colored("*" * 30 + "Start compressing the following content:" + "*" * 30, "magenta"), flush=True)

        # 1. use passed-in config and messages
        # in on_oai_token_limit of ConversableAgent, llm_config is passed in via the "config" parameter.
        llm_config = copy.deepcopy(self.llm_config) if config is None else copy.deepcopy(config)
        if llm_config is False:
            return False, None
        # remove functions from llm_config
        if "functions" in llm_config:
            del llm_config["functions"]

        if messages is None:
            messages = self._oai_messages[sender]

        # 2. stop if there is only one message in the list
        if len(messages) <= 1:
            print(f"Warning: the first message contains {count_token(messages)} tokens, which will not be compressed.")
            return False, None

        # 3. put all history into one, except the first one
        compressed_prompt = "Below is the compressed content from the previous conversation, evaluate the process and continue if necessary:\n"
        chat_to_compress = "To be compressed:\n"
        start_index = 1
        for m in messages[start_index:]:
            if m.get("role") == "function":
                chat_to_compress += f"##FUNCTION_RETURN## (from function \"{m['name']}\"): \n{m['content']}\n"
            else:
                if "name" in m:
                    # {"name" : "Bob", "role" : "assistant"} -> ##Bob(ASSISTANT)##
                    chat_to_compress += f"##{m['name']}({m['role'].upper()})## {m['content']}\n"
                elif m["content"] is not None:
                    if compressed_prompt in m["content"]:
                        # remove the compressed_prompt from the content
                        tmp = m["content"].replace(compressed_prompt, "")
                        chat_to_compress += f"{tmp}\n"
                    else:
                        chat_to_compress += f"##{m['role'].upper()}## {m['content']}\n"

                if "function_call" in m:
                    if (
                        m["function_call"].get("name", None) is None
                        or m["function_call"].get("arguments", None) is None
                    ):
                        chat_to_compress += f"##FUNCTION_CALL## {m['function_call']}\n"
                    else:
                        chat_to_compress += f"##FUNCTION_CALL## \nName: {m['function_call']['name']}\nArgs: {m['function_call']['arguments']}\n"

        chat_to_compress = [{"role": "user", "content": chat_to_compress}]
        # Print the content to be compressed
        print(chat_to_compress[0]["content"])

        # 4. ask LLM to compress
        try:
            response = oai.ChatCompletion.create(
                context=None, messages=self._oai_system_message + chat_to_compress, **llm_config
            )
        except Exception as e:
            print(f"Warning: Failed to compress the content due to {e}.")
            return False, None
        compressed_message = oai.ChatCompletion.extract_text_or_function_call(response)[0]
        print(
            colored(
                "*" * 30 + "Content after compressing: (type=" + str(type(compressed_message)) + ")" + "*" * 30,
                "magenta",
            ),
            flush=True,
        )
        print(compressed_message, colored("\n" + "*" * 80, "magenta"))

        assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
        # 5. add compressed message to the first message and return
        return True, [
            messages[0],
            {
                "content": compressed_prompt + compressed_message,
                "role": "system",
            },
        ]
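For reviewers who want to try the new agent in isolation, here is a minimal usage sketch. It is not part of this PR's diff: it assumes an OpenAI API key is configured in the environment, and the model name and message contents are placeholders.

# Hypothetical usage sketch, not part of this PR. Assumes OPENAI_API_KEY is set;
# the llm_config and messages below are placeholders.
from autogen.agentchat.contrib.compression_agent import CompressionAgent

compressor = CompressionAgent(llm_config={"model": "gpt-3.5-turbo"})

history = [
    {"role": "user", "content": "Plot the YTD stock price change of META and TESLA."},  # initial prompt, never compressed
    {"role": "assistant", "content": "Here is a long plan with a big chunk of code ..."},
    {"role": "user", "content": "Please continue and fix the error."},
]

# Expected result shape: [first_message, {"role": "system", "content": compressed_prompt + summary}]
compressed = compressor.generate_reply(history, None)
print(compressed)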
126 changes: 120 additions & 6 deletions autogen/agentchat/conversable_agent.py
@@ -14,6 +14,8 @@
    infer_lang,
)

from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions

try:
    from termcolor import colored
except ImportError:
@@ -56,6 +58,7 @@ def __init__(
        code_execution_config: Optional[Union[Dict, bool]] = None,
        llm_config: Optional[Union[Dict, bool]] = None,
        default_auto_reply: Optional[Union[str, Dict, None]] = "",
        compress_config: Optional[Dict] = False,
    ):
        """
        Args:
@@ -96,6 +99,15 @@
                Please refer to [Completion.create](/docs/reference/oai/completion#create)
                for available options.
                To disable llm-based auto reply, set to False.
            compress_config (dict or False): config for compression before oai_reply. Defaults to False, meaning no compression will be used and
                the conversation will terminate when the token count exceeds the limit. The dict can contain the following keys:
                - "mode" (Optional, str, default "COMPRESS"): Choose from ["COMPRESS", "TERMINATE"]. "COMPRESS": enable the compression agent.
                  "TERMINATE": terminate the conversation when the token count exceeds the limit.
                - "agent" (Optional, "Agent", default CompressionAgent): the agent to call before oai_reply. The `generate_reply` method of this agent will be called.
                - "trigger_count" (Optional, float or int, default 0.7): the threshold to trigger compression.
                  If a float in (0, 1], it is the fraction of the token limit used; if an int, it is the number of tokens used.
                - "async" (Optional, bool, default False): whether to compress asynchronously.
                - "broadcast" (Optional, bool, default True): whether to update the compressed message history on the sender side as well.
            default_auto_reply (str or dict or None): default auto reply when no code execution or llm-based reply is generated.
        """
        super().__init__(name)
@@ -123,7 +135,43 @@ def __init__(
        self._default_auto_reply = default_auto_reply
        self._reply_func_list = []
        self.reply_at_receive = defaultdict(bool)

        if compress_config and self.llm_config:
            if compress_config is True:
                compress_config = {}
            if not isinstance(compress_config, dict):
                raise ValueError("compress_config must be a dict or 'False'.")

            # convert trigger_count to int, default to 0.7
            trigger_count = compress_config.get("trigger_count", 0.7)
            if isinstance(trigger_count, float) and 0 < trigger_count < 1:
                trigger_count = int(trigger_count * get_max_token_limit(self.llm_config["model"]))
            else:
                trigger_count = int(trigger_count)

            assert compress_config.get("mode", "COMPRESS") in [
                "COMPRESS",
                "TERMINATE",
            ], "compress_config['mode'] must be 'COMPRESS' or 'TERMINATE'"
            if compress_config.get("mode", "COMPRESS") == "TERMINATE":
                self.compress_config = compress_config
            else:
                from .contrib.compression_agent import CompressionAgent

                self.compress_config = {
                    "mode": "COMPRESS",
                    "agent": compress_config.get(
                        "agent", CompressionAgent(llm_config=llm_config)
                    ),  # TODO: llm_config to pass in here?
                    "trigger_count": trigger_count,
                    "async": compress_config.get("async", False),  # TODO: support async compression
                    "broadcast": compress_config.get("broadcast", True),
                }
        else:
            self.compress_config = False

        self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
        self.register_reply([Agent, None], ConversableAgent.on_oai_token_limit)  # check token limit
        self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
        self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
        self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
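As a quick sanity check of the trigger_count conversion above, a hypothetical worked example (not part of the diff; the 8192 figure is gpt-4's context window and the values are illustrative):

# Hypothetical worked example, not part of this PR: a float trigger_count is
# interpreted as a fraction of the model's max token limit.
max_limit = 8192          # what get_max_token_limit("gpt-4") would return
trigger_count = 0.7
if isinstance(trigger_count, float) and 0 < trigger_count < 1:
    trigger_count = int(trigger_count * max_limit)
print(trigger_count)      # 5734 -> compression triggers once ~5.7k tokens are in use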
@@ -608,6 +656,78 @@ def generate_oai_reply(
        )
        return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]

    def compute_init_token_count(self):
        """Check if the agent is LLM-based and compute the initial token count."""
        if self.llm_config is False:
            return 0

        func_count = 0
        if "functions" in self.llm_config:
            func_count = num_tokens_from_functions(self.llm_config["functions"], self.llm_config["model"])

        return func_count + count_token(self._oai_system_message, self.llm_config["model"])

    def on_oai_token_limit(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """(Experimental) Compress previous messages when a threshold of tokens is reached."""
        llm_config = self.llm_config if config is None else config
        if llm_config is False or self.compress_config is False:
            # Only apply when this is an LLM-based agent (has llm_config) and compression is enabled.
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]

        # if mode is TERMINATE, terminate the agent if no tokens are left.
        token_used = self.compute_init_token_count() + count_token(messages, llm_config["model"])
        max_token = max(get_max_token_limit(llm_config["model"]), llm_config.get("max_token", 0))
        if self.compress_config["mode"] == "TERMINATE":
            if max_token - token_used <= 0:
                # Terminate if no token left.
                print(
                    colored(
                        f"Warning: Terminate Agent \"{self.name}\" due to no token left for oai reply. max token for {llm_config['model']}: {max_token}, existing token count: {token_used}",
                        "yellow",
                    ),
                    flush=True,
                )
                return True, None
            return False, None

        # on_oai_token_limit requires a sender. Otherwise, the compressed messages cannot be saved.
        if sender is None:
            return False, None

        # if token_used is less than trigger_count, no compression will be used.
        if token_used < self.compress_config["trigger_count"]:
            return False, None

        if self.compress_config["async"]:
            # TODO: async compress
            pass

        compressed_messages = self.compress_config["agent"].generate_reply(messages, None)
        if compressed_messages is not None:
            # TODO: maintain a list for old oai messages (messages before compression)
            to_print = (
                "Token Count (of msgs after first prompt): Before compression: "
                + str(count_token(self._oai_messages[sender][1:], llm_config["model"]))
                + " After: "
                + str(count_token(compressed_messages[1:], llm_config["model"]))
                + " | Total prompt token count after compression: "
                + str(count_token(compressed_messages, llm_config["model"]) + self.compute_init_token_count())
            )
            print(colored(to_print, "magenta"), flush=True)
            print("-" * 80)

            self._oai_messages[sender] = compressed_messages
            if self.compress_config["broadcast"]:
                sender._oai_messages[self] = copy.deepcopy(compressed_messages)

        return False, None

    def generate_code_execution_reply(
        self,
        messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -768,9 +888,6 @@ def generate_reply(
logger.error(error_msg)
raise AssertionError(error_msg)

if messages is None:
Collaborator Author

Note: I removed the two lines here (need to confirm):

        if messages is None:
            messages = self._oai_messages[sender]

Why we can delete this:
These two deleted lines appear in every generate_<>_reply function, so when both messages and sender are passed to a subsequent generate_<>_reply, it performs the same logic.

final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])

Why this is needed for compression: Compression modifies self._oai_messages, and generate_oai_reply is expected to use the updated messages from self._oai_messages. With these two lines in place, messages will not be None and the updated self._oai_messages will not be used.

Contributor

Neither justification is clear to me.

Contributor

Another option is to only remove this when compress_config is provided. This ensures the original logic is unchanged in the default setting.
            messages = self._oai_messages[sender]

        for reply_func_tuple in self._reply_func_list:
            reply_func = reply_func_tuple["reply_func"]
            if exclude and reply_func in exclude:
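One possible reading of the reviewer suggestion above, sketched as a hypothetical alternative (this is not code in the PR): keep the original fallback whenever compression is disabled, so the default code path stays untouched.

# Hypothetical alternative to outright removal (reviewer suggestion, not in this PR):
# fall back to self._oai_messages only when compression is not configured, so the
# compressed history is re-read from self._oai_messages when compress_config is set.
if messages is None and not self.compress_config:
    messages = self._oai_messages[sender]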
Expand Down Expand Up @@ -819,9 +936,6 @@ async def a_generate_reply(
logger.error(error_msg)
raise AssertionError(error_msg)

if messages is None:
messages = self._oai_messages[sender]

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
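Finally, a hypothetical end-to-end configuration sketch (not part of the diff) showing how the compress_config documented above might be passed to a ConversableAgent; the model name and threshold values are placeholders.

# Hypothetical configuration sketch, not part of this PR. Key names follow the
# compress_config docstring added above; model and thresholds are placeholders.
from autogen import ConversableAgent

assistant = ConversableAgent(
    name="assistant",
    llm_config={"model": "gpt-4"},
    compress_config={
        "mode": "COMPRESS",    # or "TERMINATE" to stop the conversation at the token limit
        "trigger_count": 0.7,  # compress once ~70% of the model's context window is used
        "broadcast": True,     # also overwrite the sender's copy of the compressed history
    },
)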