From ef4df39c5c8b70776fb2be2b6df069f61967ae11 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Fri, 23 Aug 2024 18:20:13 +0800 Subject: [PATCH 1/2] fix:graph retrieve bug --- dbgpt/app/scene/chat_knowledge/v1/chat.py | 1 + dbgpt/serve/rag/retriever/knowledge_space.py | 23 +++++++++++++++----- dbgpt/serve/rag/retriever/retriever_chain.py | 11 ++++++---- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py index dd552ddc9..445972f01 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/chat.py +++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py @@ -98,6 +98,7 @@ def __init__(self, chat_param: Dict): top_k=retriever_top_k, query_rewrite=query_rewrite, rerank=reranker, + llm_model=self.llm_model, ) self.prompt_template.template_is_strict = False diff --git a/dbgpt/serve/rag/retriever/knowledge_space.py b/dbgpt/serve/rag/retriever/knowledge_space.py index 92fd8cb41..d51a42647 100644 --- a/dbgpt/serve/rag/retriever/knowledge_space.py +++ b/dbgpt/serve/rag/retriever/knowledge_space.py @@ -4,6 +4,8 @@ from dbgpt.component import ComponentType from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.core import Chunk +from dbgpt.model import DefaultLLMClient +from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker from dbgpt.rag.retriever.base import BaseRetriever @@ -26,6 +28,7 @@ def __init__( top_k: Optional[int] = 4, query_rewrite: Optional[QueryRewrite] = None, rerank: Optional[Ranker] = None, + llm_model: Optional[str] = None, ): """ Args: @@ -40,6 +43,7 @@ def __init__( self._top_k = top_k self._query_rewrite = query_rewrite self._rerank = rerank + self._llm_model = llm_model embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) @@ -50,9 +54,19 @@ def __init__( space_dao = KnowledgeSpaceDao() space = space_dao.get_one({"id": space_id}) - config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn) + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + llm_client = DefaultLLMClient(worker_manager=worker_manager) + config = VectorStoreConfig( + name=space.name, + embedding_fn=embedding_fn, + llm_client=llm_client, + llm_model=self._llm_model, + ) + self._vector_store_connector = VectorStoreConnector( - vector_store_type=CFG.VECTOR_STORE_TYPE, + vector_store_type=space.vector_type, vector_store_config=config, ) self._executor = CFG.SYSTEM_APP.get_component( @@ -141,7 +155,6 @@ async def _aretrieve_with_score( Return: List[Chunk]: list of chunks with score. """ - candidates_with_score = await blocking_func_to_async( - self._executor, self._retrieve_with_score, query, score_threshold, filters + return await self._retriever_chain.aretrieve_with_scores( + query, score_threshold, filters ) - return candidates_with_score diff --git a/dbgpt/serve/rag/retriever/retriever_chain.py b/dbgpt/serve/rag/retriever/retriever_chain.py index 6ef435594..69e5b74ec 100644 --- a/dbgpt/serve/rag/retriever/retriever_chain.py +++ b/dbgpt/serve/rag/retriever/retriever_chain.py @@ -85,7 +85,10 @@ async def _aretrieve_with_score( Return: List[Chunk]: list of chunks with score """ - candidates_with_score = await blocking_func_to_async( - self._executor, self._retrieve_with_score, query, score_threshold, filters - ) - return candidates_with_score + for retriever in self._retrievers: + candidates_with_scores = await retriever.aretrieve_with_scores( + query=query, score_threshold=score_threshold, filters=filters + ) + if candidates_with_scores: + return candidates_with_scores + return [] From 696223f60a0faa27bf1fe0d5c95cf635f6463c7d Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Mon, 26 Aug 2024 11:58:36 +0800 Subject: [PATCH 2/2] fix:retrieve chain --- dbgpt/serve/rag/retriever/retriever_chain.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dbgpt/serve/rag/retriever/retriever_chain.py b/dbgpt/serve/rag/retriever/retriever_chain.py index 69e5b74ec..88c8b8a0b 100644 --- a/dbgpt/serve/rag/retriever/retriever_chain.py +++ b/dbgpt/serve/rag/retriever/retriever_chain.py @@ -38,17 +38,20 @@ def _retrieve( async def _aretrieve( self, query: str, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: - """Retrieve knowledge chunks. + """Async retrieve knowledge chunks. Args: query (str): query text filters: (Optional[MetadataFilters]) metadata filters. Return: List[Chunk]: list of chunks """ - candidates = await blocking_func_to_async( - self._executor, self._retrieve, query, filters - ) - return candidates + for retriever in self._retrievers: + candidates = await retriever.aretrieve( + query=query, filters=filters + ) + if candidates: + return candidates + return [] def _retrieve_with_score( self,