diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index f2ff5907cf3..b8406c1180b 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -2,7 +2,7 @@ import shutil from overrides import override import pickle -from typing import Dict, List, Optional, Sequence, Set, cast +from typing import Any, Dict, List, Optional, Sequence, Set, cast from chromadb.config import System from chromadb.segment.impl.vector.batch import Batch from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams @@ -40,6 +40,9 @@ class PersistentData: """Stores the data and metadata needed for a PersistentLocalHnswSegment""" dimensionality: Optional[int] + total_elements_added: int + total_elements_updated: int + total_invalid_operations: int max_seq_id: SeqId id_to_label: Dict[str, int] @@ -49,17 +52,29 @@ class PersistentData: def __init__( self, dimensionality: Optional[int], + total_elements_added: int, + total_elements_updated: int, + total_invalid_operations: int, max_seq_id: int, id_to_label: Dict[str, int], label_to_id: Dict[int, str], id_to_seq_id: Dict[str, SeqId], ): self.dimensionality = dimensionality + self.total_elements_added = total_elements_added + self.total_elements_updated = total_elements_updated + self.total_invalid_operations = total_invalid_operations self.max_seq_id = max_seq_id self.id_to_label = id_to_label self.label_to_id = label_to_id self.id_to_seq_id = id_to_seq_id + def __setstate__(self, state: Any) -> None: + # Fields were added after the initial implementation + self.total_elements_updated = 0 + self.total_invalid_operations = 0 + self.__dict__.update(state) + @staticmethod def load_from_file(filename: str) -> "PersistentData": """Load persistent data from a file""" @@ -85,9 +100,6 @@ class PersistentLocalHnswSegment(LocalHnswSegment): _opentelemtry_client: OpenTelemetryClient - _num_log_records_since_last_batch: int = 0 - _num_log_records_since_last_persist: int = 0 - def __init__(self, system: System, segment: Segment): super().__init__(system, segment) @@ -108,6 +120,7 @@ def __init__(self, system: System, segment: Segment): self._get_metadata_file() ) self._dimensionality = self._persist_data.dimensionality + self._total_elements_added = self._persist_data.total_elements_added self._max_seq_id = self._persist_data.max_seq_id self._id_to_label = self._persist_data.id_to_label self._label_to_id = self._persist_data.label_to_id @@ -119,6 +132,9 @@ def __init__(self, system: System, segment: Segment): else: self._persist_data = PersistentData( self._dimensionality, + self._total_elements_added, + self._total_elements_updated, + self._total_invalid_operations, self._max_seq_id, self._id_to_label, self._label_to_id, @@ -192,6 +208,8 @@ def _persist(self) -> None: # Persist the metadata self._persist_data.dimensionality = self._dimensionality + self._persist_data.total_elements_added = self._total_elements_added + self._persist_data.total_elements_updated = self._total_elements_updated self._persist_data.max_seq_id = self._max_seq_id # TODO: This should really be stored in sqlite, the index itself, or a better @@ -203,18 +221,29 @@ def _persist(self) -> None: with open(self._get_metadata_file(), "wb") as metadata_file: pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL) - self._num_log_records_since_last_persist = 0 - @trace_method( "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL ) @override def _apply_batch(self, batch: Batch) -> None: super()._apply_batch(batch) - if self._num_log_records_since_last_persist >= self._sync_threshold: - self._persist() + num_elements_added_since_last_persist = ( + self._total_elements_added - self._persist_data.total_elements_added + ) + num_elements_updated_since_last_persist = ( + self._total_elements_updated - self._persist_data.total_elements_updated + ) + num_invalid_operations_since_last_persist = ( + self._total_invalid_operations - self._persist_data.total_invalid_operations + ) - self._num_log_records_since_last_batch = 0 + if ( + num_elements_added_since_last_persist + + num_elements_updated_since_last_persist + + num_invalid_operations_since_last_persist + >= self._sync_threshold + ): + self._persist() @trace_method( "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL @@ -226,9 +255,6 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: raise RuntimeError("Cannot add embeddings to stopped component") with WriteRWLock(self._lock): for record in records: - self._num_log_records_since_last_batch += 1 - self._num_log_records_since_last_persist += 1 - if record["record"]["embedding"] is not None: self._ensure_index(len(records), len(record["record"]["embedding"])) if not self._index_initialized: @@ -279,7 +305,15 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: self._curr_batch.apply(record, exists_in_index) self._brute_force_index.upsert([record]) - if self._num_log_records_since_last_batch >= self._batch_size: + num_invalid_operations_since_last_persist = ( + self._total_invalid_operations + - self._persist_data.total_invalid_operations + ) + + if ( + len(self._curr_batch) + num_invalid_operations_since_last_persist + >= self._batch_size + ): self._apply_batch(self._curr_batch) self._curr_batch = Batch() self._brute_force_index.clear()