feat: Add listing llm models and embedding models for Azure endpoint #1846

Merged: 6 commits, Oct 8, 2024. Changes from all commits.
1 change: 0 additions & 1 deletion configs/llm_model_configs/azure-gpt-4o-mini.json
@@ -2,6 +2,5 @@
     "context_window": 128000,
     "model": "gpt-4o-mini",
     "model_endpoint_type": "azure",
-    "api_version": "2023-03-15-preview",
     "model_wrapper": null
 }
90 changes: 38 additions & 52 deletions letta/llm_api/azure_openai.py
@@ -1,76 +1,62 @@
-from typing import Union
-
 import requests
 
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completions import ChatCompletionRequest
 from letta.schemas.openai.embedding_response import EmbeddingResponse
 from letta.settings import ModelSettings
-from letta.utils import smart_urljoin
 
-MODEL_TO_AZURE_ENGINE = {
-    "gpt-4-1106-preview": "gpt-4",
-    "gpt-4": "gpt-4",
-    "gpt-4-32k": "gpt-4-32k",
-    "gpt-3.5": "gpt-35-turbo",
-    "gpt-3.5-turbo": "gpt-35-turbo",
-    "gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
-    "gpt-4o-mini": "gpt-4o-mini",
-}
 
+def get_azure_chat_completions_endpoint(base_url: str, model: str, api_version: str):
+    return f"{base_url}/openai/deployments/{model}/chat/completions?api-version={api_version}"
 
-def get_azure_endpoint(llm_config: LLMConfig, model_settings: ModelSettings):
-    assert llm_config.api_version, "Missing model version! This field must be provided in the LLM config for Azure."
-    assert llm_config.model in MODEL_TO_AZURE_ENGINE, f"{llm_config.model} not in supported models: {list(MODEL_TO_AZURE_ENGINE.keys())}"
-    model = MODEL_TO_AZURE_ENGINE[llm_config.model]
-    return f"{model_settings.azure_base_url}/openai/deployments/{model}/chat/completions?api-version={llm_config.api_version}"
 
+def get_azure_embeddings_endpoint(base_url: str, model: str, api_version: str):
+    return f"{base_url}/openai/deployments/{model}/embeddings?api-version={api_version}"
 
-def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
+
+def get_azure_model_list_endpoint(base_url: str, api_version: str):
+    return f"{base_url}/openai/models?api-version={api_version}"
+
+
+def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -> list:
     """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
-    from letta.utils import printd
-
-    # https://xxx.openai.azure.com/openai/models?api-version=xxx
-    url = smart_urljoin(url, "openai")
-    url = smart_urljoin(url, f"models?api-version={api_version}")
 
     headers = {"Content-Type": "application/json"}
     if api_key is not None:
         headers["api-key"] = f"{api_key}"
 
-    printd(f"Sending request to {url}")
+    url = get_azure_model_list_endpoint(base_url, api_version)
     try:
         response = requests.get(url, headers=headers)
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response = {response}")
-        return response
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        try:
-            response = response.json()
-        except:
-            pass
-        printd(f"Got HTTPError, exception={http_err}, response={response}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        try:
-            response = response.json()
-        except:
-            pass
-        printd(f"Got RequestException, exception={req_err}, response={response}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        try:
-            response = response.json()
-        except:
-            pass
-        printd(f"Got unknown Exception, exception={e}, response={response}")
-        raise e
+        response.raise_for_status()
+    except requests.RequestException as e:
+        raise RuntimeError(f"Failed to retrieve model list: {e}")
+
+    return response.json().get("data", [])
+
+
+def azure_openai_get_chat_completion_model_list(base_url: str, api_key: str, api_version: str) -> list:
+    model_list = azure_openai_get_model_list(base_url, api_key, api_version)
+    # Extract models that support text generation
+    model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
+    return model_options
+
+
+def azure_openai_get_embeddings_model_list(base_url: str, api_key: str, api_version: str, require_embedding_in_name: bool = True) -> list:
+    def valid_embedding_model(m: dict):
+        valid_name = True
+        if require_embedding_in_name:
+            valid_name = "embedding" in m["id"]
+        return m.get("capabilities").get("embeddings") == True and valid_name
+
+    model_list = azure_openai_get_model_list(base_url, api_key, api_version)
+    # Extract models that support embeddings
+    model_options = [m for m in model_list if valid_embedding_model(m)]
+    return model_options
@@ -93,7 +79,7 @@ def azure_openai_chat_completions_request(
         data.pop("tools")
         data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")
 
-    model_endpoint = get_azure_endpoint(llm_config, model_settings)
+    model_endpoint = get_azure_chat_completions_endpoint(model_settings.azure_base_url, llm_config.model, model_settings.azure_api_version)
     printd(f"Sending request to {model_endpoint}")
     try:
         response = requests.post(model_endpoint, headers=headers, json=data)
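
Note: taken together, the new helpers replace the old MODEL_TO_AZURE_ENGINE lookup, so the Azure deployment name is used directly when building URLs. A minimal sketch of how they compose (the base URL, key, and deployment name below are hypothetical placeholders, and the list call hits the live Azure API):

    from letta.llm_api.azure_openai import (
        azure_openai_get_chat_completion_model_list,
        get_azure_chat_completions_endpoint,
    )

    base_url = "https://my-org.openai.azure.com"  # hypothetical org-specific endpoint
    api_version = "2024-09-01-preview"

    # Deployments whose capabilities include chat_completion
    models = azure_openai_get_chat_completion_model_list(base_url, api_key="<key>", api_version=api_version)

    # Request URL for a single deployment
    url = get_azure_chat_completions_endpoint(base_url, "gpt-4o-mini", api_version)
    # https://my-org.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-09-01-preview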
10 changes: 10 additions & 0 deletions letta/llm_api/azure_openai_constants.py
@@ -0,0 +1,10 @@
+AZURE_MODEL_TO_CONTEXT_LENGTH = {
+    "babbage-002": 16384,
+    "davinci-002": 16384,
+    "gpt-35-turbo-0613": 4096,
+    "gpt-35-turbo-1106": 16385,
+    "gpt-35-turbo-0125": 16385,
+    "gpt-4-0613": 8192,
+    "gpt-4o-mini-2024-07-18": 128000,
+    "gpt-4o-2024-08-06": 128000,
+}
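
These keys are Azure model IDs; callers such as AzureProvider.get_model_context_window (below) fall back to a conservative default when a model is not listed. A quick illustration (the second key is a made-up name):

    from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH

    AZURE_MODEL_TO_CONTEXT_LENGTH.get("gpt-4o-2024-08-06", 4096)     # 128000
    AZURE_MODEL_TO_CONTEXT_LENGTH.get("my-custom-deployment", 4096)  # 4096 (unknown model falls back)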
3 changes: 3 additions & 0 deletions letta/llm_api/llm_api_tools.py
@@ -189,6 +189,9 @@ def create(
         if model_settings.azure_base_url is None:
             raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?")
 
+        if model_settings.azure_api_version is None:
+            raise ValueError(f"Azure API version is missing. Did you set AZURE_API_VERSION in your env?")
+
         # Set the llm config model_endpoint from model_settings
         # For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
         llm_config.model_endpoint = model_settings.azure_base_url
61 changes: 60 additions & 1 deletion letta/providers.py
@@ -1,8 +1,13 @@
 from typing import List, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 
 from letta.constants import LLM_MAX_TOKENS
+from letta.llm_api.azure_openai import (
+    get_azure_chat_completions_endpoint,
+    get_azure_embeddings_endpoint,
+)
+from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
@@ -274,10 +279,64 @@ def get_model_context_window(self, model_name: str):
 
 class AzureProvider(Provider):
     name: str = "azure"
+    latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
     )
     api_key: str = Field(..., description="API key for the Azure API.")
+    api_version: str = Field(latest_api_version, description="API version for the Azure API")
+
+    @model_validator(mode="before")
+    def set_default_api_version(cls, values):
+        """
+        This ensures that api_version is always set to the default if None is passed in.
+        """
+        if values.get("api_version") is None:
+            values["api_version"] = cls.model_fields["latest_api_version"].default
+        return values
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.azure_openai import (
+            azure_openai_get_chat_completion_model_list,
+        )
+
+        model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            context_window_size = self.get_model_context_window(model_name)
+            model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+        model_options = azure_openai_get_embeddings_model_list(
+            self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
+        )
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model_name,
+                    embedding_endpoint_type="azure",
+                    embedding_endpoint=model_endpoint,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str):
+        """
+        This is hardcoded for now, since there is no API endpoint to retrieve model metadata.
+        """
+        return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
 
 
 class VLLMProvider(OpenAIProvider):
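
Usage sketch for the new provider methods (placeholder credentials; list_llm_models performs a live API call): passing api_version=None is safe because the before-mode validator backfills the pinned latest_api_version.

    from letta.providers import AzureProvider

    provider = AzureProvider(
        api_key="<azure-key>",                       # placeholder
        base_url="https://my-org.openai.azure.com",  # placeholder
        api_version=None,                            # backfilled to latest_api_version
    )
    for cfg in provider.list_llm_models():
        print(cfg.model, cfg.context_window, cfg.model_endpoint)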
3 changes: 0 additions & 3 deletions letta/schemas/llm_config.py
@@ -35,9 +35,6 @@ class LLMConfig(BaseModel):
         "hugging-face",
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
-    api_version: Optional[str] = Field(
-        None, description="The version for the model API. Used by the Azure provider backend, e.g. 2023-03-15-preview."
-    )
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
8 changes: 7 additions & 1 deletion letta/server/server.py
@@ -272,7 +272,13 @@ def __init__(
         if model_settings.gemini_api_key:
             self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
         if model_settings.azure_api_key and model_settings.azure_base_url:
-            self._enabled_providers.append(AzureProvider(api_key=model_settings.azure_api_key, base_url=model_settings.azure_base_url))
+            self._enabled_providers.append(
+                AzureProvider(
+                    api_key=model_settings.azure_api_key,
+                    base_url=model_settings.azure_base_url,
+                    api_version=model_settings.azure_api_version,
+                )
+            )
 
     def save_agents(self):
         """Saves all the agents that are in the in-memory object store"""
1 change: 1 addition & 0 deletions letta/settings.py
@@ -25,6 +25,7 @@ class ModelSettings(BaseSettings):
     # azure
     azure_api_key: Optional[str] = None
     azure_base_url: Optional[str] = None
+    azure_api_version: Optional[str] = None
 
     # google ai
     gemini_api_key: Optional[str] = None
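
Since ModelSettings is a pydantic BaseSettings, the new field should be picked up from the AZURE_API_VERSION environment variable (assuming no env prefix is configured; pydantic matches env names case-insensitively). A minimal sketch:

    import os

    os.environ["AZURE_API_VERSION"] = "2024-09-01-preview"  # example value

    from letta.settings import ModelSettings

    assert ModelSettings().azure_api_version == "2024-09-01-preview"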
1 change: 0 additions & 1 deletion tests/configs/llm_model_configs/azure-gpt-4o-mini.json
@@ -2,6 +2,5 @@
     "context_window": 128000,
     "model": "gpt-4o-mini",
     "model_endpoint_type": "azure",
-    "api_version": "2023-03-15-preview",
     "model_wrapper": null
 }
8 changes: 8 additions & 0 deletions tests/test_providers.py
@@ -29,6 +29,14 @@ def test_anthropic():
 #     print(models)
 #
 #
+
+
+# TODO: Add this test
+# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
+def test_azure():
+    pass
 
 
 def test_ollama():
     provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
     models = provider.list_llm_models()
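
One possible shape for the TODO'd test_azure, mirroring test_ollama (an untested sketch; it assumes live credentials in AZURE_API_KEY, AZURE_BASE_URL, and AZURE_API_VERSION, and that AzureProvider is imported alongside the other providers):

    def test_azure():
        provider = AzureProvider(
            api_key=os.getenv("AZURE_API_KEY"),
            base_url=os.getenv("AZURE_BASE_URL"),
            api_version=os.getenv("AZURE_API_VERSION"),
        )
        llm_models = provider.list_llm_models()
        embedding_models = provider.list_embedding_models()
        assert len(llm_models) > 0 and len(embedding_models) > 0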