From 5fa0963bc07956348a27df4a8b97b3118c357463 Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:06:23 +0100 Subject: [PATCH] feat: expose sem. params, default to hybrid search (#159) Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- deepsearch/cps/queries/__init__.py | 126 +++++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 16 deletions(-) diff --git a/deepsearch/cps/queries/__init__.py b/deepsearch/cps/queries/__init__.py index 3dbc4b70..7f78870c 100644 --- a/deepsearch/cps/queries/__init__.py +++ b/deepsearch/cps/queries/__init__.py @@ -1,5 +1,8 @@ from typing import Any, Dict, List, Optional, Union +from pydantic.v1 import Field, validate_arguments +from typing_extensions import Annotated + from deepsearch.cps.client.components.elastic import ElasticSearchQuery from deepsearch.cps.client.components.projects import Project, SemanticBackendResource from deepsearch.cps.client.queries import Query, TaskCoordinates @@ -77,17 +80,38 @@ def DataQuery( return query +ConstrainedWeight = Annotated[ + float, Field(strict=True, ge=0.0, le=1.0, multiple_of=0.1) +] + + def CorpusRAGQuery( question: str, *, project: Union[str, Project], index_key: str, + retr_k: int = 10, + rerank: bool = False, + text_weight: ConstrainedWeight = 0.1, ) -> Query: - - return _get_rag_query( + """Create a RAG query against a collection + + Args: + question (str): the natural-language query + project (Union[str, Project]): project to use + index_key (str): index key of target private collection (must already be semantically indexed) + retr_k (int, optional): num of items to retrieve; defaults to 10 + rerank (bool, optional): whether to rerank retrieval results; defaults to False + text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1 + """ + + return _create_rag_query( question=question, project=project, index_key=index_key, + retr_k=retr_k, + rerank=rerank, + text_weight=text_weight, ) @@ -96,29 +120,56 @@ def DocumentRAGQuery( *, document_hash: str, project: Union[str, Project], - index_key: Optional[str] = None, # set in case of private collection + index_key: Optional[str] = None, + retr_k: int = 10, + rerank: bool = False, + text_weight: ConstrainedWeight = 0.1, ) -> Query: - - return _get_rag_query( + """Create a RAG query against a specific document + + Args: + question (str): the natural-language query + document_hash (str): hash of target document + project (Union[str, Project]): project to use + index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed) + retr_k (int, optional): num of items to retrieve; defaults to 10 + rerank (bool, optional): whether to rerank retrieval results; defaults to False + text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1 + """ + + return _create_rag_query( question=question, document_hash=document_hash, project=project, index_key=index_key, + retr_k=retr_k, + rerank=rerank, + text_weight=text_weight, ) -def _get_rag_query( +@validate_arguments +def _create_rag_query( question: str, *, document_hash: Optional[str] = None, project: Union[str, Project], - index_key: Optional[str] = None, + index_key: Optional[str], + retr_k: int, + rerank: bool, + text_weight: ConstrainedWeight, ) -> Query: proj_key = project.key if isinstance(project, Project) else project idx_key = index_key or "__project__" query = Query() - q_params = {"question": question} + + q_params = { + "question": question, + "retr_k": retr_k, + "use_reranker": rerank, + "hybrid_search_text_weight": text_weight, + } if document_hash: q_params["doc_id"] = document_hash task = query.add( @@ -138,12 +189,28 @@ def CorpusSemanticQuery( *, project: Union[str, Project], index_key: str, + retr_k: int = 10, + rerank: bool = False, + text_weight: ConstrainedWeight = 0.1, ) -> Query: - - return _get_semantic_query( + """Create a semantic retrieval query against a collection + + Args: + question (str): the natural-language query + project (Union[str, Project]): project to use + index_key (str): index key of target private collection (must already be semantically indexed) + retr_k (int, optional): num of items to retrieve; defaults to 10 + rerank (bool, optional): whether to rerank retrieval results; defaults to False + text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1 + """ + + return _create_semantic_query( question=question, project=project, index_key=index_key, + retr_k=retr_k, + rerank=rerank, + text_weight=text_weight, ) @@ -152,29 +219,56 @@ def DocumentSemanticQuery( *, document_hash: str, project: Union[str, Project], - index_key: Optional[str] = None, # set in case of private collection + index_key: Optional[str] = None, + retr_k: int = 10, + rerank: bool = False, + text_weight: ConstrainedWeight = 0.1, ) -> Query: - - return _get_semantic_query( + """Create a semantic retrieval query against a specific document + + Args: + question (str): the natural-language query + document_hash (str): hash of target document + project (Union[str, Project]): project to use + index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed) + retr_k (int, optional): num of items to retrieve; defaults to 10 + rerank (bool, optional): whether to rerank retrieval results; defaults to False + text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1 + """ + + return _create_semantic_query( question=question, document_hash=document_hash, project=project, index_key=index_key, + retr_k=retr_k, + rerank=rerank, + text_weight=text_weight, ) -def _get_semantic_query( +@validate_arguments +def _create_semantic_query( question: str, *, document_hash: Optional[str] = None, project: Union[str, Project], - index_key: Optional[str] = None, + index_key: Optional[str], + retr_k: int, + rerank: bool, + text_weight: ConstrainedWeight, ) -> Query: proj_key = project.key if isinstance(project, Project) else project idx_key = index_key or "__project__" query = Query() - q_params = {"question": question} + + q_params = { + "question": question, + "retr_k": retr_k, + "use_reranker": rerank, + "hybrid_search_text_weight": text_weight, + } if document_hash: q_params["doc_id"] = document_hash task = query.add(