diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tool/dataset_multi_retriever_tool.py index 07174b1d71be80..5cf120b63b81f6 100644 --- a/api/core/tool/dataset_multi_retriever_tool.py +++ b/api/core/tool/dataset_multi_retriever_tool.py @@ -192,7 +192,7 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document 'search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': self.top_k, 'score_threshold': self.score_threshold, @@ -210,7 +210,7 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': 'hybrid_search', 'embeddings': embeddings, diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index cc8b8e1386d8f9..822a6562be511c 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -106,7 +106,7 @@ def _run(self, query: str) -> str: if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': self.top_k, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ @@ -124,7 +124,7 @@ def _run(self, query: str) -> str: if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d9725a66d8fa57..831a37d670aa45 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -61,7 +61,7 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': retrieval_model['top_k'], 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, @@ -77,7 +77,7 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py index 3e6b93f862d43d..f12533f2b024bb 100644 --- a/api/services/retrieval_service.py +++ b/api/services/retrieval_service.py @@ -4,6 +4,7 @@ from langchain.embeddings.base import Embeddings from core.index.vector_index.vector_index import VectorIndex from core.model_providers.model_factory import ModelFactory +from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { @@ -21,10 +22,13 @@ class RetrievalService: @classmethod - def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, + def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], all_documents: list, search_method: str, embeddings: Embeddings): with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() vector_index = VectorIndex( dataset=dataset, @@ -56,10 +60,13 @@ def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, all_documents.extend(documents) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, + def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], all_documents: list, search_method: str, embeddings: Embeddings): with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() vector_index = VectorIndex( dataset=dataset,