Commit 9830211

Clean up legacy code
mattzh72 committed Oct 28, 2024
1 parent 5501531 commit 9830211
Showing 3 changed files with 23 additions and 50 deletions.
30 changes: 2 additions & 28 deletions letta/client/client.py
@@ -276,10 +276,10 @@ def get_messages(
    ) -> List[Message]:
        raise NotImplementedError

-    def list_models(self) -> List[LLMConfig]:
+    def list_model_configs(self) -> List[LLMConfig]:
        raise NotImplementedError

-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_configs(self) -> List[EmbeddingConfig]:
        raise NotImplementedError


@@ -1234,32 +1234,6 @@ def detach_source(self, source_id: str, agent_id: str):
        assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
        return Source(**response.json())

-    # server configuration commands
-
-    def list_models(self):
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list models: {response.text}")
-        return [LLMConfig(**model) for model in response.json()]
-
-    def list_embedding_models(self):
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list embedding models: {response.text}")
-        return [EmbeddingConfig(**model) for model in response.json()]
-
    # tools

    def get_tool_id(self, tool_name: str):
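Note on the removal above: this commit deletes the legacy REST implementations of list_models and list_embedding_models without adding replacements in this file. Below is a minimal sketch of what REST-backed versions of the renamed abstract methods could look like, assuming the same /models and /models/embedding endpoints the deleted code called; the class name, the default api_prefix, and the EmbeddingConfig import path are assumptions for illustration, not taken from this commit.

# Hypothetical sketch, not part of this commit: REST-backed versions of the
# renamed abstract methods, reusing the endpoints the deleted legacy code called.
from typing import List, Optional

import requests

from letta.schemas.embedding_config import EmbeddingConfig  # assumed import path
from letta.schemas.llm_config import LLMConfig


class ConfigListingClientSketch:
    def __init__(self, base_url: str, api_prefix: str = "v1", token: Optional[str] = None):
        self.base_url = base_url
        self.api_prefix = api_prefix  # assumed default prefix
        self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    def list_model_configs(self) -> List[LLMConfig]:
        # Same endpoint the removed list_models() hit
        response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
        if response.status_code != 200:
            raise ValueError(f"Failed to list model configs: {response.text}")
        return [LLMConfig(**model) for model in response.json()]

    def list_embedding_configs(self) -> List[EmbeddingConfig]:
        # Same endpoint the removed list_embedding_models() hit
        response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
        if response.status_code != 200:
            raise ValueError(f"Failed to list embedding configs: {response.text}")
        return [EmbeddingConfig(**model) for model in response.json()]

In the commit itself, the corresponding client method is exercised as client.list_llm_configs() in the test added below.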
22 changes: 21 additions & 1 deletion tests/test_client.py
@@ -2,7 +2,7 @@
import threading
import time
import uuid
-from typing import Union
+from typing import List, Union

import pytest
from dotenv import load_dotenv
@@ -18,6 +18,7 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
+from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client

# from tests.utils import create_config
@@ -486,3 +487,22 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
def test_organization(client: RESTClient):
    if isinstance(client, LocalClient):
        pytest.skip("Skipping test_organization because LocalClient does not support organizations")
+
+
+def test_list_llm_models(client: RESTClient):
+    """Test that if the user's env has the right api keys set, at least one model appears in the model list"""
+
+    def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
+        return any(model.model_endpoint_type == target_type for model in models)
+
+    models = client.list_llm_configs()
+    if model_settings.groq_api_key:
+        assert has_model_endpoint_type(models, "groq")
+    if model_settings.azure_api_key:
+        assert has_model_endpoint_type(models, "azure")
+    if model_settings.openai_api_key:
+        assert has_model_endpoint_type(models, "openai")
+    if model_settings.gemini_api_key:
+        assert has_model_endpoint_type(models, "google_ai")
+    if model_settings.anthropic_api_key:
+        assert has_model_endpoint_type(models, "anthropic")
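Because every assertion in the added test is guarded by a key check in model_settings, the test passes trivially when no provider keys are set. Below is a self-contained illustration of the has_model_endpoint_type predicate it relies on; SimpleNamespace stands in for LLMConfig since only the model_endpoint_type attribute is read, and this snippet is illustrative rather than part of the commit.

# Illustrative only: exercises the predicate logic used by test_list_llm_models.
from types import SimpleNamespace
from typing import Any, List


def has_model_endpoint_type(models: List[Any], target_type: str) -> bool:
    return any(model.model_endpoint_type == target_type for model in models)


configs = [
    SimpleNamespace(model_endpoint_type="openai"),
    SimpleNamespace(model_endpoint_type="groq"),
]
assert has_model_endpoint_type(configs, "openai")
assert not has_model_endpoint_type(configs, "anthropic")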
21 changes: 0 additions & 21 deletions tests/test_server.py
@@ -1,14 +1,12 @@
import json
import uuid
import warnings
-from typing import List

import pytest

import letta.utils as utils
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.enums import MessageRole
-from letta.settings import model_settings

utils.DEBUG = True
from letta.config import LettaConfig
@@ -543,22 +541,3 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
        + overview.num_tokens_functions_definitions
        + overview.num_tokens_external_memory_summary
    )
-
-
-def test_list_llm_models(server: SyncServer):
-    """Test that if the user's env has the right api keys set, at least one model appears in the model list"""
-
-    def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
-        return any(model.model_endpoint_type == target_type for model in models)
-
-    models = server.list_llm_models()
-    if model_settings.groq_api_key:
-        assert has_model_endpoint_type(models, "groq")
-    if model_settings.azure_api_key:
-        assert has_model_endpoint_type(models, "azure")
-    if model_settings.openai_api_key:
-        assert has_model_endpoint_type(models, "openai")
-    if model_settings.gemini_api_key:
-        assert has_model_endpoint_type(models, "google_ai")
-    if model_settings.anthropic_api_key:
-        assert has_model_endpoint_type(models, "anthropic")
