Skip to content

Commit

Permalink
feat(agent): Multi agents v0.1 (eosphoros-ai#1044)
Browse files Browse the repository at this point in the history
Co-authored-by: qidanrui <qidanrui@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
  • Loading branch information
4 people authored and Hopshine committed Sep 10, 2024
1 parent 9365a14 commit 3804d8e
Show file tree
Hide file tree
Showing 41 changed files with 1,424 additions and 380 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ __pycache__/

# C extensions
*.so

message/
dbgpt/util/extensions/
.env*
Expand Down
28 changes: 25 additions & 3 deletions dbgpt/agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from dbgpt.core import LLMClient
from dbgpt.core.interface.llm import ModelMetadata
from dbgpt.util.annotations import PublicAPI

from ..memory.gpts_memory import GptsMemory

Expand Down Expand Up @@ -44,6 +45,10 @@ def describe(self) -> str:
"""Get the name of the agent."""
return self._describe

@property
def is_terminal_agent(self) -> bool:
    """Whether this agent ends the conversation when it replies.

    Base implementation always returns False; concrete subclasses may
    override this (e.g. via a constructor flag) to mark themselves as
    terminal. Receivers check this flag to stop the dialogue loop.
    """
    return False

async def a_send(
self,
message: Union[Dict, str],
Expand Down Expand Up @@ -88,6 +93,7 @@ async def a_generate_reply(
sender: Agent,
reviewer: Agent,
silent: Optional[bool] = False,
rely_messages: Optional[List[Dict]] = None,
**kwargs,
) -> Union[str, Dict, None]:
"""(Abstract async method) Generate a reply based on the received messages.
Expand All @@ -102,10 +108,9 @@ async def a_generate_reply(
async def a_reasoning_reply(
self, messages: Optional[List[Dict]]
) -> Union[str, Dict, None]:
"""
Based on the requirements of the current agent, reason about the current task goal through LLM
"""Based on the requirements of the current agent, reason about the current task goal through LLM
Args:
message:
messages:
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
Expand Down Expand Up @@ -187,3 +192,20 @@ class AgentContext:

def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)


@dataclasses.dataclass
@PublicAPI(stability="beta")
class AgentGenerateContext:
    """A class to represent the input of an Agent reply-generation step."""

    # The current message payload to respond to (may be None when only
    # rely_messages drive the generation).
    message: Optional[Dict]
    # The agent that sent the message.
    sender: Agent
    # The agent responsible for reviewing/verifying the generated reply.
    reviewer: Agent
    # When True, suppress console output of the exchange.
    silent: Optional[bool] = False

    # Upstream messages this generation depends on (e.g. prior task results).
    rely_messages: List[Dict] = dataclasses.field(default_factory=list)
    # Whether this is the final generation round — TODO confirm exact
    # semantics against the consuming layout/chat manager code.
    final: Optional[bool] = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this context to a plain dict via dataclasses.asdict."""
        return dataclasses.asdict(self)
44 changes: 43 additions & 1 deletion dbgpt/agent/agents/agents_mange.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
import re
from collections import defaultdict
from typing import Optional, Type
from typing import Dict, List, Optional, Type

from .agent import Agent
from .expand.code_assistant_agent import CodeAssistantAgent
from .expand.dashboard_assistant_agent import DashboardAssistantAgent
from .expand.data_scientist_agent import DataScientistAgent
from .expand.plugin_assistant_agent import PluginAssistantAgent
from .expand.sql_assistant_agent import SQLAssistantAgent
from .expand.summary_assistant_agent import SummaryAssistantAgent

logger = logging.getLogger(__name__)


def get_all_subclasses(cls):
Expand All @@ -18,6 +24,40 @@ def get_all_subclasses(cls):
return all_subclasses


def participant_roles(agents: Optional[List[Agent]] = None) -> str:
    """Render a newline-separated ``name: description`` roster of agents.

    Args:
        agents: The agents to describe. Must be provided: the previous
            "default to all agents registered" fallback was never
            implemented (it self-assigned ``agents = agents`` and then
            crashed with TypeError when iterating None).

    Returns:
        One line per agent, formatted as ``<name>: <describe>``; an empty
        string for an empty list.

    Raises:
        ValueError: If ``agents`` is None.
    """
    # Bug fix: the old branch `agents = agents` was a no-op, so passing None
    # fell through to `for agent in None`. Fail fast with a clear error.
    if agents is None:
        raise ValueError("participant_roles requires a non-None list of agents")

    roles = []
    for agent in agents:
        if agent.system_message.strip() == "":
            # An empty system prompt usually means a misconfigured agent.
            logger.warning(
                f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
            )
        roles.append(f"{agent.name}: {agent.describe}")
    return "\n".join(roles)


def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
    """Count word-boundary-aware mentions of each agent in message_content.

    Returns:
        A dict mapping agent name to its mention count; agents with zero
        mentions are omitted from the result.
    """
    # Pad with spaces so a mention at the very start or end of the string
    # still has a non-word neighbour for the lookaround assertions.
    padded = " " + message_content + " "
    counts = dict()
    for candidate in agents:
        # Lookbehind/lookahead on \W so e.g. "bob" never matches inside "bobby".
        pattern = r"(?<=\W)" + re.escape(candidate.name) + r"(?=\W)"
        hits = len(re.findall(pattern, padded))
        if hits > 0:
            counts[candidate.name] = hits
    return counts


class AgentsMange:
def __init__(self):
self._agents = defaultdict()
Expand Down Expand Up @@ -46,3 +86,5 @@ def all_agents(self):
agent_mange.register_agent(DashboardAssistantAgent)
agent_mange.register_agent(DataScientistAgent)
agent_mange.register_agent(SQLAssistantAgent)
agent_mange.register_agent(SummaryAssistantAgent)
agent_mange.register_agent(PluginAssistantAgent)
75 changes: 53 additions & 22 deletions dbgpt/agent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE",
default_auto_reply: Optional[Union[str, Dict, None]] = "",
is_terminal_agent: bool = False,
):
super().__init__(name, memory, describe)

Expand All @@ -57,8 +58,9 @@ def __init__(
else self.MAX_CONSECUTIVE_AUTO_REPLY
)
self.consecutive_auto_reply_counter: int = 0

self._current_retry_counter: int = 0
self._max_retry_count: int = 5
self._is_terminal_agent = is_terminal_agent

## By default, the memory of 4 rounds of dialogue is retained.
self.dialogue_memory_rounds = 5
Expand Down Expand Up @@ -91,6 +93,10 @@ def register_reply(
},
)

@property
def is_terminal_agent(self) -> bool:
    """Whether this agent ends the conversation when it replies.

    Reflects the ``is_terminal_agent`` flag passed to the constructor;
    receivers stop the dialogue loop when the sender reports True.
    """
    return self._is_terminal_agent

@property
def system_message(self):
"""Return the system message."""
Expand Down Expand Up @@ -197,7 +203,6 @@ def append_message(self, message: Optional[Dict], role, sender: Agent) -> bool:
"""
Put the received message content into the collective message memory
Args:
conv_id:
message:
role:
sender:
Expand Down Expand Up @@ -381,17 +386,32 @@ def _gpts_message_to_ai_message(
)
return oai_messages

def process_now_message(self, sender, current_gogal: Optional[str] = None):
# Convert and tailor the information in collective memory into contextual memory available to the current Agent
def process_now_message(
self,
current_message: Optional[Dict],
sender,
rely_messages: Optional[List[Dict]] = None,
):
current_gogal = current_message.get("current_gogal", None)
### Convert and tailor the information in collective memory into contextual memory available to the current Agent
current_gogal_messages = self._gpts_message_to_ai_message(
self.memory.message_memory.get_between_agents(
self.agent_context.conv_id, self.name, sender.name, current_gogal
)
)

# relay messages
if current_gogal_messages is None or len(current_gogal_messages) <= 0:
current_message["role"] = ModelMessageRoleType.HUMAN
current_gogal_messages = [current_message]
### relay messages
cut_messages = []
cut_messages.extend(self._rely_messages)
if rely_messages:
for rely_message in rely_messages:
action_report = rely_message.get("action_report", None)
if action_report:
rely_message["content"] = action_report["content"]
cut_messages.extend(rely_messages)
else:
cut_messages.extend(self._rely_messages)

if len(current_gogal_messages) < self.dialogue_memory_rounds:
cut_messages.extend(current_gogal_messages)
Expand All @@ -409,8 +429,9 @@ async def a_generate_reply(
self,
message: Optional[Dict],
sender: Agent,
reviewer: "Agent",
reviewer: Agent,
silent: Optional[bool] = False,
rely_messages: Optional[List[Dict]] = None,
):
## 0.New message build
new_message = {}
Expand All @@ -420,11 +441,7 @@ async def a_generate_reply(
## 1.LLM Reasonging
await self.a_system_fill_param()
await asyncio.sleep(5) ##TODO Rate limit reached for gpt-3.5-turbo
current_messages = self.process_now_message(
sender, message.get("current_gogal", None)
)
if current_messages is None or len(current_messages) <= 0:
current_messages = [message]
current_messages = self.process_now_message(message, sender, rely_messages)
ai_reply, model = await self.a_reasoning_reply(messages=current_messages)
new_message["content"] = ai_reply
new_message["model_name"] = model
Expand Down Expand Up @@ -466,6 +483,9 @@ async def a_receive(
if request_reply is False or request_reply is None:
logger.info("Messages that do not require a reply")
return
if self._is_termination_msg(message) or sender.is_terminal_agent:
logger.info(f"TERMINATE!")
return

verify_paas, reply = await self.a_generate_reply(
message=message, sender=sender, reviewer=reviewer, silent=silent
Expand All @@ -476,14 +496,26 @@ async def a_receive(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
self._current_retry_counter += 1
logger.info(
"The generated answer failed to verify, so send it to yourself for optimization."
)
# TODO: Exit after the maximum number of rounds of self-optimization
await sender.a_send(
message=reply, recipient=self, reviewer=reviewer, silent=silent
)
# Exit after the maximum number of rounds of self-optimization
if self._current_retry_counter >= self._max_retry_count:
# If the maximum number of retries is exceeded, the abnormal answer will be returned directly.
logger.warning(
f"More than {self._current_retry_counter} times and still no valid answer is output."
)
reply[
"content"
] = f"After n optimizations, the following problems still exist:{reply['content']}"
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
self._current_retry_counter += 1
logger.info(
"The generated answer failed to verify, so send it to yourself for optimization."
)
await sender.a_send(
message=reply, recipient=self, reviewer=reviewer, silent=silent
)

async def a_verify(self, message: Optional[Dict]):
return True, message
Expand Down Expand Up @@ -547,7 +579,6 @@ async def a_verify_reply(
async def a_retry_chat(
self,
recipient: "ConversableAgent",
agent_map: dict,
reviewer: "Agent" = None,
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
Expand Down
Loading

0 comments on commit 3804d8e

Please sign in to comment.