fix: chromadb max batch size (#1087)
Showing 5 changed files with 142 additions and 68 deletions.
@@ -0,0 +1,87 @@
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.
    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Add nodes to index, batching the insertion to avoid issues.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        # Split the incoming nodes into chunks no larger than the client's
        # max_batch_size, so each collection.add() call stays under the limit.
        max_chunk_size = self.chroma_client.max_batch_size
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
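
For context, here is a minimal usage sketch of the class above. It is not part of the commit: chromadb.PersistentClient, the placeholder path and collection name, and the dummy embeddings are assumptions for illustration only.

import chromadb
from llama_index.schema import TextNode

# Persistent local client; "./chroma_db" and "my_collection" are placeholder names.
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_or_create_collection("my_collection")

# BatchedChromaVectorStore is the class defined above.
store = BatchedChromaVectorStore(chroma_client=client, chroma_collection=collection)

# Toy nodes with dummy pre-computed embeddings; in practice these come from an embedding model.
nodes = [TextNode(text=f"sentence {i}", embedding=[0.1, 0.2, 0.3]) for i in range(5000)]

# add() splits the insertion into chunks of client.max_batch_size, so the call
# succeeds even when the node count exceeds Chroma's per-request limit.
node_ids = store.add(nodes)

Without this batching, a single collection.add() carrying more records than max_batch_size fails in recent ChromaDB versions, which appears to be the failure this commit addresses.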
@@ -0,0 +1,27 @@
from unittest.mock import PropertyMock, patch

from llama_index import Document

from private_gpt.server.ingest.ingest_service import IngestService
from tests.fixtures.mock_injector import MockInjector


def test_save_many_nodes(injector: MockInjector) -> None:
    """This is a specific test for a local Chromadb Vector Database setup.

    Extend it when we add support for other vector databases in VectorStoreComponent.
    """
    with patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    ) as max_batch_size:
        # Make max batch size of Chromadb very small
        max_batch_size.return_value = 10

        ingest_service = injector.get(IngestService)

        documents = []
        for _i in range(100):
            documents.append(Document(text="This is a sentence."))

        ingested_docs = ingest_service._save_docs(documents)
        assert len(ingested_docs) == len(documents)
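
For intuition, a small sketch of the chunking arithmetic this test exercises. chunk_list_sketch is a hypothetical stand-in for llama_index's chunk_list helper, written under the assumption that it splits a list into consecutive sublists of at most the given size; it is not code from the commit.

# Hypothetical illustration: with the patched max_batch_size of 10,
# 100 documents should be inserted via 10 separate collection.add() calls.
def chunk_list_sketch(items: list, max_chunk_size: int) -> list[list]:
    # Split `items` into consecutive sublists of at most `max_chunk_size` elements.
    return [items[i : i + max_chunk_size] for i in range(0, len(items), max_chunk_size)]


batches = chunk_list_sketch(list(range(100)), 10)
assert len(batches) == 10
assert all(len(batch) == 10 for batch in batches)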