Skip to content

Commit

Permalink
Fix/multi thread parameter (#1604)
Browse files — browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Nov 22, 2023
1 parent f704094 commit a5b80c9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
4 changes: 2 additions & 2 deletions api/core/tool/dataset_multi_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document
'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': self.score_threshold,
Expand All @@ -210,7 +210,7 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'search_method': 'hybrid_search',
'embeddings': embeddings,
Expand Down
4 changes: 2 additions & 2 deletions api/core/tool/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _run(self, query: str) -> str:
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
Expand All @@ -124,7 +124,7 @@ def _run(self, query: str) -> str:
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
Expand Down
4 changes: 2 additions & 2 deletions api/services/hit_testing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
Expand All @@ -77,7 +77,7 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
Expand Down
11 changes: 9 additions & 2 deletions api/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset

default_retrieval_model = {
Expand All @@ -21,10 +22,13 @@
class RetrievalService:

@classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()

vector_index = VectorIndex(
dataset=dataset,
Expand Down Expand Up @@ -56,10 +60,13 @@ def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
all_documents.extend(documents)

@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()

vector_index = VectorIndex(
dataset=dataset,
Expand Down

0 comments on commit a5b80c9

Please sign in to comment.