diff --git a/examples/example_normalization_fields.py b/examples/example_normalization_fields.py new file mode 100644 index 000000000..ed5ac7b83 --- /dev/null +++ b/examples/example_normalization_fields.py @@ -0,0 +1,112 @@ +import time + +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + MilvusClient +) + +fmt = "\n=== {:30} ===\n" +search_latency_fmt = "search latency = {:.4f}s" +num_entities, dim = 3000, 8 + + +print(fmt.format("start connecting to Milvus")) +# this is milvus standalone +connection = connections.connect( + alias="default", + host='localhost', # or '0.0.0.0' or 'localhost' + port='19530' +) + +client = MilvusClient(connections=connection) + +has = utility.has_collection("hello_milvus") +print(f"Does collection hello_milvus exist in Milvus: {has}") +if has: + utility.drop_collection("hello_milvus") + +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings1", dtype=DataType.FLOAT_VECTOR, dim=dim), + FieldSchema(name="embeddings2", dtype=DataType.FLOAT_VECTOR, dim=dim) +] + +schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs") + +print(fmt.format("Create collection `hello_milvus`")) + +print(fmt.format("Message for handling an invalid format in the normalization_fields value")) # you can try with other value like: dict,... 
+try: + hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields='embeddings1') +except BaseException as e: + print(e) + + +print(fmt.format("Message for handling the invalid vector fields")) +try: + hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings']) +except BaseException as e: + print(e) + +print(fmt.format("Insert data, with conversion to standard form")) + +hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embeddings1']) + +print(fmt.format("Start inserting a row")) +rng = np.random.default_rng(seed=19530) + +row = { + "pk": "19530", + "random": 0.5, + "embeddings1": rng.random((1, dim), np.float32)[0], + "embeddings2": rng.random((1, dim), np.float32)[0] +} +_row = row.copy() +hello_milvus.insert(row) + +index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} +hello_milvus.create_index("embeddings1", index_param) +hello_milvus.create_index("embeddings2", index_param) +hello_milvus.load() + +original_vector = _row['embeddings1'] +insert_vector = hello_milvus.query( + expr="pk == '19530'", + output_fields=["embeddings1"], +)[0]['embeddings1'] + +print(fmt.format("Mean and standard deviation before normalization.")) +print("Mean: ", np.mean(original_vector)) +print("Std: ", np.std(original_vector)) + +print(fmt.format("Mean and standard deviation after normalization.")) +print("Mean: ", np.mean(insert_vector)) +print("Std: ", np.std(insert_vector)) + + +print(fmt.format("Start inserting entities")) + +entities = [ + [str(i) for i in range(num_entities)], + rng.random(num_entities).tolist(), + rng.random((num_entities, dim), np.float32), + rng.random((num_entities, dim), np.float32), +] + +insert_result = hello_milvus.insert(entities) + +insert_vector = hello_milvus.query( + expr="pk == '1'", + output_fields=["embeddings1"], +)[0]['embeddings1'] + +print(fmt.format("Mean and standard deviation 
def convert_to_standard_form(vector_data) -> np.ndarray:
    """Standardize a vector, or each row of a matrix, to zero mean and unit std.

    Each vector ``v`` is mapped to ``(v - mean(v)) / std(v)`` using the
    population standard deviation (NumPy's default ``ddof=0``).

    Args:
        vector_data: A single vector (1-D) or a batch of vectors (2-D, one
            vector per row). Anything ``np.asarray`` accepts is allowed, so
            plain Python lists work as well as NumPy arrays.

    Returns:
        A NumPy array with the same shape as the input. Vectors whose
        standard deviation is exactly zero are returned with their original
        values, because the standardized form is undefined for them.
    """
    # Accept lists / tuples too: callers are not guaranteed to pass ndarrays.
    data = np.asarray(vector_data)

    # A lone vector is reduced over its only axis; a batch is reduced per row.
    axis = 0 if data.ndim == 1 else 1
    means = np.mean(data, axis=axis, keepdims=True)
    stds = np.std(data, axis=axis, keepdims=True)

    has_spread = stds != 0
    # np.where evaluates both branches eagerly, so dividing by the raw stds
    # would still compute 0/0 for constant vectors and emit RuntimeWarnings.
    # Divide by 1 at those positions instead; the result there is discarded.
    safe_stds = np.where(has_spread, stds, 1.0)
    return np.where(has_spread, (data - means) / safe_stds, data)
\ No newline at end of file diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index a31d374cf..bed9c913f 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -156,6 +156,25 @@ def __init__( self._schema_dict = self._schema.to_dict() self._schema_dict["consistency_level"] = self._consistency_level + self._normalization_fields = kwargs.get("normalization_fields", None) + if self._normalization_fields: + self._vector_fields = self._get_vector_fields() + if self._normalization_fields == 'all': + self._normalization_fields = self._vector_fields + elif isinstance(self._normalization_fields, list): + for field in self._normalization_fields: + if field not in self._vector_fields: + raise BaseException(ExceptionsMessage.InvalidVectorFields % (field, ', '.join(self._vector_fields))) + else: + raise BaseException(ExceptionsMessage.InvalidNormalizationParam % (self._normalization_fields)) + + def _get_vector_fields(self): + vector_fields = [] + for field in self._schema_dict.get("fields", []): + if field.get("params", {}).get("dim", None): + vector_fields.append(field.get("name")) + return vector_fields + def __repr__(self) -> str: _dict = { "name": self.name, @@ -504,6 +523,9 @@ def insert( conn = self._get_connection() if is_row_based(data): + if self._normalization_fields: + for norm_fld in self._normalization_fields: + data[norm_fld] = utils.convert_to_standard_form(data[norm_fld]) return conn.insert_rows( collection_name=self._name, entities=data, @@ -513,6 +535,9 @@ def insert( **kwargs, ) + for idx, fld in enumerate(self._schema_dict['fields']): + if fld['name'] in self._normalization_fields: + data[idx] = utils.convert_to_standard_form(data[idx]) check_insert_schema(self.schema, data) entities = Prepare.prepare_data(data, self.schema) return conn.batch_insert(