diff --git a/qdrant_client/local/local_collection.py b/qdrant_client/local/local_collection.py index f9e77b29..b2619497 100644 --- a/qdrant_client/local/local_collection.py +++ b/qdrant_client/local/local_collection.py @@ -554,6 +554,9 @@ def _update_point(self, point: models.PointStruct) -> None: for vector_name, named_vectors in self.vectors.items(): vector = vectors.get(vector_name) if vector is not None: + params = self.get_vector_params(vector_name) + if params.distance == models.Distance.COSINE: + vector = np.array(vector) / np.linalg.norm(vector) self.vectors[vector_name][idx] = vector self.deleted_per_vector[vector_name][idx] = 0 else: @@ -587,6 +590,9 @@ def _add_point(self, point: models.PointStruct) -> None: ) else: vector_np = np.array(vector) + params = self.get_vector_params(vector_name) + if params.distance == models.Distance.COSINE: + vector_np = vector_np / np.linalg.norm(vector_np) named_vectors[idx] = vector_np self.vectors[vector_name] = named_vectors self.deleted_per_vector[vector_name] = np.append( diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index ad1157b8..cf85e441 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from qdrant_client import grpc as grpc diff --git a/tests/congruence_tests/test_common.py b/tests/congruence_tests/test_common.py index b4df0efd..b3d1e9b0 100644 --- a/tests/congruence_tests/test_common.py +++ b/tests/congruence_tests/test_common.py @@ -4,6 +4,7 @@ from qdrant_client import QdrantClient from qdrant_client.client_base import QdrantBase +from qdrant_client.conversions import common_types as types from qdrant_client.http import models from qdrant_client.http.models import VectorStruct from qdrant_client.local.qdrant_local import QdrantLocal @@ -84,7 +85,7 @@ def compare_collections( compare_client_results( client_1, client_2, - lambda client: client.scroll(COLLECTION_NAME, limit=num_vectors * 2), + lambda client: client.scroll(COLLECTION_NAME, with_vectors=True, limit=num_vectors * 2), ) @@ -96,11 +97,13 @@ def compare_vectors(vec1: Optional[VectorStruct], vec2: Optional[VectorStruct], if isinstance(vec1, dict): for key, value in vec1.items(): - assert np.allclose(vec1[key], vec2[key]), ( + assert np.allclose(vec1[key], vec2[key], atol=1.0e-3), ( f"res1[{i}].vectors[{key}] = {value}, " f"res2[{i}].vectors[{key}] = {vec2[key]}" ) else: - assert np.allclose(vec1, vec2), f"res1[{i}].vectors = {vec1}, res2[{i}].vectors = {vec2}" + assert np.allclose( + vec1, vec2, atol=1.0e-3 + ), f"res1[{i}].vectors = {vec1}, res2[{i}].vectors = {vec2}" def compare_scored_record( @@ -149,6 +152,13 @@ def compare_client_results( res1 = foo(client1, **kwargs) res2 = foo(client2, **kwargs) + # compare scroll results + if isinstance(res1, tuple) and len(res1) == 2: + if isinstance(res1[0], list) and (res1[1] is None or isinstance(res1[1], types.PointId)): + res1, offset1 = res1 + res2, offset2 = res2 + assert offset1 == offset2, f"offset1 = {offset1}, offset2 = {offset2}" + if isinstance(res1, list): compare_records(res1, res2) elif isinstance(res1, models.GroupsResult):