From d417af60b719a95fd1ad1c6e1d31a1242c1e2554 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 25 Nov 2024 18:18:24 -0800 Subject: [PATCH 1/3] wip --- letta/orm/block.py | 7 ++++++- letta/orm/blocks_agents.py | 5 ++++- letta/orm/sqlalchemy_base.py | 13 +++++++++++++ letta/schemas/block.py | 18 +++++++++--------- letta/services/block_manager.py | 9 +++++++++ letta/services/blocks_agents_manager.py | 9 ++++++++- tests/test_managers.py | 18 ++++++++++++++++++ 7 files changed, 67 insertions(+), 12 deletions(-) diff --git a/letta/orm/block.py b/letta/orm/block.py index ab7e40802e..8242e4dc3a 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -4,13 +4,17 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT +from letta.log import get_logger from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Human, Persona if TYPE_CHECKING: - from letta.orm.organization import Organization + from letta.orm import BlocksAgents, Organization + + +logger = get_logger(__name__) class Block(OrganizationMixin, SqlalchemyBase): @@ -35,6 +39,7 @@ class Block(OrganizationMixin, SqlalchemyBase): # relationships organization: Mapped[Optional["Organization"]] = relationship("Organization") + blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete") def to_pydantic(self) -> Type: match self.label: diff --git a/letta/orm/blocks_agents.py b/letta/orm/blocks_agents.py index 31f0fa9d34..a344964690 100644 --- a/letta/orm/blocks_agents.py +++ b/letta/orm/blocks_agents.py @@ -1,5 +1,5 @@ from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents @@ -27,3 +27,6 @@ class BlocksAgents(SqlalchemyBase): agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) block_id: Mapped[str] = mapped_column(String, primary_key=True) block_label: Mapped[str] = mapped_column(String, primary_key=True) + + # relationships + block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents") diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 8eeefac8da..84de1ec303 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -180,6 +180,19 @@ def _handle_dbapi_error(cls, e: DBAPIError): """Handle database errors and raise appropriate custom exceptions.""" orig = e.orig # Extract the original error from the DBAPIError error_code = None + error_message = str(orig) if orig else str(e) + logger.info(f"Handling DBAPIError: {error_message}") + + # Handle SQLite-specific errors + if "UNIQUE constraint failed" in error_message: + raise UniqueConstraintViolationError( + f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}" + ) from e + + if "FOREIGN KEY constraint failed" in error_message: + raise ForeignKeyConstraintViolationError( + f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}" + ) from e # For psycopg2 if hasattr(orig, "pgcode"): diff --git a/letta/schemas/block.py b/letta/schemas/block.py index b3acc8666e..472b2d6abd 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,7 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self +from pydantic import BaseModel, Field from letta.schemas.letta_base import LettaBase @@ -28,13 +27,14 @@ class BaseBlock(LettaBase, validate_assignment=True): description: Optional[str] = Field(None, description="Description of the block.") metadata_: Optional[dict] = Field({}, description="Metadata of the block.") - @model_validator(mode="after") - def verify_char_limit(self) -> Self: - if len(self.value) > self.limit: - error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}." - raise ValueError(error_msg) - - return self + # @model_validator(mode="after") + # def verify_char_limit(self) -> Self: + # + # if len(self.value) > self.limit: + # error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}." + # raise ValueError(error_msg) + # + # return self # def __len__(self): # return len(self.value) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index c559d05ac9..e2a69b3fe3 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,6 +1,7 @@ import os from typing import List, Optional +from letta.orm import BlocksAgents as BlocksAgentsModel from letta.orm.block import Block as BlockModel from letta.orm.errors import NoResultFound from letta.schemas.block import Block @@ -39,11 +40,19 @@ def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticB def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" with self.session_maker() as session: + # Update block block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(block, key, value) block.update(db_session=session, actor=actor) + + # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents + if block_update.label: + blocks_agents_record = BlocksAgentsModel.read(db_session=session, block_id=block_id) + setattr(blocks_agents_record, "block_label", block_update.label) + blocks_agents_record.update(db_session=session) + return block.to_pydantic() @enforce_types diff --git a/letta/services/blocks_agents_manager.py b/letta/services/blocks_agents_manager.py index bbc5bfc042..586a581aa1 100644 --- a/letta/services/blocks_agents_manager.py +++ b/letta/services/blocks_agents_manager.py @@ -71,11 +71,18 @@ def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_i @enforce_types def list_block_ids_for_agent(self, agent_id: str) -> List[str]: - """List all blocks associated with a specific agent.""" + """List all block ids associated with a specific agent.""" with self.session_maker() as session: blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id) return [record.block_id for record in blocks_agents_record] + @enforce_types + def list_block_labels_for_agent(self, agent_id: str) -> List[str]: + """List all block labels associated with a specific agent.""" + with self.session_maker() as session: + blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id) + return [record.block_label for record in blocks_agents_record] + @enforce_types def list_agent_ids_with_block(self, block_id: str) -> List[str]: """List all agents associated with a specific block.""" diff --git a/tests/test_managers.py b/tests/test_managers.py index 05e785917d..a986996480 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1063,6 +1063,24 @@ def test_add_block_to_agent(server, sarah_agent, default_user, default_block): assert block_association.block_label == default_block.label +def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agent, default_user, default_block): + # Add the block + block_association = server.blocks_agents_manager.add_block_to_agent( + agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label + ) + assert block_association.block_label == default_block.label + + # Change the block label + new_label = "banana" + block = server.block_manager.update_block(block_id=default_block.id, block_update=BlockUpdate(label=new_label), actor=default_user) + assert block.label == new_label + + # Get the association + labels = server.blocks_agents_manager.list_block_labels_for_agent(agent_id=sarah_agent.id) + assert new_label in labels + assert default_block.label not in labels + + def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user): with pytest.raises(ForeignKeyConstraintViolationError): server.blocks_agents_manager.add_block_to_agent( From fd3b1dd06cc29499ff127efc3a020da3715673e5 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 25 Nov 2024 18:26:43 -0800 Subject: [PATCH 2/3] Finish --- letta/services/block_manager.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index e2a69b3fe3..dcae5f5cd6 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,13 +1,13 @@ import os from typing import List, Optional -from letta.orm import BlocksAgents as BlocksAgentsModel from letta.orm.block import Block as BlockModel from letta.orm.errors import NoResultFound from letta.schemas.block import Block from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, Human, Persona from letta.schemas.user import User as PydanticUser +from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.utils import enforce_types, list_human_files, list_persona_files @@ -39,6 +39,14 @@ def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticB @enforce_types def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" + # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents + blocks_agents_manager = BlocksAgentsManager() + agent_ids = [] + if block_update.label: + agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id) + for agent_id in agent_ids: + blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id) + with self.session_maker() as session: # Update block block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) @@ -47,13 +55,12 @@ def update_block(self, block_id: str, block_update: BlockUpdate, actor: Pydantic setattr(block, key, value) block.update(db_session=session, actor=actor) - # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents - if block_update.label: - blocks_agents_record = BlocksAgentsModel.read(db_session=session, block_id=block_id) - setattr(blocks_agents_record, "block_label", block_update.label) - blocks_agents_record.update(db_session=session) + # TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents + if block_update.label: + for agent_id in agent_ids: + blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label) - return block.to_pydantic() + return block.to_pydantic() @enforce_types def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: From bc4a29fdcf6e87630a8d654b40a8f9fc9216b260 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 25 Nov 2024 18:29:24 -0800 Subject: [PATCH 3/3] Finish --- letta/orm/block.py | 4 ---- letta/schemas/block.py | 18 +++++++++--------- tests/test_managers.py | 1 - 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/letta/orm/block.py b/letta/orm/block.py index 8242e4dc3a..84bbdb7e71 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT -from letta.log import get_logger from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock @@ -14,9 +13,6 @@ from letta.orm import BlocksAgents, Organization -logger = get_logger(__name__) - - class Block(OrganizationMixin, SqlalchemyBase): """Blocks are sections of the LLM context, representing a specific part of the total Memory""" diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 472b2d6abd..6679d50357 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,6 +1,7 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from letta.schemas.letta_base import LettaBase @@ -27,14 +28,13 @@ class BaseBlock(LettaBase, validate_assignment=True): description: Optional[str] = Field(None, description="Description of the block.") metadata_: Optional[dict] = Field({}, description="Metadata of the block.") - # @model_validator(mode="after") - # def verify_char_limit(self) -> Self: - # - # if len(self.value) > self.limit: - # error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}." - # raise ValueError(error_msg) - # - # return self + @model_validator(mode="after") + def verify_char_limit(self) -> Self: + if self.value and len(self.value) > self.limit: + error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}." + raise ValueError(error_msg) + + return self # def __len__(self): # return len(self.value) diff --git a/tests/test_managers.py b/tests/test_managers.py index a986996480..218296ec80 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -925,7 +925,6 @@ def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user): # Assertions assert e2b_config.timeout == 5 * 60 - assert e2b_config.template assert e2b_config.template == tool_settings.e2b_sandbox_template_id