Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements Ollama embeddings #38

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions hybridagi/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .embeddings import Embeddings
from .fake import FakeEmbeddings
from .ollama import OllamaEmbeddings
from .sentence_transformer import SentenceTransformerEmbeddings

# PEP 8: __all__ entries must be strings naming the public API, not the
# objects themselves — mixing bare names and strings breaks `import *`.
__all__ = [
    "Embeddings",
    "FakeEmbeddings",
    "SentenceTransformerEmbeddings",
    "OllamaEmbeddings",
]
89 changes: 89 additions & 0 deletions hybridagi/embeddings/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
from typing import Union, List
from hybridagi.embeddings.embeddings import Embeddings

class OllamaEmbeddings(Embeddings):
    """Embeddings backend backed by a running Ollama server.

    Texts are sent to the Ollama `embed` endpoint in batches; image
    embedding is not supported by this backend.
    """

    def __init__(
        self,
        model_name: str = "mxbai-embed-large:latest",
        dim: int = 1024,  # Adjust default dimension based on your model
        batch_size: int = 32,
    ):
        """Initialize the Ollama embeddings class.

        Args:
            model_name: Name of the Ollama model to use for embeddings
            dim: Dimension of the embedding vectors
            batch_size: Number of texts to process at once

        Raises:
            ImportError: If the `ollama` package is not installed.
        """
        super().__init__(dim=dim)
        self.model_name = model_name
        self.batch_size = batch_size

        try:
            import ollama
        except ImportError:
            # NOTE: missing space between the two concatenated sentences was
            # fixed ("running." + "Please" previously rendered as one word).
            raise ImportError(
                "You need to install ollama library to use Ollama embeddings. Obviously, you also need an Ollama server running. "
                "Please run `pip install ollama`"
            )
        self.ollama = ollama

    def _batch_embed(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of texts using Ollama.

        Args:
            texts: Non-empty list of strings to embed

        Returns:
            numpy.ndarray: 2-D float32 array, one embedding row per text
        """
        all_embeddings = []

        # Process in batches to bound the payload size per request.
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            response = self.ollama.embed(
                model=self.model_name,
                input=batch,
            )
            all_embeddings.append(
                np.array(response['embeddings'], dtype=np.float32)
            )

        # np.concatenate handles the single-batch case too — no need for
        # the previous `len(...) > 1` special case.
        return np.concatenate(all_embeddings)

    def embed_text(self, query_or_queries: Union[str, List[str]]) -> np.ndarray:
        """Embed text or list of texts using Ollama.

        Args:
            query_or_queries: Single string or list of strings to embed

        Returns:
            numpy.ndarray: 1-D embedding for a single string, 2-D array
            (one row per input) for a list of strings

        Raises:
            ValueError: If the input is an empty string, an empty list,
                or a list containing an empty string
        """
        if isinstance(query_or_queries, str):
            if query_or_queries == "":
                raise ValueError("Input cannot be an empty string.")

            # Single string case
            response = self.ollama.embed(
                model=self.model_name,
                input=query_or_queries,
            )
            return np.array(response['embeddings'][0], dtype=np.float32)
        else:
            # List of strings case
            if not query_or_queries:  # Empty list check
                raise ValueError("Input cannot be an empty list.")

            # Reject empty elements up front, mirroring the single-string
            # check — otherwise empty strings are silently sent to the server.
            if any(not isinstance(q, str) or q == "" for q in query_or_queries):
                raise ValueError("Input cannot contain empty strings.")

            return self._batch_embed(query_or_queries)

    def embed_image(self, image_or_images: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Not implemented for Ollama embeddings."""
        raise NotImplementedError("Ollama embeddings do not support image embeddings")
20 changes: 18 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ homepage = "https://github.com/SynaLinks/HybridAGI"
[tool.poetry.dependencies]
python = ">=3.10,<3.12"
sentence-transformers = ">=2.6.0"
ollama = ">=0.3.3"
falkordb = ">=1.0.7"
dspy-ai = "==2.4.10"
colorama = ">=0.4.6"
Expand Down
86 changes: 86 additions & 0 deletions tests/embeddings/test_ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import numpy as np
from unittest.mock import MagicMock

# Assuming that hybridagi.embeddings.ollama.OllamaEmbeddings can be imported without needing the server
from hybridagi.embeddings.ollama import OllamaEmbeddings

@pytest.fixture
def ollama_embeddings():
    """Provide a MagicMock standing in for OllamaEmbeddings.

    No Ollama server is needed: embeddings are random float32 vectors of
    the configured dimension, and validation errors mimic the real class.
    """
    stub = MagicMock(spec=OllamaEmbeddings)
    stub.dim = 1024
    stub.batch_size = 10  # Arbitrary

    def fake_embed_text(text_or_texts):
        # Single string: reject blank input, otherwise one random vector.
        if isinstance(text_or_texts, str):
            if not text_or_texts.strip():
                raise ValueError("Input text is empty")
            return np.random.rand(stub.dim).astype(np.float32)
        # List of strings: reject empty lists and blank elements.
        if isinstance(text_or_texts, list):
            if not text_or_texts:
                raise ValueError("Input list is empty")
            if any(not isinstance(t, str) or not t.strip() for t in text_or_texts):
                raise ValueError("One or more input texts are empty")
            return [np.random.rand(stub.dim).astype(np.float32) for _ in text_or_texts]
        raise ValueError("Invalid input type")

    stub.embed_text.side_effect = fake_embed_text
    stub.embed_image.side_effect = NotImplementedError("Embedding images is not implemented.")
    return stub

def test_embed_text_single(ollama_embeddings):
query = "Hello"
embedding = ollama_embeddings.embed_text(query)
assert embedding.shape[0] == ollama_embeddings.dim
assert isinstance(embedding, np.ndarray)
assert embedding.dtype == np.float32

def test_embed_text_multiple(ollama_embeddings):
queries = ["Hello", "World"]
embeddings = ollama_embeddings.embed_text(queries)
assert len(embeddings) == 2
for emb in embeddings:
assert emb.shape[0] == ollama_embeddings.dim
assert isinstance(emb, np.ndarray)
assert emb.dtype == np.float32

def test_embed_image_not_implemented(ollama_embeddings):
    """Image embedding is unsupported and must raise NotImplementedError."""
    dummy_image = np.random.random((ollama_embeddings.dim,))
    with pytest.raises(NotImplementedError):
        ollama_embeddings.embed_image(dummy_image)

def test_embed_text_empty_input(ollama_embeddings):
    """An empty string is rejected with a descriptive ValueError."""
    # `match` does a regex search; the literal has no metacharacters, so
    # this is equivalent to asserting the substring on exc_info.
    with pytest.raises(ValueError, match="Input text is empty"):
        ollama_embeddings.embed_text("")

def test_embed_text_empty_list(ollama_embeddings):
    """An empty list is rejected with a descriptive ValueError."""
    with pytest.raises(ValueError, match="Input list is empty"):
        ollama_embeddings.embed_text([])

def test_batch_processing(ollama_embeddings):
# Test with a number of queries that exceeds the batch size
batch_size = ollama_embeddings.batch_size
queries = [f"Query {i}" for i in range(batch_size + 5)]
embeddings = ollama_embeddings.embed_text(queries)

assert len(embeddings) == len(queries)
for emb in embeddings:
assert emb.shape[0] == ollama_embeddings.dim
assert isinstance(emb, np.ndarray)
assert emb.dtype == np.float32

def test_ollama_import_error(monkeypatch):
    """Constructor raises a helpful ImportError when ollama is unavailable."""
    import sys

    # A None entry in sys.modules makes `import ollama` raise ImportError.
    monkeypatch.setitem(sys.modules, 'ollama', None)

    with pytest.raises(ImportError, match="You need to install ollama library"):
        OllamaEmbeddings()