Skip to content

Commit

Permalink
feat: make system_message as optional (#1038)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wendong-Fan authored Oct 13, 2024
1 parent 0b6734e commit 1b6856d
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 51 deletions.
53 changes: 36 additions & 17 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class ChatAgent(BaseAgent):
r"""Class for managing conversations of CAMEL Chat Agents.
Args:
system_message (BaseMessage): The system message for the chat agent.
system_message (BaseMessage, optional): The system message for the
chat agent.
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
Expand Down Expand Up @@ -144,7 +145,7 @@ class ChatAgent(BaseAgent):

def __init__(
self,
system_message: BaseMessage,
system_message: Optional[BaseMessage] = None,
model: Optional[BaseModelBackend] = None,
memory: Optional[AgentMemory] = None,
message_window_size: Optional[int] = None,
Expand All @@ -154,10 +155,14 @@ def __init__(
external_tools: Optional[List[FunctionTool]] = None,
response_terminators: Optional[List[ResponseTerminator]] = None,
) -> None:
self.orig_sys_message: BaseMessage = system_message
self.system_message = system_message
self.role_name: str = system_message.role_name
self.role_type: RoleType = system_message.role_type
self.orig_sys_message: Optional[BaseMessage] = system_message
self._system_message: Optional[BaseMessage] = system_message
self.role_name: str = (
getattr(system_message, 'role_name', None) or "assistant"
)
self.role_type: RoleType = (
getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
)
self.model_backend: BaseModelBackend = (
model
if model is not None
Expand Down Expand Up @@ -272,11 +277,12 @@ def reset(self):
terminator.reset()

@property
def system_message(self) -> BaseMessage:
def system_message(self) -> Optional[BaseMessage]:
r"""The getter method for the property :obj:`system_message`.
Returns:
BaseMessage: The system message of this agent.
Optional[BaseMessage]: The system message of this agent if set,
else :obj:`None`.
"""
return self._system_message

Expand Down Expand Up @@ -327,12 +333,22 @@ def set_output_language(self, output_language: str) -> BaseMessage:
BaseMessage: The updated system message object.
"""
self.output_language = output_language
content = self.orig_sys_message.content + (
language_prompt = (
"\nRegardless of the input language, "
f"you must output text in {output_language}."
)
self.system_message = self.system_message.create_new_instance(content)
return self.system_message
if self.orig_sys_message is not None:
content = self.orig_sys_message.content + language_prompt
self._system_message = self.orig_sys_message.create_new_instance(
content
)
return self._system_message
else:
self._system_message = BaseMessage.make_assistant_message(
role_name="Assistant",
content=language_prompt,
)
return self._system_message

def get_info(
self,
Expand Down Expand Up @@ -377,12 +393,15 @@ def init_messages(self) -> None:
r"""Initializes the stored messages list with the initial system
message.
"""
system_record = MemoryRecord(
message=self.system_message,
role_at_backend=OpenAIBackendRole.SYSTEM,
)
self.memory.clear()
self.memory.write_record(system_record)
if self.orig_sys_message is not None:
system_record = MemoryRecord(
message=self.orig_sys_message,
role_at_backend=OpenAIBackendRole.SYSTEM,
)
self.memory.clear()
self.memory.write_record(system_record)
else:
self.memory.clear()

def record_message(self, message: BaseMessage) -> None:
r"""Records the externally provided message into the agent memory as if
Expand Down
9 changes: 6 additions & 3 deletions camel/societies/babyagi_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
)

self.assistant_agent: ChatAgent
self.assistant_sys_msg: BaseMessage
self.assistant_sys_msg: Optional[BaseMessage]
self.task_creation_agent: TaskCreationAgent
self.task_prioritization_agent: TaskPrioritizationAgent
self.init_agents(
Expand Down Expand Up @@ -202,7 +202,8 @@ def init_agents(

self.task_creation_agent = TaskCreationAgent(
objective=self.specified_task_prompt,
role_name=self.assistant_sys_msg.role_name,
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
or "assistant",
output_language=output_language,
message_window_size=message_window_size,
**(task_creation_agent_kwargs or {}),
Expand Down Expand Up @@ -238,7 +239,9 @@ def step(self) -> ChatAgentResponse:

task_name = self.subtasks.popleft()
assistant_msg_msg = BaseMessage.make_user_message(
role_name=self.assistant_sys_msg.role_name, content=f"{task_name}"
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
or "assistant",
content=f"{task_name}",
)

assistant_response = self.assistant_agent.step(assistant_msg_msg)
Expand Down
8 changes: 5 additions & 3 deletions camel/societies/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def __init__(

self.assistant_agent: ChatAgent
self.user_agent: ChatAgent
self.assistant_sys_msg: BaseMessage
self.user_sys_msg: BaseMessage
self.assistant_sys_msg: Optional[BaseMessage]
self.user_sys_msg: Optional[BaseMessage]
self._init_agents(
init_assistant_sys_msg,
init_user_sys_msg,
Expand Down Expand Up @@ -454,9 +454,11 @@ def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage:
)
if init_msg_content is None:
init_msg_content = default_init_msg_content

# Initialize a message sent by the assistant
init_msg = BaseMessage.make_assistant_message(
role_name=self.assistant_sys_msg.role_name,
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
or "assistant",
content=init_msg_content,
)

Expand Down
108 changes: 80 additions & 28 deletions test/agents/test_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,55 +69,77 @@ def test_chat_agent(model):
dict(assistant_role="doctor"),
role_tuple=("doctor", RoleType.ASSISTANT),
)
assistant = ChatAgent(system_msg, model=model)
assistant_with_sys_msg = ChatAgent(system_msg, model=model)
assistant_without_sys_msg = ChatAgent(model=model)

assert str(assistant) == (
assert str(assistant_with_sys_msg) == (
"ChatAgent(doctor, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})"
)
assert str(assistant_without_sys_msg) == (
"ChatAgent(assistant, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})"
)

for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
assistant.reset()

assistant.reset()
user_msg = BaseMessage(
role_name="Patient",
role_type=RoleType.USER,
meta_dict=dict(),
content="Hello!",
)
assistant_response = assistant.step(user_msg)

assert isinstance(assistant_response.msgs, list)
assert len(assistant_response.msgs) > 0
assert isinstance(assistant_response.terminated, bool)
assert assistant_response.terminated is False
assert isinstance(assistant_response.info, dict)
assert assistant_response.info['id'] is not None
for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
response = assistant.step(user_msg)
assert isinstance(response.msgs, list)
assert len(response.msgs) > 0
assert isinstance(response.terminated, bool)
assert response.terminated is False
assert isinstance(response.info, dict)
assert response.info['id'] is not None


@pytest.mark.model_backend
def test_chat_agent_stored_messages():
system_msg = BaseMessage(
role_name="assistant",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a help assistant.",
)
assistant = ChatAgent(system_msg)

assistant_with_sys_msg = ChatAgent(system_msg)
assistant_without_sys_msg = ChatAgent()

expected_context = [system_msg.to_openai_system_message()]
context, _ = assistant.memory.get_context()
assert context == expected_context

context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context()
assert context_with_sys_msg == expected_context
context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context()
assert context_without_sys_msg == []

user_msg = BaseMessage(
role_name="User",
role_type=RoleType.USER,
meta_dict=dict(),
content="Tell me a joke.",
)
assistant.update_memory(user_msg, OpenAIBackendRole.USER)
expected_context = [

for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
assistant.update_memory(user_msg, OpenAIBackendRole.USER)

expected_context_with_sys_msg = [
system_msg.to_openai_system_message(),
user_msg.to_openai_user_message(),
]
context, _ = assistant.memory.get_context()
assert context == expected_context
expected_context_without_sys_msg = [
user_msg.to_openai_user_message(),
]

context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context()
assert context_with_sys_msg == expected_context_with_sys_msg
context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context()
assert context_without_sys_msg == expected_context_without_sys_msg


@pytest.mark.model_backend
Expand Down Expand Up @@ -273,17 +295,27 @@ def test_chat_agent_multiple_return_messages(n):
meta_dict=None,
content="You are a helpful assistant.",
)
assistant = ChatAgent(system_msg, model=model)
assistant.reset()
assistant_with_sys_msg = ChatAgent(system_msg, model=model)
assistant_without_sys_msg = ChatAgent(model=model)

assistant_with_sys_msg.reset()
assistant_without_sys_msg.reset()

user_msg = BaseMessage(
role_name="User",
role_type=RoleType.USER,
meta_dict=dict(),
content="Tell me a joke.",
)
assistant_response = assistant.step(user_msg)
assert assistant_response.msgs is not None
assert len(assistant_response.msgs) == n
assistant_with_sys_msg_response = assistant_with_sys_msg.step(user_msg)
assistant_without_sys_msg_response = assistant_without_sys_msg.step(
user_msg
)

assert assistant_with_sys_msg_response.msgs is not None
assert len(assistant_with_sys_msg_response.msgs) == n
assert assistant_without_sys_msg_response.msgs is not None
assert len(assistant_without_sys_msg_response.msgs) == n


@pytest.mark.model_backend
Expand Down Expand Up @@ -396,21 +428,41 @@ def test_set_multiple_output_language():
meta_dict=None,
content="You are a help assistant.",
)
agent = ChatAgent(system_message=system_message)
agent_with_sys_msg = ChatAgent(system_message=system_message)
agent_without_sys_msg = ChatAgent()

# Verify that the length of the system message is kept constant even when
# multiple set_output_language operations are called
agent.set_output_language("Chinese")
agent.set_output_language("English")
agent.set_output_language("French")
updated_system_message = BaseMessage(
agent_with_sys_msg.set_output_language("Chinese")
agent_with_sys_msg.set_output_language("English")
agent_with_sys_msg.set_output_language("French")
agent_without_sys_msg.set_output_language("Chinese")
agent_without_sys_msg.set_output_language("English")
agent_without_sys_msg.set_output_language("French")

updated_system_message_with_content = BaseMessage(
role_name="assistant",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a help assistant."
"\nRegardless of the input language, you must output text in French.",
)
assert agent.system_message.content == updated_system_message.content
updated_system_message_without_content = BaseMessage(
role_name="assistant",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="\nRegardless of the input language, you must output text "
"in French.",
)

assert (
agent_with_sys_msg.system_message.content
== updated_system_message_with_content.content
)
assert (
agent_without_sys_msg.system_message.content
== updated_system_message_without_content.content
)


@pytest.mark.model_backend
Expand Down

0 comments on commit 1b6856d

Please sign in to comment.