Skip to content

Commit

Permalink
feat: add support for openai embedding (#23)
Browse files Browse the repository at this point in the history
* feat: add support for openai embeddings

* fix: install openai
  • Loading branch information
gventuri committed Oct 15, 2024
1 parent 5bd5032 commit 59aae28
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 25 deletions.
6 changes: 5 additions & 1 deletion backend/.env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
SQLALCHEMY_DATABASE_URL=sqlite:///instance/app.db
PANDAETL_SERVER_URL="https://api.panda-etl.ai/" # optional
API_SERVER_URL="https://api.domer.ai" # optional
API_SERVER_URL="https://api.domer.ai" # optional
USE_OPENAI_EMBEDDINGS=false # optional
OPENAI_API_KEY=sk-xxxxxxxxxxxx # optional
CHROMA_BATCH_SIZE=5 # optional
MAX_FILE_SIZE=20971520 # optional
8 changes: 4 additions & 4 deletions backend/Makefile
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
.PHONY: run migrate upgrade

run:
poetry run uvicorn app.main:app --reload --port 5328
poetry run uvicorn app.main:app --reload --port 5328 --log-level info

migrate:
poetry run alembic upgrade head

.PHONY: generate-migration
generate-migration:
.PHONY: generate-migration
generate-migration:
poetry run alembic revision --autogenerate -m "$$message"

test:
poetry run pytest
poetry run pytest
2 changes: 1 addition & 1 deletion backend/app/api/v1/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def upload_files(
)

# Check if the file size is greater than 20MB
if file.size > settings.MAX_FILE_SIZE:
if file.size > settings.max_file_size:
raise HTTPException(
status_code=400,
detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file.",
Expand Down
8 changes: 7 additions & 1 deletion backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ class Settings(BaseSettings):
log_file_path: str = os.path.join(os.path.dirname(__file__), "..", "pandaetl.log")
max_retries: int = 3
max_relevant_docs: int = 10
MAX_FILE_SIZE: int = 20 * 1024 * 1024
max_file_size: int = 20 * 1024 * 1024
chroma_batch_size: int = 5

# OpenAI embeddings config
use_openai_embeddings: bool = False
openai_api_key: str = ""
openai_embedding_model: str = "text-embedding-ada-002"

class Config:
env_file = ".env"
Expand Down
5 changes: 2 additions & 3 deletions backend/app/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from sqlalchemy.sql import Select
from app.database.query import SoftDeleteQuery
from app.config import settings
from app.logger import Logger

import logging

logger = logging.getLogger(__name__)
logger = Logger(verbose=True)

engine = create_engine(
settings.sqlalchemy_database_url,
Expand Down
22 changes: 22 additions & 0 deletions backend/app/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ def log(self, message: str, level: int = logging.INFO):
}
)

def info(self, message: str):
self._logger.info(message)
self._logs.append(
{
"msg": message,
"level": logging.getLevelName(logging.INFO),
"time": self._calculate_time_diff(),
"source": self._invoked_from(),
}
)

def debug(self, message: str):
self._logger.debug(message)
self._logs.append(
{
"msg": message,
"level": logging.getLevelName(logging.DEBUG),
"time": self._calculate_time_diff(),
"source": self._invoked_from(),
}
)

def error(self, message: str):
self._logger.error(message)
self._logs.append(
Expand Down
60 changes: 46 additions & 14 deletions backend/app/vectorstore/chroma.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import uuid
from typing import Callable, Iterable, List, Optional, Tuple
from pydantic_settings import BaseSettings

import chromadb
from app.config import settings
from app.config import settings as default_settings
from app.vectorstore import VectorStore
from app.logger import Logger
from chromadb import config
from chromadb.utils import embedding_functions
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

DEFAULT_EMBEDDING_FUNCTION = embedding_functions.DefaultEmbeddingFunction()
logger = Logger(verbose=True)

DEFAULT_EMBEDDING_FUNCTION = embedding_functions.DefaultEmbeddingFunction()

class ChromaDB(VectorStore):
"""
Expand All @@ -23,36 +27,43 @@ def __init__(
client_settings: Optional[config.Settings] = None,
max_samples: int = 3,
similary_threshold: int = 1.5,
batch_size: Optional[int] = None,
settings: Optional[BaseSettings] = None,
) -> None:
self.settings = settings or default_settings
self._max_samples = max_samples
self._similarity_threshold = similary_threshold
self._batch_size = batch_size or self.settings.chroma_batch_size

# Initialize Chromadb Client
# initialize from client settings if exists
if client_settings:
client_settings.persist_directory = (
persist_path or client_settings.persist_directory
)
_client_settings = client_settings

# use persist path if exists
elif persist_path:
_client_settings = config.Settings(
is_persistent=True, anonymized_telemetry=False
)
_client_settings.persist_directory = persist_path
# else use root as default path
else:
_client_settings = config.Settings(
is_persistent=True, anonymized_telemetry=False
)
_client_settings.persist_directory = settings.chromadb_url
_client_settings.persist_directory = self.settings.chromadb_url

self._client_settings = _client_settings
self._client = chromadb.Client(_client_settings)
self._persist_directory = _client_settings.persist_directory

self._embedding_function = embedding_function or DEFAULT_EMBEDDING_FUNCTION
# Use the embedding function from config
if self.settings.use_openai_embeddings and self.settings.openai_api_key:
self._embedding_function = OpenAIEmbeddingFunction(
api_key=self.settings.openai_api_key,
model_name=self.settings.openai_embedding_model
)
else:
self._embedding_function = embedding_function or DEFAULT_EMBEDDING_FUNCTION

self._docs_collection = self._client.get_or_create_collection(
name=collection_name, embedding_function=self._embedding_function
Expand All @@ -63,33 +74,54 @@ def add_docs(
docs: Iterable[str],
ids: Optional[Iterable[str]] = None,
metadatas: Optional[List[dict]] = None,
batch_size=5,
batch_size: Optional[int] = None,
) -> List[str]:
"""
Add docs to the training set
Args:
docs: Iterable of strings to add to the vectorstore.
ids: Optional Iterable of ids associated with the texts.
metadatas: Optional list of metadatas associated with the texts.
kwargs: vectorstore specific parameters
batch_size: Optional batch size for adding documents. If not provided, uses the instance's batch size.
Returns:
List of ids from adding the texts into the vectorstore.
"""
if ids is None:
ids = [f"{str(uuid.uuid4())}-docs" for _ in docs]

if metadatas is None:
metadatas = [{}] * len(docs)

# Add previous_id and next_id to metadatas
for idx, metadata in enumerate(metadatas):
metadata["previous_sentence_id"] = ids[idx - 1] if idx > 0 else -1
metadata["next_sentence_id"] = ids[idx + 1] if idx < len(ids) - 1 else -1

for i in range(0, len(docs), batch_size):
filename = metadatas[0].get('filename', 'unknown')
logger.info(f"Adding {len(docs)} sentences to the vector store for file {filename}")

# If using OpenAI embeddings, add all documents at once
if self.settings.use_openai_embeddings and self.settings.openai_api_key:
logger.info("Using OpenAI embeddings")
self._docs_collection.add(
documents=docs[i : i + batch_size],
metadatas=metadatas[i : i + batch_size],
ids=ids[i : i + batch_size],
documents=list(docs),
metadatas=metadatas,
ids=ids,
)
else:
logger.info("Using default embedding function")
batch_size = batch_size or self._batch_size

for i in range(0, len(docs), batch_size):
logger.info(f"Processing batch {i} to {i + batch_size}")
self._docs_collection.add(
documents=docs[i : i + batch_size],
metadatas=metadatas[i : i + batch_size],
ids=ids[i : i + batch_size],
)

return list(ids)

def delete_docs(
self, ids: Optional[List[str]] = None, where: Optional[dict] = None
Expand Down
Loading

0 comments on commit 59aae28

Please sign in to comment.