Skip to content

Commit

Permalink
Revert "[ENH] simplify logic for when to persist index changes" (#2544)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Jul 19, 2024
1 parent 17c3f1d commit f879f9e
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from overrides import override
import pickle
from typing import Dict, List, Optional, Sequence, Set, cast
from typing import Any, Dict, List, Optional, Sequence, Set, cast
from chromadb.config import System
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
Expand Down Expand Up @@ -40,6 +40,9 @@ class PersistentData:
"""Stores the data and metadata needed for a PersistentLocalHnswSegment"""

dimensionality: Optional[int]
total_elements_added: int
total_elements_updated: int
total_invalid_operations: int
max_seq_id: SeqId

id_to_label: Dict[str, int]
Expand All @@ -49,17 +52,29 @@ class PersistentData:
def __init__(
self,
dimensionality: Optional[int],
total_elements_added: int,
total_elements_updated: int,
total_invalid_operations: int,
max_seq_id: int,
id_to_label: Dict[str, int],
label_to_id: Dict[int, str],
id_to_seq_id: Dict[str, SeqId],
):
self.dimensionality = dimensionality
self.total_elements_added = total_elements_added
self.total_elements_updated = total_elements_updated
self.total_invalid_operations = total_invalid_operations
self.max_seq_id = max_seq_id
self.id_to_label = id_to_label
self.label_to_id = label_to_id
self.id_to_seq_id = id_to_seq_id

def __setstate__(self, state: Any) -> None:
# Fields were added after the initial implementation
self.total_elements_updated = 0
self.total_invalid_operations = 0
self.__dict__.update(state)

@staticmethod
def load_from_file(filename: str) -> "PersistentData":
"""Load persistent data from a file"""
Expand All @@ -85,9 +100,6 @@ class PersistentLocalHnswSegment(LocalHnswSegment):

_opentelemtry_client: OpenTelemetryClient

_num_log_records_since_last_batch: int = 0
_num_log_records_since_last_persist: int = 0

def __init__(self, system: System, segment: Segment):
super().__init__(system, segment)

Expand All @@ -108,6 +120,7 @@ def __init__(self, system: System, segment: Segment):
self._get_metadata_file()
)
self._dimensionality = self._persist_data.dimensionality
self._total_elements_added = self._persist_data.total_elements_added
self._max_seq_id = self._persist_data.max_seq_id
self._id_to_label = self._persist_data.id_to_label
self._label_to_id = self._persist_data.label_to_id
Expand All @@ -119,6 +132,9 @@ def __init__(self, system: System, segment: Segment):
else:
self._persist_data = PersistentData(
self._dimensionality,
self._total_elements_added,
self._total_elements_updated,
self._total_invalid_operations,
self._max_seq_id,
self._id_to_label,
self._label_to_id,
Expand Down Expand Up @@ -192,6 +208,8 @@ def _persist(self) -> None:

# Persist the metadata
self._persist_data.dimensionality = self._dimensionality
self._persist_data.total_elements_added = self._total_elements_added
self._persist_data.total_elements_updated = self._total_elements_updated
self._persist_data.max_seq_id = self._max_seq_id

# TODO: This should really be stored in sqlite, the index itself, or a better
Expand All @@ -203,18 +221,29 @@ def _persist(self) -> None:
with open(self._get_metadata_file(), "wb") as metadata_file:
pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)

self._num_log_records_since_last_persist = 0

@trace_method(
"PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
)
@override
def _apply_batch(self, batch: Batch) -> None:
super()._apply_batch(batch)
if self._num_log_records_since_last_persist >= self._sync_threshold:
self._persist()
num_elements_added_since_last_persist = (
self._total_elements_added - self._persist_data.total_elements_added
)
num_elements_updated_since_last_persist = (
self._total_elements_updated - self._persist_data.total_elements_updated
)
num_invalid_operations_since_last_persist = (
self._total_invalid_operations - self._persist_data.total_invalid_operations
)

self._num_log_records_since_last_batch = 0
if (
num_elements_added_since_last_persist
+ num_elements_updated_since_last_persist
+ num_invalid_operations_since_last_persist
>= self._sync_threshold
):
self._persist()

@trace_method(
"PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
Expand All @@ -226,9 +255,6 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
raise RuntimeError("Cannot add embeddings to stopped component")
with WriteRWLock(self._lock):
for record in records:
self._num_log_records_since_last_batch += 1
self._num_log_records_since_last_persist += 1

if record["record"]["embedding"] is not None:
self._ensure_index(len(records), len(record["record"]["embedding"]))
if not self._index_initialized:
Expand Down Expand Up @@ -279,7 +305,15 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
self._curr_batch.apply(record, exists_in_index)
self._brute_force_index.upsert([record])

if self._num_log_records_since_last_batch >= self._batch_size:
num_invalid_operations_since_last_persist = (
self._total_invalid_operations
- self._persist_data.total_invalid_operations
)

if (
len(self._curr_batch) + num_invalid_operations_since_last_persist
>= self._batch_size
):
self._apply_batch(self._curr_batch)
self._curr_batch = Batch()
self._brute_force_index.clear()
Expand Down

0 comments on commit f879f9e

Please sign in to comment.