Skip to content

Commit

Permalink
Recommendation api update + local mode (#314)
Browse files Browse the repository at this point in the history
* WIP: recommendation api update

* implement local new reco + add new tests

* refactor calculate_best_scores, normalize score assertion

* fix raw vectors in new recommend

* fix mypy errors, add new reco on group recommend

* fix more lints from ci

* fix pyright lint

* formattng

* Add descriptions to enums

* no mutable default argument

* edit score comparison precision based on magnitude

* Address review comments

* remove custom message for type ignore

* add coverage test

* improve coverage test a little

* add more fixtures for conversion test

---------

Co-authored-by: Luis Cossío <luis.cossio@outlook.com>
  • Loading branch information
generall and coszio authored Oct 5, 2023
1 parent 8d5ed2c commit bb8c442
Show file tree
Hide file tree
Showing 24 changed files with 809 additions and 318 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ venv
__pycache__
.pytest_cache
.idea
.vscode
.devcontainer
.coverage
htmlcov
*.iml
Expand Down
12 changes: 7 additions & 5 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def recommend_batch(
def recommend(
self,
collection_name: str,
positive: Sequence[types.PointId],
negative: Optional[Sequence[types.PointId]] = None,
positive: Optional[Sequence[types.RecommendExample]] = None,
negative: Optional[Sequence[types.RecommendExample]] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
Expand All @@ -79,7 +79,8 @@ def recommend(
with_vectors: Union[bool, List[str]] = False,
score_threshold: Optional[float] = None,
using: Optional[str] = None,
lookup_from: Optional[models.LookupLocation] = None,
lookup_from: Optional[types.LookupLocation] = None,
strategy: Optional[types.RecommendStrategy] = None,
**kwargs: Any,
) -> List[types.ScoredPoint]:
raise NotImplementedError()
Expand All @@ -88,8 +89,8 @@ def recommend_groups(
self,
collection_name: str,
group_by: str,
positive: Sequence[types.PointId],
negative: Optional[Sequence[types.PointId]] = None,
positive: Optional[Sequence[types.RecommendExample]] = None,
negative: Optional[Sequence[types.RecommendExample]] = None,
query_filter: Optional[models.Filter] = None,
search_params: Optional[models.SearchParams] = None,
limit: int = 10,
Expand All @@ -100,6 +101,7 @@ def recommend_groups(
using: Optional[str] = None,
lookup_from: Optional[models.LookupLocation] = None,
with_lookup: Optional[types.WithLookupInterface] = None,
strategy: Optional[types.RecommendStrategy] = None,
**kwargs: Any,
) -> types.GroupsResult:
raise NotImplementedError()
Expand Down
3 changes: 3 additions & 0 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def get_args_subscribed(tp: type): # type: ignore
List[PointId], rest.Filter, grpc.Filter, rest.PointsSelector, grpc.PointsSelector
]
LookupLocation = Union[rest.LookupLocation, grpc.LookupLocation]
RecommendStrategy: TypeAlias = rest.RecommendStrategy
RecommendExample: TypeAlias = rest.RecommendExample

AliasOperations = Union[
rest.CreateAliasOperation,
Expand All @@ -84,6 +86,7 @@ def get_args_subscribed(tp: type): # type: ignore
InitFrom: TypeAlias = Union[rest.InitFrom, str]
UpdateOperation: TypeAlias = rest.UpdateOperation


SearchRequest = Union[rest.SearchRequest, grpc.SearchPoints]
RecommendRequest = Union[rest.RecommendRequest, grpc.RecommendPoints]

Expand Down
90 changes: 85 additions & 5 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, get_args
from typing import Any, Dict, List, Optional, Sequence, Tuple, get_args

from google.protobuf.json_format import MessageToDict
from google.protobuf.timestamp_pb2 import Timestamp

from qdrant_client._pydantic_compat import construct
from qdrant_client.conversions.common_types import get_args_subscribed

try:
from google.protobuf.pyext._message import MessageMapContainer # type: ignore
Expand Down Expand Up @@ -518,6 +519,9 @@ def convert_collection_params(cls, model: grpc.CollectionParams) -> rest.Collect
replication_factor=model.replication_factor
if model.HasField("replication_factor")
else None,
read_fan_out_factor=model.read_fan_out_factor
if model.HasField("read_fan_out_factor")
else None,
write_consistency_factor=model.write_consistency_factor
if model.HasField("write_consistency_factor")
else None,
Expand Down Expand Up @@ -741,9 +745,15 @@ def convert_search_points(cls, model: grpc.SearchPoints) -> rest.SearchRequest:

@classmethod
def convert_recommend_points(cls, model: grpc.RecommendPoints) -> rest.RecommendRequest:
positive_ids = [cls.convert_point_id(point_id) for point_id in model.positive]
negative_ids = [cls.convert_point_id(point_id) for point_id in model.negative]

positive_vectors = [cls.convert_vector(vector) for vector in model.positive_vectors]
negative_vectors = [cls.convert_vector(vector) for vector in model.negative_vectors]

return rest.RecommendRequest(
positive=[cls.convert_point_id(point_id) for point_id in model.positive],
negative=[cls.convert_point_id(point_id) for point_id in model.negative],
positive=positive_ids + positive_vectors,
negative=negative_ids + negative_vectors,
filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
limit=model.limit,
with_payload=cls.convert_with_payload_interface(model.with_payload)
Expand All @@ -759,6 +769,9 @@ def convert_recommend_points(cls, model: grpc.RecommendPoints) -> rest.Recommend
lookup_from=cls.convert_lookup_location(model.lookup_from)
if model.HasField("lookup_from")
else None,
strategy=cls.convert_recommend_strategy(model.strategy)
if model.HasField("strategy")
else None,
)

@classmethod
Expand Down Expand Up @@ -794,6 +807,9 @@ def convert_collection_params_diff(
write_consistency_factor=model.write_consistency_factor
if model.HasField("write_consistency_factor")
else None,
read_fan_out_factor=model.read_fan_out_factor
if model.HasField("read_fan_out_factor")
else None,
on_disk_payload=model.on_disk_payload if model.HasField("on_disk_payload") else None,
)

Expand Down Expand Up @@ -1094,6 +1110,14 @@ def convert_init_from(cls, model: str) -> rest.InitFrom:
return rest.InitFrom(collection=model)
raise ValueError(f"Invalid InitFrom model: {model}")

@classmethod
def convert_recommend_strategy(cls, model: grpc.RecommendStrategy) -> rest.RecommendStrategy:
if model == grpc.RecommendStrategy.AverageVector:
return rest.RecommendStrategy.AVERAGE_VECTOR
if model == grpc.RecommendStrategy.BestScore:
return rest.RecommendStrategy.BEST_SCORE
raise ValueError(f"invalid RecommendStrategy model: {model}")


# ----------------------------------------
#
Expand Down Expand Up @@ -1426,6 +1450,7 @@ def convert_collection_params(cls, model: rest.CollectionParams) -> grpc.Collect
on_disk_payload=model.on_disk_payload or False,
write_consistency_factor=model.write_consistency_factor,
replication_factor=model.replication_factor,
read_fan_out_factor=model.read_fan_out_factor,
)

@classmethod
Expand Down Expand Up @@ -1531,6 +1556,40 @@ def convert_alias_description(cls, model: rest.AliasDescription) -> grpc.AliasDe
collection_name=model.collection_name,
)

@classmethod
def convert_recommend_examples_to_ids(
cls, examples: Sequence[rest.RecommendExample]
) -> List[grpc.PointId]:
ids: List[grpc.PointId] = []
for example in examples:
if isinstance(example, get_args_subscribed(rest.ExtendedPointId)):
id_ = cls.convert_extended_point_id(example)
elif isinstance(example, grpc.PointId):
id_ = example
else:
continue

ids.append(id_)

return ids

@classmethod
def convert_recommend_examples_to_vectors(
cls, examples: Sequence[rest.RecommendExample]
) -> List[grpc.Vector]:
vectors: List[grpc.Vector] = []
for example in examples:
if isinstance(example, grpc.Vector):
vector = example
elif isinstance(example, list):
vector = grpc.Vector(data=example)
else:
continue

vectors.append(vector)

return vectors

@classmethod
def convert_extended_point_id(cls, model: rest.ExtendedPointId) -> grpc.PointId:
if isinstance(model, int):
Expand Down Expand Up @@ -1735,10 +1794,16 @@ def convert_search_points(
def convert_recommend_request(
cls, model: rest.RecommendRequest, collection_name: str
) -> grpc.RecommendPoints:
positive_ids = cls.convert_recommend_examples_to_ids(model.positive)
negative_ids = cls.convert_recommend_examples_to_ids(model.negative)

positive_vectors = cls.convert_recommend_examples_to_vectors(model.positive)
negative_vectors = cls.convert_recommend_examples_to_vectors(model.negative)

return grpc.RecommendPoints(
collection_name=collection_name,
positive=[cls.convert_extended_point_id(point_id) for point_id in model.positive],
negative=[cls.convert_extended_point_id(point_id) for point_id in model.negative],
positive=positive_ids,
negative=negative_ids,
filter=cls.convert_filter(model.filter) if model.filter is not None else None,
limit=model.limit,
with_payload=cls.convert_with_payload_interface(model.with_payload)
Expand All @@ -1754,6 +1819,11 @@ def convert_recommend_request(
lookup_from=cls.convert_lookup_location(model.lookup_from)
if model.lookup_from is not None
else None,
strategy=cls.convert_recommend_strategy(model.strategy)
if model.strategy is not None
else None,
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
)

@classmethod
Expand Down Expand Up @@ -1794,6 +1864,7 @@ def convert_collection_params_diff(
replication_factor=model.replication_factor,
write_consistency_factor=model.write_consistency_factor,
on_disk_payload=model.on_disk_payload,
read_fan_out_factor=model.read_fan_out_factor,
)

@classmethod
Expand Down Expand Up @@ -2136,3 +2207,12 @@ def convert_init_from(cls, model: rest.InitFrom) -> str:
return model.collection
else:
raise ValueError(f"invalid InitFrom model: {model}")

@classmethod
def convert_recommend_strategy(cls, model: rest.RecommendStrategy) -> grpc.RecommendStrategy:
if model == rest.RecommendStrategy.AVERAGE_VECTOR:
return grpc.RecommendStrategy.AverageVector
elif model == rest.RecommendStrategy.BEST_SCORE:
return grpc.RecommendStrategy.BestScore
else:
raise ValueError(f"invalid RecommendStrategy model: {model}")
Loading

0 comments on commit bb8c442

Please sign in to comment.