From ac4ed33572f2e0853cf507a660d6e34a20dd7d1c Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Fri, 14 Jun 2024 09:45:27 +0200 Subject: [PATCH] feat: add QA on given chunk refs Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- deepsearch/cps/queries/__init__.py | 5 ++++ deepsearch/cps/queries/results.py | 38 ++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/deepsearch/cps/queries/__init__.py b/deepsearch/cps/queries/__init__.py index 6a2abb6f..9ba514cd 100644 --- a/deepsearch/cps/queries/__init__.py +++ b/deepsearch/cps/queries/__init__.py @@ -15,6 +15,7 @@ SemanticBackendResource, ) from deepsearch.cps.client.queries import Query, TaskCoordinates +from deepsearch.cps.queries.results import ChunkRef def Wf(wf_query: Dict[str, Any], kg: TaskCoordinates) -> Query: @@ -110,6 +111,7 @@ class _APISemanticRagParameters(_APISemanticRetrievalParameters): gen_ctx_window_size: int = 5000 gen_ctx_window_lead_weight: float = 0.5 return_prompt: bool = False + chunk_refs: Optional[List[ChunkRef]] = None gen_timeout: Optional[float] = None @@ -129,6 +131,7 @@ def RAGQuery( gen_ctx_window_size: int = 5000, gen_ctx_window_lead_weight: float = 0.5, return_prompt: bool = False, + chunk_refs: Optional[List[ChunkRef]] = None, gen_timeout: Optional[float] = None, ) -> Query: """Create a RAG query @@ -147,6 +150,7 @@ def RAGQuery( gen_ctx_window_size (int, optional): (relevant only if gen_ctx_extr_method=="window") max chars to use for extracted gen context (actual extraction quantized on doc item level); defaults to 5000 gen_ctx_window_lead_weight (float, optional): (relevant only if gen_ctx_extr_method=="window") weight of leading text for distributing remaining window size after extracting the `main_path`; defaults to 0.5 (centered around `main_path`) return_prompt (bool, optional): whether to return the instantiated prompt; defaults to False + chunk_refs (Optional[List[ChunkRef]], optional): list of explicit chunk references to use instead of performing retrieval; defaults to None (i.e. retrieval-mode) gen_timeout (float, optional): timeout for LLM generation; defaults to None, i.e. determined by system """ @@ -181,6 +185,7 @@ def RAGQuery( gen_ctx_window_size=gen_ctx_window_size, gen_ctx_window_lead_weight=gen_ctx_window_lead_weight, return_prompt=return_prompt, + chunk_refs=chunk_refs, gen_timeout=gen_timeout, ) diff --git a/deepsearch/cps/queries/results.py b/deepsearch/cps/queries/results.py index d737f00a..b5077c51 100644 --- a/deepsearch/cps/queries/results.py +++ b/deepsearch/cps/queries/results.py @@ -2,21 +2,24 @@ from typing import List, Optional -from pydantic.v1 import BaseModel, root_validator +from pydantic.v1 import BaseModel from deepsearch.cps.client.components.queries import RunQueryResult -class SearchResultItem(BaseModel): +class ChunkRef(BaseModel): doc_hash: str + main_path: str # the anchor path among the contributing group + path_group: List[str] # the doc paths contributing to the encoding source + + +class SearchResultItem(ChunkRef): chunk: str - main_path: str - path_group: List[str] source_is_text: bool class RAGGroundingInfo(BaseModel): - retr_items: List[SearchResultItem] + retr_items: Optional[List[SearchResultItem]] = None gen_ctx_paths: List[str] @@ -45,26 +48,35 @@ def __init__(self, msg="Search returned no results", *args, **kwargs): class RAGResult(BaseModel): answers: List[RAGAnswerItem] - search_result_items: List[SearchResultItem] + search_result_items: Optional[List[SearchResultItem]] = None @classmethod def from_api_output(cls, data: RunQueryResult, raise_on_error=True): answers: List[RAGAnswerItem] = [] try: - search_result_items = data.outputs["retrieval"]["items"] - if raise_on_error and len(search_result_items) == 0: - raise NoSearchResultsError() + retrieval_part = data.outputs["retrieval"] + if retrieval_part is not None: + search_result_items = retrieval_part["items"] + if raise_on_error and len(search_result_items) == 0: + raise NoSearchResultsError() + else: + search_result_items = None for answer_item in data.outputs["answers"]: if raise_on_error and (gen_err := answer_item.get("gen_err")): raise GenerationError(gen_err) + retr_idxs = answer_item["grounding_info"]["retr_idxs"] answers.append( RAGAnswerItem( answer=answer_item["answer"], grounding=RAGGroundingInfo( - retr_items=[ - SearchResultItem.parse_obj(search_result_items[i]) - for i in answer_item["grounding_info"]["retr_idxs"] - ], + retr_items=( + [ + SearchResultItem.parse_obj(search_result_items[i]) + for i in retr_idxs + ] + if retr_idxs is not None and retrieval_part is not None + else None + ), gen_ctx_paths=answer_item["grounding_info"][ "gen_ctx_paths" ],