Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expose more params for semantic operations #159

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 110 additions & 16 deletions deepsearch/cps/queries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict, List, Optional, Union

from pydantic.v1 import Field, validate_arguments
from typing_extensions import Annotated

from deepsearch.cps.client.components.elastic import ElasticSearchQuery
from deepsearch.cps.client.components.projects import Project, SemanticBackendResource
from deepsearch.cps.client.queries import Query, TaskCoordinates
Expand Down Expand Up @@ -77,17 +80,38 @@ def DataQuery(
return query


ConstrainedWeight = Annotated[
float, Field(strict=True, ge=0.0, le=1.0, multiple_of=0.1)
]


def CorpusRAGQuery(
question: str,
*,
project: Union[str, Project],
index_key: str,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_rag_query(
"""Create a RAG query against a collection

Args:
question (str): the natural-language query
project (Union[str, Project]): project to use
index_key (str): index key of target private collection (must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_rag_query(
question=question,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


Expand All @@ -96,29 +120,56 @@ def DocumentRAGQuery(
*,
document_hash: str,
project: Union[str, Project],
index_key: Optional[str] = None, # set in case of private collection
index_key: Optional[str] = None,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_rag_query(
"""Create a RAG query against a specific document

Args:
question (str): the natural-language query
document_hash (str): hash of target document
project (Union[str, Project]): project to use
index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_rag_query(
question=question,
document_hash=document_hash,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


def _get_rag_query(
@validate_arguments
def _create_rag_query(
question: str,
*,
document_hash: Optional[str] = None,
project: Union[str, Project],
index_key: Optional[str] = None,
index_key: Optional[str],
retr_k: int,
rerank: bool,
text_weight: ConstrainedWeight,
) -> Query:
proj_key = project.key if isinstance(project, Project) else project
idx_key = index_key or "__project__"

query = Query()
q_params = {"question": question}

q_params = {
"question": question,
"retr_k": retr_k,
"use_reranker": rerank,
"hybrid_search_text_weight": text_weight,
}
if document_hash:
q_params["doc_id"] = document_hash
task = query.add(
Expand All @@ -138,12 +189,28 @@ def CorpusSemanticQuery(
*,
project: Union[str, Project],
index_key: str,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_semantic_query(
"""Create a semantic retrieval query against a collection

Args:
question (str): the natural-language query
project (Union[str, Project]): project to use
index_key (str): index key of target private collection (must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_semantic_query(
question=question,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


Expand All @@ -152,29 +219,56 @@ def DocumentSemanticQuery(
*,
document_hash: str,
project: Union[str, Project],
index_key: Optional[str] = None, # set in case of private collection
index_key: Optional[str] = None,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_semantic_query(
"""Create a semantic retrieval query against a specific document

Args:
question (str): the natural-language query
document_hash (str): hash of target document
project (Union[str, Project]): project to use
index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_semantic_query(
question=question,
document_hash=document_hash,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


def _get_semantic_query(
@validate_arguments
def _create_semantic_query(
question: str,
*,
document_hash: Optional[str] = None,
project: Union[str, Project],
index_key: Optional[str] = None,
index_key: Optional[str],
retr_k: int,
rerank: bool,
text_weight: ConstrainedWeight,
) -> Query:
proj_key = project.key if isinstance(project, Project) else project
idx_key = index_key or "__project__"

query = Query()
q_params = {"question": question}

q_params = {
"question": question,
"retr_k": retr_k,
"use_reranker": rerank,
"hybrid_search_text_weight": text_weight,
}
if document_hash:
q_params["doc_id"] = document_hash
task = query.add(
Expand Down
Loading