From 39cdfcfe2d64f76074b6716a3a1276eb1097c7bd Mon Sep 17 00:00:00 2001 From: jiangzhijie Date: Mon, 25 Nov 2024 11:00:21 +0800 Subject: [PATCH 1/3] fix the wrong LINDORM_PASSWORD variable name in docker-compose.yaml --- docker/docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e316ec0c5afb5c..285f576b0e83c9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -177,7 +177,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} - LINDORM_PASSWORD: ${LINDORM_USERNAME:-lindorm } + LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm } KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} From fc5fe4d397ba45ac022663e31441755fad2afc7f Mon Sep 17 00:00:00 2001 From: jiangzhijie Date: Wed, 11 Dec 2024 15:47:02 +0800 Subject: [PATCH 2/3] lindorm vdb add ugc feature --- api/.env.example | 1 + api/configs/middleware/vdb/lindorm_config.py | 12 + .../datasource/vdb/lindorm/lindorm_vector.py | 286 +++++++++--------- .../vdb/lindorm/test_lindorm.py | 29 +- 4 files changed, 188 insertions(+), 140 deletions(-) diff --git a/api/.env.example b/api/.env.example index 391fc99cff2f39..c2e3f33fc40be0 100644 --- a/api/.env.example +++ b/api/.env.example @@ -294,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 LINDORM_USERNAME=admin LINDORM_PASSWORD=admin +USING_UGC_INDEX=False # OceanBase Vector configuration OCEANBASE_VECTOR_HOST=127.0.0.1 diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index 0f6c6528066747..0de34ab8815849 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -21,3 +21,15 @@ class LindormConfig(BaseSettings): description="Lindorm password", default=None, ) + DEFAULT_INDEX_TYPE: Optional[str] = Field( + description="Lindorm Vector Index Type, hnsw or flat is available in dify", + default="hnsw", + ) + DEFAULT_DISTANCE_TYPE: Optional[str] = Field( + description="Vector Distance Type, support l2, cosinesimil, innerproduct", + default="l2" + ) + USING_UGC_INDEX: Optional[bool] = Field( + description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", + default=False, + ) \ No newline at end of file diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 60a1a89f1a0b1e..657e82760ec44f 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -1,13 +1,10 @@ import copy import json import logging -from collections.abc import Iterable from typing import Any, Optional from opensearchpy import OpenSearch -from opensearchpy.helpers import bulk from pydantic import BaseModel, model_validator -from tenacity import retry, stop_after_attempt, wait_fixed from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -23,11 +20,15 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger("lindorm").setLevel(logging.WARN) +ROUTING_FIELD = "routing_field" +UGC_INDEX_PREFIX = "ugc_index" + class LindormVectorStoreConfig(BaseModel): hosts: str username: Optional[str] = None password: Optional[str] = None + using_ugc: Optional[bool] = False @model_validator(mode="before") @classmethod @@ -42,7 +43,7 @@ def validate_config(cls, values: dict) -> dict: def to_opensearch_params(self) -> dict[str, Any]: params = { - "hosts": self.hosts, + "hosts": self.hosts } if self.username and self.password: params["http_auth"] = (self.username, self.password) @@ -51,9 +52,21 @@ def to_opensearch_params(self) -> dict[str, Any]: class LindormVectorStore(BaseVector): def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs): - super().__init__(collection_name.lower()) + self._routing = None + self._routing_field = None + if config.using_ugc: + routing_value: str = kwargs.get("routing_value") + if routing_value is None: + raise ValueError("UGC index should init vector with valid 'routing_value' parameter value") + self._routing = routing_value.lower() + self._routing_field = ROUTING_FIELD + ugc_index_name = collection_name + super().__init__(ugc_index_name.lower()) + else: + super().__init__(collection_name.lower()) self._client_config = config self._client = OpenSearch(**config.to_opensearch_params()) + self._using_ugc = config.using_ugc self.kwargs = kwargs def get_type(self) -> str: @@ -66,89 +79,41 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def refresh(self): self._client.indices.refresh(index=self._collection_name) - def __filter_existed_ids( - self, - texts: list[str], - metadatas: list[dict], - ids: list[str], - bulk_size: int = 1024, - ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]: - @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) - def __fetch_existing_ids(batch_ids: list[str]) -> set[str]: - try: - existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False) - return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} - except Exception as e: - logger.exception(f"Error fetching batch {batch_ids}") - return set() - - @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) - def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]: - try: - existing_docs = self._client.mget( - body={ - "docs": [ - {"_index": self._collection_name, "_id": id, "routing": routing} - for id, routing in zip(batch_ids, route_ids) - ] - }, - _source=False, - ) - return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} - except Exception as e: - logger.exception(f"Error fetching batch ids: {batch_ids}") - return set() - - if ids is None: - return texts, metadatas, ids - - if len(texts) != len(ids): - raise RuntimeError(f"texts {len(texts)} != {ids}") - - filtered_texts = [] - filtered_metadatas = [] - filtered_ids = [] - - def batch(iterable, n): - length = len(iterable) - for idx in range(0, length, n): - yield iterable[idx : min(idx + n, length)] - - for ids_batch, texts_batch, metadatas_batch in zip( - batch(ids, bulk_size), - batch(texts, bulk_size), - batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size), - ): - existing_ids_set = __fetch_existing_ids(ids_batch) - for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch): - if doc_id not in existing_ids_set: - filtered_texts.append(text) - filtered_ids.append(doc_id) - if metadatas is not None: - filtered_metadatas.append(metadata) - - return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): actions = [] uuids = self._get_uuids(documents) for i in range(len(documents)): - action = { - "_op_type": "index", - "_index": self._collection_name.lower(), - "_id": uuids[i], - "_source": { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], # Make sure you pass an array here - Field.METADATA_KEY.value: documents[i].metadata, - }, + action_header = { + "index": { + "_index": self.collection_name.lower(), + "_id": uuids[i], + } } - actions.append(action) - bulk(self._client, actions) - self.refresh() + action_values = { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY.value: documents[i].metadata, + } + if self._using_ugc: + action_header["index"]["routing"] = self._routing + action_values[self._routing_field] = self._routing + actions.append(action_header) + actions.append(action_values) + response = self._client.bulk(actions) + if response["errors"]: + for item in response["items"]: + print(f"{item['index']['status']}: {item['index']['error']['type']}") + else: + self.refresh() def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}} + query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} + if self._using_ugc: + query["query"]["bool"]["must"].append( + { + "term": {f"{self._routing_field}.keyword": self._routing} + } + ) response = self._client.search(index=self._collection_name, body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -156,50 +121,66 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): - query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} - results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit["_id"] for hit in results["hits"]["hits"]] + ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_ids(ids) def delete_by_ids(self, ids: list[str]) -> None: + params = {} + if self._using_ugc: + params["routing"] = self._routing for id in ids: - if self._client.exists(index=self._collection_name, id=id): - self._client.delete(index=self._collection_name, id=id) + if self._client.exists(index=self._collection_name, id=id, params=params): + params = {} + if self._using_ugc: + params["routing"] = self._routing + self._client.delete(index=self._collection_name, id=id, params=params) + self.refresh() else: logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") def delete(self) -> None: - try: + if self._using_ugc: + routing_filter_query = { + "query": {"bool": { + "must": [{ + "term": {f"{self._routing_field}.keyword": self._routing + }}] + }} + } + self._client.delete_by_query(self._collection_name, body=routing_filter_query) + self.refresh() + else: if self._client.indices.exists(index=self._collection_name): self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) logger.info("Delete index success") else: logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") - except Exception as e: - logger.exception(f"Error occurred while deleting the index: {self._collection_name}") - raise e def text_exists(self, id: str) -> bool: try: - self._client.get(index=self._collection_name, id=id) + params = {} + if self._using_ugc: + params["routing"] = self._routing + self._client.get(index=self._collection_name, id=id, params=params) return True except: return False def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Make sure query_vector is a list if not isinstance(query_vector, list): raise ValueError("query_vector should be a list of floats") - # Check whether query_vector is a floating-point number list if not all(isinstance(x, float) for x in query_vector): raise ValueError("All elements in query_vector should be floats") top_k = kwargs.get("top_k", 10) query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) try: - response = self._client.search(index=self._collection_name, body=query) + params = {} + if self._using_ugc: + params["routing"] = self._routing + response = self._client.search(index=self._collection_name, body=query, params=params) except Exception as e: logger.exception(f"Error executing vector search, query: {query}") raise @@ -232,7 +213,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: minimum_should_match = kwargs.get("minimum_should_match", 0) top_k = kwargs.get("top_k", 10) filters = kwargs.get("filter") - routing = kwargs.get("routing") + routing = self._routing full_text_query = default_text_search_query( query_text=query, k=top_k, @@ -243,6 +224,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: minimum_should_match=minimum_should_match, filters=filters, routing=routing, + routing_field=self._routing_field, ) response = self._client.search(index=self._collection_name, body=full_text_query) docs = [] @@ -265,17 +247,18 @@ def create_collection(self, dimension: int, **kwargs): logger.info(f"Collection {self._collection_name} already exists.") return if self._client.indices.exists(index=self._collection_name): - logger.info("{self._collection_name.lower()} already exists.") + logger.info(f"{self._collection_name.lower()} already exists.") + redis_client.set(collection_exist_cache_key, 1, ex=3600) return if len(self.kwargs) == 0 and len(kwargs) != 0: self.kwargs = copy.deepcopy(kwargs) vector_field = kwargs.pop("vector_field", Field.VECTOR.value) - shards = kwargs.pop("shards", 2) + shards = kwargs.pop("shards", 4) engine = kwargs.pop("engine", "lvector") - method_name = kwargs.pop("method_name", "hnsw") + method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE) + space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE) data_type = kwargs.pop("data_type", "float") - space_type = kwargs.pop("space_type", "cosinesimil") hnsw_m = kwargs.pop("hnsw_m", 24) hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500) @@ -288,10 +271,10 @@ def create_collection(self, dimension: int, **kwargs): mapping = default_text_mapping( dimension, method_name, + space_type=space_type, shards=shards, engine=engine, data_type=data_type, - space_type=space_type, vector_field=vector_field, hnsw_m=hnsw_m, hnsw_ef_construction=hnsw_ef_construction, @@ -301,6 +284,7 @@ def create_collection(self, dimension: int, **kwargs): centroids_hnsw_m=centroids_hnsw_m, centroids_hnsw_ef_construct=centroids_hnsw_ef_construct, centroids_hnsw_ef_search=centroids_hnsw_ef_search, + using_ugc=self._using_ugc, **kwargs, ) self._client.indices.create(index=self._collection_name.lower(), body=mapping) @@ -309,15 +293,20 @@ def create_collection(self, dimension: int, **kwargs): def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict: - routing_field = kwargs.get("routing_field") excludes_from_source = kwargs.get("excludes_from_source") analyzer = kwargs.get("analyzer", "ik_max_word") text_field = kwargs.get("text_field", Field.CONTENT_KEY.value) engine = kwargs["engine"] shard = kwargs["shards"] - space_type = kwargs["space_type"] + space_type = kwargs.get("space_type") + if space_type is None: + if method_name == "hnsw": + space_type = "l2" + else: + space_type = "cosine" data_type = kwargs["data_type"] vector_field = kwargs.get("vector_field", Field.VECTOR.value) + using_ugc = kwargs.get("using_ugc", False) if method_name == "ivfpq": ivfpq_m = kwargs["ivfpq_m"] @@ -366,33 +355,32 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic if excludes_from_source: mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]} - if method_name == "ivfpq" and routing_field is not None: + if using_ugc and method_name == "ivfpq": mapping["settings"]["index"]["knn_routing"] = True mapping["settings"]["index"]["knn.offline.construction"] = True - - if method_name == "flat" and routing_field is not None: + elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat": mapping["settings"]["index"]["knn_routing"] = True - return mapping def default_text_search_query( - query_text: str, - k: int = 4, - text_field: str = Field.CONTENT_KEY.value, - must: Optional[list[dict]] = None, - must_not: Optional[list[dict]] = None, - should: Optional[list[dict]] = None, - minimum_should_match: int = 0, - filters: Optional[list[dict]] = None, - routing: Optional[str] = None, - **kwargs, + query_text: str, + k: int = 4, + text_field: str = Field.CONTENT_KEY.value, + must: Optional[list[dict]] = None, + must_not: Optional[list[dict]] = None, + should: Optional[list[dict]] = None, + minimum_should_match: int = 0, + filters: Optional[list[dict]] = None, + routing: Optional[str] = None, + routing_field: Optional[str] = None, + **kwargs, ) -> dict: if routing is not None: - routing_field = kwargs.get("routing_field", "routing_field") query_clause = { "bool": { - "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}] + "must": [{"match": {text_field: query_text}}, + {"term": {f"{routing_field}.keyword": routing}}] } } else: @@ -435,17 +423,17 @@ def default_text_search_query( def default_vector_search_query( - query_vector: list[float], - k: int = 4, - min_score: str = "0.0", - ef_search: Optional[str] = None, # only for hnsw - nprobe: Optional[str] = None, # "2000" - reorder_factor: Optional[str] = None, # "20" - client_refactor: Optional[str] = None, # "true" - vector_field: str = Field.VECTOR.value, - filters: Optional[list[dict]] = None, - filter_type: Optional[str] = None, - **kwargs, + query_vector: list[float], + k: int = 4, + min_score: str = "0.0", + ef_search: Optional[str] = None, # only for hnsw + nprobe: Optional[str] = None, # "2000" + reorder_factor: Optional[str] = None, # "20" + client_refactor: Optional[str] = None, # "true" + vector_field: str = Field.VECTOR.value, + filters: Optional[list[dict]] = None, + filter_type: Optional[str] = None, + **kwargs, ) -> dict: if filters is not None: filter_type = "post_filter" if filter_type is None else filter_type @@ -483,16 +471,40 @@ def default_vector_search_query( class LindormVectorStoreFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] - collection_name = class_prefix - else: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name)) lindorm_config = LindormVectorStoreConfig( hosts=dify_config.LINDORM_URL, username=dify_config.LINDORM_USERNAME, password=dify_config.LINDORM_PASSWORD, + using_ugc=dify_config.USING_UGC_INDEX, ) - return LindormVectorStore(collection_name, lindorm_config) + using_ugc = dify_config.USING_UGC_INDEX + routing_value = None + if dataset.index_struct: + if using_ugc: + dimension = dataset.index_struct_dict["dimension"] + index_type = dataset.index_struct_dict["index_type"] + distance_type = dataset.index_struct_dict["distance_type"] + index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}" + routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"] + else: + index_name = dataset.index_struct_dict["vector_store"]["class_prefix"] + else: + embedding_vector = embeddings.embed_query("hello word") + dimension = len(embedding_vector) + index_type = dify_config.DEFAULT_INDEX_TYPE + distance_type = dify_config.DEFAULT_DISTANCE_TYPE + class_prefix = Dataset.gen_collection_name_by_id(dataset.id) + index_struct_dict = { + "type": VectorType.LINDORM, + "vector_store": {"class_prefix": class_prefix}, + "index_type": index_type, + "dimension": dimension, + "distance_type": distance_type, + } + dataset.index_struct = json.dumps(index_struct_dict) + if using_ugc: + index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}" + routing_value = class_prefix + else: + index_name = class_prefix + return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value) \ No newline at end of file diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py index f8f43ba6ef8ab3..16aec2cc9941cf 100644 --- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -7,9 +7,10 @@ class Config: - SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070") + SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070") SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN") - SEARCH_PWD = env.str("SEARCH_PWD", "PWD") + SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN") + USING_UGC = env.bool("USING_UGC", True) class TestLindormVectorStore(AbstractVectorTest): @@ -31,5 +32,27 @@ def get_ids_by_metadata_field(self): assert ids[0] == self.example_doc_id -def test_lindorm_vector(setup_mock_redis): +class TestLindormVectorStoreUGC(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name="ugc_index_test", + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + using_ugc=Config.USING_UGC, + ), + routing_value=self.collection_name, + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +def test_lindorm_vector_ugc(setup_mock_redis): TestLindormVectorStore().run_all_tests() + TestLindormVectorStoreUGC().run_all_tests() \ No newline at end of file From c5543e43410ba2cde2bfeb10c56329c3b3ca08be Mon Sep 17 00:00:00 2001 From: jiangzhijie Date: Thu, 12 Dec 2024 09:34:03 +0800 Subject: [PATCH 3/3] reformat and lint checks --- api/configs/middleware/vdb/lindorm_config.py | 5 +- .../datasource/vdb/lindorm/lindorm_vector.py | 67 ++++++++----------- .../vdb/lindorm/test_lindorm.py | 2 +- 3 files changed, 30 insertions(+), 44 deletions(-) diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index 0de34ab8815849..95e1d1cfca4b80 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -26,10 +26,9 @@ class LindormConfig(BaseSettings): default="hnsw", ) DEFAULT_DISTANCE_TYPE: Optional[str] = Field( - description="Vector Distance Type, support l2, cosinesimil, innerproduct", - default="l2" + description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2" ) USING_UGC_INDEX: Optional[bool] = Field( description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", default=False, - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 657e82760ec44f..cccc2c4abc59c6 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -42,9 +42,7 @@ def validate_config(cls, values: dict) -> dict: return values def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts - } + params = {"hosts": self.hosts} if self.username and self.password: params["http_auth"] = (self.username, self.password) return params @@ -109,11 +107,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} if self._using_ugc: - query["query"]["bool"]["must"].append( - { - "term": {f"{self._routing_field}.keyword": self._routing} - } - ) + query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -142,11 +136,7 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete(self) -> None: if self._using_ugc: routing_filter_query = { - "query": {"bool": { - "must": [{ - "term": {f"{self._routing_field}.keyword": self._routing - }}] - }} + "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}} } self._client.delete_by_query(self._collection_name, body=routing_filter_query) self.refresh() @@ -364,24 +354,21 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic def default_text_search_query( - query_text: str, - k: int = 4, - text_field: str = Field.CONTENT_KEY.value, - must: Optional[list[dict]] = None, - must_not: Optional[list[dict]] = None, - should: Optional[list[dict]] = None, - minimum_should_match: int = 0, - filters: Optional[list[dict]] = None, - routing: Optional[str] = None, - routing_field: Optional[str] = None, - **kwargs, + query_text: str, + k: int = 4, + text_field: str = Field.CONTENT_KEY.value, + must: Optional[list[dict]] = None, + must_not: Optional[list[dict]] = None, + should: Optional[list[dict]] = None, + minimum_should_match: int = 0, + filters: Optional[list[dict]] = None, + routing: Optional[str] = None, + routing_field: Optional[str] = None, + **kwargs, ) -> dict: if routing is not None: query_clause = { - "bool": { - "must": [{"match": {text_field: query_text}}, - {"term": {f"{routing_field}.keyword": routing}}] - } + "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]} } else: query_clause = {"match": {text_field: query_text}} @@ -423,17 +410,17 @@ def default_text_search_query( def default_vector_search_query( - query_vector: list[float], - k: int = 4, - min_score: str = "0.0", - ef_search: Optional[str] = None, # only for hnsw - nprobe: Optional[str] = None, # "2000" - reorder_factor: Optional[str] = None, # "20" - client_refactor: Optional[str] = None, # "true" - vector_field: str = Field.VECTOR.value, - filters: Optional[list[dict]] = None, - filter_type: Optional[str] = None, - **kwargs, + query_vector: list[float], + k: int = 4, + min_score: str = "0.0", + ef_search: Optional[str] = None, # only for hnsw + nprobe: Optional[str] = None, # "2000" + reorder_factor: Optional[str] = None, # "20" + client_refactor: Optional[str] = None, # "true" + vector_field: str = Field.VECTOR.value, + filters: Optional[list[dict]] = None, + filter_type: Optional[str] = None, + **kwargs, ) -> dict: if filters is not None: filter_type = "post_filter" if filter_type is None else filter_type @@ -507,4 +494,4 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings routing_value = class_prefix else: index_name = class_prefix - return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value) \ No newline at end of file + return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value) diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py index 16aec2cc9941cf..0a26d3ea1c9987 100644 --- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -55,4 +55,4 @@ def get_ids_by_metadata_field(self): def test_lindorm_vector_ugc(setup_mock_redis): TestLindormVectorStore().run_all_tests() - TestLindormVectorStoreUGC().run_all_tests() \ No newline at end of file + TestLindormVectorStoreUGC().run_all_tests()