Skip to content

Commit

Permalink
feat: add GET route to get the breakdown of an agent's context window (#1889)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Oct 15, 2024
1 parent 4fd82ee commit cca1cb8
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 2 deletions.
68 changes: 67 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_messages
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
from letta.persistence_manager import LocalStateManager
from letta.schemas.agent import AgentState, AgentStepResponse
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, OptionState
from letta.schemas.memory import Memory
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import (
Expand Down Expand Up @@ -1421,6 +1422,71 @@ def retry_message(self) -> List[Message]:
assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
return messages

def get_context_window(self) -> ContextWindowOverview:
    """Build a breakdown of everything currently occupying the agent's context window.

    Returns a ContextWindowOverview with both message counts and per-section
    token counts (system prompt, core memory, optional summary message,
    remaining messages), plus sizes of the external (out-of-context) memories.
    """
    # System prompt and compiled core memory, measured in tokens.
    # TODO is this the current system or the initial system?
    system_prompt = self.agent_state.system
    num_tokens_system = count_tokens(system_prompt)
    core_memory = self.memory.compile()
    num_tokens_core_memory = count_tokens(core_memory)

    # Messages in OpenAI dict format, which is what the token counter consumes
    messages_openai_format = self.messages

    # A summarizer-injected message, if present, sits at index 1 (right after the
    # system message) and is detected by role + a marker substring in its text.
    # TODO remove hardcoding
    summary_memory = None
    first_real_message_index = 1
    if len(self._messages) > 1:
        candidate = self._messages[1]
        if (
            candidate.role == MessageRole.user
            and isinstance(candidate.text, str)
            and "The following is a summary of the previous " in candidate.text
        ):
            # Summary message exists; real messages start one index later
            summary_memory = candidate.text
            first_real_message_index = 2

    num_tokens_summary_memory = count_tokens(summary_memory) if summary_memory is not None else 0

    # Token count of everything after the system (and optional summary) message
    if len(messages_openai_format) > first_real_message_index:
        num_tokens_messages = num_tokens_from_messages(
            messages=messages_openai_format[first_real_message_index:], model=self.model
        )
    else:
        num_tokens_messages = 0

    # External (out-of-context) memory sizes, plus the token cost of the
    # metadata block that summarizes them inside the context window
    num_archival_memory = self.persistence_manager.archival_memory.storage.size()
    num_recall_memory = self.persistence_manager.recall_memory.storage.size()
    external_memory_summary = compile_memory_metadata_block(
        memory_edit_timestamp=get_utc_time(),  # dummy timestamp
        archival_memory=self.persistence_manager.archival_memory,
        recall_memory=self.persistence_manager.recall_memory,
    )
    num_tokens_external_memory_summary = count_tokens(external_memory_summary)

    # NOTE(review): context_window_size_current below does not include
    # num_tokens_external_memory_summary — confirm whether that is intended.
    return ContextWindowOverview(
        # context window breakdown (in messages)
        num_messages=len(self._messages),
        num_archival_memory=num_archival_memory,
        num_recall_memory=num_recall_memory,
        num_tokens_external_memory_summary=num_tokens_external_memory_summary,
        # top-level information
        context_window_size_max=self.agent_state.llm_config.context_window,
        context_window_size_current=num_tokens_system + num_tokens_core_memory + num_tokens_summary_memory + num_tokens_messages,
        # context window breakdown (in tokens)
        num_tokens_system=num_tokens_system,
        system_prompt=system_prompt,
        num_tokens_core_memory=num_tokens_core_memory,
        core_memory=core_memory,
        num_tokens_summary_memory=num_tokens_summary_memory,
        summary_memory=summary_memory,
        num_tokens_messages=num_tokens_messages,
        messages=self._messages,
    )


def save_agent(agent: Agent, ms: MetadataStore):
"""Save agent to metadata store"""
Expand Down
37 changes: 37 additions & 0 deletions letta/schemas/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,43 @@
from letta.agent import Agent

from letta.schemas.block import Block
from letta.schemas.message import Message


class ContextWindowOverview(BaseModel):
    """
    Overview of the context window, including the number of messages and tokens.

    Produced by Agent.get_context_window() and served by the
    GET /{agent_id}/context route.
    """

    # top-level information
    context_window_size_max: int = Field(..., description="The maximum amount of tokens the context window can hold.")
    context_window_size_current: int = Field(..., description="The current number of tokens in the context window.")

    # context window breakdown (in messages)
    # (technically not in the context window, but useful to know)
    num_messages: int = Field(..., description="The number of messages in the context window.")
    num_archival_memory: int = Field(..., description="The number of messages in the archival memory.")
    num_recall_memory: int = Field(..., description="The number of messages in the recall memory.")
    # NOTE(review): this is a token count grouped under the message-count section,
    # and the producer does not include it in context_window_size_current — confirm intended.
    num_tokens_external_memory_summary: int = Field(
        ..., description="The number of tokens in the external memory summary (archival + recall metadata)."
    )

    # context window breakdown (in tokens)
    # this should all add up to context_window_size_current

    num_tokens_system: int = Field(..., description="The number of tokens in the system prompt.")
    system_prompt: str = Field(..., description="The content of the system prompt.")

    num_tokens_core_memory: int = Field(..., description="The number of tokens in the core memory.")
    core_memory: str = Field(..., description="The content of the core memory.")

    num_tokens_summary_memory: int = Field(..., description="The number of tokens in the summary memory.")
    # None when no summarizer-injected message is present in the message queue
    summary_memory: Optional[str] = Field(None, description="The content of the summary memory.")

    num_tokens_messages: int = Field(..., description="The number of tokens in the messages list.")
    # TODO make list of messages?
    # messages: List[dict] = Field(..., description="The messages in the context window.")
    messages: List[Message] = Field(..., description="The messages in the context window.")


class Memory(BaseModel, validate_assignment=True):
Expand Down
15 changes: 15 additions & 0 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from letta.schemas.memory import (
ArchivalMemorySummary,
BasicBlockMemory,
ContextWindowOverview,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
Expand Down Expand Up @@ -51,6 +52,20 @@ def list_agents(
return server.list_agents(user_id=actor.id)


@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="get_agent_context_window")
def get_agent_context_window(
    agent_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve the context window of a specific agent.

    Returns a ContextWindowOverview with the agent's current context usage
    broken down by section (system prompt, core memory, summary, messages).
    """
    # Resolve the acting user; falls back to the server's default user when no header is sent
    actor = server.get_user_or_default(user_id=user_id)

    return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)


@router.post("/", response_model=AgentState, operation_id="create_agent")
def create_agent(
agent: CreateAgent = Body(...),
Expand Down
17 changes: 16 additions & 1 deletion letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.memory import (
ArchivalMemorySummary,
ContextWindowOverview,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
Expand Down Expand Up @@ -2019,3 +2024,13 @@ def add_llm_model(self, request: LLMConfig) -> LLMConfig:

def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""

def get_agent_context_window(
    self,
    user_id: str,
    agent_id: str,
) -> ContextWindowOverview:
    """Return the context-window breakdown for the given agent.

    NOTE(review): user_id is accepted but not used to scope the agent
    lookup here — confirm whether per-user authorization is intended.
    """
    # Load the in-memory agent (or fetch it from cache) and delegate
    agent = self._get_or_load_agent(agent_id=agent_id)
    return agent.get_context_window()

0 comments on commit cca1cb8

Please sign in to comment.