Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error in agent when use_memory is False #74

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/agentscope/agents/dialog_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class DialogAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
prompt_type: Optional[PromptType] = PromptType.LIST,
Expand All @@ -29,7 +29,7 @@ def __init__(
sys_prompt (`Optional[str]`):
The system prompt of the agent, which can be passed by args
or hard-coded in the agent.
model_config_name (`str`, defaults to None):
model_config_name (`str`):
The name of the model config, which is used to load model from
configuration.
use_memory (`bool`, defaults to `True`):
Expand Down Expand Up @@ -68,13 +68,13 @@ def reply(self, x: dict = None) -> dict:
response to the user's input.
"""
# record the input if needed
if x is not None:
if not self.memory:
DavdGao marked this conversation as resolved.
Show resolved Hide resolved
self.memory.add(x)

# prepare prompt
prompt = self.engine.join(
self.sys_prompt,
self.memory.get_memory(),
self.memory and self.memory.get_memory(),
)

# call llm and generate response
Expand All @@ -85,6 +85,7 @@ def reply(self, x: dict = None) -> dict:
self.speak(msg)

# Record the message in memory
self.memory.add(msg)
if not self.memory:
self.memory.add(msg)

return msg
11 changes: 6 additions & 5 deletions src/agentscope/agents/dict_dialog_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class DictDialogAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
parse_func: Optional[Callable[..., Any]] = parse_dict,
Expand Down Expand Up @@ -127,13 +127,13 @@ def reply(self, x: dict = None) -> dict:
it defaults to treating the response as plain text.
"""
# record the input if needed
if x is not None:
if not self.memory:
self.memory.add(x)

# prepare prompt
prompt = self.engine.join(
self.sys_prompt,
self.memory.get_memory(),
self.memory and self.memory.get_memory(),
)

# call llm
Expand All @@ -158,6 +158,7 @@ def reply(self, x: dict = None) -> dict:
self.speak(msg)

# record to memory
self.memory.add(msg)
if not self.memory:
self.memory.add(msg)

return msg
12 changes: 8 additions & 4 deletions src/agentscope/agents/text_to_image_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class TextToImageAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
) -> None:
Expand Down Expand Up @@ -44,8 +44,9 @@ def __init__(
)

def reply(self, x: dict = None) -> dict:
if x is not None:
if not self.memory and x is not None:
self.memory.add(x)

image_urls = self.model(x.content).image_urls
# TODO: optimize the construction of content
msg = Msg(
Expand All @@ -54,5 +55,8 @@ def reply(self, x: dict = None) -> dict:
url=image_urls,
)
logger.chat(msg)
self.memory.add(msg)

if not self.memory:
self.memory.add(msg)

return msg
5 changes: 3 additions & 2 deletions src/agentscope/agents/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def reply(
the user's input and any additional details. This is also
stored in the object's memory.
"""
if x is not None:
if not self.memory and x is not None:
self.memory.add(x)

# TODO: To avoid order confusion, because `input` print much quicker
Expand Down Expand Up @@ -91,6 +91,7 @@ def reply(
)

# Add to memory
self.memory.add(msg)
if not self.memory:
self.memory.add(msg)

return msg
2 changes: 1 addition & 1 deletion src/agentscope/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_memory(
"""

@abstractmethod
def add(self, memories: Union[list[dict], dict]) -> None:
def add(self, memories: Union[list[dict], dict, None]) -> None:
"""
Adding new memory fragment, depending on how the memory are stored
"""
Expand Down
5 changes: 4 additions & 1 deletion src/agentscope/memory/temporary_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ def __init__(

def add(
self,
memories: Union[Sequence[dict], dict],
memories: Union[Sequence[dict], dict, None],
embed: bool = False,
) -> None:
if memories is None:
return

if not isinstance(memories, list):
record_memories = [memories]
else:
Expand Down
4 changes: 4 additions & 0 deletions src/agentscope/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def join(
converted to `Msg` from `system`.
"""
# TODO: achieve the summarize function

# Filter `None`
args = [_ for _ in args if _ is not None]

if self.prompt_type == PromptType.STRING:
return self.join_to_str(*args, format_map=format_map)
elif self.prompt_type == PromptType.LIST:
Expand Down