
Commit

fix: patch errors with OllamaProvider (#1875)
cpacker authored Oct 13, 2024
1 parent a91c94c commit 1842fd1
Showing 4 changed files with 65 additions and 17 deletions.
letta/local_llm/chat_completion_proxy.py (4 changes: 1 addition & 3 deletions)
@@ -85,9 +85,7 @@ def get_chat_completion(
elif wrapper is None:
# Warn the user that we're using the fallback
if not has_shown_warning:
print(
f"{CLI_WARNING_PREFIX}no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model-wrapper)"
)
print(f"{CLI_WARNING_PREFIX}no prompt formatter specified for local LLM, using the default formatter")
has_shown_warning = True

llm_wrapper = DEFAULT_WRAPPER()
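
For context, the reworded warning sits in a warn-once fallback that switches to the default prompt formatter. Below is a minimal sketch of that pattern (not taken from the repository); the constant and class names are stand-ins for the real ones in this file.

CLI_WARNING_PREFIX = "Warning: "  # stand-in for letta's actual prefix

class DEFAULT_WRAPPER:
    """Placeholder for the default prompt formatter class."""

has_shown_warning = False

def resolve_formatter(wrapper):
    """Return the given formatter, or warn once and fall back to the default."""
    global has_shown_warning
    if wrapper is not None:
        return wrapper
    if not has_shown_warning:
        print(f"{CLI_WARNING_PREFIX}no prompt formatter specified for local LLM, using the default formatter")
        has_shown_warning = True
    return DEFAULT_WRAPPER()
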
letta/providers.py (23 changes: 22 additions & 1 deletion)
@@ -140,9 +140,17 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:


class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""

name: str = "ollama"
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
)

def list_llm_models(self) -> List[LLMConfig]:
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
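
For readers trying the patched provider directly, here is a standalone construction sketch. The base URL and formatter string are assumptions for illustration (11434 is Ollama's usual default port; "chatml" is just a placeholder name), not values taken from this commit.

from letta.providers import OllamaProvider

provider = OllamaProvider(
    base_url="http://localhost:11434",  # assumed default Ollama endpoint
    api_key=None,                       # Ollama does not require an API key
    default_prompt_formatter="chatml",  # placeholder formatter name
)

llm_configs = provider.list_llm_models()              # now skips models with no context window
embedding_configs = provider.list_embedding_models()  # now skips models with no embedding dimension
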
@@ -156,11 +164,15 @@ def list_llm_models(self) -> List[LLMConfig]:
configs = []
for model in response_json["models"]:
context_window = self.get_model_context_window(model["name"])
if context_window is None:
print(f"Ollama model {model['name']} has no context window")
continue
configs.append(
LLMConfig(
model=model["name"],
model_endpoint_type="ollama",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
)
)
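
The new guard logs and skips any model whose context window cannot be determined, instead of passing None into LLMConfig. A toy illustration of that control flow (not from the repository), with the HTTP lookup stubbed out:

def fake_context_window(name):
    # Stub standing in for the real /api/show lookup.
    return {"llama3:latest": 8192}.get(name)

models = [{"name": "llama3:latest"}, {"name": "mystery-model"}]
configs = []
for model in models:
    context_window = fake_context_window(model["name"])
    if context_window is None:
        print(f"Ollama model {model['name']} has no context window")
        continue  # skip rather than build a config with context_window=None
    configs.append({"model": model["name"], "context_window": context_window})

print(configs)  # only llama3:latest survives
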
@@ -192,6 +204,10 @@ def get_model_context_window(self, model_name: str) -> Optional[int]:
# ]
# max_position_embeddings
# parse model cards: nous, dolphon, llama
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "context_length" in key:
return value
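
The same guard, extracted into a self-contained helper that mirrors the hunk above. The request shape is the one used elsewhere in this file; treating the "error" key as Ollama's response for an unknown model is an assumption based on the handling added here.

import requests

def get_context_length(base_url: str, model_name: str):
    """Return the model's context length, or None if it cannot be determined."""
    response = requests.post(f"{base_url}/api/show", json={"name": model_name, "verbose": True})
    response_json = response.json()
    if "model_info" not in response_json:
        if "error" in response_json:
            print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
        return None  # caller skips this model instead of hitting a KeyError
    for key, value in response_json["model_info"].items():
        if "context_length" in key:
            return value
    return None
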
@@ -202,6 +218,10 @@ def get_model_embedding_dim(self, model_name: str):

response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
response_json = response.json()
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "embedding_length" in key:
return value
@@ -220,6 +240,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
for model in response_json["models"]:
embedding_dim = self.get_model_embedding_dim(model["name"])
if not embedding_dim:
print(f"Ollama model {model['name']} has no embedding dimension")
continue
configs.append(
EmbeddingConfig(
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

def list_llm_models(self) -> List[LLMConfig]:
# not supported with vLLM
letta/server/server.py (35 changes: 29 additions & 6 deletions)
@@ -200,7 +200,7 @@ class SyncServer(Server):
def __init__(
self,
chaining: bool = True,
max_chaining_steps: bool = None,
max_chaining_steps: Optional[bool] = None,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
# default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
Expand Down Expand Up @@ -241,13 +241,32 @@ def __init__(
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
self._enabled_providers.append(
OpenAIProvider(
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
)
if model_settings.anthropic_api_key:
self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
self._enabled_providers.append(
AnthropicProvider(
api_key=model_settings.anthropic_api_key,
)
)
if model_settings.ollama_base_url:
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
self._enabled_providers.append(
OllamaProvider(
base_url=model_settings.ollama_base_url,
api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
)
if model_settings.gemini_api_key:
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
self._enabled_providers.append(
GoogleAIProvider(
api_key=model_settings.gemini_api_key,
)
)
if model_settings.azure_api_key and model_settings.azure_base_url:
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
@@ -268,7 +287,11 @@ def __init__(
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
# see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))
self._enabled_providers.append(
VLLMChatCompletionsProvider(
base_url=model_settings.vllm_api_base,
)
)

def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
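
A compact standalone version of the new server-side wiring, assuming model_settings exposes the attributes used in the hunks above (as the diff and the tests below suggest):

from letta.providers import OllamaProvider
from letta.settings import model_settings

enabled_providers = []
if model_settings.ollama_base_url:
    # Ollama is only enabled when a base URL is configured; the new
    # default_prompt_formatter setting is threaded through to the provider.
    enabled_providers.append(
        OllamaProvider(
            base_url=model_settings.ollama_base_url,
            api_key=None,  # Ollama does not require an API key
            default_prompt_formatter=model_settings.default_prompt_formatter,
        )
    )
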
tests/test_providers.py (20 changes: 13 additions & 7 deletions)
@@ -6,19 +6,21 @@
OllamaProvider,
OpenAIProvider,
)
from letta.settings import model_settings


def test_openai():

provider = OpenAIProvider(api_key=os.getenv("OPENAI_API_KEY"))
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
models = provider.list_llm_models()
print(models)


def test_anthropic():
if os.getenv("ANTHROPIC_API_KEY") is None:
return
provider = AnthropicProvider(api_key=os.getenv("ANTHROPIC_API_KEY"))
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)

@@ -38,7 +40,9 @@ def test_azure():


def test_ollama():
provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
models = provider.list_llm_models()
print(models)

@@ -47,7 +51,9 @@ def test_ollama():


def test_googleai():
provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None
provider = GoogleAIProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)

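
The updated tests assert their environment variables up front and pass the now-required default_prompt_formatter. A hedged variant (not part of this commit) that skips rather than fails when Ollama is not configured:

import os

import pytest

from letta.providers import OllamaProvider
from letta.settings import model_settings

def test_ollama_skips_when_unconfigured():
    # Variant guard: skip instead of asserting when the env var is absent.
    base_url = os.getenv("OLLAMA_BASE_URL")
    if base_url is None:
        pytest.skip("OLLAMA_BASE_URL not set")
    provider = OllamaProvider(
        base_url=base_url,
        api_key=None,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )
    print(provider.list_llm_models())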