fix: patch errors with OllamaProvider #1875

Merged 2 commits on Oct 13, 2024
4 changes: 1 addition & 3 deletions letta/local_llm/chat_completion_proxy.py
@@ -85,9 +85,7 @@ def get_chat_completion(
elif wrapper is None:
# Warn the user that we're using the fallback
if not has_shown_warning:
print(
f"{CLI_WARNING_PREFIX}no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model-wrapper)"
)
print(f"{CLI_WARNING_PREFIX}no prompt formatter specified for local LLM, using the default formatter")
has_shown_warning = True

llm_wrapper = DEFAULT_WRAPPER()
23 changes: 22 additions & 1 deletion letta/providers.py
@@ -140,9 +140,17 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:


class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint

See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""

name: str = "ollama"
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
)

def list_llm_models(self) -> List[LLMConfig]:
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
@@ -156,11 +164,15 @@ def list_llm_models(self) -> List[LLMConfig]:
configs = []
for model in response_json["models"]:
context_window = self.get_model_context_window(model["name"])
if context_window is None:
print(f"Ollama model {model['name']} has no context window")
continue
configs.append(
LLMConfig(
model=model["name"],
model_endpoint_type="ollama",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
)
)
@@ -192,6 +204,10 @@ def get_model_context_window(self, model_name: str) -> Optional[int]:
# ]
# max_position_embeddings
# parse model cards: nous, dolphon, llama
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "context_length" in key:
return value
@@ -202,6 +218,10 @@ def get_model_embedding_dim(self, model_name: str):

response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
response_json = response.json()
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "embedding_length" in key:
return value
@@ -220,6 +240,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
for model in response_json["models"]:
embedding_dim = self.get_model_embedding_dim(model["name"])
if not embedding_dim:
print(f"Ollama model {model['name']} has no embedding dimension")
continue
configs.append(
EmbeddingConfig(
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

def list_llm_models(self) -> List[LLMConfig]:
# not supported with vLLM
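The guarded /api/show lookups above keep list_llm_models and list_embedding_models from raising a KeyError when Ollama answers with an error payload (for example, for a model that is not available locally) instead of a model_info block. A minimal standalone sketch of the same pattern, assuming the requests package and a local Ollama server at http://localhost:11434:

from typing import Optional

import requests

OLLAMA_BASE_URL = "http://localhost:11434"  # assumption: default local Ollama address


def get_context_window(model_name: str) -> Optional[int]:
    # /api/show returns {"error": "..."} instead of "model_info" when the lookup fails
    response = requests.post(f"{OLLAMA_BASE_URL}/api/show", json={"name": model_name, "verbose": True})
    response_json = response.json()
    if "model_info" not in response_json:
        if "error" in response_json:
            print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
        return None
    # model_info keys are architecture-prefixed, e.g. "llama.context_length"
    for key, value in response_json["model_info"].items():
        if "context_length" in key:
            return value
    return None


print(get_context_window("llama3"))  # hypothetical model name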
35 changes: 29 additions & 6 deletions letta/server/server.py
@@ -200,7 +200,7 @@ class SyncServer(Server):
def __init__(
self,
chaining: bool = True,
max_chaining_steps: bool = None,
max_chaining_steps: Optional[bool] = None,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
# default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
@@ -241,13 +241,32 @@ def __init__(
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
self._enabled_providers.append(
OpenAIProvider(
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
)
if model_settings.anthropic_api_key:
self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
self._enabled_providers.append(
AnthropicProvider(
api_key=model_settings.anthropic_api_key,
)
)
if model_settings.ollama_base_url:
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
self._enabled_providers.append(
OllamaProvider(
base_url=model_settings.ollama_base_url,
api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
)
if model_settings.gemini_api_key:
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
self._enabled_providers.append(
GoogleAIProvider(
api_key=model_settings.gemini_api_key,
)
)
if model_settings.azure_api_key and model_settings.azure_base_url:
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
@@ -268,7 +287,11 @@ def __init__(
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
# see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))
self._enabled_providers.append(
VLLMChatCompletionsProvider(
base_url=model_settings.vllm_api_base,
)
)

def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
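With the new required field, SyncServer now passes model_settings.default_prompt_formatter through to OllamaProvider alongside the base URL. A minimal usage sketch of the patched provider constructed directly; the base URL and formatter string here are assumptions for illustration, not values taken from this PR:

from letta.providers import OllamaProvider

provider = OllamaProvider(
    base_url="http://localhost:11434",  # assumption: local Ollama; normally model_settings.ollama_base_url
    api_key=None,  # Ollama does not require an API key
    default_prompt_formatter="chatml",  # assumption; normally model_settings.default_prompt_formatter
)

# Models with no reported context window or embedding dimension are now skipped with a printed notice.
for llm_config in provider.list_llm_models():
    print(llm_config.model, llm_config.context_window)
for embedding_config in provider.list_embedding_models():
    print(embedding_config)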
20 changes: 13 additions & 7 deletions tests/test_providers.py
@@ -6,19 +6,21 @@
OllamaProvider,
OpenAIProvider,
)
from letta.settings import model_settings


def test_openai():

provider = OpenAIProvider(api_key=os.getenv("OPENAI_API_KEY"))
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
models = provider.list_llm_models()
print(models)


def test_anthropic():
if os.getenv("ANTHROPIC_API_KEY") is None:
return
provider = AnthropicProvider(api_key=os.getenv("ANTHROPIC_API_KEY"))
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)

@@ -38,7 +40,9 @@ def test_azure():


def test_ollama():
provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
models = provider.list_llm_models()
print(models)

@@ -47,7 +51,9 @@ def test_ollama():


def test_googleai():
provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None
provider = GoogleAIProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)

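The provider tests now assert that the relevant environment variables are present instead of passing None into the constructors. A hedged sketch of running just the Ollama test locally, assuming pytest is installed and an Ollama server is reachable (the base URL below is an assumption):

import os

import pytest

os.environ.setdefault("OLLAMA_BASE_URL", "http://localhost:11434")  # assumption: local Ollama server
pytest.main(["-s", "tests/test_providers.py::test_ollama"])  # -s shows the printed model list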