-->
+:fire: May 29, 2024: DeepLearning.ai launched a new short course [AI Agentic Design Patterns with AutoGen](https://www.deeplearning.ai/short-courses/ai-agentic-design-patterns-with-autogen), made in collaboration with Microsoft and Penn State University, and taught by AutoGen creators [Chi Wang](https://github.com/sonichi) and [Qingyun Wu](https://github.com/qingyun-wu).
+
+:fire: May 24, 2024: Foundation Capital published an article on [Forbes: The Promise of Multi-Agent AI](https://www.forbes.com/sites/joannechen/2024/05/24/the-promise-of-multi-agent-ai/?sh=2c1e4f454d97) and a video [AI in the Real World Episode 2: Exploring Multi-Agent AI and AutoGen with Chi Wang](https://www.youtube.com/watch?v=RLwyXRVvlNk).
+
+:fire: May 13, 2024: [The Economist](https://www.economist.com/science-and-technology/2024/05/13/todays-ai-models-are-impressive-teams-of-them-will-be-formidable) published an article about multi-agent systems (MAS) following a January 2024 interview with [Chi Wang](https://github.com/sonichi).
+
+:fire: May 11, 2024: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://openreview.net/pdf?id=uAjxFFing2) received the best paper award at the [ICLR 2024 LLM Agents Workshop](https://llmagents.github.io/).
+
+:fire: Apr 26, 2024: [AutoGen.NET](https://microsoft.github.io/autogen-for-net/) is available for .NET developers!
+
+:fire: Apr 17, 2024: Andrew Ng cited AutoGen in [The Batch newsletter](https://www.deeplearning.ai/the-batch/issue-245/) and [What's next for AI agentic workflows](https://youtu.be/sal78ACtGTc?si=JduUzN_1kDnMq0vF) at Sequoia Capital's AI Ascent (Mar 26).
+
+:fire: Mar 3, 2024: What's new in AutoGen? 📰[Blog](https://microsoft.github.io/autogen/blog/2024/03/03/AutoGen-Update); 📺[YouTube](https://www.youtube.com/watch?v=j_mtwQiaLGU).
-:warning: Jan 23: **Breaking Change in Latest Release v0.2.8** `use_docker` defaults to `True` for code-execution. See [blog post](https://microsoft.github.io/autogen/blog/2024/01/23/Code-execution-in-docker) for details and [FAQ](https://microsoft.github.io/autogen/docs/FAQ#agents-are-throwing-due-to-docker-not-running-how-can-i-resolve-this) for troubleshooting any issues.
+:fire: Mar 1, 2024: The first AutoGen multi-agent experiment on the challenging [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) benchmark achieved No. 1 accuracy across all three levels.
-:fire: Dec 31: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023).
+
+
+:tada: Dec 31, 2023: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023).
-:fire: Nov 8: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff.
+:tada: Nov 8, 2023: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff from [FLAML](https://github.com/microsoft/FLAML).
-:fire: Nov 6: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U) around 13:20.
+
-:fire: Nov 1: AutoGen is the top trending repo on GitHub in October 2023.
+
-:tada: Oct 03: AutoGen spins off from FLAML on GitHub and has a major paper update (first version on Aug 16).
+
-:tada: Mar 29: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML).
+:tada: Mar 29, 2023: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML).
+
+
## What is AutoGen
AutoGen is a framework that enables the development of LLM applications using multiple agents that can converse with each other to solve tasks. AutoGen agents are customizable, conversable, and seamlessly allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools.
@@ -52,7 +76,24 @@ AutoGen is a framework that enables the development of LLM applications using mu
- It provides a collection of working systems spanning a [wide range of applications](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat#diverse-applications-implemented-with-autogen) from various domains and complexities, demonstrating how easily AutoGen supports diverse conversation patterns.
- AutoGen provides [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification). It offers utilities like API unification and caching, and advanced usage patterns, such as error handling, multi-config inference, context programming, etc.
-AutoGen is powered by collaborative [research studies](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington.
+AutoGen grew out of collaborative [research](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington.
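+
+To make the agent abstraction concrete, a minimal two-agent sketch might look like the following (the `OAI_CONFIG_LIST` contents, the working directory, and the task message are illustrative assumptions; see the Quickstart and the linked examples for complete, tested versions):
+
+```python
+import autogen
+
+# Load model credentials from an OAI_CONFIG_LIST file or environment variable (assumed to exist).
+config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
+
+# An LLM-backed assistant that proposes code, and a user proxy that executes it locally.
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+user_proxy = autogen.UserProxyAgent(
+    "user_proxy",
+    human_input_mode="NEVER",  # fully automated; "ALWAYS" would ask a human before each reply
+    code_execution_config={"work_dir": "coding", "use_docker": False},  # set True (recommended) if Docker is available
+)
+
+# The two agents converse until a termination condition is met.
+user_proxy.initiate_chat(assistant, message="Plot a chart of META and TSLA stock price change YTD.")
+```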
+
+
## Quickstart
The easiest way to start playing is
@@ -64,10 +105,17 @@ The easiest way to start playing is
3. Start playing with the notebooks!
*NOTE*: OAI_CONFIG_LIST_sample lists GPT-4 as the default model, as this represents our current recommendation, and is known to work well with AutoGen. If you use a model other than GPT-4, you may need to revise various system prompts (especially if using weaker models like GPT-3.5-turbo). Moreover, if you use models other than those hosted by OpenAI or Azure, you may incur additional risks related to alignment and safety. Proceed with caution if updating this default.
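+
+For reference, a hedged sketch of how a notebook might load that config and pin it to one model (the file name and filter values below are assumptions to adapt to your own setup):
+
+```python
+import autogen
+
+# Read OAI_CONFIG_LIST (a file or an environment variable) and keep only GPT-4 entries,
+# matching the recommendation above; change the filter to experiment with other models.
+config_list = autogen.config_list_from_json(
+    "OAI_CONFIG_LIST",
+    filter_dict={"model": ["gpt-4"]},
+)
+assert config_list, "No matching model entry found in OAI_CONFIG_LIST."
+```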
+
+
+
## [Installation](https://microsoft.github.io/autogen/docs/Installation)
### Option 1. Install and Run AutoGen in Docker
-Find detailed instructions for users [here](https://microsoft.github.io/autogen/docs/Installation#option-1-install-and-run-autogen-in-docker), and for developers [here](https://microsoft.github.io/autogen/docs/Contribute#docker-for-development).
+Find detailed instructions for users [here](https://microsoft.github.io/autogen/docs/installation/Docker#step-1-install-docker), and for developers [here](https://microsoft.github.io/autogen/docs/Contribute#docker-for-development).
### Option 2. Install AutoGen Locally
@@ -92,6 +140,12 @@ Even if you are installing and running AutoGen locally outside of docker, the re
For LLM inference configurations, check the [FAQs](https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints).
+
+
## Multi-Agent Conversation Framework
AutoGen enables next-gen LLM applications with a generic [multi-agent conversation](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat) framework. It offers customizable and conversable agents that integrate LLMs, tools, and humans.
@@ -131,6 +185,12 @@ The figure below shows an example conversation flow with AutoGen.
Alternatively, the [sample code](https://github.com/microsoft/autogen/blob/main/samples/simple_chat.py) here allows a user to chat with an AutoGen agent in ChatGPT style.
Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#automated-multi-agent-chat) for this feature.
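+
+Beyond a two-agent exchange, richer conversation patterns follow the same recipe. Below is a rough group-chat sketch; the agent roles, the task message, and the `OAI_CONFIG_LIST` file are assumptions, and the linked examples cover these patterns in full:
+
+```python
+import autogen
+
+config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
+llm_config = {"config_list": config_list}
+
+# Three conversable agents plus a manager that decides who speaks next.
+planner = autogen.AssistantAgent(
+    "planner", llm_config=llm_config, system_message="You break the task into verifiable steps."
+)
+coder = autogen.AssistantAgent(
+    "coder", llm_config=llm_config, system_message="You write Python code for each step."
+)
+user_proxy = autogen.UserProxyAgent(
+    "user_proxy",
+    human_input_mode="NEVER",
+    code_execution_config={"work_dir": "groupchat", "use_docker": False},
+)
+
+groupchat = autogen.GroupChat(agents=[user_proxy, planner, coder], messages=[], max_round=8)
+manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+# The user proxy kicks off a conversation that the manager orchestrates among all members.
+user_proxy.initiate_chat(manager, message="Find the latest paper about GPT-4 on arXiv and summarize its findings.")
+```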
+
+
## Enhanced LLM Inferences
AutoGen also helps maximize the utility of expensive LLMs such as ChatGPT and GPT-4. It offers [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification) with powerful functionality such as caching, error handling, multi-config inference, and templating.
@@ -154,6 +214,12 @@ response = autogen.Completion.create(context=test_instance, **config)
Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#tune-gpt-models) for this feature. -->
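+
+As a rough sketch of the unified inference client (the config file and the prompt are assumptions; each entry in `config_list` can serve as a fallback for the one before it):
+
+```python
+import autogen
+
+config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
+
+# The wrapper tries the configurations in order on failure and caches completions it has already produced.
+client = autogen.OpenAIWrapper(config_list=config_list)
+response = client.create(messages=[{"role": "user", "content": "Suggest three names for a coding assistant."}])
+print(response.choices[0].message.content)
+```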
+
+
## Documentation
You can find detailed documentation about AutoGen [here](https://microsoft.github.io/autogen/).
@@ -162,12 +228,18 @@ In addition, you can find:
- [Research](https://microsoft.github.io/autogen/docs/Research), [blogposts](https://microsoft.github.io/autogen/blog) around AutoGen, and [Transparency FAQs](https://github.com/microsoft/autogen/blob/main/TRANSPARENCY_FAQS.md)
-- [Discord](https://discord.gg/pAbnFJrkgZ)
+- [Discord](https://aka.ms/autogen-dc)
- [Contributing guide](https://microsoft.github.io/autogen/docs/Contribute)
- [Roadmap](https://github.com/orgs/microsoft/projects/989/views/3)
+
+
## Related Papers
[AutoGen](https://arxiv.org/abs/2308.08155)
@@ -205,13 +277,30 @@ In addition, you can find:
}
```
+[AgentOptimizer](https://arxiv.org/pdf/2402.11359)
+
+```
+@article{zhang2024training,
+ title={Training Language Model Agents without Modifying Language Models},
+ author={Zhang, Shaokun and Zhang, Jieyu and Liu, Jiale and Song, Linxin and Wang, Chi and Krishna, Ranjay and Wu, Qingyun},
+ journal={ICML'24},
+ year={2024}
+}
+```
+
+
+
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
-If you are new to GitHub [here](https://help.github.com/categories/collaborating-with-issues-and-pull-requests/) is a detailed help source on getting involved with development on GitHub.
+If you are new to GitHub, [here](https://opensource.guide/how-to-contribute/#how-to-submit-a-contribution) is a detailed help source on getting involved with development on GitHub.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
@@ -221,11 +310,23 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+
+
# Legal Notices
Microsoft and any contributors grant you a license to the Microsoft documentation and other content
@@ -242,3 +343,9 @@ Privacy information can be found at https://privacy.microsoft.com/en-us/
Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents,
or trademarks, whether by implication, estoppel, or otherwise.
+
+
diff --git a/TRANSPARENCY_FAQS.md b/TRANSPARENCY_FAQS.md
index fee046c619a..206af084748 100644
--- a/TRANSPARENCY_FAQS.md
+++ b/TRANSPARENCY_FAQS.md
@@ -30,6 +30,7 @@ While AutoGen automates LLM workflows, decisions about how to use specific LLM o
## How was AutoGen evaluated? What metrics are used to measure performance?
- The current version of AutoGen was evaluated on six applications to illustrate its potential in simplifying the development of high-performance multi-agent applications. These applications were selected based on their real-world relevance, problem difficulty, the problem-solving capabilities enabled by AutoGen, and their innovative potential.
- These applications involve using AutoGen to solve math problems, answer questions, make decisions in text-world environments, optimize supply chains, etc. For each of these domains, AutoGen was evaluated on various success-based metrics (i.e., how often the AutoGen-based implementation solved the task). In some cases, the AutoGen-based approach was also evaluated on implementation efficiency (e.g., to track reductions in the developer effort required to build). More details can be found at: https://aka.ms/AutoGen/TechReport
+- The team has conducted tests where a “red” agent attempts to get the default AutoGen assistant to break from its alignment and guardrails. The team has observed that out of 70 attempts to break guardrails, only 1 was successful in producing text that would have been flagged as problematic by Azure OpenAI filters. The team has not observed any evidence that AutoGen (or GPT models as hosted by OpenAI or Azure) can produce novel code exploits or jailbreak prompts, since direct prompts to “be a hacker”, “write exploits”, or “produce a phishing email” are refused by existing filters.
## What are the limitations of AutoGen? How can users minimize the impact of AutoGen’s limitations when using the system?
AutoGen relies on existing LLMs. Experimenting with AutoGen retains the common limitations of large language models, including:
diff --git a/autogen/__init__.py b/autogen/__init__.py
index 3002ad5df8e..02f956c4bcf 100644
--- a/autogen/__init__.py
+++ b/autogen/__init__.py
@@ -1,9 +1,10 @@
import logging
-from .version import __version__
-from .oai import *
+
from .agentchat import *
from .code_utils import DEFAULT_MODEL, FAST_MODEL
-
+from .exception_utils import *
+from .oai import *
+from .version import __version__
# Set the root logger.
logger = logging.getLogger(__name__)
diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py
index 89dbc4fd291..c463dbb3875 100644
--- a/autogen/_pydantic.py
+++ b/autogen/_pydantic.py
@@ -13,7 +13,7 @@
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue
- def type2schema(t: Optional[Type]) -> JsonSchemaValue:
+ def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
@@ -51,11 +51,11 @@ def model_dump_json(model: BaseModel) -> str:
# Remove this once we drop support for pydantic 1.x
else: # pragma: no cover
from pydantic import schema_of
- from pydantic.typing import evaluate_forwardref as evaluate_forwardref
+ from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
- JsonSchemaValue = Dict[str, Any]
+ JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
- def type2schema(t: Optional[Type]) -> JsonSchemaValue:
+ def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
@@ -64,27 +64,27 @@ def type2schema(t: Optional[Type]) -> JsonSchemaValue:
Returns:
JsonSchemaValue: The JSON schema
"""
- if PYDANTIC_V1:
- if t is None:
- return {"type": "null"}
- elif get_origin(t) is Union:
- return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
- elif get_origin(t) in [Tuple, tuple]:
- prefixItems = [type2schema(tt) for tt in get_args(t)]
- return {
- "maxItems": len(prefixItems),
- "minItems": len(prefixItems),
- "prefixItems": prefixItems,
- "type": "array",
- }
-
- d = schema_of(t)
- if "title" in d:
- d.pop("title")
- if "description" in d:
- d.pop("description")
-
- return d
+
+ if t is None:
+ return {"type": "null"}
+ elif get_origin(t) is Union:
+ return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
+ elif get_origin(t) in [Tuple, tuple]:
+ prefixItems = [type2schema(tt) for tt in get_args(t)]
+ return {
+ "maxItems": len(prefixItems),
+ "minItems": len(prefixItems),
+ "prefixItems": prefixItems,
+ "type": "array",
+ }
+ else:
+ d = schema_of(t)
+ if "title" in d:
+ d.pop("title")
+ if "description" in d:
+ d.pop("description")
+
+ return d
def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
diff --git a/autogen/agent_utils.py b/autogen/agent_utils.py
deleted file mode 100644
index 431d03c78d0..00000000000
--- a/autogen/agent_utils.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from typing import List, Dict, Tuple
-from autogen import Agent
-
-
-def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str, any]]:
- """Gather usage summary from all agents.
-
- Args:
- agents: (list): List of agents.
-
- Returns:
- tuple: (total_usage_summary, actual_usage_summary)
-
- Example return:
- total_usage_summary = {
- 'total_cost': 0.0006090000000000001,
- 'gpt-35-turbo':
- {
- 'cost': 0.0006090000000000001,
- 'prompt_tokens': 242,
- 'completion_tokens': 123,
- 'total_tokens': 365
- }
- }
- `actual_usage_summary` follows the same format.
- If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be {'total_cost': 0}.
- """
-
- def aggregate_summary(usage_summary: Dict[str, any], agent_summary: Dict[str, any]) -> None:
- if agent_summary is None:
- return
- usage_summary["total_cost"] += agent_summary.get("total_cost", 0)
- for model, data in agent_summary.items():
- if model != "total_cost":
- if model not in usage_summary:
- usage_summary[model] = data.copy()
- else:
- usage_summary[model]["cost"] += data.get("cost", 0)
- usage_summary[model]["prompt_tokens"] += data.get("prompt_tokens", 0)
- usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0)
- usage_summary[model]["total_tokens"] += data.get("total_tokens", 0)
-
- total_usage_summary = {"total_cost": 0}
- actual_usage_summary = {"total_cost": 0}
-
- for agent in agents:
- if agent.client:
- aggregate_summary(total_usage_summary, agent.client.total_usage_summary)
- aggregate_summary(actual_usage_summary, agent.client.actual_usage_summary)
-
- return total_usage_summary, actual_usage_summary
diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py
index 52cf15b050c..d31a59d98fb 100644
--- a/autogen/agentchat/__init__.py
+++ b/autogen/agentchat/__init__.py
@@ -1,8 +1,10 @@
from .agent import Agent
from .assistant_agent import AssistantAgent
+from .chat import ChatResult, initiate_chats
from .conversable_agent import ConversableAgent, register_function
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
+from .utils import gather_usage_summary
__all__ = (
"Agent",
@@ -12,4 +14,7 @@
"GroupChat",
"GroupChatManager",
"register_function",
+ "initiate_chats",
+ "gather_usage_summary",
+ "ChatResult",
)
diff --git a/autogen/agentchat/agent.py b/autogen/agentchat/agent.py
index b83709dc30b..410635bce6e 100644
--- a/autogen/agentchat/agent.py
+++ b/autogen/agentchat/agent.py
@@ -1,70 +1,136 @@
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union, runtime_checkable
-class Agent:
- """(In preview) An abstract class for AI agent.
+@runtime_checkable
+class Agent(Protocol):
+ """(In preview) A protocol for Agent.
An agent can communicate with other agents and perform actions.
Different agents can differ in what actions they perform in the `receive` method.
"""
- def __init__(
+ @property
+ def name(self) -> str:
+ """The name of the agent."""
+ ...
+
+ @property
+ def description(self) -> str:
+ """The description of the agent. Used for the agent's introduction in
+ a group chat setting."""
+ ...
+
+ def send(
self,
- name: str,
- ):
- """
+ message: Union[Dict[str, Any], str],
+ recipient: "Agent",
+ request_reply: Optional[bool] = None,
+ ) -> None:
+ """Send a message to another agent.
+
Args:
- name (str): name of the agent.
+ message (dict or str): the message to send. If a dict, it should be
+                JSON-serializable and follow OpenAI's ChatCompletion schema.
+ recipient (Agent): the recipient of the message.
+ request_reply (bool): whether to request a reply from the recipient.
"""
- # a dictionary of conversations, default value is list
- self._name = name
+ ...
- @property
- def name(self):
- """Get the name of the agent."""
- return self._name
+ async def a_send(
+ self,
+ message: Union[Dict[str, Any], str],
+ recipient: "Agent",
+ request_reply: Optional[bool] = None,
+ ) -> None:
+ """(Async) Send a message to another agent.
- def send(self, message: Union[Dict, str], recipient: "Agent", request_reply: Optional[bool] = None):
- """(Abstract method) Send a message to another agent."""
+ Args:
+ message (dict or str): the message to send. If a dict, it should be
+                JSON-serializable and follow OpenAI's ChatCompletion schema.
+ recipient (Agent): the recipient of the message.
+ request_reply (bool): whether to request a reply from the recipient.
+ """
+ ...
- async def a_send(self, message: Union[Dict, str], recipient: "Agent", request_reply: Optional[bool] = None):
- """(Abstract async method) Send a message to another agent."""
+ def receive(
+ self,
+ message: Union[Dict[str, Any], str],
+ sender: "Agent",
+ request_reply: Optional[bool] = None,
+ ) -> None:
+ """Receive a message from another agent.
- def receive(self, message: Union[Dict, str], sender: "Agent", request_reply: Optional[bool] = None):
- """(Abstract method) Receive a message from another agent."""
+ Args:
+ message (dict or str): the message received. If a dict, it should be
+                JSON-serializable and follow OpenAI's ChatCompletion schema.
+ sender (Agent): the sender of the message.
+ request_reply (bool): whether the sender requests a reply.
+ """
- async def a_receive(self, message: Union[Dict, str], sender: "Agent", request_reply: Optional[bool] = None):
- """(Abstract async method) Receive a message from another agent."""
+ async def a_receive(
+ self,
+ message: Union[Dict[str, Any], str],
+ sender: "Agent",
+ request_reply: Optional[bool] = None,
+ ) -> None:
+ """(Async) Receive a message from another agent.
- def reset(self):
- """(Abstract method) Reset the agent."""
+ Args:
+ message (dict or str): the message received. If a dict, it should be
+                JSON-serializable and follow OpenAI's ChatCompletion schema.
+ sender (Agent): the sender of the message.
+ request_reply (bool): whether the sender requests a reply.
+ """
+ ...
def generate_reply(
self,
- messages: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
- **kwargs,
- ) -> Union[str, Dict, None]:
- """(Abstract method) Generate a reply based on the received messages.
+ **kwargs: Any,
+ ) -> Union[str, Dict[str, Any], None]:
+ """Generate a reply based on the received messages.
Args:
- messages (list[dict]): a list of messages received.
+ messages (list[dict]): a list of messages received from other agents.
+ The messages are dictionaries that are JSON-serializable and
+                follow OpenAI's ChatCompletion schema.
sender: sender of an Agent instance.
+
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
async def a_generate_reply(
self,
- messages: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
- **kwargs,
- ) -> Union[str, Dict, None]:
- """(Abstract async method) Generate a reply based on the received messages.
+ **kwargs: Any,
+ ) -> Union[str, Dict[str, Any], None]:
+ """(Async) Generate a reply based on the received messages.
Args:
- messages (list[dict]): a list of messages received.
+ messages (list[dict]): a list of messages received from other agents.
+ The messages are dictionaries that are JSON-serializable and
+                follow OpenAI's ChatCompletion schema.
sender: sender of an Agent instance.
+
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
+
+
+@runtime_checkable
+class LLMAgent(Agent, Protocol):
+ """(In preview) A protocol for an LLM agent."""
+
+ @property
+ def system_message(self) -> str:
+ """The system message of this agent."""
+
+ def update_system_message(self, system_message: str) -> None:
+ """Update this agent's system message.
+
+ Args:
+ system_message (str): system message for inference.
+ """
diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py
index bdec0fef665..c1601ea9ba8 100644
--- a/autogen/agentchat/assistant_agent.py
+++ b/autogen/agentchat/assistant_agent.py
@@ -1,5 +1,7 @@
from typing import Callable, Dict, Literal, Optional, Union
+from autogen.runtime_logging import log_new_agent, logging_enabled
+
from .conversable_agent import ConversableAgent
@@ -36,7 +38,7 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = None,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
**kwargs,
):
@@ -45,7 +47,7 @@ def __init__(
name (str): agent name.
system_message (str): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
- llm_config (dict): llm inference configuration.
+ llm_config (dict or False or None): llm inference configuration.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
is_termination_msg (function): a function that takes a message in the form of a dictionary
@@ -67,6 +69,8 @@ def __init__(
description=description,
**kwargs,
)
+ if logging_enabled():
+ log_new_agent(self, locals())
# Update the provided description if None, and we are using the default system_message,
# then use the default description.
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
new file mode 100644
index 00000000000..dd489c03625
--- /dev/null
+++ b/autogen/agentchat/chat.py
@@ -0,0 +1,284 @@
+import asyncio
+import datetime
+import logging
+import warnings
+from collections import abc, defaultdict
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, List, Set, Tuple
+
+from ..formatting_utils import colored
+from ..io.base import IOStream
+from .utils import consolidate_chat_info
+
+logger = logging.getLogger(__name__)
+Prerequisite = Tuple[int, int]
+
+
+@dataclass
+class ChatResult:
+ """(Experimental) The result of a chat. Almost certain to be changed."""
+
+ chat_id: int = None
+ """chat id"""
+    chat_history: List[Dict[str, Any]] = None
+ """The chat history."""
+ summary: str = None
+ """A summary obtained from the chat."""
+ cost: Dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
+ """The cost of the chat.
+ The value for each usage type is a dictionary containing cost information for that specific type.
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
+ """
+ human_input: List[str] = None
+ """A list of human input solicited during the chat."""
+
+
+def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None:
+ """
+    Validate that recipients exist and warn about repetitive recipients.
+ """
+ receipts_set = set()
+ for chat_info in chat_queue:
+ assert "recipient" in chat_info, "recipient must be provided."
+ receipts_set.add(chat_info["recipient"])
+ if len(receipts_set) < len(chat_queue):
+ warnings.warn(
+ "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
+ UserWarning,
+ )
+
+
+def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prerequisite]:
+ """
+    Create a list of Prerequisite tuples (chat_id, prerequisite_chat_id).
+ """
+ prerequisites = []
+ for chat_info in chat_queue:
+ if "chat_id" not in chat_info:
+ raise ValueError("Each chat must have a unique id for async multi-chat execution.")
+ chat_id = chat_info["chat_id"]
+ pre_chats = chat_info.get("prerequisites", [])
+ for pre_chat_id in pre_chats:
+ if not isinstance(pre_chat_id, int):
+ raise ValueError("Prerequisite chat id is not int.")
+ prerequisites.append((chat_id, pre_chat_id))
+ return prerequisites
+
+
+def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite]) -> List[int]:
+ """Find chat order for async execution based on the prerequisite chats
+
+ args:
+        chat_ids: a set of chat ids
+        prerequisites: List of Prerequisite tuples (chat_id, prerequisite_chat_id)
+
+ returns:
+ list: a list of chat_id in order.
+ """
+ edges = defaultdict(set)
+ indegree = defaultdict(int)
+ for pair in prerequisites:
+ chat, pre = pair[0], pair[1]
+ if chat not in edges[pre]:
+ indegree[chat] += 1
+ edges[pre].add(chat)
+ bfs = [i for i in chat_ids if i not in indegree]
+ chat_order = []
+ steps = len(indegree)
+ for _ in range(steps + 1):
+ if not bfs:
+ break
+ chat_order.extend(bfs)
+ nxt = []
+ for node in bfs:
+ if node in edges:
+ for course in edges[node]:
+ indegree[course] -= 1
+ if indegree[course] == 0:
+ nxt.append(course)
+ indegree.pop(course)
+ edges.pop(node)
+ bfs = nxt
+
+ if indegree:
+ return []
+ return chat_order
+
+
+def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
+ iostream = IOStream.get_default()
+
+ if "message" not in chat_info:
+ warnings.warn(
+ "message is not provided in a chat_queue entry. input() will be called to get the initial message.",
+ UserWarning,
+ )
+ print_carryover = (
+ ("\n").join([t for t in chat_info["carryover"]])
+ if isinstance(chat_info["carryover"], list)
+ else chat_info["carryover"]
+ )
+ message = chat_info.get("message")
+ if isinstance(message, str):
+ print_message = message
+ elif callable(message):
+ print_message = "Callable: " + message.__name__
+ elif isinstance(message, dict):
+ print_message = "Dict: " + str(message)
+ elif message is None:
+ print_message = "None"
+ iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
+ iostream.print(
+ colored(
+ "Starting a new chat....",
+ "blue",
+ ),
+ flush=True,
+ )
+ if chat_info.get("verbose", False):
+ iostream.print(colored("Message:\n" + print_message, "blue"), flush=True)
+ iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True)
+ iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
+
+
+def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
+ """Initiate a list of chats.
+ Args:
+ chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.
+
+ Each dictionary should contain the input arguments for
+ [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat).
+ For example:
+ - `"sender"` - the sender agent.
+ - `"recipient"` - the recipient agent.
+          - `"clear_history"` (bool) - whether to clear the chat history with the agent.
+ Default is True.
+ - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
+ conversation. Default is False.
+ - `"cache"` (Cache or None) - the cache client to use for this conversation.
+ Default is None.
+ - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
+ will continue until a termination condition is met. Default is None.
+ - `"summary_method"` (str or callable) - a string or callable specifying the method to get
+ a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
+ - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
+ Default is {}.
+ - `"message"` (str, callable or None) - if None, input() will be called to get the
+ initial message.
+ - `**context` - additional context information to be passed to the chat.
+ - `"carryover"` - It can be used to specify the carryover information to be passed
+ to this chat. If provided, we will combine this carryover with the "message" content when
+ generating the initial chat message in `generate_init_message`.
+ - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list,
+ from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list,
+ then summary from all the finished chats will be taken.
+ Returns:
+ (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
+ """
+
+ consolidate_chat_info(chat_queue)
+ _validate_recipients(chat_queue)
+ current_chat_queue = chat_queue.copy()
+ finished_chats = []
+ while current_chat_queue:
+ chat_info = current_chat_queue.pop(0)
+ _chat_carryover = chat_info.get("carryover", [])
+ finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
+ "finished_chat_indexes_to_exclude_from_carryover", []
+ )
+
+ if isinstance(_chat_carryover, str):
+ _chat_carryover = [_chat_carryover]
+ chat_info["carryover"] = _chat_carryover + [
+ r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover
+ ]
+
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
+ sender = chat_info["sender"]
+ chat_res = sender.initiate_chat(**chat_info)
+ finished_chats.append(chat_res)
+ return finished_chats
+
+
+def __system_now_str():
+ ct = datetime.datetime.now()
+ return f" System time at {ct}. "
+
+
+def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
+ """
+ Update ChatResult when async Task for Chat is completed.
+ """
+ logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
+ chat_result = chat_future.result()
+ chat_result.chat_id = chat_id
+
+
+async def _dependent_chat_future(
+ chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
+) -> asyncio.Task:
+ """
+ Create an async Task for each chat.
+ """
+ logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
+ _chat_carryover = chat_info.get("carryover", [])
+ finished_chats = dict()
+ for chat in prerequisite_chat_futures:
+ chat_future = prerequisite_chat_futures[chat]
+ if chat_future.cancelled():
+ raise RuntimeError(f"Chat {chat} is cancelled.")
+
+ # wait for prerequisite chat results for the new chat carryover
+ finished_chats[chat] = await chat_future
+
+ if isinstance(_chat_carryover, str):
+ _chat_carryover = [_chat_carryover]
+ chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
+
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
+ sender = chat_info["sender"]
+ chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
+ call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
+ chat_res_future.add_done_callback(call_back_with_args)
+ logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
+ return chat_res_future
+
+
+async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
+ """(async) Initiate a list of chats.
+
+ args:
+ - Please refer to `initiate_chats`.
+
+
+ returns:
+ - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
+ """
+ consolidate_chat_info(chat_queue)
+ _validate_recipients(chat_queue)
+ chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
+ num_chats = chat_book.keys()
+ prerequisites = __create_async_prerequisites(chat_queue)
+ chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
+ finished_chat_futures = dict()
+ for chat_id in chat_order_by_id:
+ chat_info = chat_book[chat_id]
+ prerequisite_chat_ids = chat_info.get("prerequisites", [])
+ pre_chat_futures = dict()
+ for pre_chat_id in prerequisite_chat_ids:
+ pre_chat_future = finished_chat_futures[pre_chat_id]
+ pre_chat_futures[pre_chat_id] = pre_chat_future
+ current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
+ finished_chat_futures[chat_id] = current_chat_future
+ await asyncio.gather(*list(finished_chat_futures.values()))
+ finished_chats = dict()
+ for chat in finished_chat_futures:
+ chat_result = finished_chat_futures[chat].result()
+ finished_chats[chat] = chat_result
+ return finished_chats
diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py
index 7a3850d79ae..c9a2d79607d 100644
--- a/autogen/agentchat/contrib/agent_builder.py
+++ b/autogen/agentchat/contrib/agent_builder.py
@@ -1,10 +1,19 @@
-import autogen
-import time
-import subprocess as sp
-import socket
-import json
import hashlib
-from typing import Optional, List, Dict, Tuple
+import importlib
+import json
+import logging
+import re
+import socket
+import subprocess as sp
+import time
+from typing import Dict, List, Optional, Tuple, Union
+
+import requests
+from termcolor import colored
+
+import autogen
+
+logger = logging.getLogger(__name__)
def _config_check(config: Dict):
@@ -15,113 +24,162 @@ def _config_check(config: Dict):
for agent_config in config["agent_configs"]:
assert agent_config.get("name", None) is not None, 'Missing agent "name" in your agent_configs.'
- assert agent_config.get("model", None) is not None, 'Missing agent "model" in your agent_configs.'
assert (
agent_config.get("system_message", None) is not None
), 'Missing agent "system_message" in your agent_configs.'
assert agent_config.get("description", None) is not None, 'Missing agent "description" in your agent_configs.'
+def _retrieve_json(text):
+ match = re.findall(autogen.code_utils.CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
+ if not match:
+ return text
+ code_blocks = []
+ for _, code in match:
+ code_blocks.append(code)
+ return code_blocks[0]
+
+
class AgentBuilder:
"""
    AgentBuilder can help a user build an automatic task-solving process powered by a multi-agent system.
Specifically, our building pipeline includes initialize and build.
- In build(), we prompt a LLM to create multiple participant agents, and specify whether this task need programming to solve.
- User can save the built agents' config by calling save(), and load the saved configs by load(), which can skip the
- building process.
"""
online_server_name = "online"
+ DEFAULT_PROXY_AUTO_REPLY = 'There is no code from the last 1 message for me to execute. Group chat manager should let other participants to continue the conversation. If the group chat manager want to end the conversation, you should let other participant reply me only with "TERMINATE"'
+
+ GROUP_CHAT_DESCRIPTION = """ # Group chat instruction
+You are now working in a group chat with different experts and a group chat manager.
+You should refer to the previous messages from other participant members or yourself, follow their topic, and reply to them.
+
+**Your role is**: {name}
+Group chat members: {members}{user_proxy_desc}
+
+When the task is complete and the result has been carefully verified, after obtaining agreement from the other members, you can end the conversation by replying only with "TERMINATE".
+
+# Your profile
+{sys_msg}
+"""
+
+ DEFAULT_DESCRIPTION = """## Your role
+[Complete this part with expert's name and skill description]
+
+## Task and skill instructions
+- [Complete this part with task description]
+- [Complete this part with skill description]
+- [(Optional) Complete this part with other information]
+"""
+
+ CODING_AND_TASK_SKILL_INSTRUCTION = """## Useful instructions for task-solving
+- Solve the task step by step if you need to.
+- When you find an answer, verify the answer carefully. Include verifiable evidence with possible test case in your response if possible.
+- All your reply should be based on the provided facts.
+
+## How to verify?
+**You have to keep believing that everyone else's answers are wrong until they provide clear enough evidence.**
+- Verifying with step-by-step backward reasoning.
+- Write test cases according to the general task.
+
+## How to use code?
+- Suggest python code (in a python coding block) or shell script (in a sh coding block) for the Computer_terminal to execute.
+- If missing python packages, you can install the package by suggesting a `pip install` code in the ```sh ... ``` block.
+- When using code, you must indicate the script type in the coding block.
+- Do not suggest a coding block that requires users to modify it.
+- Do not suggest a coding block if it's not intended to be executed by the Computer_terminal.
+- The Computer_terminal cannot modify your code.
+- **Use 'print' function for the output when relevant**.
+- Check the execution result returned by the Computer_terminal.
+- Do not ask Computer_terminal to copy and paste the result.
+- If the result indicates there is an error, fix the error and output the code again. """
+
CODING_PROMPT = """Does the following task need programming (i.e., access external API or tool by coding) to solve,
- or coding may help the following task become easier?
+or coding may help the following task become easier?
- TASK: {task}
+TASK: {task}
- Hint:
- # Answer only YES or NO.
- """
+Answer only YES or NO.
+"""
- AGENT_NAME_PROMPT = """To complete the following task, what positions/jobs should be set to maximize efficiency?
-
- TASK: {task}
-
- Hint:
- # Considering the effort, the position in this task should be no more than {max_agents}; less is better.
- # These positions' name should include enough information that can help a group chat manager know when to let this position speak.
- # The position name should be as specific as possible. For example, use "python_programmer" instead of "programmer".
- # Do not use ambiguous position name, such as "domain expert" with no specific description of domain or "technical writer" with no description of what it should write.
- # Each position should have a unique function and the position name should reflect this.
- # The positions should relate to the task and significantly different in function.
- # Add ONLY ONE programming related position if the task needs coding.
- # Generated agent's name should follow the format of ^[a-zA-Z0-9_-]{{1,64}}$, use "_" to split words.
- # Answer the names of those positions/jobs, separated names by commas.
- # Only return the list of positions.
- """
+ AGENT_NAME_PROMPT = """# Your task
+Suggest no more than {max_agents} experts with their names according to the following user requirement.
- AGENT_SYS_MSG_PROMPT = """Considering the following position and task:
+## User requirement
+{task}
- TASK: {task}
- POSITION: {position}
+# Task requirement
+- Expert's name should follow the format: [skill]_Expert.
+- Only reply the names of the experts, separated by ",".
+For example: Python_Expert, Math_Expert, ... """
- Modify the following position requirement, making it more suitable for the above task and position:
+ AGENT_SYS_MSG_PROMPT = """# Your goal
+- According to the task and expert name, write a high-quality description for the expert by filling the given template.
+- Ensure that your description is clear and unambiguous, and includes all necessary information.
- REQUIREMENT: {default_sys_msg}
+# Task
+{task}
- Hint:
- # Your answer should be natural, starting from "You are now in a group chat. You need to complete a task with other participants. As a ...".
- # [IMPORTANT] You should let them reply "TERMINATE" when they think the task is completed (the user's need has actually been satisfied).
- # The modified requirement should not contain the code interpreter skill.
- # You should remove the related skill description when the position is not a programmer or developer.
- # Coding skill is limited to Python.
- # Your answer should omit the word "REQUIREMENT".
- # People with the above position can doubt previous messages or code in the group chat (for example, if there is no
-output after executing the code) and provide a corrected answer or code.
- # People in the above position should ask for help from the group chat manager when confused and let the manager select another participant.
- """
+# Expert name
+{position}
- AGENT_DESCRIPTION_PROMPT = """Considering the following position:
+# Template
+{default_sys_msg}
+"""
- POSITION: {position}
+ AGENT_DESCRIPTION_PROMPT = """# Your goal
+Summarize the following expert's description in a sentence.
- What requirements should this position be satisfied?
+# Expert name
+{position}
- Hint:
- # This description should include enough information that can help a group chat manager know when to let this position speak.
- # People with the above position can doubt previous messages or code in the group chat (for example, if there is no
-output after executing the code) and provide a corrected answer or code.
- # Your answer should be in at most three sentences.
- # Your answer should be natural, starting from "[POSITION's name] is a ...".
- # Your answer should include the skills that this position should have.
- # Your answer should not contain coding-related skills when the position is not a programmer or developer.
- # Coding skills should be limited to Python.
- """
+# Expert's description
+{sys_msg}
+"""
- AGENT_SEARCHING_PROMPT = """Considering the following task:
+ AGENT_SEARCHING_PROMPT = """# Your goal
+Considering the following task, what experts should be involved in the task?
- TASK: {task}
+# TASK
+{task}
- What following agents should be involved to the task?
+# EXPERT LIST
+{agent_list}
- AGENT LIST:
- {agent_list}
+# Requirement
+- You should consider if the experts' name and profile match the task.
+- Considering the effort, you should select fewer than {max_agents} experts; fewer is better.
+- Separate expert names by commas and use "_" instead of space. For example, Product_manager,Programmer
+- Only return the list of expert names.
+"""
- Hint:
- # You should consider if the agent's name and profile match the task.
- # Considering the effort, you should select less then {max_agents} agents; less is better.
- # Separate agent names by commas and use "_" instead of space. For example, Product_manager,Programmer
- # Only return the list of agent names.
- """
+ AGENT_SELECTION_PROMPT = """# Your goal
+Match roles in the role set to each expert in expert set.
+
+# Skill set
+{skills}
+
+# Expert pool (formatting with name: description)
+{expert_pool}
+
+# Answer format
+```json
+{{
+ "skill_1 description": "expert_name: expert_description", // if there exists an expert that suitable for skill_1
+ "skill_2 description": "None", // if there is no experts that suitable for skill_2
+ ...
+}}
+```
+"""
def __init__(
self,
config_file_or_env: Optional[str] = "OAI_CONFIG_LIST",
config_file_location: Optional[str] = "",
- builder_model: Optional[str] = "gpt-4",
- agent_model: Optional[str] = "gpt-4",
- host: Optional[str] = "localhost",
- endpoint_building_timeout: Optional[int] = 600,
- max_tokens: Optional[int] = 945,
+ builder_model: Optional[Union[str, list]] = [],
+ agent_model: Optional[Union[str, list]] = [],
+ builder_model_tags: Optional[list] = [],
+ agent_model_tags: Optional[list] = [],
max_agents: Optional[int] = 5,
):
"""
@@ -130,17 +188,27 @@ def __init__(
config_file_or_env: path or environment of the OpenAI api configs.
builder_model: specify a model as the backbone of build manager.
agent_model: specify a model as the backbone of participant agents.
- host: endpoint host.
endpoint_building_timeout: timeout for building up an endpoint server.
- max_tokens: max tokens for each agent.
max_agents: max agents for each task.
"""
- self.host = host
- self.builder_model = builder_model
- self.agent_model = agent_model
+ builder_model = builder_model if isinstance(builder_model, list) else [builder_model]
+ builder_filter_dict = {}
+ if len(builder_model) != 0:
+ builder_filter_dict.update({"model": builder_model})
+ if len(builder_model_tags) != 0:
+ builder_filter_dict.update({"tags": builder_model_tags})
+ builder_config_list = autogen.config_list_from_json(config_file_or_env, filter_dict=builder_filter_dict)
+ if len(builder_config_list) == 0:
+ raise RuntimeError(
+ f"Fail to initialize build manager: {builder_model}{builder_model_tags} does not exist in {config_file_or_env}. "
+ f'If you want to change this model, please specify the "builder_model" in the constructor.'
+ )
+ self.builder_model = autogen.OpenAIWrapper(config_list=builder_config_list)
+
+ self.agent_model = agent_model if isinstance(agent_model, list) else [agent_model]
+ self.agent_model_tags = agent_model_tags
self.config_file_or_env = config_file_or_env
self.config_file_location = config_file_location
- self.endpoint_building_timeout = endpoint_building_timeout
self.building_task: str = None
self.agent_configs: List[Dict] = []
@@ -149,40 +217,20 @@ def __init__(
self.agent_procs_assign: Dict[str, Tuple[autogen.ConversableAgent, str]] = {}
self.cached_configs: Dict = {}
- self.max_tokens = max_tokens
self.max_agents = max_agents
- for port in range(8000, 65535):
- if self._is_port_open(host, port):
- self.open_ports.append(str(port))
-
def set_builder_model(self, model: str):
self.builder_model = model
def set_agent_model(self, model: str):
self.agent_model = model
- @staticmethod
- def _is_port_open(host, port):
- """Check if a tcp port is open."""
- try:
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- s.settimeout(10)
- s.bind((host, int(port)))
- s.close()
- return True
- except OSError:
- return False
-
def _create_agent(
self,
- agent_name: str,
- model_name_or_hf_repo: str,
+ agent_config: Dict,
+ member_name: List[str],
llm_config: dict,
- system_message: Optional[str] = autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
- description: Optional[str] = autogen.AssistantAgent.DEFAULT_DESCRIPTION,
use_oai_assistant: Optional[bool] = False,
- world_size: Optional[int] = 1,
) -> autogen.AssistantAgent:
"""
Create a group chat participant agent.
@@ -191,100 +239,46 @@ def _create_agent(
The API address of that endpoint will be "localhost:{free port}".
Args:
- agent_name: the name that identify the function of the agent (e.g., Coder, Product Manager,...)
- model_name_or_hf_repo: the name of the model or the huggingface repo.
+ agent_config: agent's config. It should include the following information:
+ 1. model_name: backbone model of an agent, e.g., gpt-4-1106-preview, meta/Llama-2-70b-chat
+                2. agent_name: used to identify an agent in the group chat.
+                3. system_message: including persona, task-solving instructions, etc.
+                4. description: a brief description of the agent that helps the group chat manager pick the speaker.
llm_config: specific configs for LLM (e.g., config_list, seed, temperature, ...).
- system_message: system prompt use to format an agent's behavior.
- description: a brief description of the agent. This will improve the group chat performance.
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
world_size: the max size of parallel tensors (in most of the cases, this is identical to the amount of GPUs).
Returns:
agent: a set-up agent.
"""
- from huggingface_hub import HfApi
- from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-
+ model_name_or_hf_repo = agent_config.get("model", [])
+ model_name_or_hf_repo = (
+ model_name_or_hf_repo if isinstance(model_name_or_hf_repo, list) else [model_name_or_hf_repo]
+ )
+ model_tags = agent_config.get("tags", [])
+ agent_name = agent_config["name"]
+ system_message = agent_config["system_message"]
+ description = agent_config["description"]
+
+ # Path to the customize **ConversableAgent** class.
+ model_path = agent_config.get("model_path", None)
+ filter_dict = {}
+ if len(model_name_or_hf_repo) > 0:
+ filter_dict.update({"model": model_name_or_hf_repo})
+ if len(model_tags) > 0:
+ filter_dict.update({"tags": model_tags})
config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [model_name_or_hf_repo]},
+ self.config_file_or_env, file_location=self.config_file_location, filter_dict=filter_dict
)
if len(config_list) == 0:
raise RuntimeError(
- f"Fail to initialize agent {agent_name}: {model_name_or_hf_repo} does not exist in {self.config_file_or_env}.\n"
+ f"Fail to initialize agent {agent_name}: {model_name_or_hf_repo}{model_tags} does not exist in {self.config_file_or_env}.\n"
f'If you would like to change this model, please specify the "agent_model" in the constructor.\n'
f"If you load configs from json, make sure the model in agent_configs is in the {self.config_file_or_env}."
)
- try:
- hf_api = HfApi()
- hf_api.model_info(model_name_or_hf_repo)
- model_name = model_name_or_hf_repo.split("/")[-1]
- server_id = f"{model_name}_{self.host}"
- except GatedRepoError as e:
- raise e
- except RepositoryNotFoundError:
- server_id = self.online_server_name
-
- if server_id != self.online_server_name:
- # The code in this block is uncovered by tests because online environment does not support gpu use.
- if self.agent_procs.get(server_id, None) is None:
- while True:
- port = self.open_ports.pop()
- if self._is_port_open(self.host, port):
- break
-
- # Use vLLM to set up a server with OpenAI API support.
- agent_proc = sp.Popen(
- [
- "python",
- "-m",
- "vllm.entrypoints.openai.api_server",
- "--host",
- f"{self.host}",
- "--port",
- f"{port}",
- "--model",
- f"{model_name_or_hf_repo}",
- "--tensor-parallel-size",
- f"{world_size}",
- ],
- stdout=sp.PIPE,
- stderr=sp.STDOUT,
- )
- timeout_start = time.time()
-
- while True:
- server_stdout = agent_proc.stdout.readline()
- if server_stdout != b"":
- print(server_stdout)
- timeout_end = time.time()
- if b"running" in server_stdout:
- print(
- f"Running {model_name_or_hf_repo} on http://{self.host}:{port} "
- f"with tensor parallel size {world_size}."
- )
- break
- elif b"address already in use" in server_stdout:
- raise RuntimeError(
- f"{self.host}:{port} already in use. Fail to set up the endpoint for "
- f"{model_name_or_hf_repo} on {self.host}:{port}."
- )
- elif timeout_end - timeout_start > self.endpoint_building_timeout:
- raise RuntimeError(
- f"Timeout exceed. Fail to set up the endpoint for "
- f"{model_name_or_hf_repo} on {self.host}:{port}."
- )
- self.agent_procs[server_id] = (agent_proc, port)
- else:
- port = self.agent_procs[server_id][1]
-
- config_list[0]["base_url"] = f"http://{self.host}:{port}/v1"
-
+ server_id = self.online_server_name
current_config = llm_config.copy()
- current_config.update(
- {"config_list": config_list, "model": model_name_or_hf_repo, "max_tokens": self.max_tokens}
- )
+ current_config.update({"config_list": config_list})
if use_oai_assistant:
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
@@ -295,12 +289,38 @@ def _create_agent(
overwrite_instructions=False,
)
else:
- agent = autogen.AssistantAgent(
- name=agent_name,
- llm_config=current_config.copy(),
- system_message=system_message,
- description=description,
+ user_proxy_desc = ""
+ if self.cached_configs["coding"] is True:
+ user_proxy_desc = (
+ "\nThe group also include a Computer_terminal to help you run the python and shell code."
+ )
+
+ model_class = autogen.AssistantAgent
+ if model_path:
+ module_path, model_class_name = model_path.replace("/", ".").rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ model_class = getattr(module, model_class_name)
+ if not issubclass(model_class, autogen.ConversableAgent):
+ logger.error(f"{model_class} is not a ConversableAgent. Use AssistantAgent as default")
+ model_class = autogen.AssistantAgent
+
+ additional_config = {
+ k: v
+ for k, v in agent_config.items()
+ if k not in ["model", "name", "system_message", "description", "model_path", "tags"]
+ }
+ agent = model_class(
+ name=agent_name, llm_config=current_config.copy(), description=description, **additional_config
)
+ if system_message == "":
+ system_message = agent.system_message
+ else:
+ system_message = f"{system_message}\n\n{self.CODING_AND_TASK_SKILL_INSTRUCTION}"
+
+ enhanced_sys_msg = self.GROUP_CHAT_DESCRIPTION.format(
+ name=agent_name, members=member_name, user_proxy_desc=user_proxy_desc, sys_msg=system_message
+ )
+ agent.update_system_message(enhanced_sys_msg)
self.agent_procs_assign[agent_name] = (agent, server_id)
return agent
@@ -324,7 +344,7 @@ def clear_agent(self, agent_name: str, recycle_endpoint: Optional[bool] = True):
return
self.agent_procs[server_id][0].terminate()
self.open_ports.append(server_id.split("_")[-1])
- print(f"Agent {agent_name} has been cleared.")
+ print(colored(f"Agent {agent_name} has been cleared.", "yellow"), flush=True)
def clear_all_agents(self, recycle_endpoint: Optional[bool] = True):
"""
@@ -332,7 +352,7 @@ def clear_all_agents(self, recycle_endpoint: Optional[bool] = True):
"""
for agent_name in [agent_name for agent_name in self.agent_procs_assign.keys()]:
self.clear_agent(agent_name, recycle_endpoint)
- print("All agents have been cleared.")
+ print(colored("All agents have been cleared.", "yellow"), flush=True)
def build(
self,
@@ -341,6 +361,8 @@ def build(
coding: Optional[bool] = None,
code_execution_config: Optional[Dict] = None,
use_oai_assistant: Optional[bool] = False,
+ user_proxy: Optional[autogen.ConversableAgent] = None,
+ max_agents: Optional[int] = None,
**kwargs,
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
@@ -352,6 +374,7 @@ def build(
code_execution_config: specific configs for user proxy (e.g., last_n_messages, work_dir, ...).
default_llm_config: specific configs for LLM (e.g., config_list, seed, temperature, ...).
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
@@ -359,34 +382,25 @@ def build(
"""
if code_execution_config is None:
code_execution_config = {
- "last_n_messages": 2,
+ "last_n_messages": 1,
"work_dir": "groupchat",
"use_docker": False,
- "timeout": 60,
+ "timeout": 10,
}
+ if max_agents is None:
+ max_agents = self.max_agents
+
agent_configs = []
self.building_task = building_task
- config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [self.builder_model]},
- )
- if len(config_list) == 0:
- raise RuntimeError(
- f"Fail to initialize build manager: {self.builder_model} does not exist in {self.config_file_or_env}. "
- f'If you want to change this model, please specify the "builder_model" in the constructor.'
- )
- build_manager = autogen.OpenAIWrapper(config_list=config_list)
-
- print("==> Generating agents...")
+ print(colored("==> Generating agents...", "green"), flush=True)
resp_agent_name = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_NAME_PROMPT.format(task=building_task, max_agents=self.max_agents),
+ "content": self.AGENT_NAME_PROMPT.format(task=building_task, max_agents=max_agents),
}
]
)
@@ -394,21 +408,21 @@ def build(
.message.content
)
agent_name_list = [agent_name.strip().replace(" ", "_") for agent_name in resp_agent_name.split(",")]
- print(f"{agent_name_list} are generated.")
+ print(f"{agent_name_list} are generated.", flush=True)
- print("==> Generating system message...")
+ print(colored("==> Generating system message...", "green"), flush=True)
agent_sys_msg_list = []
for name in agent_name_list:
- print(f"Preparing system message for {name}")
+ print(f"Preparing system message for {name}", flush=True)
resp_agent_sys_msg = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
"content": self.AGENT_SYS_MSG_PROMPT.format(
task=building_task,
position=name,
- default_sys_msg=autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
+ default_sys_msg=self.DEFAULT_DESCRIPTION,
),
}
]
@@ -418,16 +432,16 @@ def build(
)
agent_sys_msg_list.append(resp_agent_sys_msg)
- print("==> Generating description...")
+ print(colored("==> Generating description...", "green"), flush=True)
agent_description_list = []
- for name in agent_name_list:
- print(f"Preparing description for {name}")
+ for name, sys_msg in list(zip(agent_name_list, agent_sys_msg_list)):
+ print(f"Preparing description for {name}", flush=True)
resp_agent_description = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_DESCRIPTION_PROMPT.format(position=name),
+ "content": self.AGENT_DESCRIPTION_PROMPT.format(position=name, sys_msg=sys_msg),
}
]
)
@@ -438,12 +452,18 @@ def build(
for name, sys_msg, description in list(zip(agent_name_list, agent_sys_msg_list, agent_description_list)):
agent_configs.append(
- {"name": name, "model": self.agent_model, "system_message": sys_msg, "description": description}
+ {
+ "name": name,
+ "model": self.agent_model,
+ "tags": self.agent_model_tags,
+ "system_message": sys_msg,
+ "description": description,
+ }
)
if coding is None:
resp = (
- build_manager.create(
+ self.builder_model.create(
messages=[{"role": "user", "content": self.CODING_PROMPT.format(task=building_task)}]
)
.choices[0]
@@ -460,18 +480,20 @@ def build(
"code_execution_config": code_execution_config,
}
)
-
- return self._build_agents(use_oai_assistant, **kwargs)
+ _config_check(self.cached_configs)
+ return self._build_agents(use_oai_assistant, user_proxy=user_proxy, **kwargs)
def build_from_library(
self,
building_task: str,
library_path_or_json: str,
default_llm_config: Dict,
- coding: Optional[bool] = True,
+ top_k: int = 3,
+ coding: Optional[bool] = None,
code_execution_config: Optional[Dict] = None,
use_oai_assistant: Optional[bool] = False,
- embedding_model: Optional[str] = None,
+ embedding_model: Optional[str] = "all-mpnet-base-v2",
+ user_proxy: Optional[autogen.ConversableAgent] = None,
**kwargs,
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
@@ -487,81 +509,83 @@ def build_from_library(
code_execution_config: specific configs for user proxy (e.g., last_n_messages, work_dir, ...).
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
embedding_model: a Sentence-Transformers model use for embedding similarity to select agents from library.
- if None, an openai model will be prompted to select agents. As reference, chromadb use "all-mpnet-base-
- v2" as default.
+                For reference, chromadb uses "all-mpnet-base-v2" by default.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
cached_configs: cached configs.
"""
+ import sqlite3
+
+        # Some systems ship an SQLite library that is too old for chromadb (which requires SQLite >= 3.35.0).
+        # In that case, fall back to pysqlite3 if the user has installed it.
+        if sqlite3.sqlite_version_info < (3, 35, 0):
+ try:
+ __import__("pysqlite3")
+ import sys
+
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+ except Exception as e:
+ raise e
import chromadb
from chromadb.utils import embedding_functions
if code_execution_config is None:
code_execution_config = {
- "last_n_messages": 2,
+ "last_n_messages": 1,
"work_dir": "groupchat",
"use_docker": False,
- "timeout": 60,
+ "timeout": 120,
}
- agent_configs = []
-
- config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [self.builder_model]},
- )
- if len(config_list) == 0:
- raise RuntimeError(
- f"Fail to initialize build manager: {self.builder_model} does not exist in {self.config_file_or_env}. "
- f'If you want to change this model, please specify the "builder_model" in the constructor.'
- )
- build_manager = autogen.OpenAIWrapper(config_list=config_list)
-
try:
agent_library = json.loads(library_path_or_json)
except json.decoder.JSONDecodeError:
with open(library_path_or_json, "r") as f:
agent_library = json.load(f)
+ except Exception as e:
+ raise e
- print("==> Looking for suitable agents in library...")
- if embedding_model is not None:
- chroma_client = chromadb.Client()
- collection = chroma_client.create_collection(
- name="agent_list",
- embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model),
- )
- collection.add(
- documents=[agent["profile"] for agent in agent_library],
- metadatas=[{"source": "agent_profile"} for _ in range(len(agent_library))],
- ids=[f"agent_{i}" for i in range(len(agent_library))],
- )
- agent_profile_list = collection.query(query_texts=[building_task], n_results=self.max_agents)["documents"][
- 0
- ]
-
- # search name from library
- agent_name_list = []
- for profile in agent_profile_list:
- for agent in agent_library:
- if agent["profile"] == profile:
- agent_name_list.append(agent["name"])
- break
- chroma_client.delete_collection(collection.name)
- print(f"{agent_name_list} are selected.")
- else:
- agent_profiles = [
- f"No.{i + 1} AGENT's NAME: {agent['name']}\nNo.{i + 1} AGENT's PROFILE: {agent['profile']}\n\n"
- for i, agent in enumerate(agent_library)
- ]
- resp_agent_name = (
- build_manager.create(
+ print(colored("==> Looking for suitable agents in the library...", "green"), flush=True)
+ skills = building_task.replace(":", " ").split("\n")
+ # skills = [line.split("-", 1)[1].strip() if line.startswith("-") else line for line in lines]
+ if len(skills) == 0:
+ skills = [building_task]
+
+ chroma_client = chromadb.Client()
+ collection = chroma_client.create_collection(
+ name="agent_list",
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model),
+ )
+ collection.add(
+ documents=[agent["description"] for agent in agent_library],
+ metadatas=[{"source": "agent_profile"} for _ in range(len(agent_library))],
+ ids=[f"agent_{i}" for i in range(len(agent_library))],
+ )
+ agent_desc_list = set()
+ for skill in skills:
+ recall = set(collection.query(query_texts=[skill], n_results=top_k)["documents"][0])
+ agent_desc_list = agent_desc_list.union(recall)
+
+ agent_config_list = []
+ for description in list(agent_desc_list):
+ for agent in agent_library:
+ if agent["description"] == description:
+ agent_config_list.append(agent.copy())
+ break
+ chroma_client.delete_collection(collection.name)
+
+        # Second-stage recall: let the builder model select agents from the searched results.
+ expert_pool = [f"{agent['name']}: {agent['description']}" for agent in agent_config_list]
+ while True:
+ skill_agent_pair_json = (
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_SEARCHING_PROMPT.format(
- task=building_task, agent_list="".join(agent_profiles), max_agents=self.max_agents
+ "content": self.AGENT_SELECTION_PROMPT.format(
+ skills=building_task, expert_pool=expert_pool, max_agents=self.max_agents
),
}
]
@@ -569,48 +593,45 @@ def build_from_library(
.choices[0]
.message.content
)
- agent_name_list = [agent_name.strip().replace(" ", "_") for agent_name in resp_agent_name.split(",")]
-
- # search profile from library
- agent_profile_list = []
- for name in agent_name_list:
- for agent in agent_library:
- if agent["name"] == name:
- agent_profile_list.append(agent["profile"])
- break
- print(f"{agent_name_list} are selected.")
-
- print("==> Generating system message...")
- # generate system message from profile
- agent_sys_msg_list = []
- for name, profile in list(zip(agent_name_list, agent_profile_list)):
- print(f"Preparing system message for {name}...")
- resp_agent_sys_msg = (
- build_manager.create(
- messages=[
- {
- "role": "user",
- "content": self.AGENT_SYS_MSG_PROMPT.format(
- task=building_task,
- position=f"{name}\nPOSITION PROFILE: {profile}",
- default_sys_msg=autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
- ),
- }
- ]
+ try:
+ skill_agent_pair_json = _retrieve_json(skill_agent_pair_json)
+ skill_agent_pair = json.loads(skill_agent_pair_json)
+ break
+ except Exception as e:
+ print(e, flush=True)
+ time.sleep(5)
+ continue
+
+ recalled_agent_config_list = []
+ recalled_name_desc = []
+ for skill, agent_profile in skill_agent_pair.items():
+ # If no suitable agent, generate an agent
+ if agent_profile == "None":
+ _, agent_config_temp = self.build(
+ building_task=skill,
+ default_llm_config=default_llm_config.copy(),
+ coding=False,
+ use_oai_assistant=use_oai_assistant,
+ max_agents=1,
)
- .choices[0]
- .message.content
- )
- agent_sys_msg_list.append(resp_agent_sys_msg)
-
- for name, sys_msg, description in list(zip(agent_name_list, agent_sys_msg_list, agent_profile_list)):
- agent_configs.append(
- {"name": name, "model": self.agent_model, "system_message": sys_msg, "description": description}
- )
+ self.clear_agent(agent_config_temp["agent_configs"][0]["name"])
+ recalled_agent_config_list.append(agent_config_temp["agent_configs"][0])
+ else:
+ if agent_profile in recalled_name_desc:
+ # prevent identical agents
+ continue
+ recalled_name_desc.append(agent_profile)
+            # Split only on the first ":" so descriptions containing colons stay intact.
+            name = agent_profile.split(":", 1)[0].strip()
+            desc = agent_profile.split(":", 1)[1].strip()
+ for agent in agent_config_list:
+ if name == agent["name"] and desc == agent["description"]:
+ recalled_agent_config_list.append(agent.copy())
+
+ print(f"{[agent['name'] for agent in recalled_agent_config_list]} are selected.", flush=True)
if coding is None:
resp = (
- build_manager.create(
+ self.builder_model.create(
messages=[{"role": "user", "content": self.CODING_PROMPT.format(task=building_task)}]
)
.choices[0]
@@ -621,23 +642,25 @@ def build_from_library(
self.cached_configs.update(
{
"building_task": building_task,
- "agent_configs": agent_configs,
+ "agent_configs": recalled_agent_config_list,
"coding": coding,
"default_llm_config": default_llm_config,
"code_execution_config": code_execution_config,
}
)
+ _config_check(self.cached_configs)
- return self._build_agents(use_oai_assistant, **kwargs)
+ return self._build_agents(use_oai_assistant, user_proxy=user_proxy, **kwargs)
def _build_agents(
- self, use_oai_assistant: Optional[bool] = False, **kwargs
+ self, use_oai_assistant: Optional[bool] = False, user_proxy: Optional[autogen.ConversableAgent] = None, **kwargs
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
Build agents with generated configs.
Args:
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
@@ -648,37 +671,29 @@ def _build_agents(
coding = self.cached_configs["coding"]
code_execution_config = self.cached_configs["code_execution_config"]
- print("==> Creating agents...")
+ print(colored("==> Creating agents...", "green"), flush=True)
for config in agent_configs:
- print(f"Creating agent {config['name']} with backbone {config['model']}...")
+ print(f"Creating agent {config['name']}...", flush=True)
self._create_agent(
- config["name"],
- config["model"],
- default_llm_config,
- system_message=config["system_message"],
- description=config["description"],
+ agent_config=config.copy(),
+ member_name=[agent["name"] for agent in agent_configs],
+ llm_config=default_llm_config,
use_oai_assistant=use_oai_assistant,
**kwargs,
)
agent_list = [agent_config[0] for agent_config in self.agent_procs_assign.values()]
if coding is True:
- print("Adding user console proxy...")
- agent_list = (
- [
- autogen.UserProxyAgent(
- name="User_console_and_code_interpreter",
- is_termination_msg=lambda x: "TERMINATE" in x.get("content"),
- system_message="User console with a python code interpreter interface.",
- description="""A user console with a code interpreter interface.
-It can provide the code execution results. Select this player when other players provide some code that needs to be executed.
-DO NOT SELECT THIS PLAYER WHEN NO CODE TO EXECUTE; IT WILL NOT ANSWER ANYTHING.""",
- code_execution_config=code_execution_config,
- human_input_mode="NEVER",
- )
- ]
- + agent_list
- )
+ print("Adding user console proxy...", flush=True)
+ if user_proxy is None:
+ user_proxy = autogen.UserProxyAgent(
+ name="Computer_terminal",
+ is_termination_msg=lambda x: x == "TERMINATE" or x == "TERMINATE.",
+ code_execution_config=code_execution_config,
+ human_input_mode="NEVER",
+ default_auto_reply=self.DEFAULT_PROXY_AUTO_REPLY,
+ )
+ agent_list = agent_list + [user_proxy]
return agent_list, self.cached_configs.copy()
@@ -697,7 +712,7 @@ def save(self, filepath: Optional[str] = None) -> str:
filepath = f'./save_config_{hashlib.md5(self.building_task.encode("utf-8")).hexdigest()}.json'
with open(filepath, "w") as save_file:
json.dump(self.cached_configs, save_file, indent=4)
- print(f"Building config saved to {filepath}")
+ print(colored(f"Building config saved to {filepath}", "green"), flush=True)
return filepath
@@ -722,12 +737,12 @@ def load(
"""
# load json string.
if config_json is not None:
- print("Loading config from JSON...")
+ print(colored("Loading config from JSON...", "green"), flush=True)
cached_configs = json.loads(config_json)
# load from path.
if filepath is not None:
- print(f"Loading config from {filepath}")
+ print(colored(f"Loading config from {filepath}", "green"), flush=True)
with open(filepath) as f:
cached_configs = json.load(f)
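For orientation, here is a minimal sketch of driving the updated builder end to end. The model names, the `OAI_CONFIG_LIST` file, and the group-chat wiring are illustrative assumptions rather than part of this patch:

```python
import autogen
from autogen.agentchat.contrib.agent_builder import AgentBuilder

# Assumes an OAI_CONFIG_LIST file with entries for both the builder and agent models.
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")

builder = AgentBuilder(
    config_file_or_env="OAI_CONFIG_LIST",
    builder_model="gpt-4",  # model that drafts agent names, system messages, and descriptions
    agent_model="gpt-4",    # model backing the generated agents
)

building_task = "Find recent papers about LLM-based agents and summarize how they are evaluated."
agent_list, agent_configs = builder.build(
    building_task,
    default_llm_config={"temperature": 0},
    coding=True,  # also appends the Computer_terminal user proxy created in _build_agents
)

# Wire the generated agents into a group chat and run the task.
group_chat = autogen.GroupChat(agents=agent_list, messages=[], max_round=12)
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config={"config_list": config_list})
agent_list[0].initiate_chat(manager, message=building_task)

saved_path = builder.save()  # persists cached_configs so AgentBuilder.load() can rebuild the team later
builder.clear_all_agents()
```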
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
new file mode 100644
index 00000000000..6588a1ec611
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -0,0 +1,7 @@
+Agents for running the AgentEval pipeline.
+
+AgentEval is a process for evaluating an LLM-based system's performance on a given task.
+
+When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for evaluating a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
+
+For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
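As a rough illustration of the pipeline described above, the sketch below strings the pieces together; the task text, the `OAI_CONFIG_LIST` file, and the test case are placeholders:

```python
import autogen
from autogen.agentchat.contrib.agent_eval.agent_eval import generate_criteria, quantify_criteria
from autogen.agentchat.contrib.agent_eval.task import Task

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}

task = Task(
    name="Math problem solving",
    description="Solve grade-school math word problems and report the final answer.",
    successful_response="The answer is 42. TERMINATE",
    failed_response="I am not sure how to solve this.",
)

# The critic (and optionally the subcritic) proposes evaluation criteria from the example runs.
criteria = generate_criteria(llm_config=llm_config, task=task, use_subcritic=False)

# The quantifier then scores a new solution against those criteria.
result = quantify_criteria(
    llm_config=llm_config,
    criteria=criteria,
    task=task,
    test_case="Question: what is 3 + 4? Assistant: The answer is 7. TERMINATE",
    ground_truth="7",
)
print(result["estimated_performance"])
```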
diff --git a/autogen/agentchat/contrib/agent_eval/agent_eval.py b/autogen/agentchat/contrib/agent_eval/agent_eval.py
new file mode 100644
index 00000000000..b48c65a66d2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/agent_eval.py
@@ -0,0 +1,101 @@
+from typing import Dict, List, Literal, Optional, Union
+
+import autogen
+from autogen.agentchat.contrib.agent_eval.criterion import Criterion
+from autogen.agentchat.contrib.agent_eval.critic_agent import CriticAgent
+from autogen.agentchat.contrib.agent_eval.quantifier_agent import QuantifierAgent
+from autogen.agentchat.contrib.agent_eval.subcritic_agent import SubCriticAgent
+from autogen.agentchat.contrib.agent_eval.task import Task
+
+
+def generate_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ task: Task = None,
+ additional_instructions: str = "",
+ max_round=2,
+ use_subcritic: bool = False,
+):
+ """
+ Creates a list of criteria for evaluating the utility of a given task.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ task (Task): The task to evaluate.
+ additional_instructions (str): Additional instructions for the criteria agent.
+ max_round (int): The maximum number of rounds to run the conversation.
+ use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
+ Returns:
+ list: A list of Criterion objects for evaluating the utility of the given task.
+ """
+ critic = CriticAgent(
+ system_message=CriticAgent.DEFAULT_SYSTEM_MESSAGE + "\n" + additional_instructions,
+ llm_config=llm_config,
+ )
+
+ critic_user = autogen.UserProxyAgent(
+ name="critic_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ agents = [critic_user, critic]
+
+ if use_subcritic:
+ subcritic = SubCriticAgent(
+ llm_config=llm_config,
+ )
+ agents.append(subcritic)
+
+ groupchat = autogen.GroupChat(
+ agents=agents, messages=[], max_round=max_round, speaker_selection_method="round_robin"
+ )
+ critic_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+ critic_user.initiate_chat(critic_manager, message=task.get_sys_message())
+ criteria = critic_user.last_message()
+ content = criteria["content"]
+ # need to strip out any extra code around the returned json
+ content = content[content.find("[") : content.rfind("]") + 1]
+ criteria = Criterion.parse_json_str(content)
+ return criteria
+
+
+def quantify_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ criteria: List[Criterion] = None,
+ task: Task = None,
+ test_case: str = "",
+ ground_truth: str = "",
+):
+ """
+ Quantifies the performance of a system using the provided criteria.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ criteria ([Criterion]): A list of criteria for evaluating the utility of a given task.
+ task (Task): The task to evaluate.
+ test_case (str): The test case to evaluate.
+ ground_truth (str): The ground truth for the test case.
+ Returns:
+ dict: A dictionary where the keys are the criteria and the values are the assessed performance based on accepted values for each criteria.
+ """
+ quantifier = QuantifierAgent(
+ llm_config=llm_config,
+ )
+
+ quantifier_user = autogen.UserProxyAgent(
+ name="quantifier_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ quantifier_user.initiate_chat( # noqa: F841
+ quantifier,
+ message=task.get_sys_message()
+ + "Evaluation dictionary: "
+ + Criterion.write_json(criteria)
+ + "actual test case to evaluate: "
+ + test_case,
+ )
+ quantified_results = quantifier_user.last_message()
+ return {"actual_success": ground_truth, "estimated_performance": quantified_results["content"]}
diff --git a/autogen/agentchat/contrib/agent_eval/criterion.py b/autogen/agentchat/contrib/agent_eval/criterion.py
new file mode 100644
index 00000000000..5efd121ec07
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/criterion.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+import json
+from typing import List
+
+import pydantic_core
+from pydantic import BaseModel
+from pydantic.json import pydantic_encoder
+
+
+class Criterion(BaseModel):
+ """
+ A class that represents a criterion for agent evaluation.
+ """
+
+ name: str
+ description: str
+ accepted_values: List[str]
+ sub_criteria: List[Criterion] = list()
+
+ @staticmethod
+ def parse_json_str(criteria: str):
+ """
+ Create a list of Criterion objects from a json string.
+ Args:
+ criteria (str): Json string that represents the criteria
+ returns:
+ [Criterion]: A list of Criterion objects that represents the json criteria information.
+ """
+ return [Criterion(**crit) for crit in json.loads(criteria)]
+
+ @staticmethod
+ def write_json(criteria):
+ """
+ Create a json string from a list of Criterion objects.
+ Args:
+ criteria ([Criterion]): A list of Criterion objects.
+ Returns:
+ str: A json string that represents the list of Criterion objects.
+ """
+ return json.dumps([crit.model_dump() for crit in criteria], indent=2)
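A small round-trip sketch of the JSON these helpers exchange; the criterion names and accepted values below are invented for illustration:

```python
from autogen.agentchat.contrib.agent_eval.criterion import Criterion

criteria_json = """
[
    {
        "name": "Accuracy",
        "description": "Whether the final answer matches the ground truth.",
        "accepted_values": ["incorrect", "partially correct", "correct"]
    },
    {
        "name": "Efficiency",
        "description": "How concisely the solution reaches the answer.",
        "accepted_values": ["verbose", "reasonable", "concise"]
    }
]
"""

criteria = Criterion.parse_json_str(criteria_json)  # -> List[Criterion]
print(criteria[0].name, criteria[0].accepted_values)

# write_json serializes the list back into the JSON string handed to the quantifier agent.
print(Criterion.write_json(criteria))
```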
diff --git a/autogen/agentchat/contrib/agent_eval/critic_agent.py b/autogen/agentchat/contrib/agent_eval/critic_agent.py
new file mode 100644
index 00000000000..2f5e5598ba6
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/critic_agent.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class CriticAgent(ConversableAgent):
+ """
+ An agent for creating list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant. You suggest criteria for evaluating different tasks. They should be distinguishable, quantifiable and not redundant.
+    Convert the evaluation criteria into a list where each item is a criterion, represented as a dictionary as follows
+ {"name": name of the criterion, "description": criteria description , "accepted_values": possible accepted inputs for this key}
+ Make sure "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels and "description" includes the criterion description.
+ Output just the criteria string you have created, no code.
+ """
+
+    DEFAULT_DESCRIPTION = "An AI agent for creating a list of criteria for evaluating the utility of a given task."
+
+ def __init__(
+ self,
+ name="critic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/quantifier_agent.py b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
new file mode 100644
index 00000000000..02a8f650fab
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class QuantifierAgent(ConversableAgent):
+ """
+ An agent for quantifying the performance of a system using the provided criteria.
+ """
+
+    DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant. You quantify the output of different tasks based on the given criteria.
+    The criteria are given in a json list format where each element is a distinct criterion.
+    Each element is a dictionary as follows {"name": name of the criterion, "description": criteria description, "accepted_values": possible accepted inputs for this key}
+    You are going to quantify each of the criteria for a given task based on the task description.
+ Return a dictionary where the keys are the criteria and the values are the assessed performance based on accepted values for each criteria.
+ Return only the dictionary, no code."""
+
+    DEFAULT_DESCRIPTION = "An AI agent for quantifying the performance of a system using the provided criteria."
+
+ def __init__(
+ self,
+ name="quantifier",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(name=name, system_message=system_message, description=description, **kwargs)
diff --git a/autogen/agentchat/contrib/agent_eval/subcritic_agent.py b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
new file mode 100755
index 00000000000..fa994ee7bda
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class SubCriticAgent(ConversableAgent):
+ """
+ An agent for creating subcriteria from a given list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant to the critic agent. You suggest sub criteria for evaluating different tasks based on the criteria provided by the critic agent (if you feel it is needed).
+ They should be distinguishable, quantifiable, and related to the overall theme of the critic's provided criteria.
+ You operate by taking in the description of the criteria. You then create a new key called sub criteria where you provide the sub criteria for the given criteria.
+ The value of the sub_criteria is a dictionary where the keys are the subcriteria and each value is as follows {"description": sub criteria description , "accepted_values": possible accepted inputs for this key}
+ Do this for each criteria provided by the critic (removing the criteria's accepted values). "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels. "description" includes the criterion description.
+ Once you have created the sub criteria for the given criteria, you return the json (make sure to include the contents of the critic's dictionary in the final dictionary as well).
+ Make sure to return a valid json and no code"""
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating subcriteria from a given list of criteria."
+
+ def __init__(
+ self,
+ name="subcritic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/task.py b/autogen/agentchat/contrib/agent_eval/task.py
new file mode 100644
index 00000000000..9f96fbf79e2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/task.py
@@ -0,0 +1,37 @@
+import json
+
+from pydantic import BaseModel
+
+
+class Task(BaseModel):
+ """
+ Class representing a task for agent completion, includes example agent execution for criteria generation.
+ """
+
+ name: str
+ description: str
+ successful_response: str
+ failed_response: str
+
+ def get_sys_message(self):
+ return f"""Task: {self.name}.
+ Task description: {self.description}
+ Task successful example: {self.successful_response}
+ Task failed example: {self.failed_response}
+ """
+
+ @staticmethod
+ def parse_json_str(task: str):
+ """
+        Create a Task object from a json string.
+        Args:
+            task (str): A json string that represents the task.
+ Returns:
+ Task: A Task object that represents the json task information.
+ """
+ json_data = json.loads(task)
+ name = json_data.get("name")
+ description = json_data.get("description")
+ successful_response = json_data.get("successful_response")
+ failed_response = json_data.get("failed_response")
+        # Pydantic models require keyword arguments, so construct the Task explicitly.
+        return Task(name=name, description=description, successful_response=successful_response, failed_response=failed_response)
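For reference, a hedged example of the JSON shape `parse_json_str` expects; the field contents are placeholders:

```python
import json

from autogen.agentchat.contrib.agent_eval.task import Task

task_json = json.dumps(
    {
        "name": "Math problem solving",
        "description": "Solve grade-school math word problems and report the final answer.",
        "successful_response": "The answer is 42. TERMINATE",
        "failed_response": "I am not sure how to solve this.",
    }
)

task = Task.parse_json_str(task_json)
print(task.get_sys_message())
```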
diff --git a/autogen/agentchat/contrib/agent_optimizer.py b/autogen/agentchat/contrib/agent_optimizer.py
new file mode 100644
index 00000000000..af264d4b65f
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_optimizer.py
@@ -0,0 +1,444 @@
+import copy
+import json
+from typing import Dict, List, Literal, Optional, Union
+
+import autogen
+from autogen.code_utils import execute_code
+
+ADD_FUNC = {
+ "type": "function",
+ "function": {
+ "name": "add_function",
+ "description": "Add a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "description": "The name of the function in the code implementation."},
+ "description": {"type": "string", "description": "A short description of the function."},
+ "arguments": {
+ "type": "string",
+ "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.',
+ },
+ "packages": {
+ "type": "string",
+ "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.",
+ },
+ "code": {
+ "type": "string",
+ "description": "The implementation in Python. Do not include the function declaration.",
+ },
+ },
+ "required": ["name", "description", "arguments", "packages", "code"],
+ },
+ },
+}
+
+REVISE_FUNC = {
+ "type": "function",
+ "function": {
+ "name": "revise_function",
+ "description": "Revise a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "description": "The name of the function in the code implementation."},
+ "description": {"type": "string", "description": "A short description of the function."},
+ "arguments": {
+ "type": "string",
+ "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.',
+ },
+ "packages": {
+ "type": "string",
+ "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.",
+ },
+ "code": {
+ "type": "string",
+ "description": "The implementation in Python. Do not include the function declaration.",
+ },
+ },
+ "required": ["name", "description", "arguments", "packages", "code"],
+ },
+ },
+}
+
+REMOVE_FUNC = {
+ "type": "function",
+ "function": {
+ "name": "remove_function",
+        "description": "Remove one function in the context of the conversation. Once a function is removed, the assistant will not use it in future conversations.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "description": "The name of the function in the code implementation."}
+ },
+ "required": ["name"],
+ },
+ },
+}
+
+OPT_PROMPT = """You are a function optimizer. Your task is to maintain a list of functions for the assistant according to the existing function list and conversation history that happens between the assistant and the user.
+You can perform one of the following four actions to manipulate the function list using the functions you have:
+1. Revise one existing function (using revise_function).
+2. Remove one existing function (using remove_function).
+3. Add one new function (using add_function).
+4. Directly return "TERMINATE" to me if no more actions are needed for the current function list.
+
+Below are the principles that you need to follow for taking these four actions.
+(1) Revise one existing function:
+1. Pay more attention to the failed tasks and corresponding error information, and optimize the function used in these tasks according to the conversation history if needed.
+2. A failed function call can occur due to incorrect input arguments (missing arguments) or an incorrect function code implementation. You should focus more on the function code implementation and make it easy to get a successful function call.
+3. Do not revise the function that you think works well and plays a critical role in solving the problems according to the conversation history. Only make revisions if needed.
+4. Sometimes, a NameError may occur. To fix this error, you can either revise the name of the function in the code implementation or revise the name of the function call to make these two names consistent.
+(2) Remove one existing function:
+1. Only remove the function that you think is not needed anymore in future tasks.
+(3) Add one new function:
+1. The added function should be general enough to be used in future tasks. For instance, if you encounter a problem that this function can solve, or one step of it, you can use the generated function directly instead of starting from scratch
+2. The added new function should solve a higher-level question that encompasses the original query and extend the code's functionality to make it more versatile and widely applicable.
+3. Replace specific strings or variable names with general variables to enhance the tool's applicability to various queries. All names used inside the function should be passed in as arguments.
+Below is an example of a function that potentially deserves to be added in solving MATH problems, which can be used to solve a higher-level question:
+{{
+ \"name\": \"evaluate_expression\",
+ \"description\": \"Evaluate arithmetic or mathematical expressions provided as strings.\",
+ \"arguments\": {{
+ \"expression\": {{
+ \"type\": \"string\",
+ \"description\": \"The mathematical expression to evaluate.\"
+ }}
+ }},
+ \"packages\": \"sympy\",
+ \"code\": \"from sympy import sympify, SympifyError\\n\\ndef evaluate_expression(expression):\\n try:\\n result = sympify(expression)\\n if result.is_number:\\n result = float(result)\\n else:\\n result = str(result)\\n return result\\n except SympifyError as e:\\n return str(e)\"
+}}
+(4) Directly return "TERMINATE":
+If you think there is no need to perform any other actions for the current function list, since the current list is optimal and more actions will harm the performance in future tasks, please directly reply to me with "TERMINATE".
+
+One function signature includes the following five elements:
+1. Function name
+2. Function description
+3. JSON schema of arguments encoded as a string
+4. A list of package names imported by the function packages
+5. The code implementation
+
+Below are the signatures of the current functions:
+List A: {best_functions}.
+The following list are the function signatures that you have after taking {actions_num} actions to manipulate List A:
+List B: {incumbent_functions}.
+
+{accumulated_experience}
+
+Here are {best_conversations_num} conversation histories of solving {best_conversations_num} tasks using List A.
+History:
+{best_conversations_history}
+
+{statistic_informations}
+
+According to the information I provide, please take one of four actions to manipulate list B using the functions you know.
+Instead of returning TERMINATE directly or taking no action, you should try your best to optimize the function list. Only take no action if you really think the current list is optimal, as more actions will harm performance in future tasks.
+Even adding a general function that can substitute the assistant’s repeated suggestions of Python code with the same functionality could also be helpful.
+"""
+
+
+def execute_func(name, packages, code, **args):
+ """
+ The wrapper for generated functions.
+ """
+ pip_install = (
+ f"""print("Installing package: {packages}")\nsubprocess.run(["pip", "-qq", "install", "{packages}"])"""
+ if packages
+ else ""
+ )
+    code_str = f"""
+import subprocess
+{pip_install}
+print("Result of {name} function execution:")
+{code}
+args={args}
+result={name}(**args)
+if result is not None: print(result)
+"""
+    print(f"execute_code:\n{code_str}")
+    result = execute_code(code_str, use_docker="shaokun529/evoagent:v1")
+ if result[0] != 0:
+ raise Exception("Error in executing function:" + result[1])
+ print(f"Result: {result[1]}")
+ return result[1]
+
+
+class AgentOptimizer:
+ """
+ Base class for optimizing AutoGen agents. Specifically, it is used to optimize the functions used in the agent.
+ More information could be found in the following paper: https://arxiv.org/abs/2402.11359.
+ """
+
+ def __init__(
+ self,
+ max_actions_per_step: int,
+ llm_config: dict,
+ optimizer_model: Optional[str] = "gpt-4-1106-preview",
+ ):
+ """
+ (These APIs are experimental and may change in the future.)
+ Args:
+ max_actions_per_step (int): the maximum number of actions that the optimizer can take in one step.
+ llm_config (dict): llm inference configuration.
+ Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options.
+ When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
+ optimizer_model: the model used for the optimizer.
+ """
+ self.max_actions_per_step = max_actions_per_step
+ self._max_trials = 3
+ self.optimizer_model = optimizer_model
+
+ self._trial_conversations_history = []
+ self._trial_conversations_performance = []
+ self._trial_functions = []
+
+ self._best_conversations_history = []
+ self._best_conversations_performance = []
+ self._best_functions = []
+
+ self._failure_functions_performance = []
+ self._best_performance = -1
+
+ assert isinstance(llm_config, dict), "llm_config must be a dict"
+ llm_config = copy.deepcopy(llm_config)
+ self.llm_config = llm_config
+ if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
+ raise ValueError(
+ "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
+ )
+ self.llm_config["config_list"] = autogen.filter_config(
+ llm_config["config_list"], {"model": [self.optimizer_model]}
+ )
+ self._client = autogen.OpenAIWrapper(**self.llm_config)
+
+ def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None):
+ """
+ record one conversation history.
+ Args:
+ conversation_history (List[Dict]): the chat messages of the conversation.
+ is_satisfied (bool): whether the user is satisfied with the solution. If it is none, the user will be asked to input the satisfaction.
+ """
+ if is_satisfied is None:
+ reply = input(
+ "Please provide whether the user is satisfied with the solution. 1 represents satisfied. 0 represents not satisfied. Press enter to submit. \n"
+ )
+ assert reply in [
+ "0",
+ "1",
+ ], "The input is invalid. Please input 1 or 0. 1 represents satisfied. 0 represents not satisfied."
+ is_satisfied = True if reply == "1" else False
+ self._trial_conversations_history.append(
+ {"Conversation {i}".format(i=len(self._trial_conversations_history)): conversation_history}
+ )
+ self._trial_conversations_performance.append(
+ {"Conversation {i}".format(i=len(self._trial_conversations_performance)): 1 if is_satisfied else 0}
+ )
+
+ def step(self):
+ """
+ One step of training. It will return register_for_llm and register_for_executor at each iteration,
+ which are subsequently utilized to update the assistant and executor agents, respectively.
+ See example: https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb
+ """
+ performance = sum(sum(d.values()) for d in self._trial_conversations_performance) / len(
+ self._trial_conversations_performance
+ )
+
+ if performance < self._best_performance:
+ self._failure_functions_performance.append({"functions": self._trial_functions, "performance": performance})
+ self._failure_functions_performance = sorted(
+ self._failure_functions_performance, key=lambda x: x["performance"]
+ )
+ else:
+ self._failure_functions_performance = []
+ self._best_performance = performance
+ self._best_functions = copy.deepcopy(self._trial_functions)
+ self._best_conversations_history = copy.deepcopy(self._trial_conversations_history)
+ self._best_conversations_performance = copy.deepcopy(self._trial_conversations_performance)
+ self._trial_conversations_history = []
+ self._trial_conversations_performance = []
+
+ best_functions = copy.deepcopy(self._best_functions)
+ incumbent_functions = copy.deepcopy(self._best_functions)
+ failure_experience_prompt, statistic_prompt = self._construct_intermediate_prompt()
+
+ for action_index in range(self.max_actions_per_step):
+ prompt = OPT_PROMPT.format(
+ best_conversations_history=self._best_conversations_history,
+ best_conversations_num=len(self._best_conversations_history),
+ actions_num=action_index,
+ best_functions=best_functions,
+ incumbent_functions=incumbent_functions,
+ accumulated_experience=failure_experience_prompt,
+ statistic_informations=statistic_prompt,
+ )
+ messages = [{"role": "user", "content": prompt}]
+ for _ in range(self._max_trials):
+ response = self._client.create(
+ messages=messages, tools=[ADD_FUNC, REVISE_FUNC, REMOVE_FUNC], tool_choice="auto"
+ )
+ actions = response.choices[0].message.tool_calls
+ if self._validate_actions(actions, incumbent_functions):
+ break
+ if actions is not None and self._validate_actions(actions, incumbent_functions):
+ incumbent_functions = self._update_function_call(incumbent_functions, actions)
+
+ remove_functions = list(
+ set([key for dictionary in self._trial_functions for key in dictionary.keys()])
+ - set([key for dictionary in incumbent_functions for key in dictionary.keys()])
+ )
+
+ register_for_llm = []
+        register_for_executor = {}
+ for name in remove_functions:
+ register_for_llm.append({"func_sig": {"name": name}, "is_remove": True})
+            register_for_executor.update({name: None})
+ for func in incumbent_functions:
+ register_for_llm.append(
+ {
+ "func_sig": {
+ "name": func.get("name"),
+ "description": func.get("description"),
+ "parameters": {"type": "object", "properties": func.get("arguments")},
+ },
+ "is_remove": False,
+ }
+ )
+            register_for_executor.update(
+                {
+                    # Bind `func` via a default argument so each lambda keeps its own function
+                    # (a plain closure would late-bind to the last `func` of the loop).
+                    func.get("name"): lambda func=func, **args: execute_func(
+                        func.get("name"), func.get("packages"), func.get("code"), **args
+                    )
+                }
+            )
+
+ self._trial_functions = incumbent_functions
+        return register_for_llm, register_for_executor
+
+ def reset_optimizer(self):
+ """
+ reset the optimizer.
+ """
+
+ self._trial_conversations_history = []
+ self._trial_conversations_performance = []
+ self._trial_functions = []
+
+ self._best_conversations_history = []
+ self._best_conversations_performance = []
+ self._best_functions = []
+
+ self._best_performance = -1
+ self._failure_functions_performance = []
+
+ def _update_function_call(self, incumbent_functions, actions):
+ """
+ update function call.
+ """
+
+ formated_actions = []
+ for action in actions:
+ func = json.loads(action.function.arguments.strip('"'))
+ func["action_name"] = action.function.name
+
+ if func.get("action_name") == "remove_function":
+ item = {
+ "action_name": func.get("action_name"),
+ "name": func.get("name"),
+ }
+ else:
+ item = {
+ "action_name": func.get("action_name"),
+ "name": func.get("name"),
+ "description": func.get("description"),
+ "arguments": json.loads(func.get("arguments").strip('"')),
+ "packages": func.get("packages"),
+ "code": func.get("code"),
+ }
+ formated_actions.append(item)
+ actions = formated_actions
+
+ for action in actions:
+ name, description, arguments, packages, code, action_name = (
+ action.get("name"),
+ action.get("description"),
+ action.get("arguments"),
+ action.get("packages"),
+ action.get("code"),
+ action.get("action_name"),
+ )
+ if action_name == "remove_function":
+ incumbent_functions = [item for item in incumbent_functions if item["name"] != name]
+ else:
+ incumbent_functions = [item for item in incumbent_functions if item["name"] != name]
+ incumbent_functions.append(
+ {
+ "name": name,
+ "description": description,
+ "arguments": arguments,
+ "packages": packages,
+ "code": code,
+ }
+ )
+
+ return incumbent_functions
+
+ def _construct_intermediate_prompt(self):
+ """
+ construct intermediate prompts.
+ """
+ if len(self._failure_functions_performance) != 0:
+            failure_experience_prompt = "We also provide more examples for different functions and their corresponding performance (0-100).\n The following function signatures are arranged in ascending order based on their performance, where higher performance indicates better quality."
+ failure_experience_prompt += "\n"
+ for item in self._failure_functions_performance:
+ failure_experience_prompt += "Function: \n" + str(item["functions"]) + "\n"
+ failure_experience_prompt += "Performance: \n" + str(item["performance"]) + "\n"
+ else:
+ failure_experience_prompt = "\n"
+
+ if len(self._best_conversations_performance) != 0:
+            statistic_prompt = "The following table shows the statistical information for solving each task in each conversation and indicates whether the users were satisfied with the result. 1 represents satisfied. 0 represents not satisfied."
+ statistic_prompt += "\n"
+ for item in self._best_conversations_performance:
+ statistic_prompt += str(item) + "\n"
+ else:
+ statistic_prompt = "\n"
+
+ return failure_experience_prompt, statistic_prompt
+
+ def _validate_actions(self, actions, incumbent_functions):
+ """
+ validate whether the proposed actions are feasible.
+ """
+ if actions is None:
+ return True
+ else:
+ # val json format
+ for action in actions:
+ function_args = action.function.arguments
+ try:
+ function_args = json.loads(function_args.strip('"'))
+ if "arguments" in function_args.keys():
+ json.loads(function_args.get("arguments").strip('"'))
+ except Exception as e:
+ print("JSON is invalid:", e)
+ return False
+ # val syntax
+ for action in actions:
+ if action.function.name != "remove_function":
+ function_args = json.loads(action.function.arguments.strip('"'))
+ code = function_args.get("code")
+ try:
+ compile(code, "", "exec")
+ print("successfully compiled")
+ except Exception as e:
+ print("Syntax is invalid:", e)
+ return False
+ for action in actions:
+ action_name = action.function.name
+ if action_name == "remove_function":
+ function_args = json.loads(action.function.arguments.strip('"'))
+ if function_args.get("name") not in [item["name"] for item in incumbent_functions]:
+ print("The function you want to remove does not exist.")
+ return False
+ return True
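A condensed training-loop sketch in the spirit of the notebook referenced in `step()`. The assistant and user-proxy setup, the queries, and the hard-coded satisfaction signal are placeholders, and the config file is assumed to contain the optimizer model:

```python
import autogen
from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}

assistant = autogen.AssistantAgent(name="assistant", llm_config=llm_config)
user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    human_input_mode="NEVER",
    is_termination_msg=lambda x: "TERMINATE" in (x.get("content") or ""),
    code_execution_config={"work_dir": "optimizer_ws", "use_docker": False},
)

optimizer = AgentOptimizer(max_actions_per_step=3, llm_config=llm_config, optimizer_model="gpt-4-1106-preview")

for query in ["What is 13 * 7 + 5?", "Sum the first 10 prime numbers."]:
    user_proxy.initiate_chat(assistant, message=query)
    # Record the conversation and whether the user was satisfied with it (placeholder signal here).
    optimizer.record_one_conversation(assistant.chat_messages[user_proxy], is_satisfied=True)

    # One optimization step proposes add/revise/remove actions over the function list.
    register_for_llm, register_for_executor = optimizer.step()
    for item in register_for_llm:
        assistant.update_function_signature(**item)
    if register_for_executor:
        user_proxy.register_function(function_map=register_for_executor)
```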
diff --git a/autogen/agentchat/contrib/capabilities/__init__.py b/autogen/agentchat/contrib/capabilities/__init__.py
index 24d6be9de8b..e69de29bb2d 100644
--- a/autogen/agentchat/contrib/capabilities/__init__.py
+++ b/autogen/agentchat/contrib/capabilities/__init__.py
@@ -1,5 +0,0 @@
-from .teachability import Teachability
-from .agent_capability import AgentCapability
-
-
-__all__ = ["Teachability", "AgentCapability"]
diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
new file mode 100644
index 00000000000..44b10259f1b
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/context_handling.py
@@ -0,0 +1,138 @@
+import sys
+from typing import Dict, List, Optional
+from warnings import warn
+
+import tiktoken
+from termcolor import colored
+
+from autogen import ConversableAgent, token_count_utils
+
+warn(
+ "Context handling with TransformChatHistory is deprecated and will be removed in `0.2.30`. "
+ "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
+
+class TransformChatHistory:
+ """
+ An agent's chat history with other agents is a common context that it uses to generate a reply.
+ This capability allows the agent to transform its chat history prior to using it to generate a reply.
+ It does not permanently modify the chat history, but rather processes it on every invocation.
+
+ This capability class enables various strategies to transform chat history, such as:
+    - Truncate messages: Truncate each message to a maximum number of tokens.
+    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
+    - Limit number of tokens: Truncate the chat history to the most recent N messages that fit within a
+        maximum number of tokens.
+ Note that the system message, because of its special significance, is always kept as is.
+
+    The three strategies can be combined. For example, when each of these parameters is specified,
+    they are applied in the following order:
+    1. First, truncate each message to a maximum number of tokens.
+    2. Second, limit the number of messages to keep.
+    3. Third, limit the total number of tokens in the chat history.
+
+ When adding this capability to an agent, the following are modified:
+ - A hook is added to the hookable method `process_all_messages_before_reply` to transform the
+ received messages for possible truncation.
+        The stored message history itself is not modified.
+ """
+
+ def __init__(
+ self,
+ *,
+ max_tokens_per_message: Optional[int] = None,
+ max_messages: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ ):
+ """
+ Args:
+ max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
+ max_messages (Optional[int]): Maximum number of messages to keep in the context.
+ max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
+ """
+ self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
+ self.max_messages = max_messages if max_messages else sys.maxsize
+ self.max_tokens = max_tokens if max_tokens else sys.maxsize
+
+ def add_to_agent(self, agent: ConversableAgent):
+ """
+ Adds TransformChatHistory capability to the given agent.
+ """
+ agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
+
+ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
+ """
+ Args:
+ messages: List of messages to process.
+
+ Returns:
+ List of messages with the first system message and the last max_messages messages,
+ ensuring each message does not exceed max_tokens_per_message.
+ """
+ temp_messages = messages.copy()
+ processed_messages = []
+ system_message = None
+ processed_messages_tokens = 0
+
+ if messages[0]["role"] == "system":
+ system_message = messages[0].copy()
+ temp_messages.pop(0)
+
+ total_tokens = sum(
+ token_count_utils.count_token(msg["content"]) for msg in temp_messages
+ ) # Calculate tokens for all messages
+
+ # Truncate each message's content to a maximum token limit of each message
+
+ # Process recent messages first
+ for msg in reversed(temp_messages[-self.max_messages :]):
+ msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
+ msg_tokens = token_count_utils.count_token(msg["content"])
+ if processed_messages_tokens + msg_tokens > self.max_tokens:
+ break
+ # append the message to the beginning of the list to preserve order
+ processed_messages = [msg] + processed_messages
+ processed_messages_tokens += msg_tokens
+ if system_message:
+ processed_messages.insert(0, system_message)
+ # Optionally, log the number of truncated messages and tokens if needed
+ num_truncated = len(messages) - len(processed_messages)
+
+ if num_truncated > 0 or total_tokens > processed_messages_tokens:
+ print(
+ colored(
+ f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
+ "yellow",
+ )
+ )
+ print(
+ colored(
+ f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
+ "yellow",
+ )
+ )
+ return processed_messages
+
+
+def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
+ """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
+
+ Args:
+ text: The string to truncate.
+ max_tokens: The maximum number of tokens to keep.
+ model: The target OpenAI model for tokenization alignment.
+
+ Returns:
+ The truncated string.
+ """
+
+ encoding = tiktoken.encoding_for_model(model) # Get the appropriate tokenizer
+
+ encoded_tokens = encoding.encode(text)
+ truncated_tokens = encoded_tokens[:max_tokens]
+ truncated_text = encoding.decode(truncated_tokens) # Decode back to text
+
+ return truncated_text
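A minimal usage sketch, assuming an `OAI_CONFIG_LIST` file is available; note that the module itself marks this capability as deprecated in favor of `TransformMessages`:

```python
import autogen
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}
assistant = autogen.AssistantAgent(name="assistant", llm_config=llm_config)

# Keep at most 10 recent messages, each truncated to 1000 tokens, within a 4000-token budget.
context_handling = TransformChatHistory(
    max_tokens_per_message=1000,
    max_messages=10,
    max_tokens=4000,
)
context_handling.add_to_agent(assistant)  # registers the process_all_messages_before_reply hook
```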
diff --git a/autogen/agentchat/contrib/capabilities/generate_images.py b/autogen/agentchat/contrib/capabilities/generate_images.py
new file mode 100644
index 00000000000..e4a8f1195c2
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/generate_images.py
@@ -0,0 +1,291 @@
+import re
+from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union
+
+from openai import OpenAI
+from PIL.Image import Image
+
+from autogen import Agent, ConversableAgent, code_utils
+from autogen.agentchat.contrib import img_utils
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
+from autogen.cache import AbstractCache
+
+SYSTEM_MESSAGE = "You've been given the special ability to generate images."
+DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
+
+PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
+DO NOT include any advice. RESPOND like the following example:
+EXAMPLE: Blue background, 3D shapes, ...
+"""
+
+
+class ImageGenerator(Protocol):
+ """This class defines an interface for image generators.
+
+ Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
+ input and returns a PIL Image object.
+
+ NOTE: Current implementation does not allow you to edit a previously existing image.
+ """
+
+ def generate_image(self, prompt: str) -> Image:
+ """Generates an image based on the provided prompt.
+
+ Args:
+ prompt: A string describing the desired image.
+
+ Returns:
+ A PIL Image object representing the generated image.
+
+ Raises:
+ ValueError: If the image generation fails.
+ """
+ ...
+
+ def cache_key(self, prompt: str) -> str:
+ """Generates a unique cache key for the given prompt.
+
+ This key can be used to store and retrieve generated images based on the prompt.
+
+ Args:
+ prompt: A string describing the desired image.
+
+ Returns:
+ A unique string that can be used as a cache key.
+ """
+ ...
+
+
+class DalleImageGenerator:
+ """Generates images using OpenAI's DALL-E models.
+
+ This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E
+ models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate.
+
+ Note: Current implementation does not allow you to edit a previously existing image.
+ """
+
+ def __init__(
+ self,
+ llm_config: Dict,
+ resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
+ quality: Literal["standard", "hd"] = "standard",
+ num_images: int = 1,
+ ):
+ """
+ Args:
+ llm_config (dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
+ resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
+ quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
+ num_images (int): The number of images to generate.
+ """
+ config_list = llm_config["config_list"]
+ _validate_dalle_model(config_list[0]["model"])
+ _validate_resolution_format(resolution)
+
+ self._model = config_list[0]["model"]
+ self._resolution = resolution
+ self._quality = quality
+ self._num_images = num_images
+ self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])
+
+ def generate_image(self, prompt: str) -> Image:
+ response = self._dalle_client.images.generate(
+ model=self._model,
+ prompt=prompt,
+ size=self._resolution,
+ quality=self._quality,
+ n=self._num_images,
+ )
+
+ image_url = response.data[0].url
+ if image_url is None:
+ raise ValueError("Failed to generate image.")
+
+ return img_utils.get_pil_image(image_url)
+
+ def cache_key(self, prompt: str) -> str:
+ keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
+ return ",".join([str(k) for k in keys])
+
+
+class ImageGeneration(AgentCapability):
+ """This capability allows a ConversableAgent to generate images based on the message received from other Agents.
+
+ 1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
+ extract relevant details.
+ 2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
+ 3. Optionally caches generated images for faster retrieval in future conversations.
+
+ NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
+ message received by the agent.
+
+ Example:
+ ```python
+ import autogen
+ from autogen.agentchat.contrib.capabilities import generate_images
+
+ # Assuming you have LLM configs set up for the agents you want to use and for DALL-E.
+ # Create the agent
+ agent = autogen.ConversableAgent(
+ name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
+ )
+
+ # Create an ImageGenerator with desired settings
+ dalle_gen = generate_images.DalleImageGenerator(llm_config={...})
+
+ # Add the ImageGeneration capability to the agent
+ agent.add_capability(generate_images.ImageGeneration(image_generator=dalle_gen))
+ ```
+ """
+
+ def __init__(
+ self,
+ image_generator: ImageGenerator,
+ cache: Optional[AbstractCache] = None,
+ text_analyzer_llm_config: Optional[Dict] = None,
+ text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
+ verbosity: int = 0,
+ register_reply_position: int = 2,
+ ):
+ """
+ Args:
+ image_generator (ImageGenerator): The image generator you would like to use to generate images.
+ cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
+ no caching will be used.
+ text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will
+ be retrieved from the agent you're adding the ability to.
+ text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
+ incoming messages and extract the prompt for image generation. The default instructions focus on
+ summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
+ extraction.
+ Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
+ verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
+ analyzer llm calls will be silent if verbosity is less than 2.
+ register_reply_position (int): The position of the reply function in the agent's list of reply functions.
+ This capability registers a new reply function to handle messages with image generation requests.
+ Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
+ """
+ self._image_generator = image_generator
+ self._cache = cache
+ self._text_analyzer_llm_config = text_analyzer_llm_config
+ self._text_analyzer_instructions = text_analyzer_instructions
+ self._verbosity = verbosity
+ self._register_reply_position = register_reply_position
+
+ self._agent: Optional[ConversableAgent] = None
+ self._text_analyzer: Optional[TextAnalyzerAgent] = None
+
+ def add_to_agent(self, agent: ConversableAgent):
+ """Adds the Image Generation capability to the specified ConversableAgent.
+
+ This function performs the following modifications to the agent:
+
+ 1. Registers a reply function: A new reply function is registered with the agent to handle messages that
+ potentially request image generation. This function analyzes the message and triggers image generation if
+ necessary.
+ 2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
+ 3. Updates System Message: The agent's system message is updated to include a message indicating the
+ capability to generate images has been added.
+ 4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
+ capability. This might be helpful in certain use cases, like group chats.
+
+ Args:
+ agent (ConversableAgent): The ConversableAgent to add the capability to.
+ """
+ self._agent = agent
+
+ agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)
+
+ self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
+ self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)
+
+ agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
+ agent.description += "\n" + DESCRIPTION_MESSAGE
+
+ def _image_gen_reply(
+ self,
+ recipient: ConversableAgent,
+ messages: Optional[List[Dict]],
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ if messages is None:
+ return False, None
+
+ last_message = code_utils.content_str(messages[-1]["content"])
+
+ if not last_message:
+ return False, None
+
+ if self._should_generate_image(last_message):
+ prompt = self._extract_prompt(last_message)
+
+ image = self._cache_get(prompt)
+ if image is None:
+ image = self._image_generator.generate_image(prompt)
+ self._cache_set(prompt, image)
+
+ return True, self._generate_content_message(prompt, image)
+
+ else:
+ return False, None
+
+ def _should_generate_image(self, message: str) -> bool:
+ assert self._text_analyzer is not None
+
+ instructions = """
+ Does any part of the TEXT ask the agent to generate an image?
+ The TEXT must explicitly mention that the image must be generated.
+ Answer with just one word, yes or no.
+ """
+ analysis = self._text_analyzer.analyze_text(message, instructions)
+
+ return "yes" in self._extract_analysis(analysis).lower()
+
+ def _extract_prompt(self, last_message) -> str:
+ assert self._text_analyzer is not None
+
+ analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
+ return self._extract_analysis(analysis)
+
+ def _cache_get(self, prompt: str) -> Optional[Image]:
+ if self._cache:
+ key = self._image_generator.cache_key(prompt)
+ cached_value = self._cache.get(key)
+
+ if cached_value:
+ return img_utils.get_pil_image(cached_value)
+
+ def _cache_set(self, prompt: str, image: Image):
+ if self._cache:
+ key = self._image_generator.cache_key(prompt)
+ self._cache.set(key, img_utils.pil_to_data_uri(image))
+
+ def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
+ if isinstance(analysis, Dict):
+ return code_utils.content_str(analysis["content"])
+ else:
+ return code_utils.content_str(analysis)
+
+ def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
+ return {
+ "content": [
+ {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
+ {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
+ ]
+ }
+
+
+### Helpers
+def _validate_resolution_format(resolution: str):
+ """Checks if a string is in a valid resolution format (e.g., "1024x768")."""
+ pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
+ matched_resolution = re.match(pattern, resolution)
+ if matched_resolution is None:
+ raise ValueError(f"Invalid resolution format: {resolution}")
+
+
+def _validate_dalle_model(model: str):
+ if model not in ["dall-e-3", "dall-e-2"]:
+ raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py
index f673269d60b..596e449ce34 100644
--- a/autogen/agentchat/contrib/capabilities/teachability.py
+++ b/autogen/agentchat/contrib/capabilities/teachability.py
@@ -1,18 +1,15 @@
import os
-from autogen.agentchat.assistant_agent import ConversableAgent
-from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
-from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
-from typing import Dict, Optional, Union, List, Tuple, Any
+import pickle
+from typing import Dict, Optional, Union
+
import chromadb
from chromadb.config import Settings
-import pickle
-try:
- from termcolor import colored
-except ImportError:
+from autogen.agentchat.assistant_agent import ConversableAgent
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
- def colored(x, *args, **kwargs):
- return x
+from ....formatting_utils import colored
class Teachability(AgentCapability):
@@ -23,6 +20,13 @@ class Teachability(AgentCapability):
To make any conversable agent teachable, instantiate both the agent and the Teachability class,
then pass the agent to teachability.add_to_agent(agent).
Note that teachable agents in a group chat must be given unique path_to_db_dir values.
+
+ When adding Teachability to an agent, the following are modified:
+ - The agent's system message is appended with a note about the agent's new ability.
+ - A hook is added to the agent's `process_last_received_message` hookable method,
+ and the hook potentially modifies the last received message to include earlier teachings related to it.
+ Added teachings do not propagate into the stored message history.
+ If new user teachings are detected, they are added to new memos in the vector database.
"""
def __init__(
@@ -61,7 +65,7 @@ def add_to_agent(self, agent: ConversableAgent):
self.teachable_agent = agent
# Register a hook for processing the last message.
- agent.register_hook(hookable_method=agent.process_last_message, hook=self.process_last_message)
+ agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
# Was an llm_config passed to the constructor?
if self.llm_config is None:
@@ -82,7 +86,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()
- def process_last_message(self, text):
+ def process_last_received_message(self, text: Union[Dict, str]):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
@@ -99,7 +103,7 @@ def process_last_message(self, text):
# Return the (possibly) expanded message text.
return expanded_text
- def _consider_memo_storage(self, comment):
+ def _consider_memo_storage(self, comment: Union[Dict, str]):
"""Decides whether to store something from one user comment in the DB."""
memo_added = False
@@ -157,7 +161,7 @@ def _consider_memo_storage(self, comment):
# Yes. Save them to disk.
self.memo_store._save_memos()
- def _consider_memo_retrieval(self, comment):
+ def _consider_memo_retrieval(self, comment: Union[Dict, str]):
"""Decides whether to retrieve memos from the DB, and add them to the chat context."""
# First, use the comment directly as the lookup key.
@@ -191,7 +195,7 @@ def _consider_memo_retrieval(self, comment):
# Append the memos to the text of the last message.
return comment + self._concatenate_memo_texts(memo_list)
- def _retrieve_relevant_memos(self, input_text):
+ def _retrieve_relevant_memos(self, input_text: str) -> list:
"""Returns semantically related memos from the DB."""
memo_list = self.memo_store.get_related_memos(
input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
@@ -209,7 +213,7 @@ def _retrieve_relevant_memos(self, input_text):
memo_list = [memo[1] for memo in memo_list]
return memo_list
- def _concatenate_memo_texts(self, memo_list):
+ def _concatenate_memo_texts(self, memo_list: list) -> str:
"""Concatenates the memo texts into a single string for inclusion in the chat context."""
memo_texts = ""
if len(memo_list) > 0:
@@ -221,7 +225,7 @@ def _concatenate_memo_texts(self, memo_list):
memo_texts = memo_texts + "\n" + info
return memo_texts
- def _analyze(self, text_to_analyze, analysis_instructions):
+ def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]):
"""Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
self.analyzer.reset() # Clear the analyzer's list of messages.
self.teachable_agent.send(
@@ -242,10 +246,16 @@ class MemoStore:
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
"""
- def __init__(self, verbosity, reset, path_to_db_dir):
+ def __init__(
+ self,
+ verbosity: Optional[int] = 0,
+ reset: Optional[bool] = False,
+ path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
+ ):
"""
Args:
- verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
+ - reset (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored.
"""
self.verbosity = verbosity
@@ -300,7 +310,7 @@ def reset_db(self):
self.uid_text_dict = {}
self._save_memos()
- def add_input_output_pair(self, input_text, output_text):
+ def add_input_output_pair(self, input_text: str, output_text: str):
"""Adds an input-output pair to the vector DB."""
self.last_memo_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
@@ -317,7 +327,7 @@ def add_input_output_pair(self, input_text, output_text):
if self.verbosity >= 3:
self.list_memos()
- def get_nearest_memo(self, query_text):
+ def get_nearest_memo(self, query_text: str):
"""Retrieves the nearest memo to the given query text."""
results = self.vec_db.query(query_texts=[query_text], n_results=1)
uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
@@ -334,7 +344,7 @@ def get_nearest_memo(self, query_text):
)
return input_text, output_text, distance
- def get_related_memos(self, query_text, n_results, threshold):
+ def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
"""Retrieves memos that are related to the given query text within the specified distance threshold."""
if n_results > len(self.uid_text_dict):
n_results = len(self.uid_text_dict)
diff --git a/autogen/agentchat/contrib/capabilities/text_compressors.py b/autogen/agentchat/contrib/capabilities/text_compressors.py
new file mode 100644
index 00000000000..78554bdc935
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, Optional, Protocol
+
+IMPORT_ERROR: Optional[Exception] = None
+try:
+ import llmlingua
+except ImportError:
+ IMPORT_ERROR = ImportError(
+ "LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
+ )
+ PromptCompressor = object
+else:
+ from llmlingua import PromptCompressor
+
+
+class TextCompressor(Protocol):
+ """Defines a protocol for text compression to optimize agent interactions."""
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ """This method takes a string as input and returns a dictionary containing the compressed text and other
+ relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
+ To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
+ """
+ ...
+
+
+class LLMLingua:
+ """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
+
+ NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
+ and the specific configurations used for the PromptCompressor.
+ """
+
+ def __init__(
+ self,
+ prompt_compressor_kwargs: Dict = dict(
+ model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2=True,
+ device_map="cpu",
+ ),
+ structured_compression: bool = False,
+ ) -> None:
+ """
+ Args:
+ prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
+ dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2 set to True, and device_map set to "cpu".
+ structured_compression (bool): A flag indicating whether to use structured compression. If True, the
+ structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
+ is used. Defaults to False.
+
+ Raises:
+ ImportError: If the llmlingua library is not installed.
+ """
+ if IMPORT_ERROR:
+ raise IMPORT_ERROR
+
+ self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)
+
+ assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
+ self._compression_method = (
+ self._prompt_compressor.structured_compress_prompt
+ if structured_compression
+ else self._prompt_compressor.compress_prompt
+ )
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ return self._compression_method([text], **compression_params)
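`TextCompressor` is likewise a structural protocol, so a lightweight stand-in can be used when LLMLingua is not installed. A rough sketch under that assumption (the whitespace-squeezing strategy and word-based token counts below are illustrative only):

```python
import re
from typing import Any, Dict


class WhitespaceSqueezer:
    """Illustrative TextCompressor: collapses runs of whitespace instead of calling LLMLingua."""

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        compressed = re.sub(r"\s+", " ", text).strip()
        # Return the keys the transforms expect so token savings can be reported.
        return {
            "compressed_prompt": compressed,
            "origin_tokens": len(text.split()),
            "compressed_tokens": len(compressed.split()),
        }


# Example: a TextMessageCompressor could be constructed with text_compressor=WhitespaceSqueezer().
```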
diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
new file mode 100644
index 00000000000..e96dc39fa7b
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -0,0 +1,87 @@
+import copy
+from typing import Dict, List
+
+from autogen import ConversableAgent
+
+from ....formatting_utils import colored
+from .transforms import MessageTransform
+
+
+class TransformMessages:
+ """Agent capability for transforming messages before reply generation.
+
+ This capability allows you to apply a series of message transformations to
+ a ConversableAgent's incoming messages before they are processed for response
+ generation. This is useful for tasks such as:
+
+ - Limiting the number of messages considered for context.
+ - Truncating messages to meet token limits.
+ - Filtering sensitive information.
+ - Customizing message formatting.
+
+ To use `TransformMessages`:
+
+ 1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
+ 2. Instantiate `TransformMessages` with a list of these transformations.
+ 3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.
+
+ NOTE: Order of message transformations is important. You could get different results based on
+ the order of transformations.
+
+ Example:
+ ```python
+ from autogen import ConversableAgent
+ from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
+ from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter
+
+ max_messages = MessageHistoryLimiter(max_messages=2)
+ truncate_messages = MessageTokenLimiter(max_tokens=500)
+ transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])
+
+ agent = ConversableAgent(...)
+ transform_messages.add_to_agent(agent)
+ ```
+ """
+
+ def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
+ """
+ Args:
+ transforms: A list of message transformations to apply.
+ verbose: Whether to print logs of each transformation or not.
+ """
+ self._transforms = transforms
+ self._verbose = verbose
+
+ def add_to_agent(self, agent: ConversableAgent):
+ """Adds the message transformations capability to the specified ConversableAgent.
+
+ This function performs the following modifications to the agent:
+
+ 1. Registers a hook that automatically transforms all messages before they are processed for
+ response generation.
+ """
+ agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
+
+ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
+ post_transform_messages = copy.deepcopy(messages)
+ system_message = None
+
+ if messages[0]["role"] == "system":
+ system_message = copy.deepcopy(messages[0])
+ post_transform_messages.pop(0)
+
+ for transform in self._transforms:
+ # deepcopy in case pre_transform_messages will later be used for logs printing
+ pre_transform_messages = (
+ copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
+ )
+ post_transform_messages = transform.apply_transform(pre_transform_messages)
+
+ if self._verbose:
+ logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
+ if had_effect:
+ print(colored(logs_str, "yellow"))
+
+ if system_message:
+ post_transform_messages.insert(0, system_message)
+
+ return post_transform_messages
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
new file mode 100644
index 00000000000..dad3fc335ed
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -0,0 +1,423 @@
+import copy
+import sys
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+
+import tiktoken
+from termcolor import colored
+
+from autogen import token_count_utils
+from autogen.cache import AbstractCache, Cache
+from autogen.types import MessageContentType
+
+from . import transforms_util
+from .text_compressors import LLMLingua, TextCompressor
+
+
+class MessageTransform(Protocol):
+ """Defines a contract for message transformation.
+
+ Classes implementing this protocol should provide an `apply_transform` method
+ that takes a list of messages and returns the transformed list.
+ """
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies a transformation to a list of messages.
+
+ Args:
+ messages: A list of dictionaries representing messages.
+
+ Returns:
+ A new list of dictionaries containing the transformed messages.
+ """
+ ...
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ """Creates the string including the logs of the transformation
+
+ Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
+
+ Args:
+ pre_transform_messages: A list of dictionaries representing messages before the transformation.
+ post_transform_messages: A list of dictionaries representing messages after the transformation.
+
+ Returns:
+ A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
+ """
+ ...
+
+
+class MessageHistoryLimiter:
+ """Limits the number of messages considered by an agent for response generation.
+
+ This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
+ It trims the conversation history by removing older messages, retaining only the most recent messages.
+ """
+
+ def __init__(self, max_messages: Optional[int] = None):
+ """
+ Args:
+ max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+ """
+ self._validate_max_messages(max_messages)
+ self._max_messages = max_messages
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Truncates the conversation history to the specified maximum number of messages.
+
+ This method returns a new list containing the most recent messages up to the specified
+ maximum number of messages (max_messages). If max_messages is None, it returns the
+ original list of messages unmodified.
+
+ Args:
+ messages (List[Dict]): The list of messages representing the conversation history.
+
+ Returns:
+ List[Dict]: A new list containing the most recent messages up to the specified maximum.
+ """
+
+ if self._max_messages is None:
+ return messages
+
+ return messages[-self._max_messages :]
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_len = len(pre_transform_messages)
+ post_transform_messages_len = len(post_transform_messages)
+
+ if post_transform_messages_len < pre_transform_messages_len:
+ logs_str = (
+ f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
+ f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
+ )
+ return logs_str, True
+ return "No messages were removed.", False
+
+ def _validate_max_messages(self, max_messages: Optional[int]):
+ if max_messages is not None and max_messages < 1:
+ raise ValueError("max_messages must be None or greater than 1")
+
+
+class MessageTokenLimiter:
+ """Truncates messages to meet token limits for efficient processing and response generation.
+
+ This transformation applies two levels of truncation to the conversation history:
+
+ 1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
+ 2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.
+
+ NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
+ counts for the same text.
+
+ NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
+ (e.g images).
+
+ The truncation process follows these steps in order:
+
+ 1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in the messages
+ is less than this threshold, the messages are returned as is. Otherwise, the following process is applied.
+ 2. Messages are processed in reverse order (newest to oldest).
+ 3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
+ and other types of content, only the text content is truncated.
+ 4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
+ exceeds this limit, the message currently being processed is truncated to fit the total token count and any
+ remaining messages are discarded.
+ 5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
+ original message order.
+ """
+
+ def __init__(
+ self,
+ max_tokens_per_message: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ min_tokens: Optional[int] = None,
+ model: str = "gpt-3.5-turbo-0613",
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
+ ):
+ """
+ Args:
+ max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
+ Must be greater than or equal to 0 if not None.
+ max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
+ Must be greater than or equal to 0 if not None.
+ min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
+ Must be greater than or equal to 0 if not None.
+ model (str): The target OpenAI model for tokenization alignment.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from token truncation. If False, messages that match the filter will be truncated.
+ """
+ self._model = model
+ self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
+ self._max_tokens = self._validate_max_tokens(max_tokens)
+ self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies token truncation to the conversation history.
+
+ Args:
+ messages (List[Dict]): The list of messages representing the conversation history.
+
+ Returns:
+ List[Dict]: A new list containing the truncated messages up to the specified token limits.
+ """
+ assert self._max_tokens_per_message is not None
+ assert self._max_tokens is not None
+ assert self._min_tokens is not None
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not transforms_util.min_tokens_reached(messages, self._min_tokens):
+ return messages
+
+ temp_messages = copy.deepcopy(messages)
+ processed_messages = []
+ processed_messages_tokens = 0
+
+ for msg in reversed(temp_messages):
+ # Some messages may not have content.
+ if not transforms_util.is_content_right_type(msg.get("content")):
+ processed_messages.insert(0, msg)
+ continue
+
+ if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
+ processed_messages.insert(0, msg)
+ processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
+ continue
+
+ expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
+
+ # If adding this message would exceed the token limit, truncate the last message to meet the total token
+ # limit and discard all remaining messages
+ if expected_tokens_remained < 0:
+ msg["content"] = self._truncate_str_to_tokens(
+ msg["content"], self._max_tokens - processed_messages_tokens
+ )
+ processed_messages.insert(0, msg)
+ break
+
+ msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
+ msg_tokens = transforms_util.count_text_tokens(msg["content"])
+
+ # prepend the message to the list to preserve order
+ processed_messages_tokens += msg_tokens
+ processed_messages.insert(0, msg)
+
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_tokens = sum(
+ transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
+ )
+ post_transform_messages_tokens = sum(
+ transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
+ )
+
+ if post_transform_messages_tokens < pre_transform_messages_tokens:
+ logs_str = (
+ f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
+ f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
+ )
+ return logs_str, True
+ return "No tokens were truncated.", False
+
+ def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
+ if isinstance(contents, str):
+ return self._truncate_tokens(contents, n_tokens)
+ elif isinstance(contents, list):
+ return self._truncate_multimodal_text(contents, n_tokens)
+ else:
+ raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")
+
+ def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
+ """Truncates text content within a list of multimodal elements, preserving the overall structure."""
+ tmp_contents = []
+ for content in contents:
+ if content["type"] == "text":
+ truncated_text = self._truncate_tokens(content["text"], n_tokens)
+ tmp_contents.append({"type": "text", "text": truncated_text})
+ else:
+ tmp_contents.append(content)
+ return tmp_contents
+
+ def _truncate_tokens(self, text: str, n_tokens: int) -> str:
+ encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer
+
+ encoded_tokens = encoding.encode(text)
+ truncated_tokens = encoded_tokens[:n_tokens]
+ truncated_text = encoding.decode(truncated_tokens) # Decode back to text
+
+ return truncated_text
+
+ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
+ if max_tokens is not None and max_tokens < 0:
+ raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")
+
+ try:
+ allowed_tokens = token_count_utils.get_max_token_limit(self._model)
+ except Exception:
+ print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
+ allowed_tokens = None
+
+ if max_tokens is not None and allowed_tokens is not None:
+ if max_tokens > allowed_tokens:
+ print(
+ colored(
+ f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
+ "yellow",
+ )
+ )
+ return allowed_tokens
+
+ return max_tokens if max_tokens is not None else sys.maxsize
+
+ def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
+ if min_tokens is None:
+ return 0
+ if min_tokens < 0:
+ raise ValueError("min_tokens must be None or greater than or equal to 0.")
+ if max_tokens is not None and min_tokens > max_tokens:
+ raise ValueError("min_tokens must not be more than max_tokens.")
+ return min_tokens
+
+
+class TextMessageCompressor:
+ """A transform for compressing text messages in a conversation history.
+
+ It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
+ processing and response generation by downstream models.
+ """
+
+ def __init__(
+ self,
+ text_compressor: Optional[TextCompressor] = None,
+ min_tokens: Optional[int] = None,
+ compression_params: Dict = dict(),
+ cache: Optional[AbstractCache] = Cache.disk(),
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
+ ):
+ """
+ Args:
+ text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
+ protocol. If None, it defaults to LLMLingua.
+ min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
+ than or equal to 0 if not None. If None, no threshold-based compression is applied.
+ compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
+ dictionary.
+ cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
+ If None, no caching will be used.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from compression. If False, messages that match the filter will be compressed.
+ """
+
+ if text_compressor is None:
+ text_compressor = LLMLingua()
+
+ self._validate_min_tokens(min_tokens)
+
+ self._text_compressor = text_compressor
+ self._min_tokens = min_tokens
+ self._compression_args = compression_params
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
+ self._cache = cache
+
+ # Track the token savings of the most recent transform so get_logs can report them cheaply
+ self._recent_tokens_savings = 0
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies compression to messages in a conversation history based on the specified configuration.
+
+ The function processes each message according to the `compression_args` and `min_tokens` settings, applying
+ the specified compression configuration and returning a new list of messages with reduced token counts
+ where possible.
+
+ Args:
+ messages (List[Dict]): A list of message dictionaries to be compressed.
+
+ Returns:
+ List[Dict]: A list of dictionaries with the message content compressed according to the configured
+ method and scope.
+ """
+ # Make sure there is at least one message
+ if not messages:
+ return messages
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not transforms_util.min_tokens_reached(messages, self._min_tokens):
+ return messages
+
+ total_savings = 0
+ processed_messages = messages.copy()
+ for message in processed_messages:
+ # Some messages may not have content.
+ if not transforms_util.is_content_right_type(message.get("content")):
+ continue
+
+ if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
+ continue
+
+ if transforms_util.is_content_text_empty(message["content"]):
+ continue
+
+ cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
+ cached_content = transforms_util.cache_content_get(self._cache, cache_key)
+ if cached_content is not None:
+ message["content"], savings = cached_content
+ else:
+ message["content"], savings = self._compress(message["content"])
+
+ transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
+
+ assert isinstance(savings, int)
+ total_savings += savings
+
+ self._recent_tokens_savings = total_savings
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ if self._recent_tokens_savings > 0:
+ return f"{self._recent_tokens_savings} tokens saved with text compression.", True
+ else:
+ return "No tokens saved with text compression.", False
+
+ def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+ """Compresses the given text or multimodal content using the specified compression method."""
+ if isinstance(content, str):
+ return self._compress_text(content)
+ elif isinstance(content, list):
+ return self._compress_multimodal(content)
+ else:
+ return content, 0
+
+ def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+ tokens_saved = 0
+ for idx, item in enumerate(content):
+ if isinstance(item, dict) and "text" in item:
+ item["text"], savings = self._compress_text(item["text"])
+ tokens_saved += savings
+
+ elif isinstance(item, str):
+ # Write back through the index so the compressed text is kept in the content list.
+ content[idx], savings = self._compress_text(item)
+ tokens_saved += savings
+
+ return content, tokens_saved
+
+ def _compress_text(self, text: str) -> Tuple[str, int]:
+ """Compresses the given text using the specified compression method."""
+ compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
+
+ savings = 0
+ if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
+ savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
+
+ return compressed_text["compressed_prompt"], savings
+
+ def _validate_min_tokens(self, min_tokens: Optional[int]):
+ if min_tokens is not None and min_tokens <= 0:
+ raise ValueError("min_tokens must be greater than 0 or None")
diff --git a/autogen/agentchat/contrib/capabilities/transforms_util.py b/autogen/agentchat/contrib/capabilities/transforms_util.py
new file mode 100644
index 00000000000..8678dec654c
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transforms_util.py
@@ -0,0 +1,114 @@
+from typing import Any, Dict, Hashable, List, Optional, Tuple
+
+from autogen import token_count_utils
+from autogen.cache.abstract_cache_base import AbstractCache
+from autogen.oai.openai_utils import filter_config
+from autogen.types import MessageContentType
+
+
+def cache_key(content: MessageContentType, *args: Hashable) -> str:
+ """Calculates the cache key for the given message content and any other hashable args.
+
+ Args:
+ content (MessageContentType): The message content to calculate the cache key for.
+ *args: Any additional hashable args to include in the cache key.
+ """
+ str_keys = [str(key) for key in (content, *args)]
+ return "".join(str_keys)
+
+
+def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
+ """Retrieves cachedd content from the cache.
+
+ Args:
+ cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
+ key (str): The key to retrieve the content from.
+ """
+ if cache:
+ cached_value = cache.get(key)
+ if cached_value:
+ return cached_value
+
+
+def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
+ """Sets content into the cache.
+
+ Args:
+ cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
+ key (str): The key to set the content into.
+ content (MessageContentType): The message content to set into the cache.
+ *extra_values: Additional values to be passed to the cache.
+ """
+ if cache:
+ cache_value = (content, *extra_values)
+ cache.set(key, cache_value)
+
+
+def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+ """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.
+
+ Args:
+ messages (List[Dict]): A list of messages to check.
+ min_tokens (Optional[int]): The token threshold. If None or 0, the check always passes.
+ """
+ if not min_tokens:
+ return True
+
+ messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
+ return messages_tokens >= min_tokens
+
+
+def count_text_tokens(content: MessageContentType) -> int:
+ """Calculates the number of text tokens in the given message content.
+
+ Args:
+ content (MessageContentType): The message content to calculate the number of text tokens for.
+ """
+ token_count = 0
+ if isinstance(content, str):
+ token_count = token_count_utils.count_token(content)
+ elif isinstance(content, list):
+ for item in content:
+ if isinstance(item, str):
+ token_count += token_count_utils.count_token(item)
+ else:
+ token_count += count_text_tokens(item.get("text", ""))
+ return token_count
+
+
+def is_content_right_type(content: Any) -> bool:
+ """A helper function to check if the passed in content is of the right type."""
+ return isinstance(content, (str, list))
+
+
+def is_content_text_empty(content: MessageContentType) -> bool:
+ """Checks if the content of the message does not contain any text.
+
+ Args:
+ content (MessageContentType): The message content to check.
+ """
+ if isinstance(content, str):
+ return content == ""
+ elif isinstance(content, list):
+ texts = []
+ for item in content:
+ if isinstance(item, str):
+ texts.append(item)
+ elif isinstance(item, dict):
+ texts.append(item.get("text", ""))
+ return not any(texts)
+ else:
+ return True
+
+
+def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+ """Validates whether the transform should be applied according to the filter dictionary.
+
+ Args:
+ message (Dict[str, Any]): The message to validate.
+ filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
+ exclude (bool): Whether to exclude messages that match the filter dictionary.
+ """
+ if not filter_dict:
+ return True
+
+ return len(filter_config([message], filter_dict, exclude)) > 0
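A small sketch of how these helpers behave on mixed text/image content (the sample messages are made up for illustration):

```python
from autogen.agentchat.contrib.capabilities import transforms_util

messages = [
    {"role": "user", "content": "Summarize the attached chart."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Here is the chart:"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    },
]

# Only text is counted; the image_url item contributes no text tokens.
total = sum(transforms_util.count_text_tokens(m["content"]) for m in messages)

# Transforms that accept `min_tokens` use this check to decide whether to run at all.
print(total, transforms_util.min_tokens_reached(messages, min_tokens=5))
```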
diff --git a/autogen/agentchat/contrib/capabilities/vision_capability.py b/autogen/agentchat/contrib/capabilities/vision_capability.py
new file mode 100644
index 00000000000..acfb9c8f6d8
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/vision_capability.py
@@ -0,0 +1,211 @@
+import copy
+from typing import Callable, Dict, List, Optional, Union
+
+from autogen.agentchat.assistant_agent import ConversableAgent
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.contrib.img_utils import (
+ convert_base64_to_data_uri,
+ get_image_data,
+ get_pil_image,
+ gpt4v_formatter,
+ message_formatter_pil_to_b64,
+)
+from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
+from autogen.agentchat.conversable_agent import colored
+from autogen.code_utils import content_str
+from autogen.oai.client import OpenAIWrapper
+
+DEFAULT_DESCRIPTION_PROMPT = (
+ "Write a detailed caption for this image. "
+ "Pay special attention to any details that might be useful or relevant "
+ "to the ongoing conversation."
+)
+
+
+class VisionCapability(AgentCapability):
+ """We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability,
+ such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe
+ the image (captioning) before sending the information to the agent's actual client.
+
+ The vision capability will hook to the ConversableAgent's `process_last_received_message`.
+
+ Some technical details:
+ When the agent (who has the vision capability) receives a message, it will:
+ 1. _process_received_message:
+ a. _append_oai_message
+ 2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
+ a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
+ b. hook process_all_messages_before_reply
+ 3. send:
+ a. hook process_message_before_send
+ b. _append_oai_message
+ """
+
+ def __init__(
+ self,
+ lmm_config: Dict,
+ description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
+ custom_caption_func: Callable = None,
+ ) -> None:
+ """
+ Initializes a new instance, setting up the configuration for interacting with
+ a Language Multimodal (LMM) client and specifying optional parameters for image
+ description and captioning.
+
+ Args:
+ lmm_config (Dict): Configuration for the LMM client, which is used to call
+ the LMM service for describing the image. This must be a dictionary containing
+ the necessary configuration parameters. If `lmm_config` is False or an empty dictionary,
+ it is considered invalid, and initialization will assert.
+ description_prompt (Optional[str], optional): The prompt to use for generating
+ descriptions of the image. This parameter allows customization of the
+ prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
+ custom_caption_func (Callable, optional): A callable that, if provided, will be used
+ to generate captions for images. This allows for custom captioning logic outside
+ of the standard LMM service interaction.
+ The callable should take three parameters as input:
+ 1. an image URL (or local location)
+ 2. image_data (a PIL image)
+ 3. lmm_client (to call remote LMM)
+ and then return a description (as string).
+ If not provided, captioning will rely on the LMM client configured via `lmm_config`.
+ If provided, we will not run the default self._get_image_caption method.
+
+ Raises:
+ AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided,
+ an AssertionError is raised to indicate that the Vision Capability requires
+ one of these to be valid for operation.
+ """
+ self._lmm_config = lmm_config
+ self._description_prompt = description_prompt
+ self._parent_agent = None
+
+ if lmm_config:
+ self._lmm_client = OpenAIWrapper(**lmm_config)
+ else:
+ self._lmm_client = None
+
+ self._custom_caption_func = custom_caption_func
+ assert (
+ self._lmm_config or custom_caption_func
+ ), "Vision Capability requires a valid lmm_config or custom_caption_func."
+
+ def add_to_agent(self, agent: ConversableAgent) -> None:
+ self._parent_agent = agent
+
+ # Append extra info to the system message.
+ agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")
+
+ # Register a hook for processing the last message.
+ agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
+
+ def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
+ """
+ Processes the last received message content by normalizing and augmenting it
+ with descriptions of any included images. The function supports input content
+ as either a string or a list of dictionaries, where each dictionary represents
+ a content item (e.g., text, image). If the content contains image URLs, it
+ fetches the image data, generates a caption for each image, and inserts the
+ caption into the augmented content.
+
+ The function aims to transform the content into a format compatible with GPT-4V
+ multimodal inputs, specifically by formatting strings into PIL-compatible
+ images if needed and appending text descriptions for images. This allows for
+ a more accessible presentation of the content, especially in contexts where
+ images cannot be displayed directly.
+
+ Args:
+ content (Union[str, List[dict]]): The last received message content, which
+ can be a plain text string or a list of dictionaries representing
+ different types of content items (e.g., text, image_url).
+
+ Returns:
+ str: The augmented message content
+
+ Raises:
+ AssertionError: If an item in the content list is not a dictionary.
+
+ Examples:
+ Assuming `self._get_image_caption(img_data)` returns
+ "A beautiful sunset over the mountains" for the image.
+
+ - Input as String:
+ content = "Check out this cool photo!"
+ Output: "Check out this cool photo!"
+ (Content is a string without an image, remains unchanged.)
+
+ - Input as String, with image location:
+ content = "What's weather in this cool photo: "
+ Output: "What's weather in this cool photo: in case you can not see, the caption of this image is:
+ A beautiful sunset over the mountains\n"
+ (Caption added after the image)
+
+ - Input as List with Text Only:
+ content = [{"type": "text", "text": "Here's an interesting fact."}]
+ Output: "Here's an interesting fact."
+ (No images in the content, it remains unchanged.)
+
+ - Input as List with Image URL:
+ content = [
+ {"type": "text", "text": "What's weather in this cool photo:"},
+ {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
+ ]
+ Output: "What's weather in this cool photo: in case you can not see, the caption of this image is:
+ A beautiful sunset over the mountains\n"
+ (Caption added after the image)
+ """
+ content = copy.deepcopy(content)  # work on a copy so the caller's message content is not mutated
+ # normalize the content into the gpt-4v format for multimodal
+ # we want to keep the URL format to keep it concise.
+ if isinstance(content, str):
+ content = gpt4v_formatter(content, img_format="url")
+
+ aug_content: str = ""
+ for item in content:
+ assert isinstance(item, dict)
+ if item["type"] == "text":
+ aug_content += item["text"]
+ elif item["type"] == "image_url":
+ img_url = item["image_url"]["url"]
+ img_caption = ""
+
+ if self._custom_caption_func:
+ img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
+ elif self._lmm_client:
+ img_data = get_image_data(img_url)
+ img_caption = self._get_image_caption(img_data)
+ else:
+ img_caption = ""
+
+ aug_content += f" in case you can not see, the caption of this image is: {img_caption}\n"
+ else:
+ print(f"Warning: the input type should either be `test` or `image_url`. Skip {item['type']} here.")
+
+ return aug_content
+
+ def _get_image_caption(self, img_data: str) -> str:
+ """
+ Args:
+ img_data (str): base64 encoded image data.
+ Returns:
+ str: caption for the given image.
+ """
+ response = self._lmm_client.create(
+ context=None,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": self._description_prompt},
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": convert_base64_to_data_uri(img_data),
+ },
+ },
+ ],
+ }
+ ],
+ )
+ description = response.choices[0].message.content
+ return content_str(description)
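A rough sketch of attaching the capability with a `custom_caption_func`, which avoids the need for an LMM config (the trivial captioner and agent settings below are illustrative, not part of this diff):

```python
import autogen
from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability


def caption_from_url(image_url, image_data, lmm_client):
    # A real captioner would inspect `image_data` (a PIL image) or call `lmm_client`.
    return f"an image loaded from {image_url}"


# llm_config=False keeps the example self-contained; a real agent would use its own LLM config.
assistant = autogen.ConversableAgent(name="assistant", llm_config=False, human_input_mode="NEVER")

# With custom_caption_func supplied, an empty lmm_config passes the capability's validation.
VisionCapability(lmm_config={}, custom_caption_func=caption_from_url).add_to_agent(assistant)
```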
diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py
index f9ea6bd268b..bea4058b94a 100644
--- a/autogen/agentchat/contrib/compressible_agent.py
+++ b/autogen/agentchat/contrib/compressible_agent.py
@@ -1,26 +1,27 @@
-from typing import Callable, Dict, Optional, Union, Tuple, List, Any
-from autogen import OpenAIWrapper
-from autogen import Agent, ConversableAgent
import copy
-import asyncio
-import logging
import inspect
-from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions
-
-try:
- from termcolor import colored
-except ImportError:
+import logging
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from warnings import warn
- def colored(x, *args, **kwargs):
- return x
+from autogen import Agent, ConversableAgent, OpenAIWrapper
+from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions
+from ...formatting_utils import colored
logger = logging.getLogger(__name__)
+warn(
+ "Context handling with CompressibleAgent is deprecated and will be removed in `0.2.30`. "
+ "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
class CompressibleAgent(ConversableAgent):
- """(Experimental) CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
- it also provides the added feature of compression when activated through the `compress_config` setting.
+ """CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
+ it also provides the added feature of compression when activated through the `compress_config` setting.
`compress_config` is set to False by default, making this agent equivalent to the `AssistantAgent`.
This agent does not work well in a GroupChat: The compressed messages will not be sent to all the agents in the group.
@@ -58,7 +59,7 @@ def __init__(
system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Optional[Union[Dict, bool]] = False,
llm_config: Optional[Union[Dict, bool]] = None,
@@ -73,6 +74,7 @@ def __init__(
system_message (str): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
llm_config (dict): llm inference configuration.
+ Note: you must set `model` in llm_config. It will be used to compute the token count.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
is_termination_msg (function): a function that takes a message in the form of a dictionary
@@ -84,10 +86,9 @@ def __init__(
compress_config (dict or True/False): config for compression before oai_reply. Defaults to False.
It should contain the following keys:
- "mode" (Optional, str, default to "TERMINATE"): Choose from ["COMPRESS", "TERMINATE", "CUSTOMIZED"].
- "TERMINATE": terminate the conversation ONLY when token count exceeds the max limit of current model.
- `trigger_count` is NOT used in this mode.
- "COMPRESS": compress the messages when the token count exceeds the limit.
- "CUSTOMIZED": pass in a customized function to compress the messages.
+ 1. `TERMINATE`: terminate the conversation ONLY when token count exceeds the max limit of current model. `trigger_count` is NOT used in this mode.
+ 2. `COMPRESS`: compress the messages when the token count exceeds the limit.
+ 3. `CUSTOMIZED`: pass in a customized function to compress the messages.
- "compress_function" (Optional, callable, default to None): Must be provided when mode is "CUSTOMIZED".
The function should takes a list of messages and returns a tuple of (is_compress_success: bool, compressed_messages: List[Dict]).
- "trigger_count" (Optional, float, int, default to 0.7): the threshold to trigger compression.
@@ -122,6 +123,8 @@ def __init__(
self.llm_compress_config = False
self.compress_client = None
else:
+ if "model" not in llm_config:
+ raise ValueError("llm_config must contain the 'model' field.")
self.llm_compress_config = self.llm_config.copy()
# remove functions
if "functions" in self.llm_compress_config:
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index b588b2b59f5..0dcad27b16d 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -1,15 +1,15 @@
-from collections import defaultdict
-import openai
+import copy
import json
-import time
import logging
+import time
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
from autogen import OpenAIWrapper
-from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.agentchat.agent import Agent
-from autogen.agentchat.assistant_agent import ConversableAgent
-from autogen.agentchat.assistant_agent import AssistantAgent
-from typing import Dict, Optional, Union, List, Tuple, Any
+from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
+from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
+from autogen.runtime_logging import log_new_agent, logging_enabled
logger = logging.getLogger(__name__)
@@ -20,11 +20,14 @@ class GPTAssistantAgent(ConversableAgent):
This agent is unique in its reliance on the OpenAI Assistant for state management, differing from other agents like ConversableAgent.
"""
+ DEFAULT_MODEL_NAME = "gpt-4-0125-preview"
+
def __init__(
self,
name="GPT Assistant",
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
+ assistant_config: Optional[Dict] = None,
overwrite_instructions: bool = False,
overwrite_tools: bool = False,
**kwargs,
@@ -40,31 +43,55 @@ def __init__(
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
system message will be set to the existing assistant instructions.
llm_config (dict or False): llm inference configuration.
- - assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
+ assistant_config (dict, Optional): config for the OpenAI assistant. It should contain the following keys:
+ - assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
- - file_ids: files used by retrieval in run
+ - file_ids: (Deprecated) files used by retrieval in run. Use tool_resources instead; see https://platform.openai.com/docs/assistants/migration/what-has-changed.
+ - tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
kwargs (dict): Additional configuration options for the agent.
- verbose (bool): If set to True, enables more detailed output from the assistant thread.
- Other kwargs: Except verbose, all other kwargs are passed directly to ConversableAgent.
"""
- # Use AutoGen OpenAIWrapper to create a client
- oai_wrapper = OpenAIWrapper(**llm_config)
+
+ self._verbose = kwargs.pop("verbose", False)
+ openai_client_cfg, openai_assistant_cfg = self._process_assistant_config(llm_config, assistant_config)
+
+ super().__init__(
+ name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
+ )
+ if logging_enabled():
+ log_new_agent(self, locals())
+
+ # GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
+ # See: https://github.com/microsoft/autogen/pull/1721
+ model_name = self.DEFAULT_MODEL_NAME
+ if openai_client_cfg.get("config_list") is not None and len(openai_client_cfg["config_list"]) > 0:
+ model_name = openai_client_cfg["config_list"][0].pop("model", self.DEFAULT_MODEL_NAME)
+ else:
+ model_name = openai_client_cfg.pop("model", self.DEFAULT_MODEL_NAME)
+
+ logger.warning("OpenAI client config of GPTAssistantAgent(%s) - model: %s", name, model_name)
+
+ oai_wrapper = OpenAIWrapper(**openai_client_cfg)
if len(oai_wrapper._clients) > 1:
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
+
self._openai_client = oai_wrapper._clients[0]._oai_client
- openai_assistant_id = llm_config.get("assistant_id", None)
+ openai_assistant_id = openai_assistant_cfg.get("assistant_id", None)
if openai_assistant_id is None:
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
if len(candidate_assistants) > 0:
# Filter out candidates with the same name but different instructions, file IDs, and function names.
candidate_assistants = self.find_matching_assistant(
- candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
+ candidate_assistants,
+ instructions,
+ openai_assistant_cfg.get("tools", []),
)
if len(candidate_assistants) == 0:
@@ -75,12 +102,12 @@ def __init__(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
- self._openai_assistant = self._openai_client.beta.assistants.create(
+ self._openai_assistant = create_gpt_assistant(
+ self._openai_client,
name=name,
instructions=instructions,
- tools=llm_config.get("tools", []),
- model=llm_config.get("model", "gpt-4-1106-preview"),
- file_ids=llm_config.get("file_ids", []),
+ model=model_name,
+ assistant_config=openai_assistant_cfg,
)
else:
logger.warning(
@@ -101,17 +128,20 @@ def __init__(
logger.warning(
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- instructions=instructions,
+ assistant_config={
+ "instructions": instructions,
+ },
)
else:
logger.warning(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)
- # Check if tools are specified in llm_config
- specified_tools = llm_config.get("tools", None)
+ # Check if tools are specified in assistant_config
+ specified_tools = openai_assistant_cfg.get("tools", None)
if specified_tools is None:
# Check if the current assistant has tools defined
@@ -128,23 +158,23 @@ def __init__(
logger.warning(
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- tools=llm_config.get("tools", []),
+ assistant_config={
+ "tools": specified_tools,
+ "tool_resources": openai_assistant_cfg.get("tool_resources", None),
+ },
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
logger.warning("overwrite_tools is False. Using existing tools from assistant API.")
- self._verbose = kwargs.pop("verbose", False)
- super().__init__(
- name=name, system_message=instructions, human_input_mode="NEVER", llm_config=llm_config, **kwargs
- )
-
+ self.update_system_message(self._openai_assistant.instructions)
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
- self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
+ self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=2)
def _invoke_assistant(
self,
@@ -177,6 +207,8 @@ def _invoke_assistant(
assistant_thread = self._openai_threads[sender]
# Process each unread message
for message in pending_messages:
+ if message["content"].strip() == "":
+ continue
self._openai_client.beta.threads.messages.create(
thread_id=assistant_thread.id,
content=message["content"],
@@ -392,6 +424,10 @@ def assistant_id(self):
def openai_client(self):
return self._openai_client
+ @property
+ def openai_assistant(self):
+ return self._openai_assistant
+
def get_assistant_instructions(self):
"""Return the assistant instructions from OAI assistant API"""
return self._openai_assistant.instructions
@@ -401,22 +437,23 @@ def delete_assistant(self):
logger.warning("Permanently deleting assistant...")
self._openai_client.beta.assistants.delete(self.assistant_id)
- def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
+ def find_matching_assistant(self, candidate_assistants, instructions, tools):
"""
Find the matching assistant from a list of candidate assistants.
- Filter out candidates with the same name but different instructions, file IDs, and function names.
- TODO: implement accurate match based on assistant metadata fields.
+ Filter out candidates with the same name but different instructions and function names.
"""
matching_assistants = []
# Preprocess the required tools for faster comparison
- required_tool_types = set(tool.get("type") for tool in tools)
+ required_tool_types = set(
+ "file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
+ )
+
required_function_names = set(
tool.get("function", {}).get("name")
for tool in tools
- if tool.get("type") not in ["code_interpreter", "retrieval"]
+ if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
)
- required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison
for assistant in candidate_assistants:
# Check if instructions are similar
@@ -429,11 +466,12 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
continue
# Preprocess the assistant's tools
- assistant_tool_types = set(tool.type for tool in assistant.tools)
+ assistant_tool_types = set(
+ "file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
+ )
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
- assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison
- # Check if the tool types, function names, and file IDs match
+ # Check if the tool types and function names match
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
logger.warning(
"tools not match, skip assistant(%s): tools %s, functions %s",
@@ -442,11 +480,36 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
assistant_function_names,
)
continue
- if required_file_ids != assistant_file_ids:
- logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
- continue
# Append assistant to matching list if all conditions are met
matching_assistants.append(assistant)
return matching_assistants
+
+ def _process_assistant_config(self, llm_config, assistant_config):
+ """
+ Process the llm_config and assistant_config to extract the model name and assistant related configurations.
+ """
+
+ if llm_config is False:
+ raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")
+
+ if llm_config is None:
+ openai_client_cfg = {}
+ else:
+ openai_client_cfg = copy.deepcopy(llm_config)
+
+ if assistant_config is None:
+ openai_assistant_cfg = {}
+ else:
+ openai_assistant_cfg = copy.deepcopy(assistant_config)
+
+ # Move the assistant related configurations to assistant_config
+ # It's important to keep backward compatibility
+ assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
+ for item in assistant_config_items:
+ if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
+ openai_assistant_cfg[item] = openai_client_cfg[item]
+ openai_client_cfg.pop(item, None)
+
+ return openai_client_cfg, openai_assistant_cfg
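
As a rough illustration of the new llm_config/assistant_config split described above (the model, API key, and tool list are placeholder values, not prescribed by this change):

```python
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

llm_config = {
    "config_list": [{"model": "gpt-4-0125-preview", "api_key": "sk-placeholder"}],
}
assistant_config = {
    # "assistant_id": "asst_placeholder",  # set this to reuse an existing assistant
    "tools": [{"type": "code_interpreter"}],
}

coder = GPTAssistantAgent(
    name="Coder Assistant",
    instructions="You write Python code to answer questions.",
    llm_config=llm_config,               # client options only (model, api_key, ...)
    assistant_config=assistant_config,   # assistant options (assistant_id, tools, tool_resources, ...)
)
```
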
diff --git a/autogen/agentchat/contrib/img_utils.py b/autogen/agentchat/contrib/img_utils.py
index 4fc08f8f357..a389c74b064 100644
--- a/autogen/agentchat/contrib/img_utils.py
+++ b/autogen/agentchat/contrib/img_utils.py
@@ -1,24 +1,79 @@
import base64
-import mimetypes
+import copy
+import os
import re
from io import BytesIO
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
import requests
from PIL import Image
+from autogen.agentchat import utils
+
+
+def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
+ """
+ Loads an image from a file and returns a PIL Image object.
+
+ Parameters:
+ image_file (str or Image): The filename, URL, URI, or base64 string of the image file.
+
+ Returns:
+ Image.Image: The PIL Image object.
+ """
+ if isinstance(image_file, Image.Image):
+ # Already a PIL Image object
+ return image_file
+
+ # Remove surrounding quotes if present
+ if image_file.startswith('"') and image_file.endswith('"'):
+ image_file = image_file[1:-1]
+ if image_file.startswith("'") and image_file.endswith("'"):
+ image_file = image_file[1:-1]
-def get_image_data(image_file: str, use_b64=True) -> bytes:
if image_file.startswith("http://") or image_file.startswith("https://"):
+ # A URL file
response = requests.get(image_file)
- content = response.content
+ content = BytesIO(response.content)
+ image = Image.open(content)
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
- return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
+ # A URI. Remove the prefix and decode the base64 string.
+ base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
+ image = _to_pil(base64_data)
+ elif os.path.exists(image_file):
+ # A local file
+ image = Image.open(image_file)
else:
- image = Image.open(image_file).convert("RGB")
- buffered = BytesIO()
- image.save(buffered, format="PNG")
- content = buffered.getvalue()
+ # base64 encoded string
+ image = _to_pil(image_file)
+
+ return image.convert("RGB")
+
+
+def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
+ """
+ Loads an image and returns its data either as raw bytes or in base64-encoded format.
+
+ This function first loads an image from the specified file, URL, or base64 string using
+ the `get_pil_image` function. It then saves this image in memory in PNG format and
+ retrieves its binary content. Depending on the `use_b64` flag, this binary content is
+ either returned directly or as a base64-encoded string.
+
+ Parameters:
+ image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
+ string of the image.
+ use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
+ If False, it returns the raw byte data of the image. Defaults to True.
+
+ Returns:
+ bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string
+ if `use_b64` is True.
+ """
+ image = get_pil_image(image_file)
+
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ content = buffered.getvalue()
if use_b64:
return base64.b64encode(content).decode("utf-8")
@@ -72,6 +127,22 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str,
return new_prompt, images
+def pil_to_data_uri(image: Image.Image) -> str:
+ """
+ Converts a PIL Image object to a data URI.
+
+ Parameters:
+ image (Image.Image): The PIL Image object.
+
+ Returns:
+ str: The data URI string.
+ """
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ content = buffered.getvalue()
+ return convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
+
+
def convert_base64_to_data_uri(base64_image):
def _get_mime_type_from_data_uri(base64_image):
# Decode the base64 string
@@ -92,41 +163,48 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri
-def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
+def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.
- Parameters:
+ Args:
- prompt (str): The input string that may contain image tags like <img ...>.
+ - img_format (str): what image format should be used. One of "uri", "url", "pil".
Returns:
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
"""
+ assert img_format in ["uri", "url", "pil"]
+
output = []
last_index = 0
image_count = 0
- # Regular expression pattern for matching <img ...> tags
- img_tag_pattern = re.compile(r"<img ([^>]+)>")
-
# Find all image tags
- for match in img_tag_pattern.finditer(prompt):
- image_location = match.group(1)
-
+ for parsed_tag in utils.parse_tags_from_content("img", prompt):
+ image_location = parsed_tag["attr"]["src"]
try:
- img_data = get_image_data(image_location)
+ if img_format == "pil":
+ img_data = get_pil_image(image_location)
+ elif img_format == "uri":
+ img_data = get_image_data(image_location)
+ img_data = convert_base64_to_data_uri(img_data)
+ elif img_format == "url":
+ img_data = image_location
+ else:
+ raise ValueError(f"Unknown image format {img_format}")
except Exception as e:
# Warning and skip this token
print(f"Warning! Unable to load image from {image_location}, because {e}")
continue
# Add text before this image tag to output list
- output.append({"type": "text", "text": prompt[last_index : match.start()]})
+ output.append({"type": "text", "text": prompt[last_index : parsed_tag["match"].start()]})
# Add image data to output list
- output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
+ output.append({"type": "image_url", "image_url": {"url": img_data}})
- last_index = match.end()
+ last_index = parsed_tag["match"].end()
image_count += 1
# Add remaining text to output list
@@ -162,9 +240,61 @@ def _to_pil(data: str) -> Image.Image:
and finally creates and returns a PIL Image object from the BytesIO object.
Parameters:
- data (str): The base64 encoded image data string.
+ data (str): The encoded image data string.
Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))
+
+
+def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
+ """
+ Converts the PIL image URLs in the messages to base64 encoded data URIs.
+
+ This function iterates over a list of message dictionaries. For each message,
+ if it contains a 'content' key with a list of items, it looks for items
+ with an 'image_url' key. The function then converts the PIL image URL
+ (pointed to by 'image_url') to a base64 encoded data URI.
+
+ Parameters:
+ messages (List[Dict]): A list of message dictionaries. Each dictionary
+ may contain a 'content' key with a list of items,
+ some of which might be image URLs.
+
+ Returns:
+ List[Dict]: A new list of message dictionaries with PIL image URLs in the
+ 'image_url' key converted to base64 encoded data URIs.
+
+ Example Input:
+ [
+ {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
+ {'content': [
+ {'type': 'text', 'text': "What's the breed of this dog here? \n"},
+ {'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
+ {'type': 'text', 'text': '.'}],
+ 'role': 'user'}
+ ]
+
+ Example Output:
+ [
+ {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
+ {'content': [
+ {'type': 'text', 'text': "What's the breed of this dog here? \n"},
+ {'type': 'image_url', 'image_url': {'url': a B64 Image}},
+ {'type': 'text', 'text': '.'}],
+ 'role': 'user'}
+ ]
+ """
+ new_messages = []
+ for message in messages:
+ # Handle the new GPT messages format.
+ if isinstance(message, dict) and "content" in message and isinstance(message["content"], list):
+ message = copy.deepcopy(message)
+ for item in message["content"]:
+ if isinstance(item, dict) and "image_url" in item:
+ item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])
+
+ new_messages.append(message)
+
+ return new_messages
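
A short sketch of how the reworked img_utils helpers fit together; the image path is a placeholder:

```python
from autogen.agentchat.contrib.img_utils import (
    get_pil_image,
    gpt4v_formatter,
    message_formatter_pil_to_b64,
)

# Load an image from a local path, URL, data URI, or raw base64 string.
pil_img = get_pil_image("dog.jpg")

# Keep images as PIL objects while building the message...
content = gpt4v_formatter("What breed is this? <img dog.jpg>", img_format="pil")
messages = [{"role": "user", "content": content}]

# ...and convert them to base64 data URIs only when the request is sent.
api_ready_messages = message_formatter_pil_to_b64(messages)
```
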
diff --git a/autogen/agentchat/contrib/llamaindex_conversable_agent.py b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
new file mode 100644
index 00000000000..f7a9c3e615d
--- /dev/null
+++ b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
@@ -0,0 +1,109 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+from autogen import OpenAIWrapper
+from autogen.agentchat import Agent, ConversableAgent
+from autogen.agentchat.contrib.vectordb.utils import get_logger
+
+logger = get_logger(__name__)
+
+try:
+ from llama_index.core.agent.runner.base import AgentRunner
+ from llama_index.core.chat_engine.types import AgentChatResponse
+ from llama_index_client import ChatMessage
+except ImportError as e:
+ logger.fatal("Failed to import llama-index. Try running 'pip install llama-index'")
+ raise e
+
+
+class LLamaIndexConversableAgent(ConversableAgent):
+
+ def __init__(
+ self,
+ name: str,
+ llama_index_agent: AgentRunner,
+ description: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ llama_index_agent (AgentRunner): llama index agent.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): a short description of the agent. This description is used by other agents
+ (e.g. the GroupChatManager) to decide when to call upon this agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../conversable_agent#__init__).
+ """
+
+ if llama_index_agent is None:
+ raise ValueError("llama_index_agent must be provided")
+
+ if description is None or description.isspace():
+ raise ValueError("description must be provided")
+
+ super().__init__(
+ name,
+ description=description,
+ **kwargs,
+ )
+
+ self._llama_index_agent = llama_index_agent
+
+ # Override the `generate_oai_reply`
+ self.replace_reply_func(ConversableAgent.generate_oai_reply, LLamaIndexConversableAgent._generate_oai_reply)
+
+ self.replace_reply_func(ConversableAgent.a_generate_oai_reply, LLamaIndexConversableAgent._a_generate_oai_reply)
+
+ def _generate_oai_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[OpenAIWrapper] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Generate a reply using autogen.oai."""
+ user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
+
+ chatResponse: AgentChatResponse = self._llama_index_agent.chat(message=user_message, chat_history=history)
+
+ extracted_response = chatResponse.response
+
+ return (True, extracted_response)
+
+ async def _a_generate_oai_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[OpenAIWrapper] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Generate a reply using autogen.oai."""
+ user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
+
+ chatResponse: AgentChatResponse = await self._llama_index_agent.achat(
+ message=user_message, chat_history=history
+ )
+
+ extracted_response = chatResponse.response
+
+ return (True, extracted_response)
+
+ def _extract_message_and_history(
+ self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None
+ ) -> Tuple[str, List[ChatMessage]]:
+ """Extract the message and history from the messages."""
+ if not messages:
+ messages = self._oai_messages[sender]
+
+ if not messages:
+ return "", []
+
+ message = messages[-1].get("content", "")
+
+ history = messages[:-1]
+ history_messages: List[ChatMessage] = []
+ for history_message in history:
+ content = history_message.get("content", "")
+ role = history_message.get("role", "user")
+ if role:
+ if role == "user" or role == "assistant":
+ history_messages.append(ChatMessage(content=content, role=role, additional_kwargs={}))
+ return message, history_messages
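
A minimal sketch of wiring a llama-index agent into an AutoGen chat with the new class; the llama-index import paths and the model name are assumptions that depend on the installed llama-index version:

```python
from llama_index.core.agent import ReActAgent
from llama_index.llms.openai import OpenAI

from autogen import UserProxyAgent
from autogen.agentchat.contrib.llamaindex_conversable_agent import LLamaIndexConversableAgent

# Any llama-index AgentRunner works; a tool-less ReAct agent is used here for brevity.
llama_agent = ReActAgent.from_tools(tools=[], llm=OpenAI(model="gpt-4"))

assistant = LLamaIndexConversableAgent(
    name="llama_assistant",
    llama_index_agent=llama_agent,
    description="Answers general questions via a llama-index ReAct agent.",
)

user = UserProxyAgent(name="user", human_input_mode="NEVER", code_execution_config=False)
user.initiate_chat(assistant, message="What is the capital of France?")
```
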
diff --git a/autogen/agentchat/contrib/llava_agent.py b/autogen/agentchat/contrib/llava_agent.py
index 65c39fd1e20..063b256d3cd 100644
--- a/autogen/agentchat/contrib/llava_agent.py
+++ b/autogen/agentchat/contrib/llava_agent.py
@@ -1,26 +1,16 @@
import json
import logging
-import os
-import pdb
-import re
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import replicate
import requests
-from regex import R
from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.img_utils import get_image_data, llava_formatter
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.code_utils import content_str
-try:
- from termcolor import colored
-except ImportError:
-
- def colored(x, *args, **kwargs):
- return x
-
+from ...formatting_utils import colored
logger = logging.getLogger(__name__)
@@ -77,7 +67,9 @@ def _image_reply(self, messages=None, sender=None, config=None):
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
- images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
+
+ # TODO: PIL to base64
+ images = [get_image_data(im) for im in images]
print(colored(prompt, "blue"))
out = ""
diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py
index 67c86daf05d..699caeb85b3 100644
--- a/autogen/agentchat/contrib/math_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -1,15 +1,15 @@
-import re
import os
-from pydantic import BaseModel, Extra, root_validator
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+import re
from time import sleep
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from pydantic import BaseModel, Extra, root_validator
from autogen._pydantic import PYDANTIC_V1
from autogen.agentchat import Agent, UserProxyAgent
-from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
+from autogen.code_utils import UNKNOWN, execute_code, extract_code, infer_lang
from autogen.math_utils import get_answer
-
PROMPTS = {
# default
"default": """Let's use Python to solve a math problem.
@@ -136,7 +136,7 @@ def __init__(
is_termination_msg: Optional[
Callable[[Dict], bool]
] = _is_termination_msg_mathchat, # terminate if \boxed{} in message
- human_input_mode: Optional[str] = "NEVER", # Fully automated
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", # Fully automated
default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,
max_invalid_q_per_step=3, # a parameter needed in MathChat
**kwargs,
@@ -177,28 +177,35 @@ def __init__(
self._previous_code = ""
self.last_reply = None
- def generate_init_message(self, problem, prompt_type="default", customized_prompt=None):
+ @staticmethod
+ def message_generator(sender, recipient, context):
"""Generate a prompt for the assistant agent with the given problem and prompt.
Args:
- problem (str): the problem to be solved.
- prompt_type (str): the type of the prompt. Possible values are "default", "python", "wolfram".
- (1) "default": the prompt that allows the agent to choose between 3 ways to solve a problem:
- 1. write a python program to solve it directly.
- 2. solve it directly without python.
- 3. solve it step by step with python.
- (2) "python":
- a simplified prompt from the third way of the "default" prompt, that asks the assistant
- to solve the problem step by step with python.
- (3) "two_tools":
- a simplified prompt similar to the "python" prompt, but allows the model to choose between
- Python and Wolfram Alpha to solve the problem.
- customized_prompt (str): a customized prompt to be used. If it is not None, the prompt_type will be ignored.
+ sender (Agent): the sender of the message.
+ recipient (Agent): the recipient of the message.
+ context (dict): a dictionary with the following fields:
+ problem (str): the problem to be solved.
+ prompt_type (str, Optional): the type of the prompt. Possible values are "default", "python", "wolfram".
+ (1) "default": the prompt that allows the agent to choose between 3 ways to solve a problem:
+ 1. write a python program to solve it directly.
+ 2. solve it directly without python.
+ 3. solve it step by step with python.
+ (2) "python":
+ a simplified prompt from the third way of the "default" prompt, that asks the assistant
+ to solve the problem step by step with python.
+ (3) "two_tools":
+ a simplified prompt similar to the "python" prompt, but allows the model to choose between
+ Python and Wolfram Alpha to solve the problem.
+ customized_prompt (str, Optional): a customized prompt to be used. If it is not None, the prompt_type will be ignored.
Returns:
str: the generated prompt ready to be sent to the assistant agent.
"""
- self._reset()
+ sender._reset()
+ problem = context.get("problem")
+ prompt_type = context.get("prompt_type", "default")
+ customized_prompt = context.get("customized_prompt", None)
if customized_prompt is not None:
return customized_prompt + problem
return PROMPTS[prompt_type] + problem
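
A brief sketch of the new message_generator flow that replaces generate_init_message; the model, API key, and problem text are placeholders:

```python
from autogen import AssistantAgent
from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent

assistant = AssistantAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": "sk-placeholder"}]},
)
mathproxyagent = MathUserProxyAgent(name="mathproxyagent", human_input_mode="NEVER")

# The problem and prompt_type are now passed through the chat context instead of
# being baked into the initial message by generate_init_message.
mathproxyagent.initiate_chat(
    assistant,
    message=MathUserProxyAgent.message_generator,
    problem="Solve x**2 - 5*x + 6 = 0.",
    prompt_type="default",
)
```
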
diff --git a/autogen/agentchat/contrib/multimodal_conversable_agent.py b/autogen/agentchat/contrib/multimodal_conversable_agent.py
index e6f3720186c..edeb88cd531 100644
--- a/autogen/agentchat/contrib/multimodal_conversable_agent.py
+++ b/autogen/agentchat/contrib/multimodal_conversable_agent.py
@@ -1,20 +1,16 @@
import copy
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent
-from autogen.agentchat.contrib.img_utils import gpt4v_formatter
-
-try:
- from termcolor import colored
-except ImportError:
-
- def colored(x, *args, **kwargs):
- return x
-
-
+from autogen.agentchat.contrib.img_utils import (
+ gpt4v_formatter,
+ message_formatter_pil_to_b64,
+)
from autogen.code_utils import content_str
+from ..._pydantic import model_dump
+
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
DEFAULT_MODEL = "gpt-4-vision-preview"
@@ -55,6 +51,13 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
+ # Override the `generate_oai_reply`
+ self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply)
+ self.replace_reply_func(
+ ConversableAgent.a_generate_oai_reply,
+ MultimodalConversableAgent.a_generate_oai_reply,
+ )
+
def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.
@@ -76,14 +79,14 @@ def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
will be processed using the gpt4v_formatter.
"""
if isinstance(message, str):
- return {"content": gpt4v_formatter(message)}
+ return {"content": gpt4v_formatter(message, img_format="pil")}
if isinstance(message, list):
return {"content": message}
if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
- message["content"] = gpt4v_formatter(message["content"])
+ message["content"] = gpt4v_formatter(message["content"], img_format="pil")
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
@@ -91,3 +94,27 @@ def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
raise e
return message
raise ValueError(f"Unsupported message type: {type(message)}")
+
+ def generate_oai_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[OpenAIWrapper] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Generate a reply using autogen.oai."""
+ client = self.client if config is None else config
+ if client is None:
+ return False, None
+ if messages is None:
+ messages = self._oai_messages[sender]
+
+ messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)
+
+ # TODO: #1143 handle token limit exceeded error
+ response = client.create(context=messages[-1].pop("context", None), messages=messages_with_b64_img)
+
+ # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
+ extracted_response = client.extract_text_or_completion_object(response)[0]
+ if not isinstance(extracted_response, str):
+ extracted_response = model_dump(extracted_response)
+ return True, extracted_response
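
For context, a small usage sketch of the multimodal agent with the PIL-based message flow introduced above; the model, API key, and image path are placeholders:

```python
from autogen import UserProxyAgent
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent

image_agent = MultimodalConversableAgent(
    name="image-explainer",
    llm_config={
        "config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-placeholder"}],
        "max_tokens": 300,
    },
)
user = UserProxyAgent(name="user", human_input_mode="NEVER", code_execution_config=False)

# The <img ...> tag is parsed into a PIL image and only converted to a base64 data URI
# inside generate_oai_reply, right before the API call.
user.initiate_chat(image_agent, message="Describe what you see in <img dog.jpg>.")
```
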
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index 1efeb3c1926..ea81de6dff1 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -1,17 +1,21 @@
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, List, Literal, Optional
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
-from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
-import logging
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
+from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
try:
+ import fastembed
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
- import fastembed
except ImportError as e:
- logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
+ logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e
@@ -19,7 +23,7 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
- human_input_mode: Optional[str] = "ALWAYS",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
@@ -29,12 +33,12 @@ def __init__(
name (str): name of the agent.
human_input_mode (str): whether to ask for human inputs every time a message is received.
Possible values are "ALWAYS", "TERMINATE", "NEVER".
- (1) When "ALWAYS", the agent prompts for human input every time a message is received.
+ 1. When "ALWAYS", the agent prompts for human input every time a message is received.
Under this mode, the conversation stops when the human input is "exit",
or when is_termination_msg is True and there is no human input.
- (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or
+ 2. When "TERMINATE", the agent only prompts for human input only when a termination message is received or
the number of auto reply reaches the max_consecutive_auto_reply.
- (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
+ 3. When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message.
@@ -136,6 +140,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
collection_name=self._collection_name,
embedding_model=self._embedding_model,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
+ self._search_string = search_string
self._results = results
@@ -190,12 +199,12 @@ def create_qdrant_from_dir(
client.set_model(embedding_model)
if custom_text_split_function is not None:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
@@ -281,20 +290,24 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
collection_name,
query_texts,
limit=n_results,
- query_filter=models.Filter(
- must=[
- models.FieldCondition(
- key="document",
- match=models.MatchText(text=search_string),
- )
- ]
- )
- if search_string
- else None,
+ query_filter=(
+ models.Filter(
+ must=[
+ models.FieldCondition(
+ key="document",
+ match=models.MatchText(text=search_string),
+ )
+ ]
+ )
+ if search_string
+ else None
+ ),
)
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
+ "distances": [[result.score for result in sublist] for sublist in results],
+ "metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data
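
A compact sketch of a Qdrant-backed retrieval agent using the new distance filtering; the docs path, collection name, and threshold are placeholder values:

```python
from qdrant_client import QdrantClient

from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import QdrantRetrieveUserProxyAgent

ragproxyagent = QdrantRetrieveUserProxyAgent(
    name="qdrant_ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",
        "client": QdrantClient(":memory:"),
        "collection_name": "autogen-docs",
        "distance_threshold": 0.8,  # results with a larger distance score are filtered out
    },
)
```
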
diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py
index a09677710aa..9b5ace200dc 100644
--- a/autogen/agentchat/contrib/retrieve_assistant_agent.py
+++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py
@@ -1,6 +1,7 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent
-from typing import Dict, Optional, Union, List, Tuple, Any
class RetrieveAssistantAgent(AssistantAgent):
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 949afcc9678..59a4abccb1d 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,26 +1,35 @@
+import hashlib
+import os
import re
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from IPython import get_ipython
try:
import chromadb
-except ImportError:
- raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
-from autogen.agentchat.agent import Agent
+except ImportError as e:
+ raise ImportError(f"{e}. You can try `pip install pyautogen[retrievechat]`, or install `chromadb` manually.")
from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
-from autogen.token_count_utils import count_token
+from autogen.agentchat.agent import Agent
+from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
from autogen.code_utils import extract_code
-from autogen import logger
-
-from typing import Callable, Dict, Optional, Union, List, Tuple, Any
-from IPython import get_ipython
-
-try:
- from termcolor import colored
-except ImportError:
+from autogen.retrieve_utils import (
+ TEXT_FORMATS,
+ create_vector_db_from_dir,
+ get_files_from_dir,
+ query_vector_db,
+ split_files_to_chunks,
+)
+from autogen.token_count_utils import count_token
- def colored(x, *args, **kwargs):
- return x
+from ...formatting_utils import colored
+logger = get_logger(__name__)
PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
@@ -40,6 +49,10 @@ def colored(x, *args, **kwargs):
User's question is: {input_question}
Context is: {input_context}
+
+The source of the context is: {input_sources}
+
+If you can answer the question, add the source of the context at the end of your answer in the format `Sources: source1, source2, ...`.
"""
PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the
@@ -67,79 +80,146 @@ def colored(x, *args, **kwargs):
Context is: {input_context}
"""
+HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
+UPDATE_CONTEXT_IN_PROMPT = "you should reply exactly `UPDATE CONTEXT`"
+
class RetrieveUserProxyAgent(UserProxyAgent):
+ """(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
+ similarity, and sends them along with the question to the Retrieval-Augmented Assistant
+ """
+
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
- human_input_mode: Optional[str] = "ALWAYS",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
):
- """
+ r"""
Args:
name (str): name of the agent.
+
human_input_mode (str): whether to ask for human inputs every time a message is received.
Possible values are "ALWAYS", "TERMINATE", "NEVER".
- (1) When "ALWAYS", the agent prompts for human input every time a message is received.
+ 1. When "ALWAYS", the agent prompts for human input every time a message is received.
Under this mode, the conversation stops when the human input is "exit",
or when is_termination_msg is True and there is no human input.
- (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or
- the number of auto reply reaches the max_consecutive_auto_reply.
- (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
- when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+ 2. When "TERMINATE", the agent only prompts for human input only when a termination
+ message is received or the number of auto reply reaches
+ the max_consecutive_auto_reply.
+ 3. When "NEVER", the agent will never prompt for human input. Under this mode, the
+ conversation stops when the number of auto reply reaches the
+ max_consecutive_auto_reply or when is_termination_msg is True.
+
is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
+
retrieve_config (dict or None): config for the retrieve agent.
- To use default config, set to None. Otherwise, set to a dictionary with the following keys:
- - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
- prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
- - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
- will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
- - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
- the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
- - extra_docs (Optional, bool): when true, allows adding documents with unique IDs without overwriting existing ones; when false, it replaces existing documents using default IDs, risking collection overwrite.,
- when set to true it enables the system to assign unique IDs starting from "length+i" for new document chunks, preventing the replacement of existing documents and facilitating the addition of more content to the collection..
- By default, "extra_docs" is set to false, starting document IDs from zero. This poses a risk as new documents might overwrite existing ones, potentially causing unintended loss or alteration of data in the collection.
- - collection_name (Optional, str): the name of the collection.
- If key not provided, a default name `autogen-docs` will be used.
- - model (Optional, str): the model to use for the retrieve chat.
+
+ To use default config, set to None. Otherwise, set to a dictionary with the
+ following keys:
+ - `task` (Optional, str) - the task of the retrieve chat. Possible values are
+ "code", "qa" and "default". System prompt will be different for different tasks.
+ The default value is `default`, which supports both code and qa, and provides
+ source information at the end of the response.
+ - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat.
+ If it's a string, it should be the type of the vector db, such as "chroma"; otherwise,
+ it should be an instance of the VectorDB protocol. Default is "chroma".
+ Set `None` to use the deprecated `client`.
+ - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make
+ sure you understand the config for the vector db you are using, otherwise, leave it as `{}`.
+ Only valid when `vector_db` is a string.
+ - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
+ default client `chromadb.Client()` will be used. If you want to use other
+ vector db, extend this class and override the `retrieve_docs` function.
+ **Deprecated**: use `vector_db` instead.
+ - `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
+ can also be the path to a single file, the url to a single file or a list
+ of directories, files and urls. Default is None, which works only if the
+ collection is already created.
+ - `extra_docs` (Optional, bool) - when True, allows adding documents with unique IDs
+ without overwriting existing ones; when False, it replaces existing documents
+ using default IDs, risking collection overwrite. Setting it to True makes the
+ system assign unique IDs starting from "length+i" to new document chunks,
+ preventing the replacement of existing documents and facilitating the addition
+ of more content to the collection.
+ By default, "extra_docs" is set to False, starting document IDs from zero.
+ This poses a risk as new documents might overwrite existing ones, potentially
+ causing unintended loss or alteration of data in the collection.
+ **Deprecated**: use `new_docs` when using `vector_db` instead of `client`.
+ - `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
+ when False, updates existing documents and adds new ones. Default is True.
+ Document id is used to determine if a document is new or existing. By default, the
+ id is the hash value of the content.
+ - `model` (Optional, str) - the model to use for the retrieve chat.
If key not provided, a default model `gpt-4` will be used.
- - chunk_token_size (Optional, int): the chunk token size for the retrieve chat.
+ - `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat.
If key not provided, a default size `max_tokens * 0.4` will be used.
- - context_max_tokens (Optional, int): the context max token size for the retrieve chat.
+ - `context_max_tokens` (Optional, int) - the context max token size for the
+ retrieve chat.
If key not provided, a default size `max_tokens * 0.8` will be used.
- - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are
- "multi_lines" and "one_line". If key not provided, a default mode `multi_lines` will be used.
- - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True.
+ - `chunk_mode` (Optional, str) - the chunk mode for the retrieve chat. Possible values
+ are "multi_lines" and "one_line". If key not provided, a default mode
+ `multi_lines` will be used.
+ - `must_break_at_empty_line` (Optional, bool) - chunk will only break at empty line
+ if True. Default is True.
If chunk_mode is "one_line", this parameter will be ignored.
- - embedding_model (Optional, str): the embedding model to use for the retrieve chat.
- If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
- can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
- fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
- SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
- other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
- - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
- If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
- - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
- Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
- - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
- The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
- Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
- - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
- Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
- - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
- This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
- - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
- **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
-
- Example of overriding retrieve_docs:
- If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code.
+ - `embedding_model` (Optional, str) - the embedding model to use for the retrieve chat.
+ If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available
+ models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
+ The default model is a fast model. If you want to use a high performance model,
+ `all-mpnet-base-v2` is recommended.
+ **Deprecated**: not needed when using `vector_db` instead of `client`.
+ - `embedding_function` (Optional, Callable) - the embedding function for creating the
+ vector db. Default is None, SentenceTransformer with the given `embedding_model`
+ will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+ functions, you can pass it here,
+ follow the examples in `https://docs.trychroma.com/embeddings`.
+ - `customized_prompt` (Optional, str) - the customized prompt for the retrieve chat.
+ Default is None.
+ - `customized_answer_prefix` (Optional, str) - the customized answer prefix for the
+ retrieve chat. Default is "".
+ If not "" and the customized_answer_prefix is not in the answer,
+ `Update Context` will be triggered.
+ - `update_context` (Optional, bool) - if False, will not apply `Update Context` for
+ interactive retrieval. Default is True.
+ - `collection_name` (Optional, str) - the name of the collection.
+ If key not provided, a default name `autogen-docs` will be used.
+ - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True.
+ - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
+ Case 1. If the collection does not exist, create the collection.
+ Case 2. If the collection exists and overwrite is True, overwrite the collection.
+ Case 3. If the collection exists and overwrite is False, get the collection if get_or_create is True,
+ otherwise raise a ValueError.
+ - `custom_token_count_function` (Optional, Callable) - a custom function to count the
+ number of tokens in a string.
+ The function should take (text:str, model:str) as input and return the
+ token_count(int). the retrieve_config["model"] will be passed in the function.
+ Default is autogen.token_count_utils.count_token that uses tiktoken, which may
+ not be accurate for non-OpenAI models.
+ - `custom_text_split_function` (Optional, Callable) - a custom function to split a
+ string into a list of strings.
+ Default is None, will use the default function in
+ `autogen.retrieve_utils.split_text_to_chunks`.
+ - `custom_text_types` (Optional, List[str]) - a list of file types to be processed.
+ Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+ This only applies to files under the directories in `docs_path`. Explicitly
+ included files and urls will be chunked regardless of their types.
+ - `recursive` (Optional, bool) - whether to search documents recursively in the
+ docs_path. Default is True.
+ - `distance_threshold` (Optional, float) - the threshold for the distance score, only
+ distance smaller than it will be returned. Will be ignored if < 0. Default is -1.
+
+ `**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
+
+ Example:
+
+ Example of overriding retrieve_docs - If you have set up a customized vector db that is
+ not compatible with chromadb, you can easily plug it in with the code below.
+ **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@@ -172,9 +252,14 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
- self._client = self._retrieve_config.get("client", chromadb.Client())
+ self._vector_db = self._retrieve_config.get("vector_db", "chroma")
+ self._db_config = self._retrieve_config.get("db_config", {})
+ self._client = self._retrieve_config.get("client", None)
+ if self._client is None:
+ self._client = chromadb.Client()
self._docs_path = self._retrieve_config.get("docs_path", None)
self._extra_docs = self._retrieve_config.get("extra_docs", False)
+ self._new_docs = self._retrieve_config.get("new_docs", True)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
@@ -193,25 +278,104 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
+ self._overwrite = self._retrieve_config.get("overwrite", False)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
self._recursive = self._retrieve_config.get("recursive", True)
- self._context_max_tokens = self._max_tokens * 0.8
+ self._context_max_tokens = self._retrieve_config.get("context_max_tokens", self._max_tokens * 0.8)
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
+ self._current_docs_in_context = [] # the ids of the current context sources
self._search_string = "" # the search string used in the current query
+ self._distance_threshold = self._retrieve_config.get("distance_threshold", -1)
# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
)
+ if isinstance(self._vector_db, str):
+ if not isinstance(self._db_config, dict):
+ raise ValueError("`db_config` should be a dictionary.")
+ if "embedding_function" in self._retrieve_config:
+ self._db_config["embedding_function"] = self._embedding_function
+ self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)
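A minimal sketch of how the configuration keys parsed above can be supplied when constructing the agent; the `docs_path`, `db_config` values, and collection name below are placeholder assumptions, not values taken from this change.

```python
# Minimal sketch (placeholder values) of the new vector_db-based retrieve_config.
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "vector_db": "chroma",            # resolved through VectorDBFactory.create_vector_db
        "db_config": {"path": "tmp/db"},  # forwarded to the vector database constructor
        "docs_path": "./docs",            # assumed local folder with the source documents
        "collection_name": "autogen-docs",
        "new_docs": True,                 # only insert chunks whose ids are not already stored
        "overwrite": False,
        "get_or_create": True,
        "distance_threshold": -1,         # negative value disables distance filtering
    },
)
```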
+ def _init_db(self):
+ if not self._vector_db:
+ return
+
+ IS_TO_CHUNK = False # whether to chunk the raw files
+ if self._new_docs:
+ IS_TO_CHUNK = True
+ if not self._docs_path:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.")
+ self._overwrite = False
+ self._get_or_create = True
+ IS_TO_CHUNK = False
+ except ValueError:
+ raise ValueError(
+ "`docs_path` is not provided. "
+ f"The collection `{self._collection_name}` doesn't exist either. "
+ "Please provide `docs_path` or create the collection first."
+ )
+ elif self._get_or_create and not self._overwrite:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.info(f"Use the existing collection `{self._collection_name}`.", color="green")
+ except ValueError:
+ IS_TO_CHUNK = True
+ else:
+ IS_TO_CHUNK = True
+
+ self._vector_db.active_collection = self._vector_db.create_collection(
+ self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create
+ )
+
+ docs = None
+ if IS_TO_CHUNK:
+ if self.custom_text_split_function is not None:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ custom_text_split_function=self.custom_text_split_function,
+ )
+ else:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ self._chunk_token_size,
+ self._chunk_mode,
+ self._must_break_at_empty_line,
+ )
+ logger.info(f"Found {len(chunks)} chunks.")
+
+ if self._new_docs:
+ all_docs_ids = set(
+ [
+ doc["id"]
+ for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
+ ]
+ )
+ else:
+ all_docs_ids = set()
+
+ chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ chunk_ids_set = set(chunk_ids)
+ chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
+ docs = [
+ Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx])
+ for idx in chunk_ids_set_idx
+ if chunk_ids[idx] not in all_docs_ids
+ ]
+
+ self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)
+
def _is_termination_msg_retrievechat(self, message):
"""Check if a message is a termination message.
For code generation, terminate when no code block is detected. Currently only detect python code blocks.
@@ -244,37 +408,42 @@ def get_max_tokens(model="gpt-3.5-turbo"):
def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
- def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
+ def _get_context(self, results: QueryResults):
doc_contents = ""
+ self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
- for idx, doc in enumerate(results["documents"][0]):
+ for idx, doc in enumerate(results[0]):
+ doc = doc[0]
if idx <= _doc_idx:
continue
- if results["ids"][0][idx] in self._doc_ids:
+ if doc["id"] in self._doc_ids:
continue
- _doc_tokens = self.custom_token_count_function(doc, self._model)
+ _doc_tokens = self.custom_token_count_function(doc["content"], self._model)
if _doc_tokens > self._context_max_tokens:
- func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
+ func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context."
print(colored(func_print, "green"), flush=True)
self._doc_idx = idx
continue
if current_tokens + _doc_tokens > self._context_max_tokens:
break
- func_print = f"Adding doc_id {results['ids'][0][idx]} to context."
+ func_print = f"Adding content of doc {doc['id']} to context."
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
- doc_contents += doc + "\n"
+ doc_contents += doc["content"] + "\n"
+ _metadata = doc.get("metadata")
+ if isinstance(_metadata, dict):
+ self._current_docs_in_context.append(_metadata.get("source", ""))
self._doc_idx = idx
- self._doc_ids.append(results["ids"][0][idx])
- self._doc_contents.append(doc)
+ self._doc_ids.append(doc["id"])
+ self._doc_contents.append(doc["content"])
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
@@ -291,7 +460,9 @@ def _generate_message(self, doc_contents, task="default"):
elif task.upper() == "QA":
message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "DEFAULT":
- message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
+ message = PROMPT_DEFAULT.format(
+ input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
+ )
else:
raise NotImplementedError(f"task {task} is not implemented.")
return message
@@ -301,7 +472,7 @@ def _check_update_context(self, message):
message = message.get("content", "")
elif not isinstance(message, str):
message = ""
- update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
+ update_context_case1 = "UPDATE CONTEXT" in message.upper() and UPDATE_CONTEXT_IN_PROMPT not in message
update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper()
return update_context_case1, update_context_case2
@@ -366,21 +537,40 @@ def _generate_retrieve_user_reply(
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
- In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
- compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
- parameters and add your own parameters with default values, and keep the results in below type.
-
- Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
- the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
- to `chromadb.api.types.QueryResult` as an example.
- ids: List[string]
- documents: List[List[string]]
+ The retrieved docs should be of type `QueryResults`, a list of query results where each result is a list of
+ tuples containing the document and the distance.
Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved. Default is 20.
search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
+ Not used if the vector_db doesn't support it.
+
+ Returns:
+ None.
"""
+ if isinstance(self._vector_db, VectorDB):
+ if not self._collection or not self._get_or_create:
+ print("Trying to create collection.")
+ self._init_db()
+ self._collection = True
+ self._get_or_create = True
+
+ kwargs = {}
+ if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma":
+ kwargs["where_document"] = {"$contains": search_string} if search_string else None
+ results = self._vector_db.retrieve_docs(
+ queries=[problem],
+ n_results=n_results,
+ collection_name=self._collection_name,
+ distance_threshold=self._distance_threshold,
+ **kwargs,
+ )
+ self._search_string = search_string
+ self._results = results
+ print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
+ return
+
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._client = create_vector_db_from_dir(
@@ -410,27 +600,39 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
self._search_string = search_string
self._results = results
- print("doc_ids: ", results["ids"])
-
- def generate_init_message(self, problem: str, n_results: int = 20, search_string: str = ""):
- """Generate an initial message with the given problem and prompt.
+ print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
+ @staticmethod
+ def message_generator(sender, recipient, context):
+ """
+ Generate an initial message with the given context for the RetrieveUserProxyAgent.
Args:
- problem (str): the problem to be solved.
- n_results (int): the number of results to be retrieved.
- search_string (str): only docs containing this string will be retrieved.
-
+ sender (Agent): the sender agent. It should be the instance of RetrieveUserProxyAgent.
+ recipient (Agent): the recipient agent. Usually it's the assistant agent.
+ context (dict): the context for the message generation. It should contain the following keys:
+ - `problem` (str) - the problem to be solved.
+ - `n_results` (int) - the number of results to be retrieved. Default is 20.
+ - `search_string` (str) - only docs that contain an exact match of this string will be retrieved. Default is "".
Returns:
- str: the generated prompt ready to be sent to the assistant agent.
+ str: the generated message ready to be sent to the recipient agent.
"""
- self._reset()
- self.retrieve_docs(problem, n_results, search_string)
- self.problem = problem
- self.n_results = n_results
- doc_contents = self._get_context(self._results)
- message = self._generate_message(doc_contents, self._task)
+ sender._reset()
+
+ problem = context.get("problem", "")
+ n_results = context.get("n_results", 20)
+ search_string = context.get("search_string", "")
+
+ sender.retrieve_docs(problem, n_results, search_string)
+ sender.problem = problem
+ sender.n_results = n_results
+ doc_contents = sender._get_context(sender._results)
+ message = sender._generate_message(doc_contents, sender._task)
return message
def run_code(self, code, **kwargs):
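As a usage note for the `message_generator` staticmethod introduced above, the sketch below shows how it is expected to be passed to `initiate_chat`, with extra keyword arguments forwarded as the `context` dict; the assistant setup, `config_list`, and the problem string are assumptions for illustration.

```python
# Sketch only: the assistant's llm_config and the problem text are assumed values.
from autogen import AssistantAgent

assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list})  # config_list assumed

ragproxyagent.initiate_chat(
    assistant,
    message=ragproxyagent.message_generator,  # replaces the removed generate_init_message
    problem="How does RetrieveChat pick which chunks go into the context?",
    n_results=20,
    search_string="",
)
```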
diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py
index a12b228f0e6..2f6be5088a4 100644
--- a/autogen/agentchat/contrib/society_of_mind_agent.py
+++ b/autogen/agentchat/contrib/society_of_mind_agent.py
@@ -1,10 +1,9 @@
# ruff: noqa: E722
-import json
-import traceback
import copy
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Callable, Literal, Tuple
-from autogen import Agent, ConversableAgent, GroupChatManager, GroupChat, OpenAIWrapper
+import traceback
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from autogen import Agent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper
class SocietyOfMindAgent(ConversableAgent):
@@ -35,7 +34,7 @@ def __init__(
response_preparer: Optional[Union[str, Callable]] = None,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "TERMINATE",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = False,
@@ -95,6 +94,26 @@ def _llm_response_preparer(self, prompt, messages):
for message in messages:
message = copy.deepcopy(message)
message["role"] = "user"
+
+ # Convert tool and function calls to basic messages to avoid an error on the LLM call
+ if "content" not in message:
+ message["content"] = ""
+
+ if "tool_calls" in message:
+ del message["tool_calls"]
+ if "tool_responses" in message:
+ del message["tool_responses"]
+ if "function_call" in message:
+ if message["content"] == "":
+ try:
+ message["content"] = (
+ message["function_call"]["name"] + "(" + message["function_call"]["arguments"] + ")"
+ )
+ except KeyError:
+ pass
+ del message["function_call"]
+
+ # Add the modified message to the transcript
_messages.append(message)
_messages.append(
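To make the flattening above concrete, the snippet below walks a hypothetical transcript entry through the same steps; the tool name and arguments are made up.

```python
# Hypothetical illustration of the conversion above: a function_call entry becomes plain text.
message = {"role": "assistant", "function_call": {"name": "search", "arguments": '{"q": "autogen"}'}}

message["role"] = "user"
if "content" not in message:
    message["content"] = ""
if "function_call" in message:
    message["content"] = message["function_call"]["name"] + "(" + message["function_call"]["arguments"] + ")"
    del message["function_call"]

print(message)  # {'role': 'user', 'content': 'search({"q": "autogen"})'}
```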
diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py
index 10100c9e57f..62345156a53 100644
--- a/autogen/agentchat/contrib/text_analyzer_agent.py
+++ b/autogen/agentchat/contrib/text_analyzer_agent.py
@@ -1,7 +1,7 @@
-from autogen import oai
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
-from typing import Callable, Dict, Optional, Union, List, Tuple, Any
system_message = """You are an expert in text analysis.
The user will give you TEXT to analyze.
@@ -16,7 +16,7 @@ def __init__(
self,
name="analyzer",
system_message: Optional[str] = system_message,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
llm_config: Optional[Union[Dict, bool]] = None,
**kwargs,
):
diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py
new file mode 100644
index 00000000000..29a08008619
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/base.py
@@ -0,0 +1,213 @@
+from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
+
+Metadata = Union[Mapping[str, Any], None]
+Vector = Union[Sequence[float], Sequence[int]]
+ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does
+
+
+class Document(TypedDict):
+ """A Document is a record in the vector database.
+
+ id: ItemID | the unique identifier of the document.
+ content: str | the text content of the chunk.
+ metadata: Metadata, Optional | contains additional information about the document such as source, date, etc.
+ embedding: Vector, Optional | the vector representation of the content.
+ """
+
+ id: ItemID
+ content: str
+ metadata: Optional[Metadata]
+ embedding: Optional[Vector]
+
+
+"""QueryResults is the response from the vector database for a query/queries.
+A query is a list containing one string while queries is a list containing multiple strings.
+The response is a list of query results; each query result is a list of tuples containing the document and the distance.
+"""
+QueryResults = List[List[Tuple[Document, float]]]
+
+
+@runtime_checkable
+class VectorDB(Protocol):
+ """
+ Abstract class for vector database. A vector database is responsible for storing and retrieving documents.
+
+ Attributes:
+ active_collection: Any | The active collection in the vector database. Makes get_collection faster. Default is None.
+ type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
+
+ Methods:
+ create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
+ get_collection: Callable[[str], Any] | Get the collection from the vector database.
+ delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
+ insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
+ update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
+ delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
+ retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
+ get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
+ """
+
+ active_collection: Any = None
+ type: str = ""
+
+ def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raises a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Any | The collection object.
+ """
+ ...
+
+ def get_collection(self, collection_name: str = None) -> Any:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Any | The collection object.
+ """
+ ...
+
+ def delete_collection(self, collection_name: str) -> Any:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any
+ """
+ ...
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of tuples containing the document and
+ the distance.
+ """
+ ...
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ ...
+
+
+class VectorDBFactory:
+ """
+ Factory class for creating vector databases.
+ """
+
+ PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
+
+ @staticmethod
+ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
+ """
+ Create a vector database.
+
+ Args:
+ db_type: str | The type of the vector database.
+ kwargs: Dict | The keyword arguments for initializing the vector database.
+
+ Returns:
+ VectorDB | The vector database.
+ """
+ if db_type.lower() in ["chroma", "chromadb"]:
+ from .chromadb import ChromaVectorDB
+
+ return ChromaVectorDB(**kwargs)
+ if db_type.lower() in ["pgvector", "pgvectordb"]:
+ from .pgvectordb import PGVectorDB
+
+ return PGVectorDB(**kwargs)
+ else:
+ raise ValueError(
+ f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
+ )
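A small, hedged example of the types and factory defined in this file follows; the document ids, contents, and the `path` setting are invented for illustration.

```python
# Sketch using the Document/QueryResults types and VectorDBFactory defined above (values are made up).
from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDBFactory

db = VectorDBFactory.create_vector_db(db_type="chroma", path="tmp/db")  # kwargs reach ChromaVectorDB
db.active_collection = db.create_collection("autogen-docs", overwrite=False, get_or_create=True)

docs = [
    Document(id="doc-0001", content="AutoGen enables multi-agent conversations.", metadata={"source": "readme"}),
    Document(id="doc-0002", content="RetrieveChat augments prompts with retrieved context.", metadata={"source": "blog"}),
]
db.insert_docs(docs, collection_name="autogen-docs", upsert=True)

results: QueryResults = db.retrieve_docs(queries=["What is RetrieveChat?"], n_results=2)
for doc, distance in results[0]:  # one inner list per query
    print(doc["id"], round(distance, 3))
```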
diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py
new file mode 100644
index 00000000000..1ed8708409d
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/chromadb.py
@@ -0,0 +1,320 @@
+import os
+from typing import Callable, List
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger
+
+try:
+ import chromadb
+
+ if chromadb.__version__ < "0.4.15":
+ raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")
+ import chromadb.utils.embedding_functions as ef
+ from chromadb.api.models.Collection import Collection
+except ImportError:
+ raise ImportError("Please install chromadb: `pip install chromadb`")
+
+CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)
+logger = get_logger(__name__)
+
+
+class ChromaVectorDB(VectorDB):
+ """
+ A vector database that uses ChromaDB as the backend.
+ """
+
+ def __init__(
+ self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Args:
+ client: chromadb.Client | The client object of the vector database. Default is None.
+ If provided, it will use the client object directly and ignore other arguments.
+ path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
+ metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
+ setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of
+ the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances),
+ [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184),
+ and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md).
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ self.client = client
+ self.path = path
+ self.embedding_function = (
+ ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
+ if embedding_function is None
+ else embedding_function
+ )
+ self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}
+ if not self.client:
+ if self.path is not None:
+ self.client = chromadb.PersistentClient(path=self.path, **kwargs)
+ else:
+ self.client = chromadb.Client(**kwargs)
+ self.active_collection = None
+ self.type = "chroma"
+
+ def create_collection(
+ self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
+ ) -> Collection:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raises a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Collection | The collection object.
+ """
+ try:
+ if self.active_collection and self.active_collection.name == collection_name:
+ collection = self.active_collection
+ else:
+ collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function)
+ except ValueError:
+ collection = None
+ if collection is None:
+ return self.client.create_collection(
+ collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ elif overwrite:
+ self.client.delete_collection(collection_name)
+ return self.client.create_collection(
+ collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ elif get_or_create:
+ return collection
+ else:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.info(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ if not (self.active_collection and self.active_collection.name == collection_name):
+ self.active_collection = self.client.get_collection(
+ collection_name, embedding_function=self.embedding_function
+ )
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ None
+ """
+ self.client.delete_collection(collection_name)
+ if self.active_collection and self.active_collection.name == collection_name:
+ self.active_collection = None
+
+ def _batch_insert(
+ self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
+ ) -> None:
+ batch_size = int(CHROMADB_MAX_BATCH_SIZE)
+ for i in range(0, len(documents), min(batch_size, len(documents))):
+ end_idx = i + min(batch_size, len(documents) - i)
+ collection_kwargs = {
+ "documents": documents[i:end_idx],
+ "ids": ids[i:end_idx],
+ "metadatas": metadatas[i:end_idx] if metadatas else None,
+ "embeddings": embeddings[i:end_idx] if embeddings else None,
+ }
+ if upsert:
+ collection.upsert(**collection_kwargs)
+ else:
+ collection.add(**collection_kwargs)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+ documents = [doc.get("content") for doc in docs]
+ ids = [doc.get("id") for doc in docs]
+ collection = self.get_collection(collection_name)
+ if docs[0].get("embedding") is None:
+ logger.info(
+ "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding."
+ )
+ embeddings = None
+ else:
+ embeddings = [doc.get("embedding") for doc in docs]
+ if docs[0].get("metadata") is None:
+ metadatas = None
+ else:
+ metadatas = [doc.get("metadata") for doc in docs]
+ self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+
+ Returns:
+ None
+ """
+ self.insert_docs(docs, collection_name, upsert=True)
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ collection = self.get_collection(collection_name)
+ collection.delete(ids, **kwargs)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of tuples containing the document and
+ the distance.
+ """
+ collection = self.get_collection(collection_name)
+ if isinstance(queries, str):
+ queries = [queries]
+ results = collection.query(
+ query_texts=queries,
+ n_results=n_results,
+ **kwargs,
+ )
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results)
+ results = filter_results_by_distance(results, distance_threshold)
+ return results
+
+ @staticmethod
+ def _chroma_get_results_to_list_documents(data_dict) -> List[Document]:
+ """Converts a dictionary with list values to a list of Document.
+
+ Args:
+ data_dict: A dictionary where keys map to lists or None.
+
+ Returns:
+ List[Document] | The list of Document.
+
+ Example:
+ data_dict = {
+ "key1s": [1, 2, 3],
+ "key2s": ["a", "b", "c"],
+ "key3s": None,
+ "key4s": ["x", "y", "z"],
+ }
+
+ results = [
+ {"key1": 1, "key2": "a", "key4": "x"},
+ {"key1": 2, "key2": "b", "key4": "y"},
+ {"key1": 3, "key2": "c", "key4": "z"},
+ ]
+ """
+
+ results = []
+ keys = [key for key in data_dict if data_dict[key] is not None]
+
+ for i in range(len(data_dict[keys[0]])):
+ sub_dict = {}
+ for key in data_dict.keys():
+ if data_dict[key] is not None and len(data_dict[key]) > i:
+ sub_dict[key[:-1]] = data_dict[key][i]
+ results.append(sub_dict)
+ return results
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ collection = self.get_collection(collection_name)
+ include = include if include else ["metadatas", "documents"]
+ results = collection.get(ids, include=include, **kwargs)
+ results = self._chroma_get_results_to_list_documents(results)
+ return results
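For completeness, a direct-usage sketch of `ChromaVectorDB` is shown below; the collection name and the `where_document` filter are assumptions, and extra keyword arguments to `retrieve_docs` are passed straight through to chroma's `collection.query`.

```python
# Sketch of direct ChromaVectorDB usage (collection name and filter string are assumed).
from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB

db = ChromaVectorDB(path="tmp/db")  # persistent client rooted at tmp/db
db.create_collection("autogen-docs", get_or_create=True)

results = db.retrieve_docs(
    queries=["vector database"],
    collection_name="autogen-docs",
    n_results=5,
    distance_threshold=-1,                   # negative value: keep everything
    where_document={"$contains": "vector"},  # chroma-specific keyword filter
)

ids = [doc["id"] for doc, _ in results[0]]
docs = db.get_docs_by_ids(ids=ids, collection_name="autogen-docs")  # returns Document dicts
```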
diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py
new file mode 100644
index 00000000000..ac86802b672
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py
@@ -0,0 +1,952 @@
+import os
+import re
+import urllib.parse
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ import pgvector
+ from pgvector.psycopg import register_vector
+except ImportError:
+ raise ImportError("Please install pgvector: `pip install pgvector`")
+
+try:
+ import psycopg
+except ImportError:
+ raise ImportError("Please install pgvector: `pip install psycopg`")
+
+PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
+logger = get_logger(__name__)
+
+
+class Collection:
+ """
+ A Collection object for PGVector.
+
+ Attributes:
+ client: The PGVector client.
+ collection_name (str): The name of the collection. Default is "autogen-docs".
+ embedding_function (Callable): The embedding function used to generate the vector representation.
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
+ Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ metadata (Optional[dict]): The metadata of the collection.
+ get_or_create (Optional): The flag indicating whether to get or create the collection.
+ """
+
+ def __init__(
+ self,
+ client=None,
+ collection_name: str = "autogen-docs",
+ embedding_function: Callable = None,
+ metadata=None,
+ get_or_create=None,
+ ):
+ """
+ Initialize the Collection object.
+
+ Args:
+ client: The PostgreSQL client.
+ collection_name: The name of the collection. Default is "autogen-docs".
+ embedding_function: The embedding function used to generate the vector representation.
+ metadata: The metadata of the collection.
+ get_or_create: The flag indicating whether to get or create the collection.
+ Returns:
+ None
+ """
+ self.client = client
+ self.name = self.set_collection_name(collection_name)
+ self.require_embeddings_or_documents = False
+ self.ids = []
+ if embedding_function:
+ self.embedding_function = embedding_function
+ else:
+ self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
+ self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ self.documents = ""
+ self.get_or_create = get_or_create
+ # This will get the model dimension size by computing the embeddings dimensions
+ sentences = [
+ "The weather is lovely today in paradise.",
+ ]
+ embeddings = self.embedding_function(sentences)
+ self.dimension = len(embeddings[0])
+
+ def set_collection_name(self, collection_name) -> str:
+ name = re.sub("-", "_", collection_name)
+ self.name = name
+ return self.name
+
+ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Add documents to the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ embeddings (List): A list of document embeddings. Optional
+ metadatas (List): A list of document metadatas. Optional
+ documents (List): A list of documents.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document))
+ sql_string = f"INSERT INTO {self.name} (id, embedding, documents) " f"VALUES (%s, %s, %s);\n"
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, metadata, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, document, embedding))
+ sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
+ logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Upsert documents into the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ documents (List): A list of documents.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s,\n"
+ f"metadatas = %s, documents = %s;\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, documents) "
+ f"VALUES (%s, %s, %s) ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s, documents = %s;\n"
+ )
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, documents, embedding)\n"
+ f"VALUES (%s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET documents = %s;\n"
+ )
+ logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def count(self) -> int:
+ """
+ Get the total number of documents in the collection.
+
+ Returns:
+ int: The total number of documents.
+ """
+ cursor = self.client.cursor()
+ query = f"SELECT COUNT(*) FROM {self.name}"
+ cursor.execute(query)
+ total = cursor.fetchone()[0]
+ cursor.close()
+ try:
+ total = int(total)
+ except (TypeError, ValueError):
+ total = None
+ return total
+
+ def table_exists(self, table_name: str) -> bool:
+ """
+ Check if a table exists in the PostgreSQL database.
+
+ Args:
+ table_name (str): The name of the table to check.
+
+ Returns:
+ bool: True if the table exists, False otherwise.
+ """
+
+ cursor = self.client.cursor()
+ cursor.execute(
+ """
+ SELECT EXISTS (
+ SELECT 1
+ FROM information_schema.tables
+ WHERE table_name = %s
+ )
+ """,
+ (table_name,),
+ )
+ exists = cursor.fetchone()[0]
+ return exists
+
+ def get(
+ self,
+ ids: Optional[str] = None,
+ include: Optional[str] = None,
+ where: Optional[str] = None,
+ limit: Optional[Union[int, str]] = None,
+ offset: Optional[Union[int, str]] = None,
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection.
+
+ Args:
+ ids (Optional[List]): A list of document IDs.
+ include (Optional): The fields to include.
+ where (Optional): Additional filtering criteria.
+ limit (Optional): The maximum number of documents to retrieve.
+ offset (Optional): The offset for pagination.
+
+ Returns:
+ List: The retrieved documents.
+ """
+ cursor = self.client.cursor()
+
+ # Initialize variables for query components
+ select_clause = "SELECT id, metadatas, documents, embedding"
+ from_clause = f"FROM {self.name}"
+ where_clause = ""
+ limit_clause = ""
+ offset_clause = ""
+
+ # Handle include clause
+ if include:
+ select_clause = f"SELECT id, {', '.join(include)}, embedding"
+
+ # Handle where clause
+ if ids:
+ where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
+ elif where:
+ where_clause = f"WHERE {where}"
+
+ # Handle limit and offset clauses
+ if limit:
+ limit_clause = "LIMIT %s"
+ if offset:
+ offset_clause = "OFFSET %s"
+
+ # Construct the full query
+ query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
+ retrieved_documents = []
+ try:
+ # Execute the query with the appropriate values
+ if ids is not None:
+ cursor.execute(query, ids)
+ else:
+ query_params = []
+ if limit:
+ query_params.append(limit)
+ if offset:
+ query_params.append(offset)
+ cursor.execute(query, query_params)
+
+ retrieval = cursor.fetchall()
+ for retrieved_document in retrieval:
+ retrieved_documents.append(
+ Document(
+ id=retrieved_document[0].strip(),
+ metadata=retrieved_document[1],
+ content=retrieved_document[2],
+ embedding=retrieved_document[3],
+ )
+ )
+ except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
+ logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
+ self.create_collection(collection_name=self.name, dimension=self.dimension)
+ logger.info(f"Created table {self.name}")
+
+ cursor.close()
+ return retrieved_documents
+
+ def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
+ """
+ Update documents in the collection.
+
+ Args:
+ ids (List): A list of document IDs.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+ documents (List): A list of documents.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents) "
+ f"VALUES (%s, %s, %s, %s) "
+ f"ON CONFLICT (id) "
+ f"DO UPDATE SET id = %s, embedding = %s, "
+ f"metadatas = %s, documents = %s;\n"
+ )
+ logger.debug(f"Update SQL String:\n{sql_string}\n")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ @staticmethod
+ def euclidean_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the Euclidean distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The Euclidean distance between arr1 and arr2.
+ """
+ dist = np.linalg.norm(arr1 - arr2)
+ return dist
+
+ @staticmethod
+ def cosine_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the cosine distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The cosine distance between arr1 and arr2.
+ """
+ dist = np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
+ return dist
+
+ @staticmethod
+ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the inner-product distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The inner-product distance between arr1 and arr2.
+ """
+ dist = np.dot(arr1, arr2) * -1  # pgvector's <#> operator returns the negative inner product
+ return dist
+
+ def query(
+ self,
+ query_texts: List[str],
+ collection_name: Optional[str] = None,
+ n_results: Optional[int] = 10,
+ distance_type: Optional[str] = "euclidean",
+ distance_threshold: Optional[float] = -1,
+ include_embedding: Optional[bool] = False,
+ ) -> QueryResults:
+ """
+ Query documents in the collection.
+
+ Args:
+ query_texts (List[str]): A list of query texts.
+ collection_name (Optional[str]): The name of the collection.
+ n_results (int): The maximum number of results to return.
+ distance_type (Optional[str]): Distance search type - euclidean or cosine
+ distance_threshold (Optional[float]): Distance threshold to limit searches
+ include_embedding (Optional[bool]): Include embedding values in QueryResults
+ Returns:
+ QueryResults: The query results.
+ """
+ if collection_name:
+ self.name = collection_name
+
+ clause = "ORDER BY"
+ if distance_threshold == -1:
+ distance_threshold = ""
+ clause = "ORDER BY"
+ elif distance_threshold > 0:
+ distance_threshold = f"< {distance_threshold}"
+ clause = "WHERE"
+
+ cursor = self.client.cursor()
+ results = []
+ for query_text in query_texts:
+ vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
+ if distance_type.lower() == "cosine":
+ index_function = "<=>"
+ elif distance_type.lower() == "euclidean":
+ index_function = "<->"
+ elif distance_type.lower() == "inner-product":
+ index_function = "<#>"
+ else:
+ index_function = "<->"
+ query = (
+ f"SELECT id, documents, embedding, metadatas "
+ f"FROM {self.name} "
+ f"{clause} embedding {index_function} '{str(vector)}' {distance_threshold} "
+ f"LIMIT {n_results}"
+ )
+ cursor.execute(query)
+ result = []
+ for row in cursor.fetchall():
+ fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
+ fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
+ if distance_type.lower() == "cosine":
+ distance = self.cosine_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "euclidean":
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "inner-product":
+ distance = self.inner_product_distance(fetched_document_array, vector)
+ else:
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ if not include_embedding:
+ fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
+ result.append((fetched_document, distance))
+ results.append(result)
+ cursor.close()
+ logger.debug(f"Query Results: {results}")
+ return results
+
+ @staticmethod
+ def convert_string_to_array(array_string: str) -> List[float]:
+ """
+ Convert a string representation of an array to a list of floats.
+
+ Parameters:
+ - array_string (str): The string representation of the array.
+
+ Returns:
+ - list: A list of floats parsed from the input string. If the input is
+ not a string, it returns the input itself.
+ """
+ if not isinstance(array_string, str):
+ return array_string
+ array_string = array_string.strip("[]")
+ array = [float(num) for num in array_string.split()]
+ return array
+
+ def modify(self, metadata, collection_name: Optional[str] = None) -> None:
+ """
+ Modify metadata for the collection.
+
+ Args:
+ collection_name: The name of the collection.
+ metadata: The new metadata.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(
+ "UPDATE collections" "SET metadata = '%s'" "WHERE collection_name = '%s';", (metadata, self.name)
+ )
+ cursor.close()
+
+ def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None:
+ """
+ Delete documents from the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs to delete.
+ collection_name (str): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ id_placeholders = ", ".join(["%s" for _ in ids])
+ cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
+ cursor.close()
+
+ def delete_collection(self, collection_name: Optional[str] = None) -> None:
+ """
+ Delete the entire collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
+ cursor.close()
+
+ def create_collection(
+ self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
+ ) -> None:
+ """
+ Create a new collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the new collection.
+ dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+
+ if dimension:
+ self.dimension = dimension
+ elif self.dimension is None:
+ self.dimension = 384
+
+ cursor = self.client.cursor()
+ cursor.execute(
+ f"CREATE TABLE {self.name} ("
+ f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ )
+ cursor.close()
+
+
+class PGVectorDB(VectorDB):
+ """
+ A vector database that uses PGVector as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ conn: Optional[psycopg.Connection] = None,
+ connection_string: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[Union[int, str]] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ connect_timeout: Optional[int] = 10,
+ embedding_function: Callable = None,
+ metadata: Optional[dict] = None,
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Note: connection_string or host + port + dbname must be specified
+
+ Args:
+ conn: psycopg.Connection | A customer connection object to connect to the database.
+ A connection object may include additional key/values:
+ https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
+ connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
+ host: str | The host to connect to. Default is None.
+ port: int | The port to connect to. Default is None.
+ dbname: str | The database name to connect to. Default is None.
+ username: str | The database username to use. Default is None.
+ password: str | The database user password to use. Default is None.
+ connect_timeout: int | The timeout to set for the connection. Default is 10.
+ embedding_function: Callable | The embedding function used to generate the vector representation.
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
+ Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
+ setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
+ using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
+ For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
+ Returns:
+ None
+ """
+ self.client = self.establish_connection(
+ conn=conn,
+ connection_string=connection_string,
+ host=host,
+ port=port,
+ dbname=dbname,
+ username=username,
+ password=password,
+ connect_timeout=connect_timeout,
+ )
+ if embedding_function:
+ self.embedding_function = embedding_function
+ else:
+ self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
+ self.metadata = metadata
+ register_vector(self.client)
+ self.active_collection = None
+
+ def establish_connection(
+ self,
+ conn: Optional[psycopg.Connection] = None,
+ connection_string: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[Union[int, str]] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ connect_timeout: Optional[int] = 10,
+ ) -> psycopg.Connection:
+ """
+ Establishes a connection to a PostgreSQL database using psycopg.
+
+ Args:
+ conn: An existing psycopg connection object. If provided, this connection will be used.
+ connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
+ host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
+ port: The port number to connect to at the server host. Used if connection_string is not provided.
+ dbname: The database name. Used if connection_string is not provided.
+ username: The username to connect as. Used if connection_string is not provided.
+ password: The user's password. Used if connection_string is not provided.
+ connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
+
+ Returns:
+ A psycopg.Connection object representing the established connection.
+
+ Raises:
+ PermissionError: If no credentials are supplied.
+ psycopg.Error: If an error occurs while trying to connect to the database.
+ """
+ try:
+ if conn:
+ self.client = conn
+ elif connection_string:
+ parsed_connection = urllib.parse.urlparse(connection_string)
+ encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
+ encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
+ encoded_password = f":{encoded_password}@"
+ encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
+ encoded_port = f":{parsed_connection.port}"
+ encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
+ connection_string_encoded = (
+ f"{parsed_connection.scheme}://{encoded_username}{encoded_password}"
+ f"{encoded_host}{encoded_port}/{encoded_database}"
+ )
+ self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
+ elif host:
+ connection_string = ""
+ if host:
+ encoded_host = urllib.parse.quote(host, safe="")
+ connection_string += f"host={encoded_host} "
+ if port:
+ connection_string += f"port={port} "
+ if dbname:
+ encoded_database = urllib.parse.quote(dbname, safe="")
+ connection_string += f"dbname={encoded_database} "
+ if username:
+ encoded_username = urllib.parse.quote(username, safe="")
+ connection_string += f"user={encoded_username} "
+ if password:
+ encoded_password = urllib.parse.quote(password, safe="")
+ connection_string += f"password={encoded_password} "
+
+ self.client = psycopg.connect(
+ conninfo=connection_string,
+ connect_timeout=connect_timeout,
+ autocommit=True,
+ )
+ else:
+ logger.error("Credentials were not supplied...")
+ raise PermissionError
+ self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ except psycopg.Error as e:
+ logger.error("Error connecting to the database: ", e)
+ raise e
+ return self.client
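A hedged example of the connection options documented above: the snippet connects with a connection string and creates a collection. The credentials, host, and database name are placeholders, and the same values could instead be passed through `db_config` when using `RetrieveUserProxyAgent` with `vector_db="pgvector"`.

```python
# Sketch only: the connection string below uses made-up credentials and database name.
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

db = PGVectorDB(connection_string="postgresql://postgres:postgres@localhost:5432/postgres")
collection = db.create_collection("autogen_docs", overwrite=False, get_or_create=True)

# Equivalent settings via RetrieveUserProxyAgent (illustrative):
# retrieve_config={"vector_db": "pgvector",
#                  "db_config": {"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres"}}
```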
+
+ def create_collection(
+ self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
+ ) -> Collection:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raise a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Collection | The collection object.
+ """
+ try:
+ if self.active_collection and self.active_collection.name == collection_name:
+ collection = self.active_collection
+ else:
+ collection = self.get_collection(collection_name)
+ except ValueError:
+ collection = None
+ if collection is None:
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif overwrite:
+ self.delete_collection(collection_name)
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif get_or_create:
+ return collection
+ elif not collection.table_exists(table_name=collection_name):
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ else:
+ raise ValueError(f"Collection {collection_name} already exists.")
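# Usage sketch of the create/overwrite/get_or_create cases documented above. It assumes the class
# defined in this file is PGVectorDB and that its constructor accepts a connection_string (the
# constructor is not shown in this hunk, so treat both as assumptions); the DSN and collection
# name are placeholders.
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

db = PGVectorDB(connection_string="postgresql://postgres:password@localhost:5432/vectordb")
collection = db.create_collection("docs")                 # Case 1: created because it does not exist yet
same = db.create_collection("docs", get_or_create=True)   # Case 3: the existing collection is returned
fresh = db.create_collection("docs", overwrite=True)      # Case 2: the existing collection is dropped and recreated
try:
    db.create_collection("docs", overwrite=False, get_or_create=False)
except ValueError:
    pass  # the collection already exists and neither flag allows reuse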
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.debug(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ if not (self.active_collection and self.active_collection.name == collection_name):
+ self.active_collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ )
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ None
+ """
+ if self.active_collection:
+ self.active_collection.delete_collection(collection_name)
+ else:
+ collection = self.get_collection(collection_name)
+ collection.delete_collection(collection_name)
+ if self.active_collection and self.active_collection.name == collection_name:
+ self.active_collection = None
+
+ def _batch_insert(
+ self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
+ ) -> None:
+ batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
+ default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ default_metadatas = [default_metadata] * min(batch_size, len(documents))
+ for i in range(0, len(documents), min(batch_size, len(documents))):
+ end_idx = i + min(batch_size, len(documents) - i)
+ collection_kwargs = {
+ "documents": documents[i:end_idx],
+ "ids": ids[i:end_idx],
+ "metadatas": metadatas[i:end_idx] if metadatas else default_metadatas,
+ "embeddings": embeddings[i:end_idx] if embeddings else None,
+ }
+ if upsert:
+ collection.upsert(**collection_kwargs)
+ else:
+ collection.add(**collection_kwargs)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+ documents = [doc.get("content") for doc in docs]
+ ids = [doc.get("id") for doc in docs]
+
+ collection = self.get_collection(collection_name)
+ if docs[0].get("embedding") is None:
+ logger.debug(
+ "No content embedding is provided. "
+ "Will use the VectorDB's embedding function to generate the content embedding."
+ )
+ embeddings = None
+ else:
+ embeddings = [doc.get("embedding") for doc in docs]
+ if docs[0].get("metadata") is None:
+ metadatas = None
+ else:
+ metadatas = [doc.get("metadata") for doc in docs]
+
+ self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+
+ Returns:
+ None
+ """
+ self.insert_docs(docs, collection_name, upsert=True)
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ collection = self.get_collection(collection_name)
+ collection.delete(ids=ids, collection_name=collection_name)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score; only results with a distance
+ smaller than this value will be returned. No filtering is applied if it is negative. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of tuples containing the document and
+ the distance.
+ """
+ collection = self.get_collection(collection_name)
+ if isinstance(queries, str):
+ queries = [queries]
+ results = collection.query(
+ query_texts=queries,
+ n_results=n_results,
+ distance_threshold=distance_threshold,
+ )
+ logger.debug(f"Retrieve Docs Results:\n{results}")
+ return results
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, ["metadatas", "documents"] will be included; ids are always included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ collection = self.get_collection(collection_name)
+ include = include if include else ["metadatas", "documents"]
+ results = collection.get(ids, include=include, **kwargs)
+ logger.debug(f"Retrieve Documents by ID Results:\n{results}")
+ return results
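# End-to-end sketch of the document API defined above, under the same assumptions as the previous
# sketch (class name PGVectorDB, connection_string constructor argument). Document ids, contents,
# and the query are illustrative; embeddings are omitted, so the instance's embedding function is
# used, as described in insert_docs.
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

db = PGVectorDB(connection_string="postgresql://postgres:password@localhost:5432/vectordb")
db.create_collection("demo", get_or_create=True)
db.insert_docs(
    [
        {"id": "1", "content": "PostgreSQL stores vectors through the pgvector extension."},
        {"id": "2", "content": "AutoGen agents can retrieve documents for retrieval-augmented generation."},
    ],
    collection_name="demo",
)
results = db.retrieve_docs(["what is pgvector?"], collection_name="demo", n_results=2)
for doc, distance in results[0]:  # QueryResults: one list of (Document, distance) tuples per query
    print(doc["id"], distance)
db.delete_docs(["2"], collection_name="demo")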
diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py
new file mode 100644
index 00000000000..3dcf79f1f55
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/utils.py
@@ -0,0 +1,115 @@
+import logging
+from typing import Any, Dict, List
+
+from termcolor import colored
+
+from .base import QueryResults
+
+
+class ColoredLogger(logging.Logger):
+ def __init__(self, name, level=logging.NOTSET):
+ super().__init__(name, level)
+
+ def debug(self, msg, *args, color=None, **kwargs):
+ super().debug(colored(msg, color), *args, **kwargs)
+
+ def info(self, msg, *args, color=None, **kwargs):
+ super().info(colored(msg, color), *args, **kwargs)
+
+ def warning(self, msg, *args, color="yellow", **kwargs):
+ super().warning(colored(msg, color), *args, **kwargs)
+
+ def error(self, msg, *args, color="light_red", **kwargs):
+ super().error(colored(msg, color), *args, **kwargs)
+
+ def critical(self, msg, *args, color="red", **kwargs):
+ super().critical(colored(msg, color), *args, **kwargs)
+
+ def fatal(self, msg, *args, color="red", **kwargs):
+ super().fatal(colored(msg, color), *args, **kwargs)
+
+
+def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
+ logger = ColoredLogger(name, level)
+ console_handler = logging.StreamHandler()
+ logger.addHandler(console_handler)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger.handlers[0].setFormatter(formatter)
+ return logger
+
+
+logger = get_logger(__name__)
+
+
+def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults:
+ """Filters results based on a distance threshold.
+
+ Args:
+ results: QueryResults | The query results. List[List[Tuple[Document, float]]]
+ distance_threshold: The maximum distance allowed for results.
+
+ Returns:
+ QueryResults | Filtered results containing only entries with a distance smaller than the threshold.
+ """
+
+ if distance_threshold > 0:
+ results = [[(key, value) for key, value in data if value < distance_threshold] for data in results]
+
+ return results
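# Worked example for filter_results_by_distance: with a threshold of 0.5 only the first entry of
# each per-query list survives; a negative threshold (the default) disables filtering entirely.
from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance

results = [[({"id": "1", "content": "close match"}, 0.2), ({"id": "2", "content": "far match"}, 0.8)]]
print(filter_results_by_distance(results, 0.5))  # [[({'id': '1', 'content': 'close match'}, 0.2)]]
print(filter_results_by_distance(results, -1))   # unchanged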
+
+
+def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults:
+ """Converts a dictionary with list-of-list values to a list of tuples.
+
+ Args:
+ data_dict: A dictionary where keys map to lists of lists or None.
+ special_key: The key in the dictionary containing the special values
+ for each tuple.
+
+ Returns:
+ A list of tuples, where each tuple contains a sub-dictionary with
+ some keys from the original dictionary and the value from the
+ special_key.
+
+ Example:
+ data_dict = {
+ "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]],
+ "key3s": None,
+ "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]],
+ "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
+ }
+
+ results = [
+ [
+ ({"key1": 1, "key2": "a", "key4": "x"}, 0.1),
+ ({"key1": 2, "key2": "b", "key4": "y"}, 0.2),
+ ({"key1": 3, "key2": "c", "key4": "z"}, 0.3),
+ ],
+ [
+ ({"key1": 4, "key2": "c", "key4": "1"}, 0.4),
+ ({"key1": 5, "key2": "d", "key4": "2"}, 0.5),
+ ({"key1": 6, "key2": "e", "key4": "3"}, 0.6),
+ ],
+ [
+ ({"key1": 7, "key2": "e", "key4": "4"}, 0.7),
+ ({"key1": 8, "key2": "f", "key4": "5"}, 0.8),
+ ({"key1": 9, "key2": "g", "key4": "6"}, 0.9),
+ ],
+ ]
+ """
+
+ keys = [key for key in data_dict if key != special_key]
+ result = []
+
+ for i in range(len(data_dict[special_key])):
+ sub_result = []
+ for j, distance in enumerate(data_dict[special_key][i]):
+ sub_dict = {}
+ for key in keys:
+ if data_dict[key] is not None and len(data_dict[key]) > i:
+ sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key
+ sub_result.append((sub_dict, distance))
+ result.append(sub_result)
+
+ return result
diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py
index 4877a4d0949..af07be6d343 100644
--- a/autogen/agentchat/contrib/web_surfer.py
+++ b/autogen/agentchat/contrib/web_surfer.py
@@ -1,15 +1,18 @@
-import json
import copy
+import json
import logging
import re
from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Callable, Literal, Tuple
-from autogen import Agent, ConversableAgent, AssistantAgent, UserProxyAgent, GroupChatManager, GroupChat, OpenAIWrapper
-from autogen.browser_utils import SimpleTextBrowser
-from autogen.code_utils import content_str
from datetime import datetime
-from autogen.token_count_utils import count_token, get_max_token_limit
-from autogen.oai.openai_utils import filter_config
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from typing_extensions import Annotated
+
+from ... import Agent, AssistantAgent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper, UserProxyAgent
+from ...browser_utils import SimpleTextBrowser
+from ...code_utils import content_str
+from ...oai.openai_utils import filter_config
+from ...token_count_utils import count_token, get_max_token_limit
logger = logging.getLogger(__name__)
@@ -26,12 +29,12 @@ class WebSurferAgent(ConversableAgent):
def __init__(
self,
- name,
- system_message: Optional[Union[str, List]] = DEFAULT_PROMPT,
+ name: str,
+ system_message: Optional[Union[str, List[str]]] = DEFAULT_PROMPT,
description: Optional[str] = DEFAULT_DESCRIPTION,
- is_termination_msg: Optional[Callable[[Dict], bool]] = None,
+ is_termination_msg: Optional[Callable[[Dict[str, Any]], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "TERMINATE",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
@@ -52,6 +55,38 @@ def __init__(
default_auto_reply=default_auto_reply,
)
+ self._create_summarizer_client(summarizer_llm_config, llm_config)
+
+ # Create the browser
+ self.browser = SimpleTextBrowser(**(browser_config if browser_config else {}))
+
+ inner_llm_config = copy.deepcopy(llm_config)
+
+ # Set up the inner monologue
+ self._assistant = AssistantAgent(
+ self.name + "_inner_assistant",
+ system_message=system_message, # type: ignore[arg-type]
+ llm_config=inner_llm_config,
+ is_termination_msg=lambda m: False,
+ )
+
+ self._user_proxy = UserProxyAgent(
+ self.name + "_inner_user_proxy",
+ human_input_mode="NEVER",
+ code_execution_config=False,
+ default_auto_reply="",
+ is_termination_msg=lambda m: False,
+ )
+
+ if inner_llm_config not in [None, False]:
+ self._register_functions()
+
+ self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply, remove_other_reply_funcs=True)
+ self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
+ self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
+ self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
+
+ def _create_summarizer_client(self, summarizer_llm_config: Dict[str, Any], llm_config: Dict[str, Any]) -> None:
# If the summarizer_llm_config is None, we copy it from the llm_config
if summarizer_llm_config is None:
if llm_config is None: # Nothing to copy
@@ -59,10 +94,10 @@ def __init__(
elif llm_config is False: # LLMs disabled
self.summarizer_llm_config = False
else: # Create a suitable config
- self.summarizer_llm_config = copy.deepcopy(llm_config)
- if "config_list" in self.summarizer_llm_config:
- preferred_models = filter_config(
- self.summarizer_llm_config["config_list"],
+ self.summarizer_llm_config = copy.deepcopy(llm_config) # type: ignore[assignment]
+ if "config_list" in self.summarizer_llm_config: # type: ignore[operator]
+ preferred_models = filter_config( # type: ignore[no-untyped-call]
+ self.summarizer_llm_config["config_list"], # type: ignore[index]
{"model": ["gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"]},
)
if len(preferred_models) == 0:
@@ -71,142 +106,20 @@ def __init__(
"Semantic operations on webpages (summarization or Q&A) might be costly or ineffective."
)
else:
- self.summarizer_llm_config["config_list"] = preferred_models
+ self.summarizer_llm_config["config_list"] = preferred_models # type: ignore[index]
else:
- self.summarizer_llm_config = summarizer_llm_config
+ self.summarizer_llm_config = summarizer_llm_config # type: ignore[assignment]
# Create the summarizer client
- self.summarization_client = None
- if self.summarizer_llm_config is not False:
- self.summarization_client = OpenAIWrapper(**self.summarizer_llm_config)
-
- # Create the browser
- if browser_config is None:
- self.browser = SimpleTextBrowser()
- else:
- self.browser = SimpleTextBrowser(**browser_config)
-
- # Create a copy of the llm_config for the inner monologue agents to use, and set them up with function calling
- if llm_config is None: # Nothing to copy
- inner_llm_config = None
- elif llm_config is False: # LLMs disabled
- inner_llm_config = False
- else:
- inner_llm_config = copy.deepcopy(llm_config)
- inner_llm_config["functions"] = [
- {
- "name": "informational_web_search",
- "description": "Perform an INFORMATIONAL web search query then return the search results.",
- "parameters": {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "The informational web search query to perform.",
- }
- },
- },
- "required": ["query"],
- },
- {
- "name": "navigational_web_search",
- "description": "Perform a NAVIGATIONAL web search query then immediately navigate to the top result. Useful, for example, to navigate to a particular Wikipedia article or other known destination. Equivalent to Google's \"I'm Feeling Lucky\" button.",
- "parameters": {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "The navigational web search query to perform.",
- }
- },
- },
- "required": ["query"],
- },
- {
- "name": "visit_page",
- "description": "Visit a webpage at a given URL and return its text.",
- "parameters": {
- "type": "object",
- "properties": {
- "url": {
- "type": "string",
- "description": "The relative or absolute url of the webapge to visit.",
- }
- },
- },
- "required": ["url"],
- },
- {
- "name": "page_up",
- "description": "Scroll the viewport UP one page-length in the current webpage and return the new viewport content.",
- "parameters": {"type": "object", "properties": {}},
- "required": [],
- },
- {
- "name": "page_down",
- "description": "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content.",
- "parameters": {"type": "object", "properties": {}},
- "required": [],
- },
- ]
-
- # Enable semantic operations
- if self.summarization_client is not None:
- inner_llm_config["functions"].append(
- {
- "name": "answer_from_page",
- "description": "Uses AI to read the page and directly answer a given question based on the content.",
- "parameters": {
- "type": "object",
- "properties": {
- "question": {
- "type": "string",
- "description": "The question to directly answer.",
- },
- "url": {
- "type": "string",
- "description": "[Optional] The url of the page. (Defaults to the current page)",
- },
- },
- },
- "required": ["question"],
- }
- )
- inner_llm_config["functions"].append(
- {
- "name": "summarize_page",
- "description": "Uses AI to summarize the content found at a given url. If the url is not provided, the current page is summarized.",
- "parameters": {
- "type": "object",
- "properties": {
- "url": {
- "type": "string",
- "description": "[Optional] The url of the page to summarize. (Defaults to current page)",
- },
- },
- },
- "required": [],
- }
- )
+ self.summarization_client = (
+ None if self.summarizer_llm_config is False else OpenAIWrapper(**self.summarizer_llm_config)
+ ) # type: ignore[arg-type]
- # Set up the inner monologue
- self._assistant = AssistantAgent(
- self.name + "_inner_assistant",
- system_message=system_message,
- llm_config=inner_llm_config,
- is_termination_msg=lambda m: False,
- )
-
- self._user_proxy = UserProxyAgent(
- self.name + "_inner_user_proxy",
- human_input_mode="NEVER",
- code_execution_config=False,
- default_auto_reply="",
- is_termination_msg=lambda m: False,
- )
+ def _register_functions(self) -> None:
+ """Register the functions for the inner assistant and user proxy."""
# Helper functions
- def _browser_state():
+ def _browser_state() -> Tuple[str, str]:
header = f"Address: {self.browser.address}\n"
if self.browser.page_title is not None:
header += f"Title: {self.browser.page_title}\n"
@@ -217,12 +130,22 @@ def _browser_state():
header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"
return (header, self.browser.viewport)
- def _informational_search(query):
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="informational_web_search",
+ description="Perform an INFORMATIONAL web search query then return the search results.",
+ )
+ def _informational_search(query: Annotated[str, "The informational web search query to perform."]) -> str:
self.browser.visit_page(f"bing: {query}")
header, content = _browser_state()
return header.strip() + "\n=======================\n" + content
- def _navigational_search(query):
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="navigational_web_search",
+ description="Perform a NAVIGATIONAL web search query then immediately navigate to the top result. Useful, for example, to navigate to a particular Wikipedia article or other known destination. Equivalent to Google's \"I'm Feeling Lucky\" button.",
+ )
+ def _navigational_search(query: Annotated[str, "The navigational web search query to perform."]) -> str:
self.browser.visit_page(f"bing: {query}")
# Extract the first linl
@@ -234,99 +157,117 @@ def _navigational_search(query):
header, content = _browser_state()
return header.strip() + "\n=======================\n" + content
- def _visit_page(url):
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="visit_page", description="Visit a webpage at a given URL and return its text."
+ )
+ def _visit_page(url: Annotated[str, "The relative or absolute url of the webpage to visit."]) -> str:
self.browser.visit_page(url)
header, content = _browser_state()
return header.strip() + "\n=======================\n" + content
- def _page_up():
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="page_up",
+ description="Scroll the viewport UP one page-length in the current webpage and return the new viewport content.",
+ )
+ def _page_up() -> str:
self.browser.page_up()
header, content = _browser_state()
return header.strip() + "\n=======================\n" + content
- def _page_down():
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="page_down",
+ description="Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content.",
+ )
+ def _page_down() -> str:
self.browser.page_down()
header, content = _browser_state()
return header.strip() + "\n=======================\n" + content
- def _summarize_page(question, url):
- if url is not None and url != self.browser.address:
- self.browser.visit_page(url)
-
- # We are likely going to need to fix this later, but summarize only as many tokens that fit in the buffer
- limit = 4096
- try:
- limit = get_max_token_limit(self.summarizer_llm_config["config_list"][0]["model"])
- except ValueError:
- pass # limit is unknown
- except TypeError:
- pass # limit is unknown
-
- if limit < 16000:
- logger.warning(
- f"The token limit ({limit}) of the WebSurferAgent.summarizer_llm_config, is below the recommended 16k."
- )
+ if self.summarization_client is not None:
- buffer = ""
- for line in re.split(r"([\r\n]+)", self.browser.page_content):
- tokens = count_token(buffer + line)
- if tokens + 1024 > limit: # Leave room for our summary
- break
- buffer += line
-
- buffer = buffer.strip()
- if len(buffer) == 0:
- return "Nothing to summarize."
-
- messages = [
- {
- "role": "system",
- "content": "You are a helpful assistant that can summarize long documents to answer question.",
- }
- ]
-
- prompt = f"Please summarize the following into one or two paragraph:\n\n{buffer}"
- if question is not None:
- prompt = f"Please summarize the following into one or two paragraphs with respect to '{question}':\n\n{buffer}"
-
- messages.append(
- {"role": "user", "content": prompt},
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="answer_from_page",
+ description="Uses AI to read the page and directly answer a given question based on the content.",
)
+ def _answer_from_page(
+ question: Annotated[Optional[str], "The question to directly answer."],
+ url: Annotated[Optional[str], "[Optional] The url of the page. (Defaults to the current page)"] = None,
+ ) -> str:
+ if url is not None and url != self.browser.address:
+ self.browser.visit_page(url)
+
+ # We are likely going to need to fix this later, but summarize only as many tokens that fit in the buffer
+ limit = 4096
+ try:
+ limit = get_max_token_limit(self.summarizer_llm_config["config_list"][0]["model"]) # type: ignore[index]
+ except ValueError:
+ pass # limit is unknown
+ except TypeError:
+ pass # limit is unknown
+
+ if limit < 16000:
+ logger.warning(
+ f"The token limit ({limit}) of the WebSurferAgent.summarizer_llm_config is below the recommended 16k."
+ )
- response = self.summarization_client.create(context=None, messages=messages)
- extracted_response = self.summarization_client.extract_text_or_completion_object(response)[0]
- return str(extracted_response)
-
- self._user_proxy.register_function(
- function_map={
- "informational_web_search": lambda query: _informational_search(query),
- "navigational_web_search": lambda query: _navigational_search(query),
- "visit_page": lambda url: _visit_page(url),
- "page_up": lambda: _page_up(),
- "page_down": lambda: _page_down(),
- "answer_from_page": lambda question=None, url=None: _summarize_page(question, url),
- "summarize_page": lambda question=None, url=None: _summarize_page(None, url),
- }
- )
+ buffer = ""
+ for line in re.split(r"([\r\n]+)", self.browser.page_content):
+ tokens = count_token(buffer + line)
+ if tokens + 1024 > limit: # Leave room for our summary
+ break
+ buffer += line
- self._reply_func_list = []
- self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply)
- self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
- self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
- self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
+ buffer = buffer.strip()
+ if len(buffer) == 0:
+ return "Nothing to summarize."
+
+ messages = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant that can summarize long documents to answer questions.",
+ }
+ ]
+
+ prompt = f"Please summarize the following into one or two paragraphs:\n\n{buffer}"
+ if question is not None:
+ prompt = f"Please summarize the following into one or two paragraphs with respect to '{question}':\n\n{buffer}"
+
+ messages.append(
+ {"role": "user", "content": prompt},
+ )
+
+ response = self.summarization_client.create(context=None, messages=messages) # type: ignore[union-attr]
+ extracted_response = self.summarization_client.extract_text_or_completion_object(response)[0] # type: ignore[union-attr]
+ return str(extracted_response)
+
+ @self._user_proxy.register_for_execution()
+ @self._assistant.register_for_llm(
+ name="summarize_page",
+ description="Uses AI to summarize the content found at a given url. If the url is not provided, the current page is summarized.",
+ )
+ def _summarize_page(
+ url: Annotated[
+ Optional[str], "[Optional] The url of the page to summarize. (Defaults to current page)"
+ ] = None,
+ ) -> str:
+ return _answer_from_page(url=url, question=None)
def generate_surfer_reply(
self,
- messages: Optional[List[Dict]] = None,
+ messages: Optional[List[Dict[str, str]]] = None,
sender: Optional[Agent] = None,
config: Optional[OpenAIWrapper] = None,
- ) -> Tuple[bool, Union[str, Dict, None]]:
+ ) -> Tuple[bool, Optional[Union[str, Dict[str, str]]]]:
"""Generate a reply using autogen.oai."""
if messages is None:
messages = self._oai_messages[sender]
- self._user_proxy.reset()
- self._assistant.reset()
+ self._user_proxy.reset() # type: ignore[no-untyped-call]
+ self._assistant.reset() # type: ignore[no-untyped-call]
# Clone the messages to give context
self._assistant.chat_messages[self._user_proxy] = list()
@@ -353,4 +294,4 @@ def generate_surfer_reply(
if proxy_reply == "": # Was the default reply
return True, None if agent_reply is None else agent_reply["content"]
else:
- return True, None if proxy_reply is None else proxy_reply["content"]
+ return True, None if proxy_reply is None else proxy_reply["content"] # type: ignore[index]
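# Condensed sketch of the registration pattern adopted above: each browser tool is declared once,
# described to the LLM via register_for_llm with typing_extensions.Annotated parameter descriptions,
# and made executable by the inner user proxy via register_for_execution, replacing the hand-written
# JSON function schemas that were removed. The agent names, llm_config values, and the toy tool body
# are illustrative only.
from typing_extensions import Annotated

from autogen import AssistantAgent, UserProxyAgent

assistant = AssistantAgent("inner_assistant", llm_config={"config_list": [{"model": "gpt-4", "api_key": "sk-placeholder"}]})
user_proxy = UserProxyAgent("inner_user_proxy", human_input_mode="NEVER", code_execution_config=False)

@user_proxy.register_for_execution()
@assistant.register_for_llm(name="visit_page", description="Visit a webpage at a given URL and return its text.")
def visit_page(url: Annotated[str, "The relative or absolute url of the webpage to visit."]) -> str:
    return f"(stub) fetched {url}"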
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index e3fb3d8be4a..b434fc648eb 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -5,35 +5,36 @@
import json
import logging
import re
-from collections import defaultdict
-from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import warnings
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
+
+from openai import BadRequestError
-from .. import OpenAIWrapper, ModelClient
-from ..cache.cache import Cache
+from autogen.exception_utils import InvalidCarryOverType, SenderRequired
+
+from .._pydantic import model_dump
+from ..cache.cache import AbstractCache
from ..code_utils import (
- DEFAULT_MODEL,
+ PYTHON_VARIANTS,
UNKNOWN,
- content_str,
check_can_use_docker_or_throw,
+ content_str,
decide_use_docker,
execute_code,
extract_code,
infer_lang,
)
-
-
+from ..coding.base import CodeExecutor
+from ..coding.factory import CodeExecutorFactory
+from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
-from .agent import Agent
-from .._pydantic import model_dump
-
-try:
- from termcolor import colored
-except ImportError:
-
- def colored(x, *args, **kwargs):
- return x
-
+from ..io.base import IOStream
+from ..oai.client import ModelClient, OpenAIWrapper
+from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
+from .agent import Agent, LLMAgent
+from .chat import ChatResult, a_initiate_chats, initiate_chats
+from .utils import consolidate_chat_info, gather_usage_summary
__all__ = ("ConversableAgent",)
@@ -42,7 +43,7 @@ def colored(x, *args, **kwargs):
F = TypeVar("F", bound=Callable[..., Any])
-class ConversableAgent(Agent):
+class ConversableAgent(LLMAgent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg.
@@ -54,12 +55,13 @@ class ConversableAgent(Agent):
To modify the way to get human input, override `get_human_input` method.
To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`,
`run_code`, and `execute_function` methods respectively.
- To customize the initial message when a conversation starts, override `generate_init_message` method.
"""
- DEFAULT_CONFIG = {} # An empty configuration
+ DEFAULT_CONFIG = False # False or dict, the default config for llm inference
MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change)
+ DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
+ DEFAULT_SUMMARY_METHOD = "last_msg"
llm_config: Union[Dict, Literal[False]]
def __init__(
@@ -68,12 +70,13 @@ def __init__(
system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "TERMINATE",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
- default_auto_reply: Optional[Union[str, Dict, None]] = "",
+ default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
+ chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
):
"""
Args:
@@ -110,54 +113,57 @@ def __init__(
- timeout (Optional, int): The maximum execution time in seconds.
- last_n_messages (Experimental, int or str): The number of messages to look back for code execution.
If set to 'auto', it will scan backwards through all messages arriving since the agent last spoke, which is typically the last time execution was attempted. (Default: auto)
- llm_config (dict or False): llm inference configuration.
+ llm_config (dict or False or None): llm inference configuration.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
+ When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
To disable llm-based auto reply, set to False.
- default_auto_reply (str or dict or None): default auto reply when no code execution or llm-based reply is generated.
+ When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
+ default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+ chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents.
+ Can be used to give the agent a memory by providing the chat history. This will allow the agent to
+ resume previously held conversations. Defaults to an empty chat history.
"""
- super().__init__(name)
+ # we change code_execution_config below and we have to make sure we don't change the input
+ # in case of UserProxyAgent, without this we could even change the default value {}
+ code_execution_config = (
+ code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config
+ )
+
+ self._name = name
# a dictionary of conversations, default value is list
- self._oai_messages = defaultdict(list)
+ if chat_messages is None:
+ self._oai_messages = defaultdict(list)
+ else:
+ self._oai_messages = chat_messages
+
self._oai_system_message = [{"content": system_message, "role": "system"}]
- self.description = description if description is not None else system_message
+ self._description = description if description is not None else system_message
self._is_termination_msg = (
is_termination_msg
if is_termination_msg is not None
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
+ # Take a copy to avoid modifying the given dict
+ if isinstance(llm_config, dict):
+ try:
+ llm_config = copy.deepcopy(llm_config)
+ except TypeError as e:
+ raise TypeError(
+ "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy."
+ " Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy"
+ ) from e
- if llm_config is False:
- self.llm_config = False
- self.client = None
- else:
- self.llm_config = self.DEFAULT_CONFIG.copy()
- if isinstance(llm_config, dict):
- self.llm_config.update(llm_config)
- self.client = OpenAIWrapper(**self.llm_config)
+ self._validate_llm_config(llm_config)
+
+ if logging_enabled():
+ log_new_agent(self, locals())
# Initialize standalone client cache object.
self.client_cache = None
- if code_execution_config is None:
- warnings.warn(
- "Using None to signal a default code_execution_config is deprecated. "
- "Use {} to use default or False to disable code execution.",
- stacklevel=2,
- )
-
- self._code_execution_config: Union[Dict, Literal[False]] = (
- {} if code_execution_config is None else code_execution_config
- )
-
- if isinstance(self._code_execution_config, dict):
- use_docker = self._code_execution_config.get("use_docker", None)
- use_docker = decide_use_docker(use_docker)
- check_can_use_docker_or_throw(use_docker)
- self._code_execution_config["use_docker"] = use_docker
-
self.human_input_mode = human_input_mode
self._max_consecutive_auto_reply = (
max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY
@@ -171,11 +177,58 @@ def __init__(
)
self._default_auto_reply = default_auto_reply
self._reply_func_list = []
- self._ignore_async_func_in_sync_chat_list = []
+ self._human_input = []
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)
- self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
+
+ # Setting up code execution.
+ # Do not register code execution reply if code execution is disabled.
+ if code_execution_config is not False:
+ # If code_execution_config is None, set it to an empty dict.
+ if code_execution_config is None:
+ warnings.warn(
+ "Using None to signal a default code_execution_config is deprecated. "
+ "Use {} to use default or False to disable code execution.",
+ stacklevel=2,
+ )
+ code_execution_config = {}
+ if not isinstance(code_execution_config, dict):
+ raise ValueError("code_execution_config must be a dict or False.")
+
+ # We have got a valid code_execution_config.
+ self._code_execution_config = code_execution_config
+
+ if self._code_execution_config.get("executor") is not None:
+ if "use_docker" in self._code_execution_config:
+ raise ValueError(
+ "'use_docker' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead."
+ )
+
+ if "work_dir" in self._code_execution_config:
+ raise ValueError(
+ "'work_dir' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead."
+ )
+
+ if "timeout" in self._code_execution_config:
+ raise ValueError(
+ "'timeout' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead."
+ )
+
+ # Use the new code executor.
+ self._code_executor = CodeExecutorFactory.create(self._code_execution_config)
+ self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor)
+ else:
+ # Legacy code execution using code_utils.
+ use_docker = self._code_execution_config.get("use_docker", None)
+ use_docker = decide_use_docker(use_docker)
+ check_can_use_docker_or_throw(use_docker)
+ self._code_execution_config["use_docker"] = use_docker
+ self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
+ else:
+ # Code execution is disabled.
+ self._code_execution_config = False
+
self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
@@ -189,7 +242,47 @@ def __init__(
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
- self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
+ self.hook_lists: Dict[str, List[Callable]] = {
+ "process_last_received_message": [],
+ "process_all_messages_before_reply": [],
+ "process_message_before_send": [],
+ }
+
+ def _validate_llm_config(self, llm_config):
+ assert llm_config in (None, False) or isinstance(
+ llm_config, dict
+ ), "llm_config must be a dict or False or None."
+ if llm_config is None:
+ llm_config = self.DEFAULT_CONFIG
+ self.llm_config = self.DEFAULT_CONFIG if llm_config is None else llm_config
+ # TODO: more complete validity check
+ if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
+ raise ValueError(
+ "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
+ )
+ self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
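# Small illustration of the validation above (the model name and key are placeholders): a config
# that names a model passes, while an empty config_list is rejected with the ValueError raised by
# _validate_llm_config.
from autogen import ConversableAgent

ok = ConversableAgent("ok", llm_config={"config_list": [{"model": "gpt-4", "api_key": "sk-placeholder"}]})
try:
    bad = ConversableAgent("bad", llm_config={"config_list": []})
except ValueError as e:
    print(e)  # asks for a non-empty 'model' in llm_config or in each config of config_list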
+
+ @property
+ def name(self) -> str:
+ """Get the name of the agent."""
+ return self._name
+
+ @property
+ def description(self) -> str:
+ """Get the description of the agent."""
+ return self._description
+
+ @description.setter
+ def description(self, description: str):
+ """Set the description of the agent."""
+ self._description = description
+
+ @property
+ def code_executor(self) -> Optional[CodeExecutor]:
+ """The code executor used by this agent. Returns None if code execution is disabled."""
+ if not hasattr(self, "_code_executor"):
+ return None
+ return self._code_executor
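# Sketch of the two code-execution configurations distinguished in __init__ above: the legacy dict
# that drives code_utils (including the docker decision), and the new executor-based path that goes
# through CodeExecutorFactory. LocalCommandLineCodeExecutor is assumed to be exposed from
# autogen.coding at this revision, and the work_dir values are placeholders.
from autogen import ConversableAgent
from autogen.coding import LocalCommandLineCodeExecutor

legacy = ConversableAgent(
    "legacy_executor_agent",
    llm_config=False,
    code_execution_config={"work_dir": "coding", "use_docker": False},
)
modern = ConversableAgent(
    "new_executor_agent",
    llm_config=False,
    code_execution_config={"executor": LocalCommandLineCodeExecutor(work_dir="coding")},
)
print(modern.code_executor)  # the executor instance; None when code execution is disabled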
def register_reply(
self,
@@ -200,6 +293,7 @@ def register_reply(
reset_config: Optional[Callable] = None,
*,
ignore_async_in_sync_chat: bool = False,
+ remove_other_reply_funcs: bool = False,
):
"""Register a reply function.
@@ -211,34 +305,29 @@ def register_reply(
from both sync and async chats. However, an async reply function will only be triggered from async
chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered
and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows:
- - if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and
- - if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored.
+ if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and
+ if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored.
Args:
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
- - If a class is provided, the reply function will be called when the sender is an instance of the class.
- - If a string is provided, the reply function will be called when the sender's name matches the string.
- - If an agent instance is provided, the reply function will be called when the sender is the agent instance.
- - If a callable is provided, the reply function will be called when the callable returns True.
- - If a list is provided, the reply function will be called when any of the triggers in the list is activated.
- - If None is provided, the reply function will be called only when the sender is None.
- Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`.
+ If a class is provided, the reply function will be called when the sender is an instance of the class.
+ If a string is provided, the reply function will be called when the sender's name matches the string.
+ If an agent instance is provided, the reply function will be called when the sender is the agent instance.
+ If a callable is provided, the reply function will be called when the callable returns True.
+ If a list is provided, the reply function will be called when any of the triggers in the list is activated.
+ If None is provided, the reply function will be called only when the sender is None.
+ Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`.
reply_func (Callable): the reply function.
The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
- position: the position of the reply function in the reply function list.
- config: the config to be passed to the reply function, see below.
- reset_config: the function to reset the config, see below.
- ignore_async_in_sync_chat: whether to ignore the async reply function in sync chats. If `False`, an exception
- will be raised if an async reply function is registered and a chat is initialized with a sync
- function.
- ```python
- def reply_func(
- recipient: ConversableAgent,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- config: Optional[Any] = None,
- ) -> Tuple[bool, Union[str, Dict, None]]:
- ```
+
+ ```python
+ def reply_func(
+ recipient: ConversableAgent,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ ```
position (int): the position of the reply function in the reply function list.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.
@@ -246,9 +335,15 @@ def reply_func(
When an agent is reset, the config will be reset to the original value.
reset_config (Callable): the function to reset the config.
The function returns None. Signature: ```def reset_config(config: Any)```
+ ignore_async_in_sync_chat (bool): whether to ignore the async reply function in sync chats. If `False`, an exception
+ will be raised if an async reply function is registered and a chat is initialized with a sync
+ function.
+ remove_other_reply_funcs (bool): whether to remove other reply functions when registering this reply function.
"""
if not isinstance(trigger, (type, str, Agent, Callable, list)):
raise ValueError("trigger must be a class, a string, an agent, a callable or a list.")
+ if remove_other_reply_funcs:
+ self._reply_func_list.clear()
self._reply_func_list.insert(
position,
{
@@ -257,21 +352,112 @@ def reply_func(
"config": copy.copy(config),
"init_config": config,
"reset_config": reset_config,
+ "ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func),
},
)
- if ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func):
- self._ignore_async_func_in_sync_chat_list.append(reply_func)
+
+ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable):
+ """Replace a registered reply function with a new one.
+
+ Args:
+ old_reply_func (Callable): the old reply function to be replaced.
+ new_reply_func (Callable): the new reply function to replace the old one.
+ """
+ for f in self._reply_func_list:
+ if f["reply_func"] == old_reply_func:
+ f["reply_func"] = new_reply_func
+
+ @staticmethod
+ def _summary_from_nested_chats(
+ chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+ ) -> Tuple[bool, str]:
+ """A simple chat reply function.
+ This function initiates one or a sequence of chats between the "recipient" and the agents in the
+ chat_queue.
+
+ It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+ Returns:
+ Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+ """
+ last_msg = messages[-1].get("content")
+ chat_to_run = []
+ for i, c in enumerate(chat_queue):
+ current_c = c.copy()
+ if current_c.get("sender") is None:
+ current_c["sender"] = recipient
+ message = current_c.get("message")
+ # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue).
+ # NOTE: This setting is prone to change.
+ if message is None and i == 0:
+ message = last_msg
+ if callable(message):
+ message = message(recipient, messages, sender, config)
+ # We only run chats that have a valid message. NOTE: This is prone to change depending on applications.
+ if message:
+ current_c["message"] = message
+ chat_to_run.append(current_c)
+ if not chat_to_run:
+ return True, None
+ res = initiate_chats(chat_to_run)
+ return True, res[-1].summary
+
+ def register_nested_chats(
+ self,
+ chat_queue: List[Dict[str, Any]],
+ trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
+ reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
+ position: int = 2,
+ **kwargs,
+ ) -> None:
+ """Register a nested chat reply function.
+ Args:
+ chat_queue (list): a list of chat objects to be initiated.
+ trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
+ reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
+ The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
+ Defaults to "summary_from_nested_chats", which corresponds to a built-in reply function that gets a summary from the nested chat_queue.
+ ```python
+ def reply_func_from_nested_chats(
+ chat_queue: List[Dict],
+ recipient: ConversableAgent,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ ```
+ position (int): Refer to `register_reply` for details. Defaults to 2, which means the termination and human reply are checked first, then the registered nested chat reply.
+ kwargs: Refer to `register_reply` for details.
+ """
+ if reply_func_from_nested_chats == "summary_from_nested_chats":
+ reply_func_from_nested_chats = self._summary_from_nested_chats
+ if not callable(reply_func_from_nested_chats):
+ raise ValueError("reply_func_from_nested_chats must be a callable")
+
+ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+ functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
+
+ self.register_reply(
+ trigger,
+ wrapped_reply_func,
+ position,
+ kwargs.get("config"),
+ kwargs.get("reset_config"),
+ ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
+ )
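# Minimal sketch of register_nested_chats based on the docstring above: whenever `user` triggers
# `writer`, a one-turn nested chat with `critic` runs first and its summary becomes writer's reply.
# The agent names, canned auto-replies, and messages are illustrative; llm_config is disabled so
# the sketch does not require model credentials.
from autogen import ConversableAgent

user = ConversableAgent("user", llm_config=False, human_input_mode="NEVER")
critic = ConversableAgent(
    "critic", llm_config=False, human_input_mode="NEVER",
    default_auto_reply="Looks reasonable; tighten the introduction.",
)
writer = ConversableAgent(
    "writer", llm_config=False, human_input_mode="NEVER",
    default_auto_reply="Here is a first draft.",
)

writer.register_nested_chats(
    chat_queue=[
        {"recipient": critic, "message": "Please critique the draft above.", "max_turns": 1, "summary_method": "last_msg"}
    ],
    trigger=user,
)

result = user.initiate_chat(writer, message="Write a short intro about AutoGen.", max_turns=1)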
@property
- def system_message(self) -> Union[str, List]:
+ def system_message(self) -> str:
"""Return the system message."""
return self._oai_system_message[0]["content"]
- def update_system_message(self, system_message: Union[str, List]):
+ def update_system_message(self, system_message: str) -> None:
"""Update the system message.
Args:
- system_message (str or List): system message for the ChatCompletion inference.
+ system_message (str): system message for the ChatCompletion inference.
"""
self._oai_system_message[0]["content"] = system_message
@@ -298,6 +484,10 @@ def chat_messages(self) -> Dict[Agent, List[Dict]]:
"""A dictionary of conversations from agent to list of messages."""
return self._oai_messages
+ def chat_messages_for_summary(self, agent: Agent) -> List[Dict]:
+ """A list of messages as a conversation to summarize."""
+ return self._oai_messages[agent]
+
def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]:
"""The last message exchanged with the agent.
@@ -396,6 +586,11 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
if message.get("role") in ["function", "tool"]:
oai_message["role"] = message.get("role")
+ elif "override_role" in message:
+ # If we have a direction to override the role then set the
+ # role accordingly. Used to customise the role for the
+ # select speaker prompt.
+ oai_message["role"] = message.get("override_role")
else:
oai_message["role"] = role
@@ -404,6 +599,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
self._oai_messages[conversation_id].append(oai_message)
return True
+ def _process_message_before_send(
+ self, message: Union[Dict, str], recipient: Agent, silent: bool
+ ) -> Union[Dict, str]:
+ """Process the message before sending it to the recipient."""
+ hook_list = self.hook_lists["process_message_before_send"]
+ for hook in hook_list:
+ message = hook(sender=self, message=message, recipient=recipient, silent=silent)
+ return message
+
def send(
self,
message: Union[Dict, str],
@@ -443,6 +647,7 @@ def send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
+ message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
@@ -492,6 +697,7 @@ async def a_send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
+ message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
@@ -503,8 +709,9 @@ async def a_send(
)
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
+ iostream = IOStream.get_default()
# print the message received
- print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
+ iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
message = self._message_to_dict(message)
if message.get("tool_responses"): # Handle tool multi-call responses
@@ -518,11 +725,11 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
id_key = "name"
else:
id_key = "tool_call_id"
-
- func_print = f"***** Response from calling {message['role']} \"{message[id_key]}\" *****"
- print(colored(func_print, "green"), flush=True)
- print(message["content"], flush=True)
- print(colored("*" * len(func_print), "green"), flush=True)
+ id = message.get(id_key, "No id found")
+ func_print = f"***** Response from calling {message['role']} ({id}) *****"
+ iostream.print(colored(func_print, "green"), flush=True)
+ iostream.print(message["content"], flush=True)
+ iostream.print(colored("*" * len(func_print), "green"), flush=True)
else:
content = message.get("content")
if content is not None:
@@ -532,39 +739,42 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
- print(content_str(content), flush=True)
+ iostream.print(content_str(content), flush=True)
if "function_call" in message and message["function_call"]:
function_call = dict(message["function_call"])
func_print = (
- f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"
+ f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"
)
- print(colored(func_print, "green"), flush=True)
- print(
+ iostream.print(colored(func_print, "green"), flush=True)
+ iostream.print(
"Arguments: \n",
function_call.get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
- print(colored("*" * len(func_print), "green"), flush=True)
+ iostream.print(colored("*" * len(func_print), "green"), flush=True)
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
- id = tool_call.get("id", "(No id found)")
+ id = tool_call.get("id", "No tool call id found")
function_call = dict(tool_call.get("function", {}))
- func_print = f"***** Suggested tool Call ({id}): {function_call.get('name', '(No function name found)')} *****"
- print(colored(func_print, "green"), flush=True)
- print(
+ func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****"
+ iostream.print(colored(func_print, "green"), flush=True)
+ iostream.print(
"Arguments: \n",
function_call.get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
- print(colored("*" * len(func_print), "green"), flush=True)
+ iostream.print(colored("*" * len(func_print), "green"), flush=True)
- print("\n", "-" * 80, flush=True, sep="")
+ iostream.print("\n", "-" * 80, flush=True, sep="")
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self._append_oai_message(message, "user", sender)
+ if logging_enabled():
+ log_event(self, "received_message", message=message, sender=sender.name, valid=valid)
+
if not valid:
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
@@ -646,13 +856,20 @@ async def a_receive(
if reply is not None:
await self.a_send(reply, sender, silent=silent)
- def _prepare_chat(self, recipient: "ConversableAgent", clear_history: bool, prepare_recipient: bool = True) -> None:
+ def _prepare_chat(
+ self,
+ recipient: "ConversableAgent",
+ clear_history: bool,
+ prepare_recipient: bool = True,
+ reply_at_receive: bool = True,
+ ) -> None:
self.reset_consecutive_auto_reply_counter(recipient)
- self.reply_at_receive[recipient] = True
+ self.reply_at_receive[recipient] = reply_at_receive
if clear_history:
self.clear_history(recipient)
+ self._human_input = []
if prepare_recipient:
- recipient._prepare_chat(self, clear_history, False)
+ recipient._prepare_chat(self, clear_history, False, reply_at_receive)
def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
@@ -660,12 +877,12 @@ def _raise_exception_on_async_reply_functions(self) -> None:
Raises:
RuntimeError: if any async reply functions are registered.
"""
- reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(
- self._ignore_async_func_in_sync_chat_list
- )
+ reply_functions = {
+ f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False)
+ }
async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)]
- if async_reply_functions != []:
+ if async_reply_functions:
msg = (
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
+ ", ".join([f.__name__ for f in async_reply_functions])
@@ -676,70 +893,374 @@ def _raise_exception_on_async_reply_functions(self) -> None:
def initiate_chat(
self,
recipient: "ConversableAgent",
- clear_history: Optional[bool] = True,
+ clear_history: bool = True,
silent: Optional[bool] = False,
- cache: Optional[Cache] = None,
- **context,
- ):
+ cache: Optional[AbstractCache] = None,
+ max_turns: Optional[int] = None,
+ summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD,
+ summary_args: Optional[dict] = {},
+ message: Optional[Union[Dict, str, Callable]] = None,
+ **kwargs,
+ ) -> ChatResult:
"""Initiate a chat with the recipient agent.
Reset the consecutive auto reply counter.
If `clear_history` is True, the chat history with the recipient agent will be cleared.
- `generate_init_message` is called to generate the initial message for the agent.
+
Args:
recipient: the recipient agent.
- clear_history (bool): whether to clear the chat history with the agent.
- silent (bool or None): (Experimental) whether to print the messages for this conversation.
- cache (Cache or None): the cache client to be used for this conversation.
- **context: any context information.
- "message" needs to be provided if the `generate_init_message` method is not overridden.
- Otherwise, input() will be called to get the initial message.
+ clear_history (bool): whether to clear the chat history with the agent. Default is True.
+ silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+ cache (AbstractCache or None): the cache client to be used for this conversation. Default is None.
+ max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from
+ [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session.
+ If max_turns is set to None, the chat will continue until a termination condition is met. Default is None.
+ summary_method (str or callable): a method to get a summary from the chat. Default is DEFAULT_SUMMARY_METHOD, i.e., "last_msg".
+
+ Supported strings are "last_msg" and "reflection_with_llm":
+ - when set to "last_msg", it returns the last message of the dialog as the summary.
+ - when set to "reflection_with_llm", it returns a summary extracted using an llm client.
+ `llm_config` must be set in either the recipient or sender.
+
+ A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g.,
+
+ ```python
+ def my_summary_method(
+ sender: ConversableAgent,
+ recipient: ConversableAgent,
+ summary_args: dict,
+ ):
+ return recipient.last_message(sender)["content"]
+ ```
+ summary_args (dict): a dictionary of arguments to be passed to the summary_method.
+ One example key is "summary_prompt", and its value is a string of text used to prompt an LLM-based agent (the sender or receiver agent) to reflect
+ on the conversation and extract a summary when summary_method is "reflection_with_llm".
+ The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+ Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
+ message (str, dict or Callable): the initial message to be sent to the recipient. If not provided, input() will be called to get the initial message.
+ - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
+ If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
+
+ 1. "content": content of the message, can be None.
+ 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls")
+ 3. "tool_calls": a list of dictionaries containing the function name and arguments.
+ 4. "role": role of the message, can be "assistant", "user", "function".
+ This field is only needed to distinguish between "function" and "assistant"/"user".
+ 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
+ 6. "context" (dict): the context of the message, which will be passed to
+ [OpenAIWrapper.create](../oai/client#create).
+
+ - If a callable is provided, it will be called to get the initial message in the form of a string or a dict.
+ If the returned type is dict, it may contain the reserved fields mentioned above.
+
+ Example of a callable message (returning a string):
+
+ ```python
+ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]:
+ carryover = context.get("carryover", "")
+ if isinstance(carryover, list):
+ carryover = carryover[-1]
+ final_msg = "Write a blogpost." + "\\nContext: \\n" + carryover
+ return final_msg
+ ```
+
+ Example of a callable message (returning a dict):
+
+ ```python
+ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]:
+ final_msg = {}
+ carryover = context.get("carryover", "")
+ if isinstance(carryover, list):
+ carryover = carryover[-1]
+ final_msg["content"] = "Write a blogpost." + "\\nContext: \\n" + carryover
+ final_msg["context"] = {"prefix": "Today I feel"}
+ return final_msg
+ ```
+ **kwargs: any additional information. It has the following reserved fields:
+ - "carryover": a string or a list of string to specify the carryover information to be passed to this chat.
+ If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat
+ message in `generate_init_message`.
+ - "verbose": a boolean to specify whether to print the message and carryover in a chat. Default is False.
Raises:
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
+
+ Returns:
+ ChatResult: a ChatResult object.
"""
+ _chat_info = locals().copy()
+ _chat_info["sender"] = self
+ consolidate_chat_info(_chat_info, uniform_sender=self)
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
agent.previous_cache = agent.client_cache
agent.client_cache = cache
- self._prepare_chat(recipient, clear_history)
- self.send(self.generate_init_message(**context), recipient, silent=silent)
+ if isinstance(max_turns, int):
+ self._prepare_chat(recipient, clear_history, reply_at_receive=False)
+ for _ in range(max_turns):
+ if _ == 0:
+ if isinstance(message, Callable):
+ msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
+ else:
+ msg2send = self.generate_init_message(message, **kwargs)
+ else:
+ msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient)
+ if msg2send is None:
+ break
+ self.send(msg2send, recipient, request_reply=True, silent=silent)
+ else:
+ self._prepare_chat(recipient, clear_history)
+ if isinstance(message, Callable):
+ msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
+ else:
+ msg2send = self.generate_init_message(message, **kwargs)
+ self.send(msg2send, recipient, silent=silent)
+ summary = self._summarize_chat(
+ summary_method,
+ summary_args,
+ recipient,
+ cache=cache,
+ )
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None
+ chat_result = ChatResult(
+ chat_history=self.chat_messages[recipient],
+ summary=summary,
+ cost=gather_usage_summary([self, recipient]),
+ human_input=self._human_input,
+ )
+ return chat_result
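To make the reworked signature concrete, here is a minimal usage sketch of the new `initiate_chat` options (`max_turns`, `summary_method`, `summary_args`) and the returned `ChatResult`. The agent classes, model name, and `llm_config` contents are illustrative assumptions, not part of this diff.

```python
# Minimal usage sketch (not part of this diff): the agent classes, model name,
# and llm_config contents below are illustrative assumptions.
from autogen import AssistantAgent, UserProxyAgent

assistant = AssistantAgent("assistant", llm_config={"model": "gpt-4"})
user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)

chat_result = user.initiate_chat(
    assistant,
    message="What is the capital of France?",
    max_turns=2,                           # at most two conversation round trips
    summary_method="reflection_with_llm",  # ask an LLM to summarize the chat
    summary_args={"summary_prompt": "Summarize the answer in one sentence."},
)
print(chat_result.summary)      # summary produced by _summarize_chat
print(chat_result.human_input)  # human inputs collected during the chat (empty here)
```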
async def a_initiate_chat(
self,
recipient: "ConversableAgent",
- clear_history: Optional[bool] = True,
+ clear_history: bool = True,
silent: Optional[bool] = False,
- cache: Optional[Cache] = None,
- **context,
- ):
+ cache: Optional[AbstractCache] = None,
+ max_turns: Optional[int] = None,
+ summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD,
+ summary_args: Optional[dict] = {},
+ message: Optional[Union[str, Callable]] = None,
+ **kwargs,
+ ) -> ChatResult:
"""(async) Initiate a chat with the recipient agent.
Reset the consecutive auto reply counter.
If `clear_history` is True, the chat history with the recipient agent will be cleared.
- `generate_init_message` is called to generate the initial message for the agent.
+ `a_generate_init_message` is called to generate the initial message for the agent.
- Args:
- recipient: the recipient agent.
- clear_history (bool): whether to clear the chat history with the agent.
- silent (bool or None): (Experimental) whether to print the messages for this conversation.
- cache (Cache or None): the cache client to be used for this conversation.
- **context: any context information.
- "message" needs to be provided if the `generate_init_message` method is not overridden.
- Otherwise, input() will be called to get the initial message.
+ Args: Please refer to `initiate_chat`.
+
+ Returns:
+ ChatResult: a ChatResult object.
"""
- self._prepare_chat(recipient, clear_history)
+ _chat_info = locals().copy()
+ _chat_info["sender"] = self
+ consolidate_chat_info(_chat_info, uniform_sender=self)
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
agent.client_cache = cache
- await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
+ if isinstance(max_turns, int):
+ self._prepare_chat(recipient, clear_history, reply_at_receive=False)
+ for _ in range(max_turns):
+ if _ == 0:
+ if isinstance(message, Callable):
+ msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
+ else:
+ msg2send = await self.a_generate_init_message(message, **kwargs)
+ else:
+ msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient)
+ if msg2send is None:
+ break
+ await self.a_send(msg2send, recipient, request_reply=True, silent=silent)
+ else:
+ self._prepare_chat(recipient, clear_history)
+ if isinstance(message, Callable):
+ msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
+ else:
+ msg2send = await self.a_generate_init_message(message, **kwargs)
+ await self.a_send(msg2send, recipient, silent=silent)
+ summary = self._summarize_chat(
+ summary_method,
+ summary_args,
+ recipient,
+ cache=cache,
+ )
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None
+ chat_result = ChatResult(
+ chat_history=self.chat_messages[recipient],
+ summary=summary,
+ cost=gather_usage_summary([self, recipient]),
+ human_input=self._human_input,
+ )
+ return chat_result
+
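A matching sketch for the async entry point, under the same illustrative assumptions as the synchronous example above:

```python
# Async counterpart, under the same illustrative assumptions as above.
import asyncio

from autogen import AssistantAgent, UserProxyAgent

async def main() -> None:
    assistant = AssistantAgent("assistant", llm_config={"model": "gpt-4"})
    user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)
    result = await user.a_initiate_chat(
        assistant,
        message="Name three prime numbers.",
        max_turns=1,  # a single round trip
    )
    print(result.summary)

asyncio.run(main())
```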
+ def _summarize_chat(
+ self,
+ summary_method,
+ summary_args,
+ recipient: Optional[Agent] = None,
+ cache: Optional[AbstractCache] = None,
+ ) -> str:
+ """Get a chat summary from an agent participating in a chat.
+
+ Args:
+ summary_method (str or callable): the summary_method to get the summary.
+ The callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g.,
+ ```python
+ def my_summary_method(
+ sender: ConversableAgent,
+ recipient: ConversableAgent,
+ summary_args: dict,
+ ):
+ return recipient.last_message(sender)["content"]
+ ```
+ summary_args (dict): a dictionary of arguments to be passed to the summary_method.
+ recipient: the recipient agent in a chat.
+ cache (AbstractCache or None): the cache client to be used for this conversation.
+
+ Returns:
+ str: a chat summary from the agent.
+ """
+ summary = ""
+ if summary_method is None:
+ return summary
+ if "cache" not in summary_args:
+ summary_args["cache"] = cache
+ if summary_method == "reflection_with_llm":
+ summary_method = self._reflection_with_llm_as_summary
+ elif summary_method == "last_msg":
+ summary_method = self._last_msg_as_summary
+
+ if isinstance(summary_method, Callable):
+ summary = summary_method(self, recipient, summary_args)
+ else:
+ raise ValueError(
+ "If not None, the summary_method must be a string from [`reflection_with_llm`, `last_msg`] or a callable."
+ )
+ return summary
+
+ @staticmethod
+ def _last_msg_as_summary(sender, recipient, summary_args) -> str:
+ """Get a chat summary from the last message of the recipient."""
+ summary = ""
+ try:
+ content = recipient.last_message(sender)["content"]
+ if isinstance(content, str):
+ summary = content.replace("TERMINATE", "")
+ elif isinstance(content, list):
+ # Remove the `TERMINATE` word in the content list.
+ summary = "\n".join(
+ x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
+ )
+ except (IndexError, AttributeError) as e:
+ warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
+ return summary
+
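The `last_msg` summary simply strips the `TERMINATE` marker from the final message. The following stand-alone snippet mirrors that behaviour for the two content shapes handled above (plain strings and multimodal lists):

```python
# Stand-alone mirror of the TERMINATE-stripping logic above, for the two
# content shapes the method handles: plain strings and multimodal lists.
def strip_terminate(content):
    if isinstance(content, str):
        return content.replace("TERMINATE", "")
    if isinstance(content, list):
        # keep only the text items and drop the TERMINATE marker
        return "\n".join(
            x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
        )
    return ""

print(strip_terminate("Paris is the capital. TERMINATE"))
print(strip_terminate([{"type": "text", "text": "Done. TERMINATE"}]))
```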
+ @staticmethod
+ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
+ prompt = summary_args.get("summary_prompt")
+ prompt = ConversableAgent.DEFAULT_SUMMARY_PROMPT if prompt is None else prompt
+ if not isinstance(prompt, str):
+ raise ValueError("The summary_prompt must be a string.")
+ msg_list = recipient.chat_messages_for_summary(sender)
+ agent = sender if recipient is None else recipient
+ role = summary_args.get("summary_role", None)
+ if role and not isinstance(role, str):
+ raise ValueError("The summary_role in summary_arg must be a string.")
+ try:
+ summary = sender._reflection_with_llm(
+ prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+ )
+ except BadRequestError as e:
+ warnings.warn(
+ f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
+ )
+ summary = ""
+ return summary
+
+ def _reflection_with_llm(
+ self,
+ prompt,
+ messages,
+ llm_agent: Optional[Agent] = None,
+ cache: Optional[AbstractCache] = None,
+ role: Union[str, None] = None,
+ ) -> str:
+ """Get a chat summary using reflection with an llm client based on the conversation history.
+
+ Args:
+ prompt (str): The prompt (in this method it is used as system prompt) used to get the summary.
+ messages (list): The messages generated as part of a chat conversation.
+ llm_agent: the agent with an llm client.
+ cache (AbstractCache or None): the cache client to be used for this conversation.
+ role (str): the role of the message, usually "system" or "user". Default is "system".
+ """
+ if not role:
+ role = "system"
+
+ system_msg = [
+ {
+ "role": role,
+ "content": prompt,
+ }
+ ]
+
+ messages = messages + system_msg
+ if llm_agent and llm_agent.client is not None:
+ llm_client = llm_agent.client
+ elif self.client is not None:
+ llm_client = self.client
+ else:
+ raise ValueError("No OpenAIWrapper client is found.")
+ response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
+ return response
+
+ def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Check the chat queue and add the "sender" key if it's missing.
+
+ Args:
+ chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information.
+
+ Returns:
+ List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing.
+ """
+ chat_queue_with_sender = []
+ for chat_info in chat_queue:
+ if chat_info.get("sender") is None:
+ chat_info["sender"] = self
+ chat_queue_with_sender.append(chat_info)
+ return chat_queue_with_sender
+
+ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
+ """(Experimental) Initiate chats with multiple agents.
+
+ Args:
+ chat_queue (List[Dict]): a list of dictionaries containing the information of the chats.
+ Each dictionary should contain the input arguments for [`initiate_chat`](conversable_agent#initiate_chat)
+
+ Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
+ """
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
+ self._finished_chats = initiate_chats(_chat_queue)
+ return self._finished_chats
+
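A hedged sketch of `initiate_chats`: each queue entry is a dict of `initiate_chat` arguments plus a `recipient` (the `sender` key is filled in by `_check_chat_queue_for_sender` when omitted). Agent names, messages, and `llm_config` values below are placeholders.

```python
# Hedged sketch: every queue entry is a dict of initiate_chat arguments plus a
# "recipient"; agent names, messages, and llm_config values are placeholders.
from autogen import AssistantAgent, UserProxyAgent

user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)
researcher = AssistantAgent("researcher", llm_config={"model": "gpt-4"})
writer = AssistantAgent("writer", llm_config={"model": "gpt-4"})

chat_results = user.initiate_chats(
    [
        {
            "recipient": researcher,
            "message": "Collect three facts about solar energy.",
            "max_turns": 1,
            "summary_method": "last_msg",
        },
        {
            "recipient": writer,
            "message": "Write a short paragraph from the collected facts.",
            "max_turns": 1,
            "summary_method": "reflection_with_llm",
        },
    ]
)
for result in chat_results:
    print(result.summary)
```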
+ async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
+ self._finished_chats = await a_initiate_chats(_chat_queue)
+ return self._finished_chats
+
+ def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
+ """A summary from the finished chats of particular agents."""
+ if chat_index is not None:
+ return self._finished_chats[chat_index]
+ else:
+ return self._finished_chats
def reset(self):
"""Reset the agent."""
@@ -775,17 +1296,28 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents.
nr_messages_to_preserve: the number of newest messages to preserve in the chat history.
"""
+ iostream = IOStream.get_default()
if recipient is None:
if nr_messages_to_preserve:
for key in self._oai_messages:
+ nr_messages_to_preserve_internal = nr_messages_to_preserve
+ # If clearing history would split a function call from its function response, preserve one
+ # additional message; otherwise the OpenAI API will return an error.
+ first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal]
+ if "tool_responses" in first_msg_to_save:
+ nr_messages_to_preserve_internal += 1
+ iostream.print(
+ f"Preserving one more message for {self.name} to not divide history between tool call and "
+ f"tool response."
+ )
# Remove messages from history except last `nr_messages_to_preserve` messages.
- self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve:]
+ self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve_internal:]
else:
self._oai_messages.clear()
else:
self._oai_messages[recipient].clear()
if nr_messages_to_preserve:
- print(
+ iostream.print(
colored(
"WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.",
"yellow",
@@ -805,7 +1337,12 @@ def generate_oai_reply(
return False, None
if messages is None:
messages = self._oai_messages[sender]
+ extracted_response = self._generate_oai_reply_from_client(
+ client, self._oai_system_message + messages, self.client_cache
+ )
+ return (False, None) if extracted_response is None else (True, extracted_response)
+ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[str, Dict, None]:
# unroll tool_responses
all_messages = []
for message in messages:
@@ -819,16 +1356,16 @@ def generate_oai_reply(
all_messages.append(message)
# TODO: #1143 handle token limit exceeded error
- response = client.create(
- context=messages[-1].pop("context", None),
- messages=self._oai_system_message + all_messages,
- cache=self.client_cache,
+ response = llm_client.create(
+ context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
)
+ extracted_response = llm_client.extract_text_or_completion_object(response)[0]
- extracted_response = client.extract_text_or_completion_object(response)[0]
-
+ if extracted_response is None:
+ warnings.warn(f"Extracted_response from {response} is None.", UserWarning)
+ return None
# ensure function and tool calls will be accepted when sent back to the LLM
- if not isinstance(extracted_response, str):
+ if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
extracted_response = model_dump(extracted_response)
if isinstance(extracted_response, dict):
if extracted_response.get("function_call"):
@@ -837,7 +1374,13 @@ def generate_oai_reply(
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
- return True, extracted_response
+ # Remove id and type if they are not present.
+ # This is to make the tool call object compatible with Mistral API.
+ if tool_call.get("id") is None:
+ tool_call.pop("id")
+ if tool_call.get("type") is None:
+ tool_call.pop("type")
+ return extracted_response
async def a_generate_oai_reply(
self,
@@ -846,10 +1389,90 @@ async def a_generate_oai_reply(
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai asynchronously."""
+ iostream = IOStream.get_default()
+
+ def _generate_oai_reply(
+ self, iostream: IOStream, *args: Any, **kwargs: Any
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ with IOStream.set_default(iostream):
+ return self.generate_oai_reply(*args, **kwargs)
+
return await asyncio.get_event_loop().run_in_executor(
- None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
+ None,
+ functools.partial(
+ _generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config
+ ),
)
+ def _generate_code_execution_reply_using_executor(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Union[Dict, Literal[False]]] = None,
+ ):
+ """Generate a reply using code executor."""
+ iostream = IOStream.get_default()
+
+ if config is not None:
+ raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.")
+ if self._code_execution_config is False:
+ return False, None
+ if messages is None:
+ messages = self._oai_messages[sender]
+ last_n_messages = self._code_execution_config.get("last_n_messages", "auto")
+
+ if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto":
+ raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.")
+
+ num_messages_to_scan = last_n_messages
+ if last_n_messages == "auto":
+ # Find when the agent last spoke
+ num_messages_to_scan = 0
+ for message in reversed(messages):
+ if "role" not in message:
+ break
+ elif message["role"] != "user":
+ break
+ else:
+ num_messages_to_scan += 1
+ num_messages_to_scan = min(len(messages), num_messages_to_scan)
+ messages_to_scan = messages[-num_messages_to_scan:]
+
+ # iterate through the last n messages in reverse
+ # if code blocks are found, execute the code blocks and return the output
+ # if no code blocks are found, continue
+ for message in reversed(messages_to_scan):
+ if not message["content"]:
+ continue
+ code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"])
+ if len(code_blocks) == 0:
+ continue
+
+ num_code_blocks = len(code_blocks)
+ if num_code_blocks == 1:
+ iostream.print(
+ colored(
+ f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...",
+ "red",
+ ),
+ flush=True,
+ )
+ else:
+ iostream.print(
+ colored(
+ f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...",
+ "red",
+ ),
+ flush=True,
+ )
+
+ # found code blocks, execute code.
+ code_result = self._code_executor.execute_code_blocks(code_blocks)
+ exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed"
+ return True, f"exitcode: {code_result.exit_code} ({exitcode2str})\nCode output: {code_result.output}"
+
+ return False, None
+
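The `last_n_messages="auto"` branch above scans backwards over the consecutive trailing messages whose role is `"user"` (i.e. everything since this agent last spoke). A stand-alone sketch of that scan:

```python
# Stand-alone sketch of the last_n_messages == "auto" scan: count the
# consecutive trailing messages whose role is "user" (everything since this
# agent last spoke), capped at the length of the history.
def messages_to_scan_auto(messages):
    count = 0
    for message in reversed(messages):
        if "role" not in message or message["role"] != "user":
            break
        count += 1
    return min(len(messages), count)

history = [
    {"role": "assistant", "content": "Here is the plan."},
    {"role": "user", "content": "Please run the script."},
    {"role": "user", "content": "Also check the installed version."},
]
print(messages_to_scan_auto(history))  # 2
```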
def generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
@@ -986,7 +1609,6 @@ def generate_tool_calls_reply(
message = messages[-1]
tool_returns = []
for tool_call in message.get("tool_calls", []):
- id = tool_call["id"]
function_call = tool_call.get("function", {})
func = self._function_map.get(function_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
@@ -1004,13 +1626,24 @@ def generate_tool_calls_reply(
loop.close()
else:
_, func_return = self.execute_function(function_call)
- tool_returns.append(
- {
- "tool_call_id": id,
+ content = func_return.get("content", "")
+ if content is None:
+ content = ""
+ tool_call_id = tool_call.get("id", None)
+ if tool_call_id is not None:
+ tool_call_response = {
+ "tool_call_id": tool_call_id,
"role": "tool",
- "content": func_return.get("content", ""),
+ "content": content,
}
- )
+ else:
+ # Do not include tool_call_id if it is not present.
+ # This is to make the tool call object compatible with Mistral API.
+ tool_call_response = {
+ "role": "tool",
+ "content": content,
+ }
+ tool_returns.append(tool_call_response)
if tool_returns:
return True, {
"role": "tool",
@@ -1077,18 +1710,19 @@ def check_termination_and_human_reply(
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
- # Function implementation...
+ iostream = IOStream.get_default()
if config is None:
config = self
if messages is None:
- messages = self._oai_messages[sender]
+ messages = self._oai_messages[sender] if sender else []
message = messages[-1]
reply = ""
no_human_input_msg = ""
+ sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = self.get_human_input(
- f"Provide feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1101,9 +1735,9 @@ def check_termination_and_human_reply(
# self.human_input_mode == "TERMINATE":
terminate = self._is_termination_msg(message)
reply = self.get_human_input(
- f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+ f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: "
if terminate
- else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
+ else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1114,7 +1748,7 @@ def check_termination_and_human_reply(
else:
# self.human_input_mode == "TERMINATE":
reply = self.get_human_input(
- f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+ f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1122,7 +1756,7 @@ def check_termination_and_human_reply(
# print the no_human_input_msg
if no_human_input_msg:
- print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
+ iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
# stop the conversation
if reply == "exit":
@@ -1162,7 +1796,7 @@ def check_termination_and_human_reply(
# increment the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender] += 1
if self.human_input_mode != "NEVER":
- print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
+ iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
return False, None
@@ -1189,16 +1823,19 @@ async def a_check_termination_and_human_reply(
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
+ iostream = IOStream.get_default()
+
if config is None:
config = self
if messages is None:
- messages = self._oai_messages[sender]
- message = messages[-1]
+ messages = self._oai_messages[sender] if sender else []
+ message = messages[-1] if messages else {}
reply = ""
no_human_input_msg = ""
+ sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = await self.a_get_human_input(
- f"Provide feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1211,9 +1848,9 @@ async def a_check_termination_and_human_reply(
# self.human_input_mode == "TERMINATE":
terminate = self._is_termination_msg(message)
reply = await self.a_get_human_input(
- f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+ f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: "
if terminate
- else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
+ else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1224,7 +1861,7 @@ async def a_check_termination_and_human_reply(
else:
# self.human_input_mode == "TERMINATE":
reply = await self.a_get_human_input(
- f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+ f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1232,7 +1869,7 @@ async def a_check_termination_and_human_reply(
# print the no_human_input_msg
if no_human_input_msg:
- print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
+ iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
# stop the conversation
if reply == "exit":
@@ -1272,15 +1909,15 @@ async def a_check_termination_and_human_reply(
# increment the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender] += 1
if self.human_input_mode != "NEVER":
- print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
+ iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
return False, None
def generate_reply(
self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- exclude: Optional[List[Callable]] = None,
+ messages: Optional[List[Dict[str, Any]]] = None,
+ sender: Optional["Agent"] = None,
+ **kwargs: Any,
) -> Union[str, Dict, None]:
"""Reply based on the conversation history and the sender.
@@ -1301,9 +1938,10 @@ def generate_reply(
Args:
messages: a list of messages in the conversation history.
- default_reply (str or dict): default reply.
sender: sender of an Agent instance.
- exclude: a list of functions to exclude.
+
+ Additional keyword arguments:
+ exclude (List[Callable]): a list of reply functions to be excluded.
Returns:
str or dict or None: reply. None if no reply is generated.
@@ -1318,26 +1956,39 @@ def generate_reply(
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
- messages = self.process_last_message(messages)
+ messages = self.process_last_received_message(messages)
+
+ # Call the hookable method that gives registered hooks a chance to process all messages.
+ # Message modifications do not affect the incoming messages or self._oai_messages.
+ messages = self.process_all_messages_before_reply(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
- if exclude and reply_func in exclude:
+ if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue
if inspect.iscoroutinefunction(reply_func):
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
+ if logging_enabled():
+ log_event(
+ self,
+ "reply_func_executed",
+ reply_func_module=reply_func.__module__,
+ reply_func_name=reply_func.__name__,
+ final=final,
+ reply=reply,
+ )
if final:
return reply
return self._default_auto_reply
async def a_generate_reply(
self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- exclude: Optional[List[Callable]] = None,
- ) -> Union[str, Dict, None]:
+ messages: Optional[List[Dict[str, Any]]] = None,
+ sender: Optional["Agent"] = None,
+ **kwargs: Any,
+ ) -> Union[str, Dict[str, Any], None]:
"""(async) Reply based on the conversation history and the sender.
Either messages or sender must be provided.
@@ -1357,9 +2008,10 @@ async def a_generate_reply(
Args:
messages: a list of messages in the conversation history.
- default_reply (str or dict): default reply.
sender: sender of an Agent instance.
- exclude: a list of functions to exclude.
+
+ Additional keyword arguments:
+ exclude (List[Callable]): a list of reply functions to be excluded.
Returns:
str or dict or None: reply. None if no reply is generated.
@@ -1372,14 +2024,19 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]
+ # Call the hookable method that gives registered hooks a chance to process all messages.
+ # Message modifications do not affect the incoming messages or self._oai_messages.
+ messages = self.process_all_messages_before_reply(messages)
+
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
- messages = self.process_last_message(messages)
+ messages = self.process_last_received_message(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
- if exclude and reply_func in exclude:
+ if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue
+
if self._match_trigger(reply_func_tuple["trigger"], sender):
if inspect.iscoroutinefunction(reply_func):
final, reply = await reply_func(
@@ -1391,7 +2048,7 @@ async def a_generate_reply(
return reply
return self._default_auto_reply
- def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Agent) -> bool:
+ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool:
"""Check if the sender matches the trigger.
Args:
@@ -1408,6 +2065,8 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List],
if trigger is None:
return sender is None
elif isinstance(trigger, str):
+ if sender is None:
+ raise SenderRequired()
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
@@ -1416,7 +2075,7 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List],
return trigger == sender
elif isinstance(trigger, Callable):
rst = trigger(sender)
- assert rst in [True, False], f"trigger {trigger} must return a boolean value."
+ assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value."
return rst
elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger)
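A hedged sketch of how triggers are consumed: `register_reply` (defined elsewhere in this class, not shown in this hunk) stores the trigger, and `_match_trigger` decides whether a given sender activates the reply function. Triggers may be an agent name, a class, an instance, a callable, or a list of these.

```python
# Hedged sketch (register_reply is assumed from the rest of the class, it is
# not part of this hunk): the trigger may be a name, a class, an instance,
# a callable, or a list of these, and _match_trigger decides activation.
from typing import Tuple, Union

from autogen import ConversableAgent

def shout_back(recipient, messages=None, sender=None, config=None) -> Tuple[bool, Union[str, None]]:
    # Illustration only: echo the last message in upper case.
    last = messages[-1]["content"] if messages else ""
    return True, str(last).upper()

bob = ConversableAgent("bob", llm_config=False, human_input_mode="NEVER")

# Fire when the sender is named "alice" or when the custom predicate matches.
bob.register_reply(["alice", lambda s: s is not None and s.name.startswith("a")], shout_back)
```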
@@ -1434,7 +2093,10 @@ def get_human_input(self, prompt: str) -> str:
Returns:
str: human input.
"""
- reply = input(prompt)
+ iostream = IOStream.get_default()
+
+ reply = iostream.input(prompt)
+ self._human_input.append(reply)
return reply
async def a_get_human_input(self, prompt: str) -> str:
@@ -1448,7 +2110,8 @@ async def a_get_human_input(self, prompt: str) -> str:
Returns:
str: human input.
"""
- reply = input(prompt)
+ loop = asyncio.get_running_loop()
+ reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt))
return reply
def run_code(self, code, **kwargs):
@@ -1469,12 +2132,14 @@ def run_code(self, code, **kwargs):
def execute_code_blocks(self, code_blocks):
"""Execute the code blocks and return the result."""
+ iostream = IOStream.get_default()
+
logs_all = ""
for i, code_block in enumerate(code_blocks):
lang, code = code_block
if not lang:
lang = infer_lang(code)
- print(
+ iostream.print(
colored(
f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...",
"red",
@@ -1483,7 +2148,7 @@ def execute_code_blocks(self, code_blocks):
)
if lang in ["bash", "shell", "sh"]:
exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config)
- elif lang in ["python", "Python"]:
+ elif lang in PYTHON_VARIANTS:
if code.startswith("# filename: "):
filename = code[11 : code.find("\n")].strip()
else:
@@ -1555,6 +2220,8 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
"""
+ iostream = IOStream.get_default()
+
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
@@ -1570,7 +2237,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
# Try to execute the function
if arguments is not None:
- print(
+ iostream.print(
colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"),
flush=True,
)
@@ -1583,7 +2250,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
content = f"Error: Function {func_name} not found."
if verbose:
- print(
+ iostream.print(
colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"),
flush=True,
)
@@ -1610,6 +2277,8 @@ async def a_execute_function(self, func_call):
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
"""
+ iostream = IOStream.get_default()
+
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
@@ -1625,7 +2294,7 @@ async def a_execute_function(self, func_call):
# Try to execute the function
if arguments is not None:
- print(
+ iostream.print(
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
flush=True,
)
@@ -1647,43 +2316,98 @@ async def a_execute_function(self, func_call):
"content": str(content),
}
- def generate_init_message(self, **context) -> Union[str, Dict]:
+ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent.
-
- Override this function to customize the initial message based on user's request.
- If not overridden, "message" needs to be provided in the context.
+ If message is None, input() will be called to get the initial message.
Args:
- **context: any context information, and "message" parameter needs to be provided.
- If message is not given, prompt for it via input()
+ message (str, dict or None): the message to be processed.
+ **kwargs: any additional information. It has the following reserved fields:
+ "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string.
+ If provided, we will combine this carryover with the "message" content when generating the initial chat
+ message.
+ Returns:
+ str or dict: the processed message.
"""
- if "message" not in context:
- context["message"] = self.get_human_input(">")
- return context["message"]
+ if message is None:
+ message = self.get_human_input(">")
- async def a_generate_init_message(self, **context) -> Union[str, Dict]:
- """Generate the initial message for the agent.
+ return self._handle_carryover(message, kwargs)
+
+ def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
+ if not kwargs.get("carryover"):
+ return message
+
+ if isinstance(message, str):
+ return self._process_carryover(message, kwargs)
+
+ elif isinstance(message, dict):
+ if isinstance(message.get("content"), str):
+ # Makes sure the original message is not mutated
+ message = message.copy()
+ message["content"] = self._process_carryover(message["content"], kwargs)
+ elif isinstance(message.get("content"), list):
+ # Makes sure the original message is not mutated
+ message = message.copy()
+ message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
+ else:
+ raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
- Override this function to customize the initial message based on user's request.
- If not overridden, "message" needs to be provided in the context.
+ return message
+
+ def _process_carryover(self, content: str, kwargs: dict) -> str:
+ # Makes sure there's a carryover
+ if not kwargs.get("carryover"):
+ return content
+
+ # if carryover is string
+ if isinstance(kwargs["carryover"], str):
+ content += "\nContext: \n" + kwargs["carryover"]
+ elif isinstance(kwargs["carryover"], list):
+ content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
+ else:
+ raise InvalidCarryOverType(
+ "Carryover should be a string or a list of strings. Not adding carryover to the message."
+ )
+ return content
+
+ def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
+ """Prepends the context to a multimodal message."""
+ # Makes sure there's a carryover
+ if not kwargs.get("carryover"):
+ return content
+
+ return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
+
+ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
+ """Generate the initial message for the agent.
+ If message is None, input() will be called to get the initial message.
Args:
- **context: any context information, and "message" parameter needs to be provided.
- If message is not given, prompt for it via input()
+ Please refer to `generate_init_message` for the description of the arguments.
+
+ Returns:
+ str or dict: the processed message.
"""
- if "message" not in context:
- context["message"] = await self.a_get_human_input(">")
- return context["message"]
+ if message is None:
+ message = await self.a_get_human_input(">")
- def register_function(self, function_map: Dict[str, Callable]):
+ return self._handle_carryover(message, kwargs)
+
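The carryover helpers above append a `Context:` block to the initial message (or prepend an extra text item for multimodal content). A stand-alone sketch of the same transformation, using a plain `TypeError` in place of the package's `InvalidCarryOverType`:

```python
# Stand-alone sketch of the carryover transformation above; a plain TypeError
# stands in for the package's InvalidCarryOverType.
def with_carryover(content, carryover):
    if not carryover:
        return content
    if isinstance(carryover, str):
        return content + "\nContext: \n" + carryover
    if isinstance(carryover, list):
        return content + "\nContext: \n" + "\n".join(carryover)
    raise TypeError("Carryover should be a string or a list of strings.")

print(with_carryover("Write a blogpost.", ["Solar is cheap.", "Wind is variable."]))

# A multimodal message gets the same text prepended as an extra text item.
multimodal = [{"type": "image_url", "image_url": {"url": "https://example.com/x.png"}}]
print([{"type": "text", "text": with_carryover("", ["Solar is cheap."])}] + multimodal)
```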
+ def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent.
Args:
- function_map: a dictionary mapping function names to functions.
+ function_map: a dictionary mapping function names to functions. If function_map[name] is None, the function will be removed from the function_map.
"""
- for name in function_map.keys():
+ for name, func in function_map.items():
self._assert_valid_name(name)
+ if func is None and name not in self._function_map.keys():
+ warnings.warn(f"The function {name} to remove doesn't exist", name)
+ if name in self._function_map:
+ warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
self._function_map.update(function_map)
+ self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
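A short usage sketch of the updated `register_function` semantics, where mapping a name to `None` removes it from the function map (the agent construction details are illustrative only):

```python
# Usage sketch: mapping a name to None removes it from the function map.
from autogen import ConversableAgent

def get_weather(city: str) -> str:
    """Toy function used only for illustration."""
    return f"It is sunny in {city}."

agent = ConversableAgent("executor", llm_config=False, human_input_mode="NEVER")
agent.register_function({"get_weather": get_weather})  # register
agent.register_function({"get_weather": None})         # remove again
```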
def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None):
"""update a function_signature in the LLM configuration for function_call.
@@ -1711,8 +2435,16 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
func for func in self.llm_config["functions"] if func["name"] != func_sig
]
else:
+ if not isinstance(func_sig, dict):
+ raise ValueError(
+ f"The function signature must be of the type dict. Received function signature type {type(func_sig)}"
+ )
+
self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
+ if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]):
+ warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning)
+
self.llm_config["functions"] = [
func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]
] + [func_sig]
@@ -1747,8 +2479,14 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig
]
else:
+ if not isinstance(tool_sig, dict):
+ raise ValueError(
+ f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
+ )
self._assert_valid_name(tool_sig["function"]["name"])
- if "tools" in self.llm_config.keys():
+ if "tools" in self.llm_config:
+ if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]):
+ warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning)
self.llm_config["tools"] = [
tool
for tool in self.llm_config["tools"]
@@ -1788,13 +2526,14 @@ def _wrap_function(self, func: F) -> F:
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)
-
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
@@ -1955,13 +2694,13 @@ def register_model_client(self, model_client_cls: ModelClient, **kwargs):
"""
self.client.register_model_client(model_client_cls, **kwargs)
- def register_hook(self, hookable_method: Callable, hook: Callable):
+ def register_hook(self, hookable_method: str, hook: Callable):
"""
Registers a hook to be called by a hookable method, in order to add a capability to the agent.
Registered hooks are kept in lists (one per hookable method), and are called in their order of registration.
Args:
- hookable_method: A hookable method implemented by ConversableAgent.
+ hookable_method: A hookable method name implemented by ConversableAgent.
hook: A method implemented by a subclass of AgentCapability.
"""
assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method."
@@ -1969,14 +2708,29 @@ def register_hook(self, hookable_method: Callable, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)
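With hooks now keyed by method name, registering a capability hook looks like the following sketch; the hook body is illustrative only:

```python
# Sketch of the string-keyed hook registration; the hook body is illustrative.
from typing import Dict, List

from autogen import ConversableAgent

def drop_empty_messages(messages: List[Dict]) -> List[Dict]:
    # Remove messages with empty content before the agent composes a reply.
    return [m for m in messages if m.get("content")]

agent = ConversableAgent("helper", llm_config=False, human_input_mode="NEVER")
agent.register_hook("process_all_messages_before_reply", drop_empty_messages)
```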
- def process_last_message(self, messages):
+ def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
+ """
+ Calls any registered capability hooks to process all messages, potentially modifying the messages.
+ """
+ hook_list = self.hook_lists["process_all_messages_before_reply"]
+ # If no hooks are registered, or if there are no messages to process, return the original message list.
+ if len(hook_list) == 0 or messages is None:
+ return messages
+
+ # Call each hook (in order of registration) to process the messages.
+ processed_messages = messages
+ for hook in hook_list:
+ processed_messages = hook(processed_messages)
+ return processed_messages
+
+ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""
# If any required condition is not met, return the original message list.
- hook_list = self.hook_lists[self.process_last_message]
+ hook_list = self.hook_lists["process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
@@ -1990,30 +2744,36 @@ def process_last_message(self, messages):
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
- user_text = last_message["content"]
- if not isinstance(user_text, str):
- return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
- if user_text == "exit":
+
+ user_content = last_message["content"]
+ if not isinstance(user_content, str) and not isinstance(user_content, list):
+ # if the user_content is a string, it is for regular LLM
+ # if the user_content is a list, it should follow the multimodal LMM format.
+ return messages
+ if user_content == "exit":
return messages # Last message is an exit command.
# Call each hook (in order of registration) to process the user's message.
- processed_user_text = user_text
+ processed_user_content = user_content
for hook in hook_list:
- processed_user_text = hook(processed_user_text)
- if processed_user_text == user_text:
+ processed_user_content = hook(processed_user_content)
+
+ if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.
# Replace the last user message with the expanded one.
messages = messages.copy()
- messages[-1]["content"] = processed_user_text
+ messages[-1]["content"] = processed_user_content
return messages
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
+ iostream = IOStream.get_default()
+
if self.client is None:
- print(f"No cost incurred from agent '{self.name}'.")
+ iostream.print(f"No cost incurred from agent '{self.name}'.")
else:
- print(f"Agent '{self.name}':")
+ iostream.print(f"Agent '{self.name}':")
self.client.print_usage_summary(mode)
def get_actual_usage(self) -> Union[None, Dict[str, int]]:
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 611cd4795d2..48f11d526cc 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -1,12 +1,20 @@
+import copy
+import json
import logging
import random
import re
import sys
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Tuple
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from ..code_utils import content_str
+from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
+from ..formatting_utils import colored
+from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
+from ..io.base import IOStream
+from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
+from .chat import ChatResult
from .conversable_agent import ConversableAgent
logger = logging.getLogger(__name__)
@@ -24,28 +32,226 @@ class GroupChat:
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
+ - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. If the string contains "{roles}" it will be replaced with the agents' names and their role descriptions. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You are in a role play game. The following roles are available:
+ {roles}.
+ Read the following conversation.
+ Then select the next role from {agentlist} to play. Only return the role."
+ - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+ To ignore this prompt being used, set this to None. If set to None, ensure your instructions for selecting a speaker are in the select_speaker_message_template string.
+ - select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
+ - select_speaker_auto_none_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains no agent names. This prompt guides the LLM to return an agent name and provides a list of agent names. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat.
+ - a customized speaker selection function (Callable): the function will be called to select the next speaker.
+ The function should take the last speaker and the group chat as input and return one of the following:
+ 1. an `Agent` class, it must be one of the agents in the group chat.
+ 2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
+ 3. None, which would terminate the conversation gracefully.
+ ```python
+ def custom_speaker_selection_func(
+ last_speaker: Agent, groupchat: GroupChat
+ ) -> Union[Agent, str, None]:
+ ```
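+ For example, a minimal sketch (using hypothetical agent names "Coder" and "Reviewer") that always routes to the reviewer after the coder and otherwise falls back to the default "auto" method:
+ ```python
+ def custom_speaker_selection_func(
+     last_speaker: Agent, groupchat: GroupChat
+ ) -> Union[Agent, str, None]:
+     # Hypothetical rule: the "Reviewer" agent always follows the "Coder" agent.
+     if last_speaker.name == "Coder":
+         return groupchat.agent_by_name("Reviewer")
+     # Otherwise defer to the built-in LLM-based selection.
+     return "auto"
+ ```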
+ - max_retries_for_selecting_speaker: the maximum number of times the speaker selection will be requeried.
+ If, during speaker selection, the LLM returns multiple agent names or no agent name, it is queried again
+ until a single agent is returned or the maximum number of retries is exhausted.
+ Applies only to "auto" speaker selection method.
+ Default is 2.
+ - select_speaker_auto_verbose: whether to output the select speaker responses and selections.
+ If set to True, the outputs from the two agents in the nested select speaker chat will be printed, along with
+ whether each response successfully selected an agent.
+ Applies only to "auto" speaker selection method. Default is False.
+ - allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
+ Default is True, in which case all speakers are allowed to speak consecutively.
+ If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
+ If set to False, then no speakers are allowed to repeat.
+ `allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
+ - allowed_or_disallowed_speaker_transitions: dict.
+ The keys are source agents, and the values are the agents that the key agent can (or can't) transition to,
+ depending on speaker_transitions_type. Default is None, which means all agents can transition to all other agents.
+ `allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
+ A usage sketch is shown at the end of this parameter list.
+ - speaker_transitions_type: whether `allowed_or_disallowed_speaker_transitions` contains lists of allowed or disallowed agents.
+ "allowed" means the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of allowed agents.
+ If set to "disallowed", then the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of disallowed agents.
+ Must be supplied if `allowed_or_disallowed_speaker_transitions` is not None.
- enable_clear_history: enable possibility to clear history of messages for agents manually by providing
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
+ - send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
+ - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
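+
+ A minimal construction sketch (the `planner`, `coder`, and `reviewer` agents are hypothetical) showing how
+ `allowed_or_disallowed_speaker_transitions`, `speaker_transitions_type`, and `send_introductions` fit together:
+ ```python
+ # Sketch only: planner, coder, and reviewer are hypothetical ConversableAgent instances.
+ allowed_transitions = {
+     planner: [coder],            # the planner may only hand over to the coder
+     coder: [reviewer, planner],  # the coder may hand over to the reviewer or back to the planner
+     reviewer: [planner],         # the reviewer always hands back to the planner
+ }
+ groupchat = GroupChat(
+     agents=[planner, coder, reviewer],
+     messages=[],
+     max_round=12,
+     allowed_or_disallowed_speaker_transitions=allowed_transitions,
+     speaker_transitions_type="allowed",
+     send_introductions=True,
+ )
+ ```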
"""
agents: List[Agent]
messages: List[Dict]
- max_round: Optional[int] = 10
- admin_name: Optional[str] = "Admin"
- func_call_filter: Optional[bool] = True
- speaker_selection_method: Optional[str] = "auto"
- allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True
- enable_clear_history: Optional[bool] = False
+ max_round: int = 10
+ admin_name: str = "Admin"
+ func_call_filter: bool = True
+ speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
+ max_retries_for_selecting_speaker: int = 2
+ allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
+ allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
+ speaker_transitions_type: Literal["allowed", "disallowed", None] = None
+ enable_clear_history: bool = False
+ send_introductions: bool = False
+ select_speaker_message_template: str = """You are in a role play game. The following roles are available:
+ {roles}.
+ Read the following conversation.
+ Then select the next role from {agentlist} to play. Only return the role."""
+ select_speaker_prompt_template: str = (
+ "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+ )
+ select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_auto_verbose: Optional[bool] = False
+ role_for_select_speaker_messages: Optional[str] = "system"
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
+ _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
+
+ # Define a class attribute for the default introduction message
+ DEFAULT_INTRO_MSG = (
+ "Hello everyone. We have assembled a great team today to answer questions and solve tasks. In attendance are:"
+ )
+
+ allowed_speaker_transitions_dict: Dict = field(init=False)
+
+ def __post_init__(self):
+ # __post_init__ runs after the dataclass-generated __init__; inputs are validated and derived fields are built here
+
+ if self.allow_repeat_speaker is not None and not isinstance(self.allow_repeat_speaker, (bool, list)):
+ raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
+
+ # Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and speaker_transitions_type, and lastly check for validity.
+
+ # Check input
+ if self.speaker_transitions_type is not None:
+ self.speaker_transitions_type = self.speaker_transitions_type.lower()
+
+ if self.speaker_transitions_type not in self._VALID_SPEAKER_TRANSITIONS_TYPE:
+ raise ValueError(
+ f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
+ f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
+ )
+
+ # If both self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None, set allow_repeat_speaker to True to ensure backward compatibility
+ # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451541204
+ if self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None:
+ self.allow_repeat_speaker = True
+
+ # self.allowed_or_disallowed_speaker_transitions and self.allow_repeat_speaker are mutually exclusive parameters.
+ # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451266661
+ if self.allowed_or_disallowed_speaker_transitions is not None and self.allow_repeat_speaker is not None:
+ raise ValueError(
+ "Don't provide both allowed_or_disallowed_speaker_transitions and allow_repeat_speaker in group chat. "
+ "Please set one of them to None."
+ )
+
+ # Asks the user to specify whether the speaker_transitions_type is allowed or disallowed if speaker_transitions_type is supplied
+ # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451259524
+ if self.allowed_or_disallowed_speaker_transitions is not None and self.speaker_transitions_type is None:
+ raise ValueError(
+ "GroupChat allowed_or_disallowed_speaker_transitions is not None, but speaker_transitions_type is None. "
+ "Please set speaker_transitions_type to either 'allowed' or 'disallowed'."
+ )
+
+ # Inferring self.allowed_speaker_transitions_dict
+ # Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is None, using allow_repeat_speaker
+ if self.allowed_or_disallowed_speaker_transitions is None:
+ self.allowed_speaker_transitions_dict = {}
+
+ # Create a fully connected allowed_speaker_transitions_dict not including self loops
+ for agent in self.agents:
+ self.allowed_speaker_transitions_dict[agent] = [
+ other_agent for other_agent in self.agents if other_agent != agent
+ ]
+
+ # If self.allow_repeat_speaker is True, add self loops to all agents
+ if self.allow_repeat_speaker is True:
+ for agent in self.agents:
+ self.allowed_speaker_transitions_dict[agent].append(agent)
+
+ # Else if self.allow_repeat_speaker is a list of Agents, add self loops to the agents in the list
+ elif isinstance(self.allow_repeat_speaker, list):
+ for agent in self.allow_repeat_speaker:
+ self.allowed_speaker_transitions_dict[agent].append(agent)
+
+ # Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is not None, using allowed_or_disallowed_speaker_transitions
+ else:
+ # Process based on speaker_transitions_type
+ if self.speaker_transitions_type == "allowed":
+ self.allowed_speaker_transitions_dict = self.allowed_or_disallowed_speaker_transitions
+ else:
+ # Logic for processing disallowed allowed_or_disallowed_speaker_transitions to allowed_speaker_transitions_dict
+ self.allowed_speaker_transitions_dict = invert_disallowed_to_allowed(
+ self.allowed_or_disallowed_speaker_transitions, self.agents
+ )
+
+ # Check for validity
+ check_graph_validity(
+ allowed_speaker_transitions_dict=self.allowed_speaker_transitions_dict,
+ agents=self.agents,
+ )
+
+ # Check select speaker messages, prompts, roles, and retries have values
+ if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
+ raise ValueError("select_speaker_message_template cannot be empty or None.")
+
+ if self.select_speaker_prompt_template is not None and len(self.select_speaker_prompt_template) == 0:
+ self.select_speaker_prompt_template = None
+
+ if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
+ raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
+
+ if self.select_speaker_auto_multiple_template is None or len(self.select_speaker_auto_multiple_template) == 0:
+ raise ValueError("select_speaker_auto_multiple_template cannot be empty or None.")
+
+ if self.select_speaker_auto_none_template is None or len(self.select_speaker_auto_none_template) == 0:
+ raise ValueError("select_speaker_auto_none_template cannot be empty or None.")
+
+ # Validate max select speakers retries
+ if self.max_retries_for_selecting_speaker is None or not isinstance(
+ self.max_retries_for_selecting_speaker, int
+ ):
+ raise ValueError("max_retries_for_selecting_speaker cannot be None or non-int")
+ elif self.max_retries_for_selecting_speaker < 0:
+ raise ValueError("max_retries_for_selecting_speaker must be greater than or equal to zero")
+
+ # Validate select_speaker_auto_verbose
+ if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool):
+ raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")
@property
def agent_names(self) -> List[str]:
@@ -68,15 +274,36 @@ def append(self, message: Dict, speaker: Agent):
message["content"] = content_str(message["content"])
self.messages.append(message)
- def agent_by_name(self, name: str) -> Agent:
- """Returns the agent with a given name."""
- return self.agents[self.agent_names.index(name)]
+ def agent_by_name(
+ self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False
+ ) -> Optional[Agent]:
+ """Returns the agent with a given name. If recursive is True, it will search in nested teams."""
+ agents = self.nested_agents() if recursive else self.agents
+ filtered_agents = [agent for agent in agents if agent.name == name]
+
+ if raise_on_name_conflict and len(filtered_agents) > 1:
+ raise AgentNameConflict()
+
+ return filtered_agents[0] if filtered_agents else None
+
+ def nested_agents(self) -> List[Agent]:
+ """Returns all agents in the group chat manager."""
+ agents = self.agents.copy()
+ for agent in agents:
+ if isinstance(agent, GroupChatManager):
+ # Recursive call for nested teams
+ agents.extend(agent.groupchat.nested_agents())
+ return agents
def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent:
"""Return the next agent in the list."""
if agents is None:
agents = self.agents
+ # Ensure the provided list of agents is a subset of self.agents
+ if not set(agents).issubset(set(self.agents)):
+ raise UndefinedNextAgent()
+
# What index is the agent? (-1 if not present)
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1
@@ -89,40 +316,69 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]
+ # Explicitly handle cases where no valid next agent exists in the provided subset.
+ raise UndefinedNextAgent()
+
def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
if agents is None:
agents = self.agents
- return f"""You are in a role play game. The following roles are available:
-{self._participant_roles(agents)}.
-Read the following conversation.
-Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
+ roles = self._participant_roles(agents)
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist)
+ return return_msg
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
- """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
+ """Return the floating system prompt selecting the next speaker.
+ This is always the *last* message in the context.
+ Will return None if the select_speaker_prompt_template is None."""
+
+ if self.select_speaker_prompt_template is None:
+ return None
+
if agents is None:
agents = self.agents
- return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."
+
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist)
+ return return_prompt
+
+ def introductions_msg(self, agents: Optional[List[Agent]] = None) -> str:
+ """Return the system message for selecting the next speaker. This is always the *first* message in the context."""
+ if agents is None:
+ agents = self.agents
+
+ # Use the class attribute instead of a hardcoded string
+ intro_msg = self.DEFAULT_INTRO_MSG
+ participant_roles = self._participant_roles(agents)
+
+ return f"{intro_msg}\n\n{participant_roles}"
def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
"""Manually select the next speaker."""
+ iostream = IOStream.get_default()
+
if agents is None:
agents = self.agents
- print("Please select the next speaker from the following list:")
+ iostream.print("Please select the next speaker from the following list:")
_n_agents = len(agents)
for i in range(_n_agents):
- print(f"{i+1}: {agents[i].name}")
+ iostream.print(f"{i+1}: {agents[i].name}")
try_count = 0
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
while try_count <= 3:
try_count += 1
if try_count >= 3:
- print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
+ iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
break
try:
- i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
+ i = iostream.input(
+ "Enter the number of the next speaker (enter nothing or `q` to use auto selection): "
+ )
if i == "" or i == "q":
break
i = int(i)
@@ -131,24 +387,51 @@ def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
else:
raise ValueError
except ValueError:
- print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
+ iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None
+ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
+ """Randomly select the next speaker."""
+ if agents is None:
+ agents = self.agents
+ return random.choice(agents)
+
def _prepare_and_select_agents(
- self, last_speaker: Agent
+ self,
+ last_speaker: Agent,
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
- if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
+ # If self.speaker_selection_method is a callable, call it to get the next speaker.
+ # If self.speaker_selection_method is a string, return it.
+ speaker_selection_method = self.speaker_selection_method
+ if isinstance(self.speaker_selection_method, Callable):
+ selected_agent = self.speaker_selection_method(last_speaker, self)
+ if selected_agent is None:
+ raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.")
+ elif isinstance(selected_agent, Agent):
+ if selected_agent in self.agents:
+ return selected_agent, self.agents, None
+ else:
+ raise ValueError(
+ f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
+ )
+ elif isinstance(selected_agent, str):
+ # If returned a string, assume it is a speaker selection method
+ speaker_selection_method = selected_agent
+ else:
+ raise ValueError(
+ f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
+ )
+
+ if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
- f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
+ f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)
- if not isinstance(self.allow_repeat_speaker, (bool, list)):
- raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
# If provided a list, make sure the agent is in the list
allow_repeat_speaker = (
self.allow_repeat_speaker
- if isinstance(self.allow_repeat_speaker, bool)
+ if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None
else last_speaker in self.allow_repeat_speaker
)
@@ -160,11 +443,11 @@ def _prepare_and_select_agents(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
- elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
+ elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
- "It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
- "Or, use direct communication instead."
+ "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
+ "or use direct communication, unless repeated speaker is desired."
)
if (
@@ -196,16 +479,41 @@ def _prepare_and_select_agents(
"Please check the function_map of the agents."
)
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
- agents = agents if allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
+ agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents
- select_speaker_messages = None
- if self.speaker_selection_method.lower() == "manual":
- selected_agent = self.manual_select_speaker(agents)
- elif self.speaker_selection_method.lower() == "round_robin":
- selected_agent = self.next_agent(last_speaker, agents)
- elif self.speaker_selection_method.lower() == "random":
- selected_agent = random.choice(agents)
+ # Filter agents with allowed_speaker_transitions_dict
+
+ is_last_speaker_in_group = last_speaker in self.agents
+
+ # last_speaker has no outgoing transitions (it is a sink in the graph), so no agents are eligible
+ if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group:
+ raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.")
+ # last_speaker is not in the group, so all agents are eligible
+ elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group:
+ graph_eligible_agents = []
else:
+ # Keep only the agents that the last speaker is allowed to transition to
+ graph_eligible_agents = [
+ agent for agent in agents if agent in self.allowed_speaker_transitions_dict[last_speaker]
+ ]
+
+ # If there is only one eligible agent, just return it to avoid the speaker selection prompt
+ if len(graph_eligible_agents) == 1:
+ return graph_eligible_agents[0], graph_eligible_agents, None
+
+ # If there are no eligible agents, set the list to None so that all agents are considered in the next step
+ if len(graph_eligible_agents) == 0:
+ graph_eligible_agents = None
+
+ # Use the selected speaker selection method
+ select_speaker_messages = None
+ if speaker_selection_method.lower() == "manual":
+ selected_agent = self.manual_select_speaker(graph_eligible_agents)
+ elif speaker_selection_method.lower() == "round_robin":
+ selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
+ elif speaker_selection_method.lower() == "random":
+ selected_agent = self.random_select_speaker(graph_eligible_agents)
+ else: # auto
selected_agent = None
select_speaker_messages = self.messages.copy()
# If last message is a tool call or function call, blank the call so the api doesn't throw
@@ -213,32 +521,36 @@ def _prepare_and_select_agents(
select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None)
if select_speaker_messages[-1].get("tool_calls", False):
select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
- select_speaker_messages = select_speaker_messages + [
- {"role": "system", "content": self.select_speaker_prompt(agents)}
- ]
- return selected_agent, agents, select_speaker_messages
+ return selected_agent, graph_eligible_agents, select_speaker_messages
+
+ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
+ """Select the next speaker (with requery)."""
- def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
- """Select the next speaker."""
+ # Prepare the list of available agents and select an agent if selection method allows (non-auto)
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = selector.generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
+
+ # auto speaker selection with 2-agent chat
+ return self._auto_select_speaker(last_speaker, selector, messages, agents)
+
+ async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
+ """Select the next speaker (with requery), asynchronously."""
- async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
- """Select the next speaker."""
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = await selector.a_generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
- def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: List[Agent]) -> Agent:
+ # auto speaker selection with 2-agent chat
+ return await self.a_auto_select_speaker(last_speaker, selector, messages, agents)
+
+ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[List[Agent]]) -> Agent:
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)
@@ -253,10 +565,315 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
)
# Return the result
- try:
- return self.agent_by_name(name)
- except ValueError:
- return self.next_agent(last_speaker, agents)
+ agent = self.agent_by_name(name)
+ return agent if agent else self.next_agent(last_speaker, agents)
+
+ def _auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+ 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+ selector ConversableAgent: Agent whose llm_config is used for the nested speaker selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+ Agent: The agent selected to speak next, or the next agent in the list if no single agent could be determined
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # is the initial speaker selection attempt plus the maximum number of retries.
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name.
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+ # Agent for checking the response from the speaker_select_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # NOTE: If there is no speaker prompt (select_speaker_prompt_template is None), the last message is fed in to start the nested chat instead
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages=(
+ {checking_agent: messages}
+ if self.select_speaker_prompt_template is not None
+ else {checking_agent: messages[:-1]}
+ ),
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
+ # Run the speaker selection chat
+ result = checking_agent.initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+ message=start_message,
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ async def a_auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+ 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+ selector ConversableAgent: Agent whose llm_config is used for the nested speaker selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+ Agent: The agent selected to speak next, or the next agent in the list if no single agent could be determined
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+ # Agent for checking the response from the speaker_select_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # NOTE: If there is no speaker prompt (select_speaker_prompt_template is None), the last message is fed in to start the nested chat instead
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages=(
+ {checking_agent: messages}
+ if self.select_speaker_prompt_template is not None
+ else {checking_agent: messages[:-1]}
+ ),
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
+ # Run the speaker selection chat
+ result = await checking_agent.a_initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+ message=start_message,
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ def _validate_speaker_name(
+ self, recipient, messages, sender, config, attempts_left, attempt, agents
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Validates the speaker response for each round in the internal 2-agent
+ chat within the auto select speaker method.
+
+ Used by auto_select_speaker and a_auto_select_speaker.
+ """
+
+ # Output the query and requery results
+ if self.select_speaker_auto_verbose:
+ iostream = IOStream.get_default()
+
+ # Validate the speaker name selected
+ select_name = messages[-1]["content"].strip()
+
+ mentions = self._mentioned_agents(select_name, agents)
+
+ if len(mentions) == 1:
+ # Success, exactly one name mentioned
+ selected_agent_name = next(iter(mentions))
+
+ # Add the selected agent to the response so we can return it
+ messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"})
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} successfully selected: {selected_agent_name}",
+ "green",
+ ),
+ flush=True,
+ )
+
+ elif len(mentions) > 1:
+ # More than one name on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} failed as it included multiple agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.",
+ }
+ )
+
+ else:
+ # No names at all on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt #{attempt} failed as it did not include any agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_none_template.format(agentlist=agentlist),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.",
+ }
+ )
+
+ return True, None
+
+ def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[List[Agent]]):
+ """Checks the result of the auto_select_speaker function, returning the
+ agent to speak.
+
+ Used by auto_select_speaker and a_auto_select_speaker."""
+ if len(result.chat_history) > 0:
+ # Use the final message, which will have the selected agent or reason for failure
+ final_message = result.chat_history[-1]["content"]
+
+ if "[AGENT SELECTED]" in final_message:
+ # Have successfully selected an agent, return it
+ return self.agent_by_name(final_message.replace("[AGENT SELECTED]", ""))
+
+ else: # "[AGENT SELECTION FAILED]"
+ # Failed to select an agent, so we'll select the next agent in the list
+ next_agent = self.next_agent(last_speaker, agents)
+
+ # Return the next agent in the list as the fallback
+ return next_agent
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
@@ -272,8 +889,12 @@ def _participant_roles(self, agents: List[Agent] = None) -> str:
roles.append(f"{agent.name}: {agent.description}".strip())
return "\n".join(roles)
- def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agent]) -> Dict:
+ def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[List[Agent]]) -> Dict:
"""Counts the number of times each agent is mentioned in the provided message content.
+ Agent names will match under any of the following conditions (all case-sensitive):
+ - Exact name match
+ - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
+ - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
Args:
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
@@ -282,6 +903,9 @@ def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agen
Returns:
Dict: a counter for mentioned agents.
"""
+ if agents is None:
+ agents = self.agents
+
# Cast message content to str
if isinstance(message_content, dict):
message_content = message_content["content"]
@@ -289,9 +913,17 @@ def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agen
mentions = dict()
for agent in agents:
+ # Finds agent mentions, taking word boundaries into account,
+ # accommodates escaping underscores and underscores as spaces
regex = (
- r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
- ) # Finds agent mentions, taking word boundaries into account
+ r"(?<=\W)("
+ + re.escape(agent.name)
+ + r"|"
+ + re.escape(agent.name.replace("_", " "))
+ + r"|"
+ + re.escape(agent.name.replace("_", r"\_"))
+ + r")(?=\W)"
+ )
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
@@ -307,11 +939,16 @@ def __init__(
name: Optional[str] = "chat_manager",
# unlimited consecutive auto reply by default
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
system_message: Optional[Union[str, List]] = "Group chat manager.",
+ silent: bool = False,
**kwargs,
):
- if kwargs.get("llm_config") and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")):
+ if (
+ kwargs.get("llm_config")
+ and isinstance(kwargs["llm_config"], dict)
+ and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools"))
+ ):
raise ValueError(
"GroupChatManager is not allowed to make function/tool calls. Please remove the 'functions' or 'tools' config in 'llm_config' you passed in."
)
@@ -323,9 +960,13 @@ def __init__(
system_message=system_message,
**kwargs,
)
+ if logging_enabled():
+ log_new_agent(self, locals())
# Store groupchat
self._groupchat = groupchat
+ self._silent = silent
+
# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
@@ -338,15 +979,32 @@ def __init__(
ignore_async_in_sync_chat=True,
)
- def _prepare_chat(self, recipient: ConversableAgent, clear_history: bool, prepare_recipient: bool = True) -> None:
- super()._prepare_chat(recipient, clear_history, prepare_recipient)
+ @property
+ def groupchat(self) -> GroupChat:
+ """Returns the group chat managed by the group chat manager."""
+ return self._groupchat
+
+ def chat_messages_for_summary(self, agent: Agent) -> List[Dict]:
+ """The list of messages in the group chat as a conversation to summarize.
+ The agent is ignored.
+ """
+ return self._groupchat.messages
+
+ def _prepare_chat(
+ self,
+ recipient: ConversableAgent,
+ clear_history: bool,
+ prepare_recipient: bool = True,
+ reply_at_receive: bool = True,
+ ) -> None:
+ super()._prepare_chat(recipient, clear_history, prepare_recipient, reply_at_receive)
if clear_history:
self._groupchat.reset()
for agent in self._groupchat.agents:
if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent):
- agent._prepare_chat(self, clear_history, False)
+ agent._prepare_chat(self, clear_history, False, reply_at_receive)
def run_chat(
self,
@@ -360,25 +1018,36 @@ def run_chat(
message = messages[-1]
speaker = sender
groupchat = config
+ send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
+
+ if send_introductions:
+ # Broadcast the intro
+ intro = groupchat.introductions_msg()
+ for agent in groupchat.agents:
+ self.send(intro, agent, request_reply=False, silent=True)
+ # NOTE: We do not also append to groupchat.messages,
+ # since groupchat handles its own introductions
+
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)
- if self._is_termination_msg(message):
- # The conversation is over
- break
# broadcast the message to all agents except the speaker
for agent in groupchat.agents:
if agent != speaker:
self.send(message, agent, request_reply=False, silent=True)
- if i == groupchat.max_round - 1:
- # the last round
+ if self._is_termination_msg(message) or i == groupchat.max_round - 1:
+ # The conversation is over or it's the last round
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True)
# let the speaker speak
reply = speaker.generate_reply(sender=self)
except KeyboardInterrupt:
@@ -390,21 +1059,26 @@ def run_chat(
else:
# admin agent is not found in the participants
raise
+ except NoEligibleSpeaker:
+ # No eligible speaker, terminate the conversation
+ break
+
if reply is None:
+ # no reply is generated, exit the chat
break
# check for "clear history" phrase in reply and activate clear history function if found
if (
groupchat.enable_clear_history
and isinstance(reply, dict)
+ and reply["content"]
and "CLEAR HISTORY" in reply["content"].upper()
):
- reply["content"] = self.clear_agents_history(reply["content"], groupchat)
+ reply["content"] = self.clear_agents_history(reply, groupchat)
+
# The speaker sends the message without requesting a reply
- speaker.send(reply, self, request_reply=False)
+ speaker.send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
- if i == groupchat.max_round - 1:
- groupchat.append(message, speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
@@ -423,6 +1097,17 @@ async def a_run_chat(
message = messages[-1]
speaker = sender
groupchat = config
+ send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
+
+ if send_introductions:
+ # Broadcast the intro
+ intro = groupchat.introductions_msg()
+ for agent in groupchat.agents:
+ await self.a_send(intro, agent, request_reply=False, silent=True)
+ # NOTE: We do not also append to groupchat.messages,
+ # since groupchat handles its own introductions
+
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
@@ -458,7 +1143,7 @@ async def a_run_chat(
if reply is None:
break
# The speaker sends the message without requesting a reply
- await speaker.a_send(reply, self, request_reply=False)
+ await speaker.a_send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
@@ -466,6 +1151,304 @@ async def a_run_chat(
a.previous_cache = None
return True, None
+ def resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Union[str, Callable[[str], str]] = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
+ If a string is provided, this string will be removed from last message.
+ If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication the ChatResult messages were used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ async def a_resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Union[str, Callable[[str], str]],
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
+ If a string is provided, this string will be removed from last message.
+ If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication the ChatResult messages were used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ await self.a_send(
+ message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
+ )
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ def _valid_resume_messages(self, messages: List[Dict]):
+ """Validates the messages used for resuming
+
+ args:
+ messages (List[Dict]): list of messages to resume with
+
+ returns:
+ - None: an Exception is raised if the messages are not valid for resuming
+ """
+ # Must have messages to start with, otherwise they should run run_chat
+ if not messages:
+ raise Exception(
+ "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat."
+ )
+
+ # Check that all agents in the chat messages exist in the group chat
+ for message in messages:
+ if message.get("name"):
+ if (
+ not self._groupchat.agent_by_name(message["name"])
+ and not message["name"] == self._groupchat.admin_name # ignore group chat's name
+ and not message["name"] == self.name # ignore group chat manager's name
+ ):
+ raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")
+
+ def _process_resume_termination(
+ self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict]
+ ):
+ """Removes termination string, if required, and checks if termination may occur.
+
+ args:
+ remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
+ If a string is provided, this string will be removed from last message.
+ If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
+
+ returns:
+ None
+ """
+
+ last_message = messages[-1]
+
+ # Replace any given termination string in the last message
+ if isinstance(remove_termination_string, str):
+
+ def _remove_termination_string(content: str) -> str:
+ return content.replace(remove_termination_string, "")
+
+ else:
+ _remove_termination_string = remove_termination_string
+
+ if _remove_termination_string:
+ if messages[-1].get("content"):
+ messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])
+
+ # Check if the last message meets termination (if it has one)
+ if self._is_termination_msg:
+ if self._is_termination_msg(last_message):
+ logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.")
+
+ def messages_from_string(self, message_string: str) -> List[Dict]:
+ """Reads the saved state of messages in Json format for resume and returns as a messages list
+
+ args:
+ - message_string: Json string, the saved state
+
+ returns:
+ - List[Dict]: List of messages
+ """
+ try:
+ state = json.loads(message_string)
+ except json.JSONDecodeError:
+ raise Exception("Messages string is not a valid JSON string")
+
+ return state
+
+ def messages_to_string(self, messages: List[Dict]) -> str:
+ """Converts the provided messages into a Json string that can be used for resuming the chat.
+ The state is made up of a list of messages
+
+ args:
+ - messages (List[Dict]): set of messages to convert to a string
+
+ returns:
+ - str: Json representation of the messages which can be persisted for resuming later
+ """
+
+ return json.dumps(messages)
+
def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
@@ -477,7 +1460,7 @@ def _raise_exception_on_async_reply_functions(self) -> None:
for agent in self._groupchat.agents:
agent._raise_exception_on_async_reply_functions()
- def clear_agents_history(self, reply: str, groupchat: GroupChat) -> str:
+ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str:
"""Clears history of messages for all agents or selected one. Can preserve selected number of last messages.
That function is called when user manually provide "clear history" phrase in his reply.
When "clear history" is provided, the history of messages for all agents is cleared.
@@ -489,23 +1472,29 @@ def clear_agents_history(self, reply: str, groupchat: GroupChat) -> str:
Phrase "clear history" and optional arguments are cut out from the reply before it passed to the chat.
Args:
- reply (str): Admin reply to analyse.
+ reply (dict): reply message dict to analyze.
groupchat (GroupChat): GroupChat object.
"""
+ iostream = IOStream.get_default()
+
+ reply_content = reply["content"]
# Split the reply into words
- words = reply.split()
+ words = reply_content.split()
# Find the position of "clear" to determine where to start processing
clear_word_index = next(i for i in reversed(range(len(words))) if words[i].upper() == "CLEAR")
# Extract potential agent name and steps
words_to_check = words[clear_word_index + 2 : clear_word_index + 4]
nr_messages_to_preserve = None
+ nr_messages_to_preserve_provided = False
agent_to_memory_clear = None
for word in words_to_check:
if word.isdigit():
nr_messages_to_preserve = int(word)
+ nr_messages_to_preserve_provided = True
elif word[:-1].isdigit(): # for the case when number of messages is followed by dot or other sign
nr_messages_to_preserve = int(word[:-1])
+ nr_messages_to_preserve_provided = True
else:
for agent in groupchat.agents:
if agent.name == word:
@@ -514,24 +1503,30 @@ def clear_agents_history(self, reply: str, groupchat: GroupChat) -> str:
elif agent.name == word[:-1]: # for the case when agent name is followed by dot or other sign
agent_to_memory_clear = agent
break
+ # preserve last tool call message if clear history called inside of tool response
+ if "tool_responses" in reply and not nr_messages_to_preserve:
+ nr_messages_to_preserve = 1
+ logger.warning(
+ "The last tool call message will be saved to prevent errors caused by tool response without tool call."
+ )
# clear history
if agent_to_memory_clear:
if nr_messages_to_preserve:
- print(
+ iostream.print(
f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages."
)
else:
- print(f"Clearing history for {agent_to_memory_clear.name}.")
+ iostream.print(f"Clearing history for {agent_to_memory_clear.name}.")
agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve)
else:
if nr_messages_to_preserve:
- print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
+ iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
# clearing history for groupchat here
temp = groupchat.messages[-nr_messages_to_preserve:]
groupchat.messages.clear()
groupchat.messages.extend(temp)
else:
- print("Clearing history for all agents.")
+ iostream.print("Clearing history for all agents.")
# clearing history for groupchat here
groupchat.messages.clear()
# clearing history for agents
@@ -539,7 +1534,7 @@ def clear_agents_history(self, reply: str, groupchat: GroupChat) -> str:
agent.clear_history(nr_messages_to_preserve=nr_messages_to_preserve)
# Reconstruct the reply without the "clear history" command and parameters
- skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(bool(nr_messages_to_preserve))
- reply = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :])
+ skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(nr_messages_to_preserve_provided)
+ reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :])
- return reply
+ return reply_content
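The parsing above means the "clear history" command in an admin reply can optionally name an agent and/or give a number of most-recent messages to preserve, and either may be followed by trailing punctuation. A few illustrative reply strings (the agent name `coder` is hypothetical):

```python
# Clear the history of every agent and of the group chat itself.
"We are going in circles. clear history"

# Clear only the history of the agent named "coder".
"clear history coder"

# Clear all histories but keep the last 3 messages everywhere.
"clear history 3"

# Clear only "coder"'s history, preserving its last 3 messages.
"clear history coder 3."
```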
diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py
index 86b7e1e7b1c..a80296a8355 100644
--- a/autogen/agentchat/user_proxy_agent.py
+++ b/autogen/agentchat/user_proxy_agent.py
@@ -1,5 +1,6 @@
from typing import Callable, Dict, List, Literal, Optional, Union
+from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent
@@ -13,7 +14,6 @@ class UserProxyAgent(ConversableAgent):
To modify the way to get human input, override `get_human_input` method.
To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`,
`run_code`, and `execute_function` methods respectively.
- To customize the initial message when a conversation starts, override `generate_init_message` method.
"""
# Default UserProxyAgent.description values, based on human_input_mode
@@ -28,9 +28,9 @@ def __init__(
name: str,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "ALWAYS",
+ human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS",
function_map: Optional[Dict[str, Callable]] = None,
- code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
+ code_execution_config: Union[Dict, Literal[False]] = {},
default_auto_reply: Optional[Union[str, Dict, None]] = "",
llm_config: Optional[Union[Dict, Literal[False]]] = False,
system_message: Optional[Union[str, List]] = "",
@@ -70,10 +70,11 @@ def __init__(
- timeout (Optional, int): The maximum execution time in seconds.
- last_n_messages (Experimental, Optional, int): The number of messages to look back for code execution. Default to 1.
default_auto_reply (str or dict or None): the default auto reply message when no code execution or llm based reply is generated.
- llm_config (dict or False): llm inference configuration.
+ llm_config (dict or False or None): llm inference configuration.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
- Default to false, which disables llm-based auto reply.
+ Default to False, which disables llm-based auto reply.
+ When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
system_message (str or List): system message for ChatCompletion inference.
Only used when llm_config is not False. Use it to reprogram the agent.
description (str): a short description of the agent. This description is used by other agents
@@ -89,7 +90,10 @@ def __init__(
code_execution_config=code_execution_config,
llm_config=llm_config,
default_auto_reply=default_auto_reply,
- description=description
- if description is not None
- else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode],
+ description=(
+ description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
+ ),
)
+
+ if logging_enabled():
+ log_new_agent(self, locals())
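Given the narrowed `human_input_mode` type and the changed `code_execution_config` default above, a typical construction might look like the sketch below (the working directory, reply limit, and Docker setting are illustrative assumptions, not defaults):

```python
import autogen

user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    human_input_mode="NEVER",  # must now be one of "ALWAYS", "TERMINATE", "NEVER"
    max_consecutive_auto_reply=5,
    code_execution_config={"work_dir": "coding", "use_docker": False},
    llm_config=False,  # False (the default) disables LLM-based auto reply
)
```

When runtime logging is enabled (e.g., via `autogen.runtime_logging.start()`), the new `log_new_agent` call records this agent's construction.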
diff --git a/autogen/agentchat/utils.py b/autogen/agentchat/utils.py
new file mode 100644
index 00000000000..b32c2f5f0a0
--- /dev/null
+++ b/autogen/agentchat/utils.py
@@ -0,0 +1,201 @@
+import re
+from typing import Any, Callable, Dict, List, Union
+
+from .agent import Agent
+
+
+def consolidate_chat_info(chat_info, uniform_sender=None) -> None:
+ if isinstance(chat_info, dict):
+ chat_info = [chat_info]
+ for c in chat_info:
+ if uniform_sender is None:
+ assert "sender" in c, "sender must be provided."
+ sender = c["sender"]
+ else:
+ sender = uniform_sender
+ assert "recipient" in c, "recipient must be provided."
+ summary_method = c.get("summary_method")
+ assert (
+ summary_method is None
+ or isinstance(summary_method, Callable)
+ or summary_method in ("last_msg", "reflection_with_llm")
+ ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
+ if summary_method == "reflection_with_llm":
+ assert (
+ sender.client is not None or c["recipient"].client is not None
+ ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
+
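For reference, a chat-info entry that passes these assertions might look like the following sketch (`user_proxy`, `assistant`, and the chosen summary method are placeholders):

```python
chat_info = [
    {
        "sender": user_proxy,          # required unless uniform_sender is supplied
        "recipient": assistant,        # always required
        "summary_method": "last_msg",  # None, "last_msg", "reflection_with_llm", or a callable
    }
]
consolidate_chat_info(chat_info)  # raises AssertionError if any entry is malformed
```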
+
+def gather_usage_summary(agents: List[Agent]) -> Dict[str, Dict[str, Any]]:
+ r"""Gather usage summary from all agents.
+
+ Args:
+ agents (list): List of agents.
+
+ Returns:
+ dictionary: A dictionary containing two keys:
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
+
+ Example:
+
+ ```python
+ {
+ "usage_including_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
+ },
+
+ "usage_excluding_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
+ }
+ }
+ ```
+
+ Note:
+
+ If none of the agents incurred any cost (e.g., because they do not have a client), then both usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`.
+ """
+
+ def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None:
+ if agent_summary is None:
+ return
+ usage_summary["total_cost"] += agent_summary.get("total_cost", 0)
+ for model, data in agent_summary.items():
+ if model != "total_cost":
+ if model not in usage_summary:
+ usage_summary[model] = data.copy()
+ else:
+ usage_summary[model]["cost"] += data.get("cost", 0)
+ usage_summary[model]["prompt_tokens"] += data.get("prompt_tokens", 0)
+ usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0)
+ usage_summary[model]["total_tokens"] += data.get("total_tokens", 0)
+
+ usage_including_cached_inference = {"total_cost": 0}
+ usage_excluding_cached_inference = {"total_cost": 0}
+
+ for agent in agents:
+ if getattr(agent, "client", None):
+ aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary)
+ aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary)
+
+ return {
+ "usage_including_cached_inference": usage_including_cached_inference,
+ "usage_excluding_cached_inference": usage_excluding_cached_inference,
+ }
+
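A brief usage sketch, assuming `assistant` and `user_proxy` are agents that have already completed some LLM calls:

```python
usage = gather_usage_summary([assistant, user_proxy])

# Total spend across all agents, counting cached completions.
print(usage["usage_including_cached_inference"]["total_cost"])

# Spend for non-cached calls only; never larger than the figure above.
print(usage["usage_excluding_cached_inference"]["total_cost"])
```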
+
+def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]:
+ """Parses HTML style tags from message contents.
+
+ The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is
+ specified as an argument to the function. The function looks for this tag in the text and extracts its content. The
+ content of a tag is everything that is inside the tag, between the opening and closing angle brackets. The content
+ can be a single string or a set of attribute-value pairs.
+
+ Examples:
+ -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]
+