fix: Clean up some legacy code and fix Groq provider (#1950)
mattzh72 authored Oct 28, 2024
1 parent 9302c56 commit c372234
Showing 7 changed files with 29 additions and 76 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -5,6 +5,7 @@ env:
   COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
   GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+  GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
 
 on:
   push:
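
For context on how this secret reaches the code: letta reads provider keys through model_settings (see the server.py change below). A minimal sketch of such a field, assuming pydantic-settings is what backs ModelSettings:

    from typing import Optional

    from pydantic_settings import BaseSettings

    class ModelSettings(BaseSettings):
        # Filled from the GROQ_API_KEY environment variable exported above;
        # pydantic-settings matches field names to env vars case-insensitively.
        # Stays None when the key is absent, which disables the Groq provider.
        groq_api_key: Optional[str] = None

    model_settings = ModelSettings()
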
48 changes: 2 additions & 46 deletions letta/client/client.py
@@ -276,10 +276,10 @@ def get_messages(
     ) -> List[Message]:
         raise NotImplementedError
 
-    def list_models(self) -> List[LLMConfig]:
+    def list_model_configs(self) -> List[LLMConfig]:
         raise NotImplementedError
 
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_configs(self) -> List[EmbeddingConfig]:
         raise NotImplementedError


@@ -1234,32 +1234,6 @@ def detach_source(self, source_id: str, agent_id: str):
         assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
         return Source(**response.json())
 
-    # server configuration commands
-
-    def list_models(self):
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list models: {response.text}")
-        return [LLMConfig(**model) for model in response.json()]
-
-    def list_embedding_models(self):
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list embedding models: {response.text}")
-        return [EmbeddingConfig(**model) for model in response.json()]
-
     # tools
 
     def get_tool_id(self, tool_name: str):
@@ -2572,24 +2546,6 @@ def get_messages(
             return_message_object=True,
         )
 
-    def list_models(self) -> List[LLMConfig]:
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        return self.server.list_models()
-
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        return [self.server.server_embedding_config]
-
     def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
         """
         List available blocks
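
Downstream callers must switch to the renamed methods. A hedged usage sketch against the abstract interface above (the import path and constructor arguments are assumptions, and concrete clients may expose the listing under a similar name — the updated test further down calls list_llm_configs):

    from letta.client.client import RESTClient  # import path is an assumption

    client = RESTClient(base_url="http://localhost:8283", token="my-token")
    for llm_config in client.list_model_configs():
        print(llm_config.model, llm_config.model_endpoint_type)
    for embedding_config in client.list_embedding_configs():
        print(embedding_config.embedding_model)  # field name is an assumption
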
1 change: 0 additions & 1 deletion letta/llm_api/llm_api_tools.py
@@ -313,7 +313,6 @@ def create(
             stream_interface.stream_start()
         try:
             # groq uses the openai chat completions API, so this component should be reusable
-            assert model_settings.groq_api_key is not None, "Groq key is missing"
             response = openai_chat_completions_request(
                 url=llm_config.model_endpoint,
                 api_key=model_settings.groq_api_key,
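
The dropped assert presumably leans on the provider gating added in server.py below: a Groq-typed config should only exist when groq_api_key is set. Since Groq's endpoint is OpenAI-compatible, the same request shape works outside letta too; a standalone sketch (the model id is an example and may be outdated):

    import os

    import requests

    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",  # Groq's OpenAI-compatible endpoint
        headers={"Authorization": f"Bearer {os.environ['GROQ_API_KEY']}"},
        json={
            "model": "llama-3.1-8b-instant",  # example id; query /models for current ones
            "messages": [{"role": "user", "content": "Say hello."}],
        },
        timeout=30,
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["message"]["content"])
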
2 changes: 1 addition & 1 deletion letta/providers.py
@@ -313,7 +313,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
                 )
             )
         return configs
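
This one-word change is the actual Groq fix: model_endpoint_type is what selects the request path (and which API key) in llm_api_tools.create, so "openai" presumably misrouted Groq models to the OpenAI handler. A sketch of a config the provider now emits (values other than the endpoint type are illustrative):

    from letta.schemas.llm_config import LLMConfig

    config = LLMConfig(
        model="llama-3.1-70b-versatile",  # example id returned by Groq's /models
        model_endpoint_type="groq",  # previously "openai"
        model_endpoint="https://api.groq.com/openai/v1",  # assumed value of self.base_url
        context_window=131072,  # example; comes from model["context_window"]
    )
    print(config.model_endpoint_type)
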
3 changes: 3 additions & 0 deletions letta/server/server.py
@@ -43,6 +43,7 @@
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     LettaProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -298,6 +299,8 @@ def __init__(
                     api_version=model_settings.azure_api_version,
                 )
             )
+        if model_settings.groq_api_key:
+            self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
         if model_settings.vllm_api_base:
             # vLLM exposes both a /chat/completions and a /completions endpoint
             self._enabled_providers.append(
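
Registration follows the same key-gated pattern as the other providers, so a missing GROQ_API_KEY simply means no Groq models are offered. A hedged sketch of how the enabled providers are then aggregated (the class name matches server.py; the loop body is an assumption about the server's internals):

    from typing import List

    class SyncServer:
        def __init__(self, providers: list):
            self._enabled_providers = providers

        def list_llm_models(self) -> List["LLMConfig"]:
            # Assumed aggregation: concatenate the configs from every
            # provider whose API key was present at startup.
            configs = []
            for provider in self._enabled_providers:
                configs.extend(provider.list_llm_models())
            return configs
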
39 changes: 17 additions & 22 deletions tests/test_client.py
@@ -2,7 +2,7 @@
 import threading
 import time
 import uuid
-from typing import Union
+from typing import List, Union
 
 import pytest
 from dotenv import load_dotenv
@@ -18,6 +18,7 @@
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
+from letta.settings import model_settings
 from tests.helpers.client_helper import upload_file_using_client
 
 # from tests.utils import create_config
@@ -262,21 +263,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
     assert human.value == "Human text", "Creating human failed"
 
 
-def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
-    # _reset_config()
-
-    models_response = client.list_models()
-    print("MODELS", models_response)
-
-    embeddings_response = client.list_embedding_models()
-    print("EMBEDDINGS", embeddings_response)
-
-    # TODO: add back
-    # config_response = client.get_config()
-    # TODO: ensure config is the same as the one in the server
-    # print("CONFIG", config_response)
-
-
 def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
     tools = client.list_tools()
     visited_ids = {t.id: False for t in tools}
@@ -503,11 +489,20 @@ def test_organization(client: RESTClient):
         pytest.skip("Skipping test_organization because LocalClient does not support organizations")
 
 
-def test_model_configs(client: Union[LocalClient, RESTClient]):
-    # _reset_config()
+def test_list_llm_models(client: RESTClient):
+    """Test that if the user's env has the right api keys set, at least one model appears in the model list"""
 
-    model_configs = client.list_models()
-    print("MODEL CONFIGS", model_configs)
+    def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
+        return any(model.model_endpoint_type == target_type for model in models)
 
-    embedding_configs = client.list_embedding_models()
-    print("EMBEDDING CONFIGS", embedding_configs)
+    models = client.list_llm_configs()
+    if model_settings.groq_api_key:
+        assert has_model_endpoint_type(models, "groq")
+    if model_settings.azure_api_key:
+        assert has_model_endpoint_type(models, "azure")
+    if model_settings.openai_api_key:
+        assert has_model_endpoint_type(models, "openai")
+    if model_settings.gemini_api_key:
+        assert has_model_endpoint_type(models, "google_ai")
+    if model_settings.anthropic_api_key:
+        assert has_model_endpoint_type(models, "anthropic")
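
Each assertion is gated on the matching key in model_settings, so the rewritten test degrades to a no-op in environments with no provider keys rather than failing. It can also be run in isolation; a hedged sketch (assumes the test fixtures can start or reach a server):

    import pytest

    raise SystemExit(pytest.main(["-k", "test_list_llm_models", "tests/test_client.py"]))
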
11 changes: 5 additions & 6 deletions tests/test_providers.py
@@ -4,6 +4,7 @@
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     MistralProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -27,12 +28,10 @@ def test_anthropic():
     print(models)
 
 
-# def test_groq():
-#     provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#
+def test_groq():
+    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
+    models = provider.list_llm_models()
+    print(models)
 
 
 def test_azure():
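
With the CI secret added in tests.yml above, the re-enabled test now runs in GitHub Actions; locally it needs GROQ_API_KEY exported. Its body also works standalone (the printed fields follow the LLMConfig shape shown in providers.py):

    import os

    from letta.providers import GroqProvider

    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
    for config in provider.list_llm_models():
        print(config.model, config.context_window)
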
