Skip to content

Commit

Permalink
new: add collection exists interface (#518)
Browse files Browse the repository at this point in the history
* new: add collection exists interface

* tests: fix version comparison

* tests: add collection exists type stub, add async test
  • Loading branch information
joein committed Mar 4, 2024
1 parent 0849671 commit ee67ccf
Show file tree
Hide file tree
Showing 12 changed files with 192 additions and 60 deletions.
3 changes: 3 additions & 0 deletions qdrant_client/async_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
raise NotImplementedError()

async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()

async def update_collection(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()

Expand Down
12 changes: 12 additions & 0 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,18 @@ async def get_collection(self, collection_name: str, **kwargs: Any) -> types.Col
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return await self._client.get_collection(collection_name=collection_name, **kwargs)

async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
"""Check whether collection already exists
Args:
collection_name: Name of the collection
Returns:
True if collection exists, False if not
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return await self._client.collection_exists(collection_name=collection_name, **kwargs)

async def update_collection(
self,
collection_name: str,
Expand Down
14 changes: 14 additions & 0 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,20 @@ async def get_collection(self, collection_name: str, **kwargs: Any) -> types.Col
assert result is not None, "Get collection returned None"
return result

async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
if self._prefer_grpc:
return (
await self.grpc_collections.CollectionExists(
grpc.CollectionExistsRequest(collection_name=collection_name),
timeout=self._timeout,
)
).result.exists
result: Optional[models.CollectionExistence] = (
await self.http.collections_api.collection_exists(collection_name=collection_name)
).result
assert result is not None, "Collection exists returned None"
return result.exists

async def update_collection(
self,
collection_name: str,
Expand Down
3 changes: 3 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
raise NotImplementedError()

def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()

def update_collection(
self,
collection_name: str,
Expand Down
7 changes: 7 additions & 0 deletions qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,13 @@ async def get_collection(self, collection_name: str, **kwargs: Any) -> types.Col
collection = self._get_collection(collection_name)
return collection.info()

async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
try:
self._get_collection(collection_name)
return True
except ValueError:
return False

async def update_collection(self, collection_name: str, **kwargs: Any) -> bool:
self._get_collection(collection_name)
return False
Expand Down
7 changes: 7 additions & 0 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,13 @@ def get_collection(self, collection_name: str, **kwargs: Any) -> types.Collectio
collection = self._get_collection(collection_name)
return collection.info()

def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
try:
self._get_collection(collection_name)
return True
except ValueError:
return False

def update_collection(self, collection_name: str, **kwargs: Any) -> bool:
_collection = self._get_collection(collection_name)
return False
Expand Down
13 changes: 13 additions & 0 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,19 @@ def get_collection(self, collection_name: str, **kwargs: Any) -> types.Collectio

return self._client.get_collection(collection_name=collection_name, **kwargs)

def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
"""Check whether collection already exists
Args:
collection_name: Name of the collection
Returns:
True if collection exists, False if not
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"

return self._client.collection_exists(collection_name=collection_name, **kwargs)

def update_collection(
self,
collection_name: str,
Expand Down
159 changes: 103 additions & 56 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ def search_batch(
) -> List[List[types.ScoredPoint]]:
if self._prefer_grpc:
requests = [
RestToGrpc.convert_search_request(r, collection_name)
if isinstance(r, models.SearchRequest)
else r
(
RestToGrpc.convert_search_request(r, collection_name)
if isinstance(r, models.SearchRequest)
else r
)
for r in requests
]

Expand Down Expand Up @@ -655,9 +657,11 @@ def recommend_batch(
) -> List[List[types.ScoredPoint]]:
if self._prefer_grpc:
requests = [
RestToGrpc.convert_recommend_request(r, collection_name)
if isinstance(r, models.RecommendRequest)
else r
(
RestToGrpc.convert_recommend_request(r, collection_name)
if isinstance(r, models.RecommendRequest)
else r
)
for r in requests
]

Expand All @@ -679,9 +683,11 @@ def recommend_batch(
]
else:
requests = [
GrpcToRest.convert_recommend_points(r)
if isinstance(r, grpc.RecommendPoints)
else r
(
GrpcToRest.convert_recommend_points(r)
if isinstance(r, grpc.RecommendPoints)
else r
)
for r in requests
]
http_res: List[List[models.ScoredPoint]] = self.http.points_api.recommend_batch_points(
Expand Down Expand Up @@ -775,16 +781,20 @@ def recommend(
return [GrpcToRest.convert_scored_point(hit) for hit in res.result]
else:
positive = [
GrpcToRest.convert_point_id(example)
if isinstance(example, grpc.PointId)
else example
(
GrpcToRest.convert_point_id(example)
if isinstance(example, grpc.PointId)
else example
)
for example in positive
]

negative = [
GrpcToRest.convert_point_id(example)
if isinstance(example, grpc.PointId)
else example
(
GrpcToRest.convert_point_id(example)
if isinstance(example, grpc.PointId)
else example
)
for example in negative
]

Expand Down Expand Up @@ -918,16 +928,20 @@ def recommend_groups(
with_lookup = GrpcToRest.convert_with_lookup(with_lookup)

positive = [
GrpcToRest.convert_point_id(point_id)
if isinstance(point_id, grpc.PointId)
else point_id
(
GrpcToRest.convert_point_id(point_id)
if isinstance(point_id, grpc.PointId)
else point_id
)
for point_id in positive
]

negative = [
GrpcToRest.convert_point_id(point_id)
if isinstance(point_id, grpc.PointId)
else point_id
(
GrpcToRest.convert_point_id(point_id)
if isinstance(point_id, grpc.PointId)
else point_id
)
for point_id in negative
]

Expand Down Expand Up @@ -1000,9 +1014,11 @@ def discover(
)

context = [
RestToGrpc.convert_context_example_pair(pair)
if isinstance(pair, models.ContextExamplePair)
else pair
(
RestToGrpc.convert_context_example_pair(pair)
if isinstance(pair, models.ContextExamplePair)
else pair
)
for pair in context
]

Expand Down Expand Up @@ -1056,9 +1072,11 @@ def discover(
)

context = [
GrpcToRest.convert_context_example_pair(pair)
if isinstance(pair, grpc.ContextExamplePair)
else pair
(
GrpcToRest.convert_context_example_pair(pair)
if isinstance(pair, grpc.ContextExamplePair)
else pair
)
for pair in context
]

Expand Down Expand Up @@ -1105,9 +1123,11 @@ def discover_batch(
) -> List[List[types.ScoredPoint]]:
if self._prefer_grpc:
requests = [
RestToGrpc.convert_discover_request(r, collection_name)
if isinstance(r, models.DiscoverRequest)
else r
(
RestToGrpc.convert_discover_request(r, collection_name)
if isinstance(r, models.DiscoverRequest)
else r
)
for r in requests
]

Expand Down Expand Up @@ -1187,11 +1207,11 @@ def scroll(
timeout=self._timeout,
)

return [
GrpcToRest.convert_retrieved_point(point) for point in res.result
], GrpcToRest.convert_point_id(res.next_page_offset) if res.HasField(
"next_page_offset"
) else None
return [GrpcToRest.convert_retrieved_point(point) for point in res.result], (
GrpcToRest.convert_point_id(res.next_page_offset)
if res.HasField("next_page_offset")
else None
)
else:
if isinstance(offset, grpc.PointId):
offset = GrpcToRest.convert_point_id(offset)
Expand Down Expand Up @@ -1282,17 +1302,21 @@ def upsert(
grpc.PointStruct(
id=RestToGrpc.convert_extended_point_id(points.ids[idx]),
vectors=vectors_batch[idx],
payload=RestToGrpc.convert_payload(points.payloads[idx])
if points.payloads is not None
else None,
payload=(
RestToGrpc.convert_payload(points.payloads[idx])
if points.payloads is not None
else None
),
)
for idx in range(len(points.ids))
]
if isinstance(points, list):
points = [
RestToGrpc.convert_point_struct(point)
if isinstance(point, models.PointStruct)
else point
(
RestToGrpc.convert_point_struct(point)
if isinstance(point, models.PointStruct)
else point
)
for point in points
]

Expand All @@ -1318,9 +1342,11 @@ def upsert(
else:
if isinstance(points, list):
points = [
GrpcToRest.convert_point_struct(point)
if isinstance(point, grpc.PointStruct)
else point
(
GrpcToRest.convert_point_struct(point)
if isinstance(point, grpc.PointStruct)
else point
)
for point in points
]

Expand Down Expand Up @@ -1444,9 +1470,11 @@ def retrieve(
with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

ids = [
RestToGrpc.convert_extended_point_id(idx)
if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
else idx
(
RestToGrpc.convert_extended_point_id(idx)
if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
else idx
)
for idx in ids
]

Expand Down Expand Up @@ -1505,9 +1533,11 @@ def _try_argument_to_grpc_selector(
points_selector = grpc.PointsSelector(
points=grpc.PointsIdsList(
ids=[
RestToGrpc.convert_extended_point_id(idx)
if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
else idx
(
RestToGrpc.convert_extended_point_id(idx)
if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
else idx
)
for idx in points
]
)
Expand Down Expand Up @@ -1893,9 +1923,11 @@ def update_collection_aliases(
) -> bool:
if self._prefer_grpc:
change_aliases_operation = [
RestToGrpc.convert_alias_operations(operation)
if not isinstance(operation, grpc.AliasOperations)
else operation
(
RestToGrpc.convert_alias_operations(operation)
if not isinstance(operation, grpc.AliasOperations)
else operation
)
for operation in change_aliases_operations
]
return self.grpc_collections.UpdateAliases(
Expand All @@ -1907,9 +1939,11 @@ def update_collection_aliases(
).result

change_aliases_operation = [
GrpcToRest.convert_alias_operations(operation)
if isinstance(operation, grpc.AliasOperations)
else operation
(
GrpcToRest.convert_alias_operations(operation)
if isinstance(operation, grpc.AliasOperations)
else operation
)
for operation in change_aliases_operations
]
result: Optional[bool] = self.http.collections_api.update_aliases(
Expand Down Expand Up @@ -1991,6 +2025,19 @@ def get_collection(self, collection_name: str, **kwargs: Any) -> types.Collectio
assert result is not None, "Get collection returned None"
return result

def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
if self._prefer_grpc:
return self.grpc_collections.CollectionExists(
grpc.CollectionExistsRequest(collection_name=collection_name),
timeout=self._timeout,
).result.exists

result: Optional[models.CollectionExistence] = self.http.collections_api.collection_exists(
collection_name=collection_name
).result
assert result is not None, "Collection exists returned None"
return result.exists

def update_collection(
self,
collection_name: str,
Expand Down
Loading

0 comments on commit ee67ccf

Please sign in to comment.