Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batch_size for vectordb #1449

Merged
merged 3 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ install:

install_all:
poetry install --all-extras
poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface
poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need psutil here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test was failing because psutil was not installed. The failure can be seen in the test runs of #1448.


install_es:
poetry install --extras elasticsearch
Expand Down
4 changes: 0 additions & 4 deletions embedchain/config/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def __init__(
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
**kwargs,
):
"""
Expand All @@ -24,16 +23,13 @@ def __init__(
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
self.collection_name = collection_name or "embedchain_store"
self.dir = dir
self.host = host
self.port = port
self.batch_size = batch_size
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():
Expand Down
4 changes: 4 additions & 0 deletions embedchain/config/vectordb/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
allow_reset=False,
chroma_settings: Optional[dict] = None,
):
Expand All @@ -26,6 +27,8 @@ def __init__(
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param allow_reset: Resets the database. defaults to False
:type allow_reset: bool
:param chroma_settings: Chroma settings dict, defaults to None
Expand All @@ -34,4 +37,5 @@ def __init__(

self.chroma_settings = chroma_settings
self.allow_reset = allow_reset
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
7 changes: 7 additions & 0 deletions embedchain/config/vectordb/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Expand All @@ -24,6 +25,10 @@ def __init__(
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
Expand All @@ -46,4 +51,6 @@ def __init__(
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")

self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)
4 changes: 4 additions & 0 deletions embedchain/config/vectordb/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
Expand All @@ -28,10 +29,13 @@ def __init__(
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size

super().__init__(collection_name=collection_name, dir=dir)
2 changes: 2 additions & 0 deletions embedchain/config/vectordb/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
serverless_config: Optional[dict[str, any]] = None,
hybrid_search: bool = False,
bm25_encoder: any = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.metric = metric
Expand All @@ -26,6 +27,7 @@ def __init__(
self.extra_params = extra_params
self.hybrid_search = hybrid_search
self.bm25_encoder = bm25_encoder
self.batch_size = batch_size
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
Expand Down
4 changes: 4 additions & 0 deletions embedchain/config/vectordb/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None,
batch_size: Optional[int] = 10,
**extra_params: dict[str, any],
):
"""
Expand All @@ -36,9 +37,12 @@ def __init__(
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
:type on_disk: bool, optional, defaults to None
:param batch_size: Number of items to insert in one batch, defaults to 10
:type batch_size: Optional[int], optional
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)
2 changes: 2 additions & 0 deletions embedchain/config/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)
10 changes: 6 additions & 4 deletions embedchain/vectordb/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, config: Optional[ChromaDbConfig] = None):

self.settings = Settings(anonymized_telemetry=False)
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
self.batch_size = self.config.batch_size
if self.config.chroma_settings:
for key, value in self.config.chroma_settings.items():
if hasattr(self.settings, key):
Expand Down Expand Up @@ -153,12 +154,13 @@ def add(
" Ids size: {}".format(len(documents), len(metadatas), len(ids))
)

for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in chromadb"):
self.collection.add(
documents=documents[i : i + self.config.batch_size],
metadatas=metadatas[i : i + self.config.batch_size],
ids=ids[i : i + self.config.batch_size],
documents=documents[i : i + self.batch_size],
metadatas=metadatas[i : i + self.batch_size],
ids=ids[i : i + self.batch_size],
)
self.config

@staticmethod
def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
Expand Down
3 changes: 2 additions & 1 deletion embedchain/vectordb/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
"Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
)

self.batch_size = self.config.batch_size
# Call parent init here because embedder is needed
super().__init__(config=self.config)

Expand Down Expand Up @@ -139,7 +140,7 @@ def add(

for chunk in chunks(
list(zip(ids, documents, metadatas, embeddings)),
self.config.batch_size,
self.batch_size,
desc="Inserting batches in elasticsearch",
): # noqa: E501
ids, docs, metadatas, embeddings = [], [], [], []
Expand Down
7 changes: 3 additions & 4 deletions embedchain/vectordb/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, config: OpenSearchDBConfig):
if config is None:
raise ValueError("OpenSearchDBConfig is required")
self.config = config
self.batch_size = self.config.batch_size
self.client = OpenSearch(
hosts=[self.config.opensearch_url],
http_auth=self.config.http_auth,
Expand Down Expand Up @@ -118,10 +119,8 @@ def add(self, documents: list[str], metadatas: list[object], ids: list[str], **k
"""Adds documents to the opensearch index"""

embeddings = self.embedder.embedding_fn(documents)
for batch_start in tqdm(
range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
):
batch_end = batch_start + self.config.batch_size
for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
batch_end = batch_start + self.batch_size
batch_documents = documents[batch_start:batch_end]
batch_embeddings = embeddings[batch_start:batch_end]

Expand Down
7 changes: 4 additions & 3 deletions embedchain/vectordb/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(

# Setup BM25Encoder if sparse vectors are to be used
self.bm25_encoder = None
self.batch_size = self.config.batch_size
if self.config.hybrid_search:
logger.info("Initializing BM25Encoder for sparse vectors..")
self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
Expand Down Expand Up @@ -102,8 +103,8 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] =
metadatas = []

if ids is not None:
for i in range(0, len(ids), self.config.batch_size):
result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
for i in range(0, len(ids), self.batch_size):
result = self.pinecone_index.fetch(ids=ids[i : i + self.batch_size])
vectors = result.get("vectors")
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
Expand Down Expand Up @@ -142,7 +143,7 @@ def add(
},
)

for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
for chunk in chunks(docs, self.batch_size, desc="Adding chunks in batches"):
self.pinecone_index.upsert(chunk, **kwargs)

def query(
Expand Down
11 changes: 6 additions & 5 deletions embedchain/vectordb/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, config: QdrantDBConfig = None):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.batch_size = self.config.batch_size
self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
# Call parent init here because embedder is needed
super().__init__(config=self.config)
Expand Down Expand Up @@ -114,7 +115,7 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] =
collection_name=self.collection_name,
scroll_filter=models.Filter(must=qdrant_must_filters),
offset=offset,
limit=self.config.batch_size,
limit=self.batch_size,
)
offset = response[1]
for doc in response[0]:
Expand Down Expand Up @@ -146,13 +147,13 @@ def add(
qdrant_ids.append(id)
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})

for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
for i in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
self.client.upsert(
collection_name=self.collection_name,
points=Batch(
ids=qdrant_ids[i : i + self.config.batch_size],
payloads=payloads[i : i + self.config.batch_size],
vectors=embeddings[i : i + self.config.batch_size],
ids=qdrant_ids[i : i + self.batch_size],
payloads=payloads[i : i + self.batch_size],
vectors=embeddings[i : i + self.batch_size],
),
**kwargs,
)
Expand Down
5 changes: 3 additions & 2 deletions embedchain/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.batch_size = self.config.batch_size
self.client = weaviate.Client(
url=os.environ.get("WEAVIATE_ENDPOINT"),
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
Expand Down Expand Up @@ -167,7 +168,7 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] =
)
.with_where(weaviate_where_clause)
.with_additional(["id"])
.with_limit(limit or self.config.batch_size),
.with_limit(limit or self.batch_size),
offset,
)

Expand Down Expand Up @@ -196,7 +197,7 @@ def add(self, documents: list[str], metadatas: list[object], ids: list[str], **k
:type ids: list[str]
"""
embeddings = self.embedder.embedding_fn(documents)
self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3) # Configure batch
self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3) # Configure batch
with self.client.batch as batch: # Initialize a batch process
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
doc = {"identifier": id, "text": text}
Expand Down
3 changes: 1 addition & 2 deletions tests/vectordb/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,14 @@ def test_add(self, weaviate_mock):
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
db.config.batch_size = 1

documents = ["This is test document"]
metadatas = [None]
ids = ["id_1"]
db.add(documents, metadatas, ids)

# Check if the document was added to the database.
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
)
Expand Down
Loading