Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable base constructs to automatically populate "created_by" and "last_updated_by" fields for relevant objects #1944

Merged
merged 8 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions letta/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def run(
)

# create agent
tools = [
server.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=client.user_id) for tool_name in agent_state.tools
]
tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tools]
letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)

else: # create new agent
Expand Down Expand Up @@ -300,7 +298,7 @@ def run(
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [server.tool_manager.get_tool_by_name_and_user_id(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
tools = [server.tool_manager.get_tool_by_name(tool_name, actor=client.user) for tool_name in agent_state.tools]

letta_agent = Agent(
interface=interface(),
Expand Down
30 changes: 15 additions & 15 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,9 @@ def __init__(
# get default user
self.user_id = self.server.user_manager.DEFAULT_USER_ID

self.user = self.server.get_user_or_default(self.user_id)
self.organization = self.server.get_organization_or_default(self.org_id)

# agents
def list_agents(self) -> List[AgentState]:
self.interface.clear()
Expand Down Expand Up @@ -1648,7 +1651,7 @@ def create_agent(
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
),
user_id=self.user_id,
actor=self.user,
)
return agent_state

Expand Down Expand Up @@ -1720,7 +1723,7 @@ def update_agent(
message_ids=message_ids,
memory=memory,
),
user_id=self.user_id,
actor=self.user,
)
return agent_state

Expand Down Expand Up @@ -2198,24 +2201,22 @@ def delete_human(self, id: str):
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_langchain(
langchain_tool=langchain_tool,
user_id=self.user_id,
organization_id=self.org_id,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(tool_create)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)

def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_crewai(
crewai_tool=crewai_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
user_id=self.user_id,
organization_id=self.org_id,
)
return self.server.tool_manager.create_or_update_tool(tool_create)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)

def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action=action, user_id=self.user_id, organization_id=self.org_id)
return self.server.tool_manager.create_or_update_tool(tool_create)
tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)

# TODO: Use the above function `add_tool` here as there is duplicate logic
def create_tool(
Expand Down Expand Up @@ -2250,14 +2251,13 @@ def create_tool(
# call server function
return self.server.tool_manager.create_or_update_tool(
ToolCreate(
user_id=self.user_id,
organization_id=self.org_id,
source_type=source_type,
source_code=source_code,
name=name,
tags=tags,
terminal=terminal,
),
actor=self.user,
)

def update_tool(
Expand Down Expand Up @@ -2289,7 +2289,7 @@ def update_tool(
# Filter out any None values from the dictionary
update_data = {key: value for key, value in update_data.items() if value is not None}

return self.server.tool_manager.update_tool_by_id(id, ToolUpdate(**update_data))
return self.server.tool_manager.update_tool_by_id(tool_id=id, tool_update=ToolUpdate(**update_data), actor=self.user)

def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
Expand All @@ -2298,7 +2298,7 @@ def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) ->
Returns:
tools (List[Tool]): List of tools
"""
return self.server.tool_manager.list_tools_for_org(cursor=cursor, limit=limit, organization_id=self.org_id)
return self.server.tool_manager.list_tools(cursor=cursor, limit=limit, actor=self.user)

def get_tool(self, id: str) -> Optional[Tool]:
"""
Expand All @@ -2310,7 +2310,7 @@ def get_tool(self, id: str) -> Optional[Tool]:
Returns:
tool (Tool): Tool
"""
return self.server.tool_manager.get_tool_by_id(id)
return self.server.tool_manager.get_tool_by_id(id, actor=self.user)

def delete_tool(self, id: str):
"""
Expand All @@ -2319,7 +2319,7 @@ def delete_tool(self, id: str):
Args:
id (str): ID of the tool
"""
return self.server.tool_manager.delete_tool_by_id(id)
return self.server.tool_manager.delete_tool_by_id(id, user_id=self.user_id)

def get_tool_id(self, name: str) -> Optional[str]:
"""
Expand All @@ -2331,7 +2331,7 @@ def get_tool_id(self, name: str) -> Optional[str]:
Returns:
id (str): ID of the tool (`None` if not found)
"""
tool = self.server.tool_manager.get_tool_by_name_and_org_id(tool_name=name, organization_id=self.org_id)
tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user)
return tool.id

def load_data(self, connector: DataConnector, source_name: str):
Expand Down
15 changes: 10 additions & 5 deletions letta/orm/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

from sqlalchemy import UUID as SQLUUID
from sqlalchemy import Boolean, DateTime, func, text
from sqlalchemy import Boolean, DateTime, String, func, text
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Expand All @@ -25,6 +23,13 @@ class CommonSqlalchemyMetaMixins(Base):
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))

def _set_created_and_updated_by_fields(self, actor_id: str) -> None:
"""Populate created_by_id and last_updated_by_id based on actor."""
if not self.created_by_id:
self.created_by_id = actor_id
# Always set the last_updated_by_id when updating
self.last_updated_by_id = actor_id

@declared_attr
def _created_by_id(cls):
return cls._user_by_id()
Expand All @@ -38,7 +43,7 @@ def _user_by_id(cls):
"""a flexible non-constrained record of a user.
This way users can get added, deleted etc without history freaking out
"""
return mapped_column(SQLUUID(), nullable=True)
return mapped_column(String, nullable=True)

@property
def last_updated_by_id(self) -> Optional[str]:
Expand Down Expand Up @@ -72,4 +77,4 @@ def _user_id_setter(self, prop: str, value: str) -> None:
return
prefix, id_ = value.split("-", 1)
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
setattr(self, full_prop, UUID(id_))
setattr(self, full_prop, id_)
113 changes: 50 additions & 63 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
from uuid import UUID, uuid4
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from uuid import uuid4

from humps import depascalize
from sqlalchemy import Boolean, String, select
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] =
def read(
cls,
db_session: "Session",
identifier: Union[str, UUID],
identifier: Optional[str] = None,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
**kwargs,
Expand All @@ -105,71 +105,76 @@ def read(
Raises:
NoResultFound: if the object is not found
"""
del kwargs # arity for more complex reads
identifier = cls.get_uid_from_identifier(identifier)
query = select(cls).where(cls._id == identifier)
# if actor:
# query = cls.apply_access_predicate(query, actor, access)
# Start the query
query = select(cls)

# If an identifier is provided, add it to the query conditions
if identifier is not None:
identifier = cls.get_uid_from_identifier(identifier)
query = query.where(cls._id == identifier)

if kwargs:
query = query.filter_by(**kwargs)

if actor:
query = cls.apply_access_predicate(query, actor, access)

if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
if found := db_session.execute(query).scalar():
return found
raise NoResultFound(f"{cls.__name__} with id {identifier} not found")

def create(self, db_session: "Session") -> Type["SqlalchemyBase"]:
# self._infer_organization(db_session)
def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)

with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self

def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]:
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)

self.is_deleted = True
return self.update(db_session)

def update(self, db_session: "Session") -> Type["SqlalchemyBase"]:
def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)

with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self

@classmethod
def read_or_create(cls, *, db_session: "Session", **kwargs) -> Type["SqlalchemyBase"]:
"""get an instance by search criteria or create it if it doesn't exist"""
try:
return cls.read(db_session=db_session, identifier=kwargs.get("id", None))
except NoResultFound:
clean_kwargs = {k: v for k, v in kwargs.items() if k in cls.__table__.columns}
return cls(**clean_kwargs).create(db_session=db_session)

# TODO: Add back later when access predicates are actually important
# The idea behind this is that you can add a WHERE clause restricting the actions you can take, e.g. R/W
# @classmethod
# def apply_access_predicate(
# cls,
# query: "Select",
# actor: "User",
# access: List[Literal["read", "write", "admin"]],
# ) -> "Select":
# """applies a WHERE clause restricting results to the given actor and access level
# Args:
# query: The initial sqlalchemy select statement
# actor: The user acting on the query. **Note**: this is called 'actor' to identify the
# person or system acting. Users can act on users, making naming very sticky otherwise.
# access:
# what mode of access should the query restrict to? This will be used with granular permissions,
# but because of how it will impact every query we want to be explicitly calling access ahead of time.
# Returns:
# the sqlalchemy select statement restricted to the given access.
# """
# del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
# org_uid = getattr(actor, "_organization_id", getattr(actor.organization, "_id", None))
# if not org_uid:
# raise ValueError("object %s has no organization accessor", actor)
# return query.where(cls._organization_id == org_uid, cls.is_deleted == False)
def apply_access_predicate(
cls,
query: "Select",
actor: "User",
access: List[Literal["read", "write", "admin"]],
) -> "Select":
"""applies a WHERE clause restricting results to the given actor and access level
Args:
query: The initial sqlalchemy select statement
actor: The user acting on the query. **Note**: this is called 'actor' to identify the
person or system acting. Users can act on users, making naming very sticky otherwise.
access:
what mode of access should the query restrict to? This will be used with granular permissions,
but because of how it will impact every query we want to be explicitly calling access ahead of time.
Returns:
the sqlalchemy select statement restricted to the given access.
"""
del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
org_id = getattr(actor, "organization_id", None)
if not org_id:
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls._organization_id == cls.get_uid_from_identifier(org_id, indifferent=True), cls.is_deleted == False)

@property
def __pydantic_model__(self) -> Type["BaseModel"]:
Expand All @@ -183,21 +188,3 @@ def to_record(self) -> Type["BaseModel"]:
"""Deprecated accessor for to_pydantic"""
logger.warning("to_record is deprecated, use to_pydantic instead.")
return self.to_pydantic()

def _infer_organization(self, db_session: "Session") -> None:
"""🪄 MAGIC ALERT! 🪄
Because so much of the original API is centered around user scopes,
this allows us to continue with that scope and then infer the org from the creating user.

IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
If not do nothing to the object. Mutates in place.
"""
if self.created_by_id and hasattr(self, "_organization_id"):
try:
from letta.orm.user import User # to avoid circular import

created_by = User.read(db_session, self.created_by_id)
except NoResultFound:
logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
return
self._organization_id = created_by._organization_id
15 changes: 3 additions & 12 deletions letta/orm/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@

# TODO everything in functions should live in this model
from letta.orm.enums import ToolSourceType
from letta.orm.mixins import OrganizationMixin, UserMixin
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tool import Tool as PydanticTool

if TYPE_CHECKING:
pass

from letta.orm.organization import Organization
from letta.orm.user import User


class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):
class Tool(SqlalchemyBase, OrganizationMixin):
"""Represents an available tool that the LLM can invoke.

NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools
Expand All @@ -29,10 +26,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):

# Add unique constraint on (name, _organization_id)
# An organization should not have multiple tools with the same name
__table_args__ = (
UniqueConstraint("name", "_organization_id", name="uix_name_organization"),
UniqueConstraint("name", "_user_id", name="uix_name_user"),
)
__table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_name_organization"),)

name: Mapped[str] = mapped_column(doc="The display name of the tool.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
Expand All @@ -48,7 +42,4 @@ class Tool(SqlalchemyBase, OrganizationMixin, UserMixin):
# This was an intentional decision by Sarah

# relationships
# TODO: Possibly add in user in the future
# This will require some more thought and justification to add this in.
user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
4 changes: 1 addition & 3 deletions letta/orm/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING

from sqlalchemy.orm import Mapped, mapped_column, relationship

Expand All @@ -8,7 +8,6 @@

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.tool import Tool


class User(SqlalchemyBase, OrganizationMixin):
Expand All @@ -21,7 +20,6 @@ class User(SqlalchemyBase, OrganizationMixin):

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan")

# TODO: Add this back later potentially
# agents: Mapped[List["Agent"]] = relationship(
Expand Down
Loading
Loading