Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add QA on given chunk refs #181

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepsearch/cps/queries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SemanticBackendResource,
)
from deepsearch.cps.client.queries import Query, TaskCoordinates
from deepsearch.cps.queries.results import ChunkRef


def Wf(wf_query: Dict[str, Any], kg: TaskCoordinates) -> Query:
Expand Down Expand Up @@ -110,6 +111,7 @@ class _APISemanticRagParameters(_APISemanticRetrievalParameters):
gen_ctx_window_size: int = 5000
gen_ctx_window_lead_weight: float = 0.5
return_prompt: bool = False
chunk_refs: Optional[List[ChunkRef]] = None
gen_timeout: Optional[float] = None


Expand All @@ -129,6 +131,7 @@ def RAGQuery(
gen_ctx_window_size: int = 5000,
gen_ctx_window_lead_weight: float = 0.5,
return_prompt: bool = False,
chunk_refs: Optional[List[ChunkRef]] = None,
gen_timeout: Optional[float] = None,
) -> Query:
"""Create a RAG query
Expand All @@ -147,6 +150,7 @@ def RAGQuery(
gen_ctx_window_size (int, optional): (relevant only if gen_ctx_extr_method=="window") max chars to use for extracted gen context (actual extraction quantized on doc item level); defaults to 5000
gen_ctx_window_lead_weight (float, optional): (relevant only if gen_ctx_extr_method=="window") weight of leading text for distributing remaining window size after extracting the `main_path`; defaults to 0.5 (centered around `main_path`)
return_prompt (bool, optional): whether to return the instantiated prompt; defaults to False
chunk_refs (Optional[List[ChunkRef]], optional): list of explicit chunk references to use instead of performing retrieval; defaults to None (i.e. retrieval-mode)
gen_timeout (float, optional): timeout for LLM generation; defaults to None, i.e. determined by system
"""

Expand Down Expand Up @@ -181,6 +185,7 @@ def RAGQuery(
gen_ctx_window_size=gen_ctx_window_size,
gen_ctx_window_lead_weight=gen_ctx_window_lead_weight,
return_prompt=return_prompt,
chunk_refs=chunk_refs,
gen_timeout=gen_timeout,
)

Expand Down
38 changes: 25 additions & 13 deletions deepsearch/cps/queries/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

from typing import List, Optional

from pydantic.v1 import BaseModel, root_validator
from pydantic.v1 import BaseModel

from deepsearch.cps.client.components.queries import RunQueryResult


class SearchResultItem(BaseModel):
class ChunkRef(BaseModel):
doc_hash: str
main_path: str # the anchor path among the contributing group
path_group: List[str] # the doc paths contributing to the encoding source


class SearchResultItem(ChunkRef):
chunk: str
main_path: str
path_group: List[str]
source_is_text: bool


class RAGGroundingInfo(BaseModel):
retr_items: List[SearchResultItem]
retr_items: Optional[List[SearchResultItem]] = None
gen_ctx_paths: List[str]


Expand Down Expand Up @@ -45,26 +48,35 @@ def __init__(self, msg="Search returned no results", *args, **kwargs):

class RAGResult(BaseModel):
answers: List[RAGAnswerItem]
search_result_items: List[SearchResultItem]
search_result_items: Optional[List[SearchResultItem]] = None

@classmethod
def from_api_output(cls, data: RunQueryResult, raise_on_error=True):
answers: List[RAGAnswerItem] = []
try:
search_result_items = data.outputs["retrieval"]["items"]
if raise_on_error and len(search_result_items) == 0:
raise NoSearchResultsError()
retrieval_part = data.outputs["retrieval"]
if retrieval_part is not None:
search_result_items = retrieval_part["items"]
if raise_on_error and len(search_result_items) == 0:
raise NoSearchResultsError()
else:
search_result_items = None
for answer_item in data.outputs["answers"]:
if raise_on_error and (gen_err := answer_item.get("gen_err")):
raise GenerationError(gen_err)
retr_idxs = answer_item["grounding_info"]["retr_idxs"]
answers.append(
RAGAnswerItem(
answer=answer_item["answer"],
grounding=RAGGroundingInfo(
retr_items=[
SearchResultItem.parse_obj(search_result_items[i])
for i in answer_item["grounding_info"]["retr_idxs"]
],
retr_items=(
[
SearchResultItem.parse_obj(search_result_items[i])
for i in retr_idxs
]
if retr_idxs is not None and retrieval_part is not None
else None
),
gen_ctx_paths=answer_item["grounding_info"][
"gen_ctx_paths"
],
Expand Down
Loading