diff --git a/api/.env.example b/api/.env.example
index 391fc99cff2f39..c2e3f33fc40be0 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -294,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
+USING_UGC_INDEX=False
 
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1
diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py
index 0f6c6528066747..95e1d1cfca4b80 100644
--- a/api/configs/middleware/vdb/lindorm_config.py
+++ b/api/configs/middleware/vdb/lindorm_config.py
@@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
+    DEFAULT_INDEX_TYPE: Optional[str] = Field(
+        description="Lindorm vector index type; hnsw and flat are available in Dify",
+        default="hnsw",
+    )
+    DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
+        description="Vector distance type; supports l2, cosinesimil, innerproduct", default="l2"
+    )
+    USING_UGC_INDEX: Optional[bool] = Field(
+        description="When enabled, indexes of the same type are stored in a single UGC index and retrieved separately via routing.",
+        default=False,
+    )
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index 60a1a89f1a0b1e..cccc2c4abc59c6 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -1,13 +1,10 @@
 import copy
 import json
 import logging
-from collections.abc import Iterable
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
-from tenacity import retry, stop_after_attempt, wait_fixed
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"
+
 
 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False
 
     @model_validator(mode="before")
     @classmethod
@@ -41,9 +42,7 @@ def validate_config(cls, values: dict) -> dict:
         return values
 
     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
@@ -51,9 +50,21 @@ def to_opensearch_params(self) -> dict[str, Any]:
 
 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
+        if config.using_ugc:
+            routing_value: str = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("A UGC index must be initialized with a valid 'routing_value' parameter")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
         self.kwargs = kwargs
 
     def get_type(self) -> str:
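For orientation, a minimal sketch of how the store is constructed in UGC mode; the endpoint, index name, and routing value below are illustrative, not taken from this diff:

```python
from core.rag.datasource.vdb.lindorm.lindorm_vector import (
    LindormVectorStore,
    LindormVectorStoreConfig,
)

# Illustrative values only; in Dify these come from dify_config and the dataset.
config = LindormVectorStoreConfig(
    hosts="http://localhost:30070",
    username="admin",
    password="admin",
    using_ugc=True,
)

# With using_ugc=True, collection_name names the shared UGC index and the
# per-dataset identity travels in routing_value; omitting it raises ValueError.
store = LindormVectorStore(
    "ugc_index_1536_hnsw_l2",
    config,
    routing_value="vector_index_1234_node",
)
```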
@@ -66,89 +77,37 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs)
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def __filter_existed_ids(
-        self,
-        texts: list[str],
-        metadatas: list[dict],
-        ids: list[str],
-        bulk_size: int = 1024,
-    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch {batch_ids}")
-                return set()
-
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(
-                    body={
-                        "docs": [
-                            {"_index": self._collection_name, "_id": id, "routing": routing}
-                            for id, routing in zip(batch_ids, route_ids)
-                        ]
-                    },
-                    _source=False,
-                )
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch ids: {batch_ids}")
-                return set()
-
-        if ids is None:
-            return texts, metadatas, ids
-
-        if len(texts) != len(ids):
-            raise RuntimeError(f"texts {len(texts)} != {ids}")
-
-        filtered_texts = []
-        filtered_metadatas = []
-        filtered_ids = []
-
-        def batch(iterable, n):
-            length = len(iterable)
-            for idx in range(0, length, n):
-                yield iterable[idx : min(idx + n, length)]
-
-        for ids_batch, texts_batch, metadatas_batch in zip(
-            batch(ids, bulk_size),
-            batch(texts, bulk_size),
-            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
-        ):
-            existing_ids_set = __fetch_existing_ids(ids_batch)
-            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
-                if doc_id not in existing_ids_set:
-                    filtered_texts.append(text)
-                    filtered_ids.append(doc_id)
-                    if metadatas is not None:
-                        filtered_metadatas.append(metadata)
-
-        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
             }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            for item in response["items"]:
+                logger.error(f"Bulk index failed with status {item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()
 
     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
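The rewritten `add_texts` hand-builds the alternating header/document payload that the low-level `bulk` API expects, instead of going through `opensearchpy.helpers.bulk`. A sketch of one routed pair, assuming Dify's `Field` enum resolves to `page_content`/`vector`/`metadata` (ids and values are made up):

```python
# One document becomes two entries in `actions`: an action header carrying the
# shard routing, followed by the document body, which also stores the routing
# value under ROUTING_FIELD so keyword filters can reproduce the isolation.
actions = [
    {"index": {"_index": "ugc_index_1536_hnsw_l2", "_id": "doc-uuid-1", "routing": "vector_index_1234_node"}},
    {
        "page_content": "some chunk text",
        "vector": [0.1, 0.2, 0.3],  # the embedding array
        "metadata": {"doc_id": "doc-uuid-1"},
        "routing_field": "vector_index_1234_node",
    },
]
```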
{f"{Field.METADATA_KEY.value}.{key}.keyword": value}}} + query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} + if self._using_ugc: + query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -156,50 +115,62 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): - query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} - results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit["_id"] for hit in results["hits"]["hits"]] + ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_ids(ids) def delete_by_ids(self, ids: list[str]) -> None: + params = {} + if self._using_ugc: + params["routing"] = self._routing for id in ids: - if self._client.exists(index=self._collection_name, id=id): - self._client.delete(index=self._collection_name, id=id) + if self._client.exists(index=self._collection_name, id=id, params=params): + params = {} + if self._using_ugc: + params["routing"] = self._routing + self._client.delete(index=self._collection_name, id=id, params=params) + self.refresh() else: logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") def delete(self) -> None: - try: + if self._using_ugc: + routing_filter_query = { + "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}} + } + self._client.delete_by_query(self._collection_name, body=routing_filter_query) + self.refresh() + else: if self._client.indices.exists(index=self._collection_name): self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) logger.info("Delete index success") else: logger.warning(f"Index '{self._collection_name}' does not exist. 
@@ -232,7 +203,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
@@ -243,6 +214,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ -265,17 +237,18 @@ def create_collection(self, dimension: int, **kwargs):
             logger.info(f"Collection {self._collection_name} already exists.")
             return
         if self._client.indices.exists(index=self._collection_name):
-            logger.info("{self._collection_name.lower()} already exists.")
+            logger.info(f"{self._collection_name.lower()} already exists.")
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
             return
         if len(self.kwargs) == 0 and len(kwargs) != 0:
             self.kwargs = copy.deepcopy(kwargs)
         vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-        shards = kwargs.pop("shards", 2)
+        shards = kwargs.pop("shards", 4)
 
         engine = kwargs.pop("engine", "lvector")
-        method_name = kwargs.pop("method_name", "hnsw")
+        method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+        space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
         data_type = kwargs.pop("data_type", "float")
-        space_type = kwargs.pop("space_type", "cosinesimil")
         hnsw_m = kwargs.pop("hnsw_m", 24)
         hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ def create_collection(self, dimension: int, **kwargs):
         mapping = default_text_mapping(
             dimension,
             method_name,
+            space_type=space_type,
             shards=shards,
             engine=engine,
             data_type=data_type,
-            space_type=space_type,
             vector_field=vector_field,
             hnsw_m=hnsw_m,
             hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ def create_collection(self, dimension: int, **kwargs):
             centroids_hnsw_m=centroids_hnsw_m,
             centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
             centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+            using_ugc=self._using_ugc,
             **kwargs,
         )
         self._client.indices.create(index=self._collection_name.lower(), body=mapping)
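For orientation, a hedged sketch of the kind of mapping `default_text_mapping` assembles for the hnsw defaults with UGC enabled; the exact knob names belong to Lindorm's lvector engine and are not spelled out in this diff, so treat the shape as approximate:

```python
# Approximate shape only; consult the Lindorm documentation for the
# authoritative index schema. Dimension and space_type are illustrative.
mapping = {
    "settings": {
        "index": {
            "number_of_shards": 4,
            "knn": True,
            "knn_routing": True,  # set when using_ugc is enabled
        }
    },
    "mappings": {
        "properties": {
            "vector": {
                "type": "knn_vector",
                "dimension": 1536,
                "method": {
                    "engine": "lvector",
                    "name": "hnsw",
                    "space_type": "l2",
                    "parameters": {"m": 24, "ef_construction": 500},
                },
            },
            "page_content": {"type": "text", "analyzer": "ik_max_word"},
        }
    },
}
```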
@@ -309,15 +283,20 @@ def create_collection(self, dimension: int, **kwargs):
 
 
 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
-    routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosinesimil"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)
 
     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}
 
-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name in ("hnsw", "flat"):
         mapping["settings"]["index"]["knn_routing"] = True
-
     return mapping
@@ -386,14 +363,12 @@ def default_text_search_query(
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}
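With the new `routing_field` parameter, the routed full-text clause filters on the top-level routing field rather than a `metadata.*` sub-field. An example of the `query_clause` produced when `routing` is set (query text and routing value illustrative):

```python
query_clause = {
    "bool": {
        "must": [
            {"match": {"page_content": "what is dify"}},
            # routing_field is the module-level ROUTING_FIELD constant
            {"term": {"routing_field.keyword": "vector_index_1234_node"}},
        ]
    }
}
```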
@@ -483,16 +458,40 @@ def default_vector_search_query(
 
 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            embedding_vector = embeddings.embed_query("hello world")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
index f8f43ba6ef8ab3..0a26d3ea1c9987 100644
--- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
+++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
@@ -7,9 +7,10 @@
 
 class Config:
-    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
+    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
     SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
-    SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
+    SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
+    USING_UGC = env.bool("USING_UGC", True)
 
 
 class TestLindormVectorStore(AbstractVectorTest):
@@ -31,5 +32,27 @@ def get_ids_by_metadata_field(self):
         assert ids[0] == self.example_doc_id
 
 
-def test_lindorm_vector(setup_mock_redis):
+class TestLindormVectorStoreUGC(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = LindormVectorStore(
+            collection_name="ugc_index_test",
+            config=LindormVectorStoreConfig(
+                hosts=Config.SEARCH_ENDPOINT,
+                username=Config.SEARCH_USERNAME,
+                password=Config.SEARCH_PWD,
+                using_ugc=Config.USING_UGC,
+            ),
+            routing_value=self.collection_name,
+        )
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
+        assert ids is not None
+        assert len(ids) == 1
+        assert ids[0] == self.example_doc_id
+
+
+def test_lindorm_vector_ugc(setup_mock_redis):
     TestLindormVectorStore().run_all_tests()
+    TestLindormVectorStoreUGC().run_all_tests()
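To summarize the factory logic, a sketch of how the shared UGC index name is derived from the persisted struct; the `embed_query` probe exists only to learn the embedding dimension, and the concrete values below are illustrative:

```python
index_struct_dict = {
    "type": "lindorm",
    "vector_store": {"class_prefix": "vector_index_1234_node"},
    "index_type": "hnsw",
    "dimension": 1536,
    "distance_type": "l2",
}

# Datasets sharing dimension/index_type/distance_type land in one physical
# index; class_prefix becomes the routing value that keeps them separate.
index_name = "ugc_index_{dimension}_{index_type}_{distance_type}".format(**index_struct_dict)
assert index_name == "ugc_index_1536_hnsw_l2"
```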