Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update block label also updates the BlocksAgents table #2106

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion letta/orm/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from letta.schemas.block import Human, Persona

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm import BlocksAgents, Organization


class Block(OrganizationMixin, SqlalchemyBase):
Expand All @@ -35,6 +35,7 @@ class Block(OrganizationMixin, SqlalchemyBase):

# relationships
organization: Mapped[Optional["Organization"]] = relationship("Organization")
blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete")

def to_pydantic(self) -> Type:
match self.label:
Expand Down
5 changes: 4 additions & 1 deletion letta/orm/blocks_agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
Expand Down Expand Up @@ -27,3 +27,6 @@ class BlocksAgents(SqlalchemyBase):
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, primary_key=True)
block_label: Mapped[str] = mapped_column(String, primary_key=True)

# relationships
block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents")
13 changes: 13 additions & 0 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,19 @@ def _handle_dbapi_error(cls, e: DBAPIError):
"""Handle database errors and raise appropriate custom exceptions."""
orig = e.orig # Extract the original error from the DBAPIError
error_code = None
error_message = str(orig) if orig else str(e)
logger.info(f"Handling DBAPIError: {error_message}")

# Handle SQLite-specific errors
if "UNIQUE constraint failed" in error_message:
raise UniqueConstraintViolationError(
f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
) from e

if "FOREIGN KEY constraint failed" in error_message:
raise ForeignKeyConstraintViolationError(
f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
) from e

# For psycopg2
if hasattr(orig, "pgcode"):
Expand Down
2 changes: 1 addition & 1 deletion letta/schemas/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BaseBlock(LettaBase, validate_assignment=True):

@model_validator(mode="after")
def verify_char_limit(self) -> Self:
if len(self.value) > self.limit:
if self.value and len(self.value) > self.limit:
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
raise ValueError(error_msg)

Expand Down
18 changes: 17 additions & 1 deletion letta/services/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.utils import enforce_types, list_human_files, list_persona_files


Expand Down Expand Up @@ -38,13 +39,28 @@ def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticB
@enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
blocks_agents_manager = BlocksAgentsManager()
agent_ids = []
if block_update.label:
agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id)
for agent_id in agent_ids:
blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)

with self.session_maker() as session:
# Update block
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(block, key, value)
block.update(db_session=session, actor=actor)
return block.to_pydantic()

# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
if block_update.label:
for agent_id in agent_ids:
blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label)

return block.to_pydantic()

@enforce_types
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
Expand Down
9 changes: 8 additions & 1 deletion letta/services/blocks_agents_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_i

@enforce_types
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all blocks associated with a specific agent."""
"""List all block ids associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_id for record in blocks_agents_record]

@enforce_types
def list_block_labels_for_agent(self, agent_id: str) -> List[str]:
"""List all block labels associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_label for record in blocks_agents_record]

@enforce_types
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
"""List all agents associated with a specific block."""
Expand Down
19 changes: 18 additions & 1 deletion tests/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,6 @@ def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user):

# Assertions
assert e2b_config.timeout == 5 * 60
assert e2b_config.template
assert e2b_config.template == tool_settings.e2b_sandbox_template_id


Expand Down Expand Up @@ -1063,6 +1062,24 @@ def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
assert block_association.block_label == default_block.label


def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agent, default_user, default_block):
# Add the block
block_association = server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
)
assert block_association.block_label == default_block.label

# Change the block label
new_label = "banana"
block = server.block_manager.update_block(block_id=default_block.id, block_update=BlockUpdate(label=new_label), actor=default_user)
assert block.label == new_label

# Get the association
labels = server.blocks_agents_manager.list_block_labels_for_agent(agent_id=sarah_agent.id)
assert new_label in labels
assert default_block.label not in labels


def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(
Expand Down
Loading