Skip to content

Commit

Permalink
feat: allow passing in tags to client.create_agent(tags=[..]) (#2073)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Nov 20, 2024
1 parent 746efc4 commit 37d700a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
12 changes: 11 additions & 1 deletion letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def create_agent(
include_base_tools: Optional[bool] = True,
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> AgentState:
raise NotImplementedError

Expand All @@ -94,6 +95,7 @@ def update_agent(
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
memory: Optional[Memory] = None,
tags: Optional[List[str]] = None,
):
raise NotImplementedError

Expand Down Expand Up @@ -381,6 +383,7 @@ def create_agent(
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
) -> AgentState:
"""Create an agent
Expand All @@ -394,6 +397,7 @@ def create_agent(
include_base_tools (bool): Include base tools
metadata (Dict): Metadata
description (str): Description
tags (List[str]): Tags for filtering agents
Returns:
agent_state (AgentState): State of the created agent
Expand Down Expand Up @@ -435,6 +439,7 @@ def create_agent(
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
)

# Use model_dump_json() instead of model_dump()
Expand Down Expand Up @@ -482,12 +487,12 @@ def update_agent(
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
memory: Optional[Memory] = None,
tags: Optional[List[str]] = None,
):
"""
Update an existing agent
Expand All @@ -503,6 +508,7 @@ def update_agent(
embedding_config (EmbeddingConfig): Embedding configuration
message_ids (List[str]): List of message IDs
memory (Memory): Memory configuration
tags (List[str]): Tags for filtering agents
Returns:
agent_state (AgentState): State of the updated agent
Expand Down Expand Up @@ -1669,6 +1675,7 @@ def create_agent(
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
) -> AgentState:
"""Create an agent
Expand All @@ -1683,6 +1690,7 @@ def create_agent(
include_base_tools (bool): Include base tools
metadata (Dict): Metadata
description (str): Description
tags (List[str]): Tags for filtering agents
Returns:
agent_state (AgentState): State of the created agent
Expand Down Expand Up @@ -1724,6 +1732,7 @@ def create_agent(
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
),
actor=self.user,
)
Expand Down Expand Up @@ -1780,6 +1789,7 @@ def update_agent(
embedding_config (EmbeddingConfig): Embedding configuration
message_ids (List[str]): List of message IDs
memory (Memory): Memory configuration
tags (List[str]): Tags for filtering agents
Returns:
agent_state (AgentState): State of the updated agent
Expand Down
10 changes: 9 additions & 1 deletion letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ def create_agent(
memory=request.memory,
description=request.description,
metadata_=request.metadata_,
tags=request.tags,
)
if request.agent_type == AgentType.memgpt_agent:
agent = Agent(
Expand Down Expand Up @@ -904,8 +905,15 @@ def create_agent(
save_agent(agent, self.ms)
logger.debug(f"Created new agent from config: {agent}")

# TODO: move this into save_agent. save_agent should be moved to server.py
if request.tags:
for tag in request.tags:
self.agents_tags_manager.add_tag_to_agent(agent_id=agent.agent_state.id, tag=tag, actor=actor)

assert isinstance(agent.agent_state.memory, Memory), f"Invalid memory type: {type(agent_state.memory)}"
# return AgentState

# TODO: remove (hacky)
agent.agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent.agent_state.id, actor=actor)

return agent.agent_state

Expand Down
6 changes: 5 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,13 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
"""
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]

# Step 0: create an agent with tags
tagged_agent = client.create_agent(tags=tags_to_add)
assert set(tagged_agent.tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {tagged_agent.tags}"

# Step 1: Add multiple tags to the agent
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
client.update_agent(agent_id=agent.id, tags=tags_to_add)

# Step 2: Retrieve tags for the agent and verify they match the added tags
Expand Down

0 comments on commit 37d700a

Please sign in to comment.