From 113aeb1147ca95490b8438dca740efae28eb7e35 Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Mon, 4 Nov 2024 16:19:13 +0000 Subject: [PATCH 1/6] normalize vector data to a standard form during insertion (#27469) Signed-off-by: NamCaoHai --- examples/example_normalization_fields.py | 112 +++++++++++++++++++++++ pymilvus/client/utils.py | 23 +++++ pymilvus/exceptions.py | 2 + pymilvus/orm/collection.py | 31 +++++++ 4 files changed, 168 insertions(+) create mode 100644 examples/example_normalization_fields.py diff --git a/examples/example_normalization_fields.py b/examples/example_normalization_fields.py new file mode 100644 index 000000000..ed5ac7b83 --- /dev/null +++ b/examples/example_normalization_fields.py @@ -0,0 +1,112 @@ +import time + +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + MilvusClient +) + +fmt = "\n=== {:30} ===\n" +search_latency_fmt = "search latency = {:.4f}s" +num_entities, dim = 3000, 8 + + +print(fmt.format("start connecting to Milvus")) +# this is milvus standalone +connection = connections.connect( + alias="default", + host='localhost', # or '0.0.0.0' or 'localhost' + port='19530' +) + +client = MilvusClient(connections=connection) + +has = utility.has_collection("hello_milvus") +print(f"Does collection hello_milvus exist in Milvus: {has}") +if has: + utility.drop_collection("hello_milvus") + +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings1", dtype=DataType.FLOAT_VECTOR, dim=dim), + FieldSchema(name="embeddings2", dtype=DataType.FLOAT_VECTOR, dim=dim) +] + +schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs") + +print(fmt.format("Create collection `hello_milvus`")) + +print(fmt.format("Message for handling an invalid format in the normalization_fields value")) # you can try with 
other value like: dict,... +try: + hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields='embeddings1') +except BaseException as e: + print(e) + + +print(fmt.format("Message for handling the invalid vector fields")) +try: + hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings']) +except BaseException as e: + print(e) + +print(fmt.format("Insert data, with conversion to standard form")) + +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embeddings1']) + +print(fmt.format("Start inserting a row")) +rng = np.random.default_rng(seed=19530) + +row = { + "pk": "19530", + "random": 0.5, + "embeddings1": rng.random((1, dim), np.float32)[0], + "embeddings2": rng.random((1, dim), np.float32)[0] +} +_row = row.copy() +hello_milvus.insert(row) + +index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} +hello_milvus.create_index("embeddings1", index_param) +hello_milvus.create_index("embeddings2", index_param) +hello_milvus.load() + +original_vector = _row['embeddings1'] +insert_vector = hello_milvus.query( + expr="pk == '19530'", + output_fields=["embeddings1"], +)[0]['embeddings1'] + +print(fmt.format("Mean and standard deviation before normalization.")) +print("Mean: ", np.mean(original_vector)) +print("Std: ", np.std(original_vector)) + +print(fmt.format("Mean and standard deviation after normalization.")) +print("Mean: ", np.mean(insert_vector)) +print("Std: ", np.std(insert_vector)) + + +print(fmt.format("Start inserting entities")) + +entities = [ + [str(i) for i in range(num_entities)], + rng.random(num_entities).tolist(), + rng.random((num_entities, dim), np.float32), + rng.random((num_entities, dim), np.float32), +] + +insert_result = hello_milvus.insert(entities) + +insert_vector = hello_milvus.query( + expr="pk == '1'", + output_fields=["embeddings1"], +)[0]['embeddings1'] + 
+print(fmt.format("Mean and standard deviation after normalization.")) +print("Mean: ", np.mean(insert_vector)) +print("Std: ", np.std(insert_vector)) + +utility.drop_collection("hello_milvus") \ No newline at end of file diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 46bc8173f..7c9080ad5 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -10,6 +10,7 @@ from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK from .types import DataType +import numpy as np MILVUS = "milvus" ZILLIZ = "zilliz" @@ -375,3 +376,25 @@ def is_scipy_sparse(cls, data: Any): "csr_array", "spmatrix", ] + + +def convert_to_standard_form(vector_data): + if len(vector_data.shape) == 1: + # Calculate the mean and standard deviation of the vector + mean = np.mean(vector_data) + std_dev = np.std(vector_data) + + # Standardize the vector + standardized_vector = (vector_data - mean) / std_dev if std_dev != 0 else vector_data + return standardized_vector + + else: + # Calculate mean and standard deviation for each row + row_means = np.mean(vector_data, axis=1, keepdims=True) + row_stds = np.std(vector_data, axis=1, keepdims=True) + + # Standardize each row independently + standardized_matrix = np.where( + row_stds != 0, (vector_data - row_means) / row_stds, vector_data + ) + return standardized_matrix diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 6abd40860..467af5e3e 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -257,3 +257,5 @@ class ExceptionsMessage: DefaultValueInvalid = ( "Default value cannot be None for a field that is defined as nullable == false." ) + InvalidVectorFields = "%s is not a valid vector field; expected %s" + InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s." 
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index a31d374cf..d9f062b23 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -34,6 +34,7 @@ IndexNotExistException, PartitionAlreadyExistException, SchemaNotReadyException, + MilvusException, ) from pymilvus.grpc_gen import schema_pb2 from pymilvus.settings import Config @@ -156,6 +157,30 @@ def __init__( self._schema_dict = self._schema.to_dict() self._schema_dict["consistency_level"] = self._consistency_level + self._normalization_fields = kwargs.get("normalization_fields", None) + if self._normalization_fields: + self._vector_fields = self._get_vector_fields() + if self._normalization_fields == "all": + self._normalization_fields = self._vector_fields + elif isinstance(self._normalization_fields, list): + for field in self._normalization_fields: + if field not in self._vector_fields: + raise MilvusException( + ExceptionsMessage.InvalidVectorFields + % (field, ", ".join(self._vector_fields)) + ) + else: + raise MilvusException( + ExceptionsMessage.InvalidNormalizationParam % (self._normalization_fields) + ) + + def _get_vector_fields(self): + vector_fields = [] + for field in self._schema_dict.get("fields", []): + if field.get("params", {}).get("dim", None): + vector_fields.append(field.get("name")) + return vector_fields + def __repr__(self) -> str: _dict = { "name": self.name, @@ -504,6 +529,9 @@ def insert( conn = self._get_connection() if is_row_based(data): + if self._normalization_fields: + for norm_fld in self._normalization_fields: + data[norm_fld] = utils.convert_to_standard_form(data[norm_fld]) return conn.insert_rows( collection_name=self._name, entities=data, @@ -513,6 +541,9 @@ def insert( **kwargs, ) + for idx, fld in enumerate(self._schema_dict["fields"]): + if fld["name"] in self._normalization_fields: + data[idx] = utils.convert_to_standard_form(data[idx]) check_insert_schema(self.schema, data) entities = Prepare.prepare_data(data, self.schema) return 
conn.batch_insert( From 301586235f5b7b5320c04694854f021c1d779e5e Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Mon, 4 Nov 2024 16:19:13 +0000 Subject: [PATCH 2/6] normalize vector data to a standard form during insertion (#27469) Signed-off-by: NamCaoHai --- examples/example_normalization_fields.py | 16 +++++++++++++++ pymilvus/client/utils.py | 25 ++++++++++-------------- pymilvus/exceptions.py | 2 +- pymilvus/orm/collection.py | 15 +++++++------- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/examples/example_normalization_fields.py b/examples/example_normalization_fields.py index ed5ac7b83..e8f66f958 100644 --- a/examples/example_normalization_fields.py +++ b/examples/example_normalization_fields.py @@ -52,6 +52,22 @@ hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings']) except BaseException as e: print(e) + +print(fmt.format("Insert data, without conversion to standard form")) + +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong") + +print(fmt.format("Start inserting a row")) +rng = np.random.default_rng(seed=19530) + +row = { + "pk": "19530", + "random": 0.5, + "embeddings1": rng.random((1, dim), np.float32)[0], + "embeddings2": rng.random((1, dim), np.float32)[0] +} +hello_milvus.insert(row) +utility.drop_collection("hello_milvus") print(fmt.format("Insert data, with conversion to standard form")) diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 7c9080ad5..d400d28a4 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -3,6 +3,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +import numpy as np import ujson from pymilvus.exceptions import MilvusException, ParamError @@ -10,7 +11,6 @@ from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK from .types import DataType -import numpy as np MILVUS = "milvus" ZILLIZ = "zilliz" @@ -378,23 +378,18 @@ def 
is_scipy_sparse(cls, data: Any): ] -def convert_to_standard_form(vector_data): +def convert_to_standard_form(vector_data: Any) -> Any: if len(vector_data.shape) == 1: # Calculate the mean and standard deviation of the vector mean = np.mean(vector_data) std_dev = np.std(vector_data) # Standardize the vector - standardized_vector = (vector_data - mean) / std_dev if std_dev != 0 else vector_data - return standardized_vector - - else: - # Calculate mean and standard deviation for each row - row_means = np.mean(vector_data, axis=1, keepdims=True) - row_stds = np.std(vector_data, axis=1, keepdims=True) - - # Standardize each row independently - standardized_matrix = np.where( - row_stds != 0, (vector_data - row_means) / row_stds, vector_data - ) - return standardized_matrix + return (vector_data - mean) / std_dev if std_dev != 0 else vector_data + + # Calculate mean and standard deviation for each row + row_means = np.mean(vector_data, axis=1, keepdims=True) + row_stds = np.std(vector_data, axis=1, keepdims=True) + + # Standardize each row independently + return np.where(row_stds != 0, (vector_data - row_means) / row_stds, vector_data) \ No newline at end of file diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 467af5e3e..ac8ef4717 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -258,4 +258,4 @@ class ExceptionsMessage: "Default value cannot be None for a field that is defined as nullable == false." ) InvalidVectorFields = "%s is not a valid vector field; expected %s" - InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s." + InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s." 
\ No newline at end of file diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index d9f062b23..7252fcd0b 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -32,9 +32,9 @@ DataTypeNotSupportException, ExceptionsMessage, IndexNotExistException, + MilvusException, PartitionAlreadyExistException, SchemaNotReadyException, - MilvusException, ) from pymilvus.grpc_gen import schema_pb2 from pymilvus.settings import Config @@ -114,6 +114,7 @@ def __init__( self._using = using self._kwargs = kwargs self._num_shards = None + self._normalization_fields = None conn = self._get_connection() has = conn.has_collection(self._name, **kwargs) @@ -157,7 +158,7 @@ def __init__( self._schema_dict = self._schema.to_dict() self._schema_dict["consistency_level"] = self._consistency_level - self._normalization_fields = kwargs.get("normalization_fields", None) + self._normalization_fields = kwargs.get("normalization_fields") if self._normalization_fields: self._vector_fields = self._get_vector_fields() if self._normalization_fields == "all": @@ -540,10 +541,10 @@ def insert( schema=self._schema_dict, **kwargs, ) - - for idx, fld in enumerate(self._schema_dict["fields"]): - if fld["name"] in self._normalization_fields: - data[idx] = utils.convert_to_standard_form(data[idx]) + if self._normalization_fields: + for idx, fld in enumerate(self._schema_dict["fields"]): + if fld["name"] in self._normalization_fields: + data[idx] = utils.convert_to_standard_form(data[idx]) check_insert_schema(self.schema, data) entities = Prepare.prepare_data(data, self.schema) return conn.batch_insert( @@ -1622,4 +1623,4 @@ def get_replicas(self, timeout: Optional[float] = None, **kwargs) -> Replica: def describe(self, timeout: Optional[float] = None): conn = self._get_connection() - return conn.describe_collection(self.name, timeout=timeout) + return conn.describe_collection(self.name, timeout=timeout) \ No newline at end of file From 
dee71bc26f10e1ca1b1a839b9797f7050678a8c4 Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Thu, 7 Nov 2024 09:12:17 +0000 Subject: [PATCH 3/6] refactor code Signed-off-by: NamCaoHai --- pymilvus/exceptions.py | 2 +- pymilvus/orm/collection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index ac8ef4717..467af5e3e 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -258,4 +258,4 @@ class ExceptionsMessage: "Default value cannot be None for a field that is defined as nullable == false." ) InvalidVectorFields = "%s is not a valid vector field; expected %s" - InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s." \ No newline at end of file + InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s." diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 7252fcd0b..50ec3bb3e 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -1623,4 +1623,4 @@ def get_replicas(self, timeout: Optional[float] = None, **kwargs) -> Replica: def describe(self, timeout: Optional[float] = None): conn = self._get_connection() - return conn.describe_collection(self.name, timeout=timeout) \ No newline at end of file + return conn.describe_collection(self.name, timeout=timeout) From 951730386ef6fba4d59d8f92e112dc01c52f81be Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Thu, 7 Nov 2024 09:35:36 +0000 Subject: [PATCH 4/6] refactor code Signed-off-by: NamCaoHai --- pymilvus/client/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index d400d28a4..aed8ee735 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -392,4 +392,4 @@ def convert_to_standard_form(vector_data: Any) -> Any: row_stds = 
np.std(vector_data, axis=1, keepdims=True) # Standardize each row independently - return np.where(row_stds != 0, (vector_data - row_means) / row_stds, vector_data) \ No newline at end of file + return np.where(row_stds != 0, (vector_data - row_means) / row_stds, vector_data) From 31ffe31f7ad9e6d1a1f9fc4de88cc875c2e8fc43 Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Fri, 8 Nov 2024 07:21:42 +0000 Subject: [PATCH 5/6] Refactor: simplify code Signed-off-by: NamCaoHai --- pymilvus/orm/collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 50ec3bb3e..b4604cbb3 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -114,7 +114,6 @@ def __init__( self._using = using self._kwargs = kwargs self._num_shards = None - self._normalization_fields = None conn = self._get_connection() has = conn.has_collection(self._name, **kwargs) @@ -158,7 +157,7 @@ def __init__( self._schema_dict = self._schema.to_dict() self._schema_dict["consistency_level"] = self._consistency_level - self._normalization_fields = kwargs.get("normalization_fields") + self._normalization_fields = self._kwargs.get("normalization_fields", None) if self._normalization_fields: self._vector_fields = self._get_vector_fields() if self._normalization_fields == "all": From ad80aa41b633bc72bb8f8afa229634452b2cd8b1 Mon Sep 17 00:00:00 2001 From: NamCaoHai Date: Sat, 16 Nov 2024 03:19:38 +0000 Subject: [PATCH 6/6] Add a description for the normalization_fields parameter. Signed-off-by: NamCaoHai --- pymilvus/orm/collection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index b4604cbb3..27f701c0b 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -94,6 +94,9 @@ def __init__( If timeout is not set, the client keeps waiting until the server responds or an error occurs. 
+ * *normalization_fields* (``str/list``, optional) + Vector fields to z-score normalize on insert: ``"all"`` or a list of field names. + Raises: SchemaNotReadyException: if the schema is wrong.