Skip to content

Commit

Permalink
[ENH] Add retry on mismatch (#2843)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Adds a `VersionMismatchError` in Python and retries on it from the frontend (FE).
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
Added tests for version mismatches after waiting for compaction: if the
version matches, the request should succeed; if we set it to an erroneous
value, we expect a `VersionMismatchError`.
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Sep 25, 2024
1 parent e3da324 commit c515df8
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 19 deletions.
33 changes: 31 additions & 2 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tenacity import retry, stop_after_attempt, retry_if_exception, wait_fixed
from chromadb.api import ServerAPI
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
Expand All @@ -15,8 +16,11 @@
from chromadb.ingest import Producer
from chromadb.types import Collection as CollectionModel
from chromadb import __version__
from chromadb.errors import InvalidDimensionException, InvalidCollectionException

from chromadb.errors import (
InvalidDimensionException,
InvalidCollectionException,
VersionMismatchError,
)
from chromadb.api.types import (
URI,
CollectionMetadata,
Expand Down Expand Up @@ -443,6 +447,12 @@ def _upsert(
return True

@trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@rate_limit(subject="collection_id", resource=Resource.GET_PER_MINUTE)
@override
def _get(
Expand Down Expand Up @@ -631,6 +641,12 @@ def _delete(
return ids_to_delete

@trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@override
def _count(self, collection_id: UUID) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
Expand All @@ -644,6 +660,19 @@ def _count(self, collection_id: UUID) -> int:
return metadata_segment.count(request_version_context)

@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
# We retry on version mismatch errors because the version of the collection
# may have changed between the time we got the version and the time we
# actually query the collection on the FE. A fixed wait time is fine here
# because a version mismatch is not an error caused by network issues or
# other transient failures; it is the result of the collection being
# updated between fetching the version and executing the query.
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@rate_limit(subject="collection_id", resource=Resource.QUERY_PER_MINUTE)
@override
def _query(
Expand Down
12 changes: 12 additions & 0 deletions chromadb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def name(cls) -> str:
return "BatchSizeExceededError"


class VersionMismatchError(ChromaError):
    """Error signalling that a collection version did not match.

    The gRPC segment readers raise this when an ``RpcError``'s details
    contain the text "Collection version mismatch". ``SegmentAPI`` read
    paths (``_get``, ``_count``, ``_query``) retry on this error type.
    """

    @overrides
    def code(self) -> int:
        # Numeric code reported for this error.
        # NOTE(review): presumably surfaced as an HTTP status -- confirm.
        return 500

    @classmethod
    @overrides
    def name(cls) -> str:
        # Same string as this class's key in the `error_types` mapping.
        return "VersionMismatchError"


error_types: Dict[str, Type[ChromaError]] = {
"InvalidDimension": InvalidDimensionException,
"InvalidCollection": InvalidCollectionException,
Expand All @@ -135,4 +146,5 @@ def name(cls) -> str:
"AuthorizationError": AuthorizationError,
"NotFoundError": NotFoundError,
"BatchSizeExceededError": BatchSizeExceededError,
"VersionMismatchError": VersionMismatchError,
}
35 changes: 26 additions & 9 deletions chromadb/segment/impl/metadata/grpc_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment import MetadataReader
from chromadb.config import System
from chromadb.errors import InvalidArgumentError
from chromadb.errors import InvalidArgumentError, VersionMismatchError
from chromadb.types import Segment, RequestVersionContext
from overrides import override
from chromadb.telemetry.opentelemetry import (
Expand Down Expand Up @@ -55,10 +55,18 @@ def count(self, request_version_context: RequestVersionContext) -> int:
collection_id=self._segment["collection"].hex,
version_context=to_proto_request_version_context(request_version_context),
)
response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords(
request,
timeout=self._request_timeout_seconds,
)

try:
response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords(
request,
timeout=self._request_timeout_seconds,
)
except grpc.RpcError as rpc_error:
message = rpc_error.details()
if "Collection version mismatch" in message:
raise VersionMismatchError()
raise rpc_error

return response.count

@override
Expand Down Expand Up @@ -110,10 +118,19 @@ def get_metadata(
version_context=to_proto_request_version_context(request_version_context),
)

response: pb.QueryMetadataResponse = self._metadata_reader_stub.QueryMetadata(
request,
timeout=self._request_timeout_seconds,
)
try:
response: pb.QueryMetadataResponse = (
self._metadata_reader_stub.QueryMetadata(
request,
timeout=self._request_timeout_seconds,
)
)
except grpc.RpcError as rpc_error:
message = rpc_error.details()
if "Collection version mismatch" in message:
raise VersionMismatchError()
raise rpc_error

results: List[MetadataEmbeddingRecord] = []
for record in response.records:
result = self._from_proto(record)
Expand Down
33 changes: 25 additions & 8 deletions chromadb/segment/impl/vector/grpc_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment import VectorReader
from chromadb.errors import VersionMismatchError
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
Expand Down Expand Up @@ -67,10 +68,18 @@ def get_vectors(
collection_id=self._segment["collection"].hex,
version_context=to_proto_request_version_context(request_version_context),
)
response: GetVectorsResponse = self._vector_reader_stub.GetVectors(
request,
timeout=self._request_timeout_seconds,
)

try:
response: GetVectorsResponse = self._vector_reader_stub.GetVectors(
request,
timeout=self._request_timeout_seconds,
)
except grpc.RpcError as rpc_error:
message = rpc_error.details()
if "Collection version mismatch" in message:
raise VersionMismatchError()
raise rpc_error

results: List[VectorEmbeddingRecord] = []
for vector in response.records:
result = from_proto_vector_embedding_record(vector)
Expand All @@ -96,10 +105,18 @@ def query_vectors(
query["request_version_context"]
),
)
response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(
request,
timeout=self._request_timeout_seconds,
)

try:
response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(
request,
timeout=self._request_timeout_seconds,
)
except grpc.RpcError as rpc_error:
message = rpc_error.details()
if "Collection version mismatch" in message:
raise VersionMismatchError()
raise rpc_error

results: List[List[VectorQueryResult]] = []
for result in response.results:
curr_result: List[VectorQueryResult] = []
Expand Down
Loading

0 comments on commit c515df8

Please sign in to comment.