From 59aae28511c3e0ece4baa63822f32757f3723f6e Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Tue, 15 Oct 2024 15:37:00 +0200 Subject: [PATCH] feat: add support for openai embedding (#23) * feat: add support for openai embeddings * fix: install openai --- backend/.env.example | 6 +- backend/Makefile | 8 +- backend/app/api/v1/projects.py | 2 +- backend/app/config.py | 8 +- backend/app/database/__init__.py | 5 +- backend/app/logger.py | 22 ++++ backend/app/vectorstore/chroma.py | 60 +++++++--- backend/poetry.lock | 119 +++++++++++++++++++- backend/pyproject.toml | 1 + backend/tests/vectorstore/test_chroma_db.py | 64 +++++++++++ 10 files changed, 270 insertions(+), 25 deletions(-) create mode 100644 backend/tests/vectorstore/test_chroma_db.py diff --git a/backend/.env.example b/backend/.env.example index 9f599d3..aaefd3c 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1,3 +1,7 @@ SQLALCHEMY_DATABASE_URL=sqlite:///instance/app.db PANDAETL_SERVER_URL="https://api.panda-etl.ai/" # optional -API_SERVER_URL="https://api.domer.ai" # optional \ No newline at end of file +API_SERVER_URL="https://api.domer.ai" # optional +USE_OPENAI_EMBEDDINGS=false # optional +OPENAI_API_KEY=sk-xxxxxxxxxxxx # optional +CHROMA_BATCH_SIZE=5 # optional +MAX_FILE_SIZE=20971520 # optional diff --git a/backend/Makefile b/backend/Makefile index c12d1a2..ed82a1e 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -1,14 +1,14 @@ .PHONY: run migrate upgrade run: - poetry run uvicorn app.main:app --reload --port 5328 + poetry run uvicorn app.main:app --reload --port 5328 --log-level info migrate: poetry run alembic upgrade head -.PHONY: generate-migration -generate-migration: +.PHONY: generate-migration +generate-migration: poetry run alembic revision --autogenerate -m "$$message" test: - poetry run pytest \ No newline at end of file + poetry run pytest diff --git a/backend/app/api/v1/projects.py b/backend/app/api/v1/projects.py index 2a95586..96791c0 100644 --- a/backend/app/api/v1/projects.py +++ b/backend/app/api/v1/projects.py @@ -167,7 +167,7 @@ async def upload_files( ) # Check if the file size is greater than 20MB - if file.size > settings.MAX_FILE_SIZE: + if file.size > settings.max_file_size: raise HTTPException( status_code=400, detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file.", diff --git a/backend/app/config.py b/backend/app/config.py index b77f6b6..ed49d66 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -18,7 +18,13 @@ class Settings(BaseSettings): log_file_path: str = os.path.join(os.path.dirname(__file__), "..", "pandaetl.log") max_retries: int = 3 max_relevant_docs: int = 10 - MAX_FILE_SIZE: int = 20 * 1024 * 1024 + max_file_size: int = 20 * 1024 * 1024 + chroma_batch_size: int = 5 + + # OpenAI embeddings config + use_openai_embeddings: bool = False + openai_api_key: str = "" + openai_embedding_model: str = "text-embedding-ada-002" class Config: env_file = ".env" diff --git a/backend/app/database/__init__.py b/backend/app/database/__init__.py index 6d21a01..8d12cb7 100644 --- a/backend/app/database/__init__.py +++ b/backend/app/database/__init__.py @@ -4,10 +4,9 @@ from sqlalchemy.sql import Select from app.database.query import SoftDeleteQuery from app.config import settings +from app.logger import Logger -import logging - -logger = logging.getLogger(__name__) +logger = Logger(verbose=True) engine = create_engine( settings.sqlalchemy_database_url, diff --git a/backend/app/logger.py b/backend/app/logger.py index 6484813..aa108f5 100644 --- a/backend/app/logger.py +++ b/backend/app/logger.py @@ -69,6 +69,28 @@ def log(self, message: str, level: int = logging.INFO): } ) + def info(self, message: str): + self._logger.info(message) + self._logs.append( + { + "msg": message, + "level": logging.getLevelName(logging.INFO), + "time": self._calculate_time_diff(), + "source": self._invoked_from(), + } + ) + + def debug(self, message: str): + self._logger.debug(message) + self._logs.append( + { + "msg": message, + "level": logging.getLevelName(logging.DEBUG), + "time": self._calculate_time_diff(), + "source": self._invoked_from(), + } + ) + def error(self, message: str): self._logger.error(message) self._logs.append( diff --git a/backend/app/vectorstore/chroma.py b/backend/app/vectorstore/chroma.py index d1742db..7a2768f 100644 --- a/backend/app/vectorstore/chroma.py +++ b/backend/app/vectorstore/chroma.py @@ -1,14 +1,18 @@ import uuid from typing import Callable, Iterable, List, Optional, Tuple +from pydantic_settings import BaseSettings import chromadb -from app.config import settings +from app.config import settings as default_settings from app.vectorstore import VectorStore +from app.logger import Logger from chromadb import config from chromadb.utils import embedding_functions +from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction -DEFAULT_EMBEDDING_FUNCTION = embedding_functions.DefaultEmbeddingFunction() +logger = Logger(verbose=True) +DEFAULT_EMBEDDING_FUNCTION = embedding_functions.DefaultEmbeddingFunction() class ChromaDB(VectorStore): """ @@ -23,36 +27,43 @@ def __init__( client_settings: Optional[config.Settings] = None, max_samples: int = 3, similary_threshold: int = 1.5, + batch_size: Optional[int] = None, + settings: Optional[BaseSettings] = None, ) -> None: + self.settings = settings or default_settings self._max_samples = max_samples self._similarity_threshold = similary_threshold + self._batch_size = batch_size or self.settings.chroma_batch_size # Initialize Chromadb Client - # initialize from client settings if exists if client_settings: client_settings.persist_directory = ( persist_path or client_settings.persist_directory ) _client_settings = client_settings - - # use persist path if exists elif persist_path: _client_settings = config.Settings( is_persistent=True, anonymized_telemetry=False ) _client_settings.persist_directory = persist_path - # else use root as default path else: _client_settings = config.Settings( is_persistent=True, anonymized_telemetry=False ) - _client_settings.persist_directory = settings.chromadb_url + _client_settings.persist_directory = self.settings.chromadb_url self._client_settings = _client_settings self._client = chromadb.Client(_client_settings) self._persist_directory = _client_settings.persist_directory - self._embedding_function = embedding_function or DEFAULT_EMBEDDING_FUNCTION + # Use the embedding function from config + if self.settings.use_openai_embeddings and self.settings.openai_api_key: + self._embedding_function = OpenAIEmbeddingFunction( + api_key=self.settings.openai_api_key, + model_name=self.settings.openai_embedding_model + ) + else: + self._embedding_function = embedding_function or DEFAULT_EMBEDDING_FUNCTION self._docs_collection = self._client.get_or_create_collection( name=collection_name, embedding_function=self._embedding_function @@ -63,7 +74,7 @@ def add_docs( docs: Iterable[str], ids: Optional[Iterable[str]] = None, metadatas: Optional[List[dict]] = None, - batch_size=5, + batch_size: Optional[int] = None, ) -> List[str]: """ Add docs to the training set @@ -71,7 +82,7 @@ def add_docs( docs: Iterable of strings to add to the vectorstore. ids: Optional Iterable of ids associated with the texts. metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters + batch_size: Optional batch size for adding documents. If not provided, uses the instance's batch size. Returns: List of ids from adding the texts into the vectorstore. @@ -79,17 +90,38 @@ def add_docs( if ids is None: ids = [f"{str(uuid.uuid4())}-docs" for _ in docs] + if metadatas is None: + metadatas = [{}] * len(docs) + # Add previous_id and next_id to metadatas for idx, metadata in enumerate(metadatas): metadata["previous_sentence_id"] = ids[idx - 1] if idx > 0 else -1 metadata["next_sentence_id"] = ids[idx + 1] if idx < len(ids) - 1 else -1 - for i in range(0, len(docs), batch_size): + filename = metadatas[0].get('filename', 'unknown') + logger.info(f"Adding {len(docs)} sentences to the vector store for file {filename}") + + # If using OpenAI embeddings, add all documents at once + if self.settings.use_openai_embeddings and self.settings.openai_api_key: + logger.info("Using OpenAI embeddings") self._docs_collection.add( - documents=docs[i : i + batch_size], - metadatas=metadatas[i : i + batch_size], - ids=ids[i : i + batch_size], + documents=list(docs), + metadatas=metadatas, + ids=ids, ) + else: + logger.info("Using default embedding function") + batch_size = batch_size or self._batch_size + + for i in range(0, len(docs), batch_size): + logger.info(f"Processing batch {i} to {i + batch_size}") + self._docs_collection.add( + documents=docs[i : i + batch_size], + metadatas=metadatas[i : i + batch_size], + ids=ids[i : i + batch_size], + ) + + return list(ids) def delete_docs( self, ids: Optional[List[str]] = None, where: Optional[dict] = None diff --git a/backend/poetry.lock b/backend/poetry.lock index 5ae7484..bca8e84 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -495,6 +495,17 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dnspython" version = "2.6.1" @@ -1053,6 +1064,88 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.6.1" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.6.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d08510593cb57296851080018006dfc394070178d238b767b1879dc1013b106c"}, + {file = "jiter-0.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adef59d5e2394ebbad13b7ed5e0306cceb1df92e2de688824232a91588e77aa7"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3e02f7a27f2bcc15b7d455c9df05df8ffffcc596a2a541eeda9a3110326e7a3"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed69a7971d67b08f152c17c638f0e8c2aa207e9dd3a5fcd3cba294d39b5a8d2d"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2019d966e98f7c6df24b3b8363998575f47d26471bfb14aade37630fae836a1"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:36c0b51a285b68311e207a76c385650322734c8717d16c2eb8af75c9d69506e7"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:220e0963b4fb507c525c8f58cde3da6b1be0bfddb7ffd6798fb8f2531226cdb1"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aa25c7a9bf7875a141182b9c95aed487add635da01942ef7ca726e42a0c09058"}, + {file = "jiter-0.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e90552109ca8ccd07f47ca99c8a1509ced93920d271bb81780a973279974c5ab"}, + {file = "jiter-0.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:67723a011964971864e0b484b0ecfee6a14de1533cff7ffd71189e92103b38a8"}, + {file = "jiter-0.6.1-cp310-none-win32.whl", hash = "sha256:33af2b7d2bf310fdfec2da0177eab2fedab8679d1538d5b86a633ebfbbac4edd"}, + {file = "jiter-0.6.1-cp310-none-win_amd64.whl", hash = "sha256:7cea41c4c673353799906d940eee8f2d8fd1d9561d734aa921ae0f75cb9732f4"}, + {file = "jiter-0.6.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b03c24e7da7e75b170c7b2b172d9c5e463aa4b5c95696a368d52c295b3f6847f"}, + {file = "jiter-0.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:47fee1be677b25d0ef79d687e238dc6ac91a8e553e1a68d0839f38c69e0ee491"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0d2f6e01a8a0fb0eab6d0e469058dab2be46ff3139ed2d1543475b5a1d8e7"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b809e39e342c346df454b29bfcc7bca3d957f5d7b60e33dae42b0e5ec13e027"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9ac7c2f092f231f5620bef23ce2e530bd218fc046098747cc390b21b8738a7a"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e51a2d80d5fe0ffb10ed2c82b6004458be4a3f2b9c7d09ed85baa2fbf033f54b"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3343d4706a2b7140e8bd49b6c8b0a82abf9194b3f0f5925a78fc69359f8fc33c"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82521000d18c71e41c96960cb36e915a357bc83d63a8bed63154b89d95d05ad1"}, + {file = "jiter-0.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3c843e7c1633470708a3987e8ce617ee2979ee18542d6eb25ae92861af3f1d62"}, + {file = "jiter-0.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a2e861658c3fe849efc39b06ebb98d042e4a4c51a8d7d1c3ddc3b1ea091d0784"}, + {file = "jiter-0.6.1-cp311-none-win32.whl", hash = "sha256:7d72fc86474862c9c6d1f87b921b70c362f2b7e8b2e3c798bb7d58e419a6bc0f"}, + {file = "jiter-0.6.1-cp311-none-win_amd64.whl", hash = "sha256:3e36a320634f33a07794bb15b8da995dccb94f944d298c8cfe2bd99b1b8a574a"}, + {file = "jiter-0.6.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1fad93654d5a7dcce0809aff66e883c98e2618b86656aeb2129db2cd6f26f867"}, + {file = "jiter-0.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4e6e340e8cd92edab7f6a3a904dbbc8137e7f4b347c49a27da9814015cc0420c"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:691352e5653af84ed71763c3c427cff05e4d658c508172e01e9c956dfe004aba"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:defee3949313c1f5b55e18be45089970cdb936eb2a0063f5020c4185db1b63c9"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26d2bdd5da097e624081c6b5d416d3ee73e5b13f1703bcdadbb1881f0caa1933"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18aa9d1626b61c0734b973ed7088f8a3d690d0b7f5384a5270cd04f4d9f26c86"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a3567c8228afa5ddcce950631c6b17397ed178003dc9ee7e567c4c4dcae9fa0"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c0507131c922defe3f04c527d6838932fcdfd69facebafd7d3574fa3395314"}, + {file = "jiter-0.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:540fcb224d7dc1bcf82f90f2ffb652df96f2851c031adca3c8741cb91877143b"}, + {file = "jiter-0.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e7b75436d4fa2032b2530ad989e4cb0ca74c655975e3ff49f91a1a3d7f4e1df2"}, + {file = "jiter-0.6.1-cp312-none-win32.whl", hash = "sha256:883d2ced7c21bf06874fdeecab15014c1c6d82216765ca6deef08e335fa719e0"}, + {file = "jiter-0.6.1-cp312-none-win_amd64.whl", hash = "sha256:91e63273563401aadc6c52cca64a7921c50b29372441adc104127b910e98a5b6"}, + {file = "jiter-0.6.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:852508a54fe3228432e56019da8b69208ea622a3069458252f725d634e955b31"}, + {file = "jiter-0.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f491cc69ff44e5a1e8bc6bf2b94c1f98d179e1aaf4a554493c171a5b2316b701"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc56c8f0b2a28ad4d8047f3ae62d25d0e9ae01b99940ec0283263a04724de1f3"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51b58f7a0d9e084a43b28b23da2b09fc5e8df6aa2b6a27de43f991293cab85fd"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f79ce15099154c90ef900d69c6b4c686b64dfe23b0114e0971f2fecd306ec6c"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:03a025b52009f47e53ea619175d17e4ded7c035c6fbd44935cb3ada11e1fd592"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c74a8d93718137c021d9295248a87c2f9fdc0dcafead12d2930bc459ad40f885"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40b03b75f903975f68199fc4ec73d546150919cb7e534f3b51e727c4d6ccca5a"}, + {file = "jiter-0.6.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:825651a3f04cf92a661d22cad61fc913400e33aa89b3e3ad9a6aa9dc8a1f5a71"}, + {file = "jiter-0.6.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:928bf25eb69ddb292ab8177fe69d3fbf76c7feab5fce1c09265a7dccf25d3991"}, + {file = "jiter-0.6.1-cp313-none-win32.whl", hash = "sha256:352cd24121e80d3d053fab1cc9806258cad27c53cad99b7a3cac57cf934b12e4"}, + {file = "jiter-0.6.1-cp313-none-win_amd64.whl", hash = "sha256:be7503dd6f4bf02c2a9bacb5cc9335bc59132e7eee9d3e931b13d76fd80d7fda"}, + {file = "jiter-0.6.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:31d8e00e1fb4c277df8ab6f31a671f509ebc791a80e5c61fdc6bc8696aaa297c"}, + {file = "jiter-0.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77c296d65003cd7ee5d7b0965f6acbe6cffaf9d1fa420ea751f60ef24e85fed5"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeeb0c0325ef96c12a48ea7e23e2e86fe4838e6e0a995f464cf4c79fa791ceeb"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a31c6fcbe7d6c25d6f1cc6bb1cba576251d32795d09c09961174fe461a1fb5bd"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59e2b37f3b9401fc9e619f4d4badcab2e8643a721838bcf695c2318a0475ae42"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bae5ae4853cb9644144e9d0755854ce5108d470d31541d83f70ca7ecdc2d1637"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df588e9c830b72d8db1dd7d0175af6706b0904f682ea9b1ca8b46028e54d6e9"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:15f8395e835cf561c85c1adee72d899abf2733d9df72e9798e6d667c9b5c1f30"}, + {file = "jiter-0.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a99d4e0b5fc3b05ea732d67eb2092fe894e95a90e6e413f2ea91387e228a307"}, + {file = "jiter-0.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a311df1fa6be0ccd64c12abcd85458383d96e542531bafbfc0a16ff6feda588f"}, + {file = "jiter-0.6.1-cp38-none-win32.whl", hash = "sha256:81116a6c272a11347b199f0e16b6bd63f4c9d9b52bc108991397dd80d3c78aba"}, + {file = "jiter-0.6.1-cp38-none-win_amd64.whl", hash = "sha256:13f9084e3e871a7c0b6e710db54444088b1dd9fbefa54d449b630d5e73bb95d0"}, + {file = "jiter-0.6.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:f1c53615fcfec3b11527c08d19cff6bc870da567ce4e57676c059a3102d3a082"}, + {file = "jiter-0.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f791b6a4da23238c17a81f44f5b55d08a420c5692c1fda84e301a4b036744eb1"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c97e90fec2da1d5f68ef121444c2c4fa72eabf3240829ad95cf6bbeca42a301"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3cbc1a66b4e41511209e97a2866898733c0110b7245791ac604117b7fb3fedb7"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4e85f9e12cd8418ab10e1fcf0e335ae5bb3da26c4d13a0fd9e6a17a674783b6"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08be33db6dcc374c9cc19d3633af5e47961a7b10d4c61710bd39e48d52a35824"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:677be9550004f5e010d673d3b2a2b815a8ea07a71484a57d3f85dde7f14cf132"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e8bd065be46c2eecc328e419d6557bbc37844c88bb07b7a8d2d6c91c7c4dedc9"}, + {file = "jiter-0.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bd95375ce3609ec079a97c5d165afdd25693302c071ca60c7ae1cf826eb32022"}, + {file = "jiter-0.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db459ed22d0208940d87f614e1f0ea5a946d29a3cfef71f7e1aab59b6c6b2afb"}, + {file = "jiter-0.6.1-cp39-none-win32.whl", hash = "sha256:d71c962f0971347bd552940ab96aa42ceefcd51b88c4ced8a27398182efa8d80"}, + {file = "jiter-0.6.1-cp39-none-win_amd64.whl", hash = "sha256:d465db62d2d10b489b7e7a33027c4ae3a64374425d757e963f86df5b5f2e7fc5"}, + {file = "jiter-0.6.1.tar.gz", hash = "sha256:e19cd21221fc139fb032e4112986656cb2739e9fe6d84c13956ab30ccc7d4449"}, +] + [[package]] name = "kubernetes" version = "30.1.0" @@ -1425,6 +1518,30 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "openai" +version = "1.51.2" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.51.2-py3-none-any.whl", hash = "sha256:5c5954711cba931423e471c37ff22ae0fd3892be9b083eee36459865fbbb83fa"}, + {file = "openai-1.51.2.tar.gz", hash = "sha256:c6a51fac62a1ca9df85a522e462918f6bb6bc51a8897032217e453a0730123a6"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.11,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "opentelemetry-api" version = "1.26.0" @@ -3154,4 +3271,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "9e6c182950fed6da9590f4809fb8dd02a7f72e22fd995af1964ead9494fc57d3" +content-hash = "f93ac443a905910a7a021a4163b8f067845c178bc03eaa1f68de01014f5e25c4" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2ba6c07..55c7781 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -16,6 +16,7 @@ starlette = "^0.37.2" dateparser = "^1.2.0" requests = "^2.32.3" chromadb = "^0.5.5" +openai = "^1.51.2" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" diff --git a/backend/tests/vectorstore/test_chroma_db.py b/backend/tests/vectorstore/test_chroma_db.py new file mode 100644 index 0000000..4ab7960 --- /dev/null +++ b/backend/tests/vectorstore/test_chroma_db.py @@ -0,0 +1,64 @@ +import pytest +from unittest.mock import MagicMock +from app.config import Settings +from app.vectorstore.chroma import ChromaDB + +@pytest.fixture +def mock_settings(monkeypatch): + monkeypatch.setenv("SQLALCHEMY_DATABASE_URL", "sqlite:///test.db") + monkeypatch.setenv("USE_OPENAI_EMBEDDINGS", "false") + monkeypatch.setenv("OPENAI_API_KEY", "") + monkeypatch.setenv("CHROMA_BATCH_SIZE", "10") + return Settings() + +def test_chroma_db_initialization(mock_settings): + chroma_db = ChromaDB(settings=mock_settings) + assert chroma_db._batch_size == mock_settings.chroma_batch_size + assert chroma_db._embedding_function.__class__.__name__ == "ONNXMiniLM_L6_V2" + +def test_chroma_db_with_openai_embeddings(): + settings = Settings(use_openai_embeddings=True, openai_api_key="test_key") + chroma_db = ChromaDB(settings=settings) + assert chroma_db._embedding_function.__class__.__name__ == "OpenAIEmbeddingFunction" + +def test_add_docs_with_custom_batch_size(mock_settings): + chroma_db = ChromaDB(batch_size=3, settings=mock_settings) + docs = ["doc1", "doc2", "doc3", "doc4", "doc5"] + ids = ["id1", "id2", "id3", "id4", "id5"] + metadatas = [{"key": "value", "filename": "test.txt"} for _ in range(5)] + + # Mock the _docs_collection.add method + chroma_db._docs_collection.add = MagicMock() + + # Call add_docs with a custom batch size + chroma_db.add_docs(docs, ids=ids, metadatas=metadatas, batch_size=2) + + # Assert that the _docs_collection.add method was called 3 times + # (2 batches of 2 and 1 batch of 1) + assert chroma_db._docs_collection.add.call_count == 3 + +def test_add_docs_with_openai_embeddings(): + settings = Settings(use_openai_embeddings=True, openai_api_key="test_key") + chroma_db = ChromaDB(settings=settings) + docs = ["doc1", "doc2", "doc3", "doc4", "doc5"] + ids = ["id1", "id2", "id3", "id4", "id5"] + metadatas = [{"key": "value", "filename": "test.txt"} for _ in range(5)] + + # Mock the _docs_collection.add method + chroma_db._docs_collection.add = MagicMock() + + # Call add_docs + chroma_db.add_docs(docs, ids=ids, metadatas=metadatas) + + # Assert that the _docs_collection.add method was called only once + chroma_db._docs_collection.add.assert_called_once_with( + documents=docs, + metadatas=metadatas, + ids=ids + ) + +def test_max_file_size_setting(mock_settings): + assert mock_settings.max_file_size == 20 * 1024 * 1024 # 20 MB + +def test_chroma_batch_size_setting(mock_settings): + assert mock_settings.chroma_batch_size == 10 # Set by the mock_settings fixture