simple apis now cited/context doc indices (onyx-dot-app#2419)
* simple apis now cited/context doc indices

* minor fixes
hagen-danswer authored and rajiv chodisetti committed Oct 2, 2024
1 parent c6afe7a commit 35420b4
Showing 12 changed files with 134 additions and 36 deletions.
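In short: this commit renames the relevance-filter field relevant_chunk_indices to llm_selected_doc_indices, adds three stream packet models (FinalUsedContextDocsResponse, AllCitations, MessageSpecificCitations), renames the FINAL_CONTEXT_DOCUMENTS tool-response ID constant to FINAL_CONTEXT_DOCUMENTS_ID, and makes the JSON-line serializer datetime-aware. As a rough sketch of the client-side impact (not part of the commit; the dispatch-by-key logic is illustrative, key names come from the models in the diffs below):

# --- illustration, not part of the commit ---
import json

def handle_packet(line: str) -> None:
    packet = json.loads(line)
    if "llm_selected_doc_indices" in packet:  # LLMRelevanceFilterResponse (renamed field)
        print("docs the LLM kept:", packet["llm_selected_doc_indices"])
    elif "final_context_docs" in packet:  # FinalUsedContextDocsResponse (new)
        print("docs actually fed to the LLM:", len(packet["final_context_docs"]))
    elif "citation_map" in packet:  # MessageSpecificCitations (new)
        print("citation number -> saved doc index:", packet["citation_map"])
    elif "citations" in packet:  # AllCitations (new)
        print("all citations for the message:", packet["citations"])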
18 changes: 16 additions & 2 deletions backend/danswer/chat/models.py
@@ -60,7 +60,11 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:
 
 
 class LLMRelevanceFilterResponse(BaseModel):
-    relevant_chunk_indices: list[int]
+    llm_selected_doc_indices: list[int]
+
+
+class FinalUsedContextDocsResponse(BaseModel):
+    final_context_docs: list[LlmDoc]
 
 
 class RelevanceAnalysis(BaseModel):
@@ -93,6 +97,16 @@ class CitationInfo(BaseModel):
     document_id: str
 
 
+class AllCitations(BaseModel):
+    citations: list[CitationInfo]
+
+
+# This is a mapping of the citation number to the document index within
+# the result search doc set
+class MessageSpecificCitations(BaseModel):
+    citation_map: dict[int, int]
+
+
 class MessageResponseIDInfo(BaseModel):
     user_message_id: int | None
     reserved_assistant_message_id: int
@@ -138,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
     predicted_flow: QueryFlow
     predicted_search: SearchType
     eval_res_valid: bool | None = None
-    llm_chunks_indices: list[int] | None = None
+    llm_selected_doc_indices: list[int] | None = None
     error_msg: str | None = None
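For reference, a minimal sketch (not part of the commit; values are made up, and CitationInfo's citation_num field is inferred from its use in process_message.py below) of the new models in use:

# --- illustration, not part of the commit ---
from danswer.chat.models import AllCitations, CitationInfo, MessageSpecificCitations

# One citation, numbered [1] in the answer text, pointing at a document by ID.
citation = CitationInfo(citation_num=1, document_id="doc-abc")
all_citations = AllCitations(citations=[citation])

# Citation [1] resolves to index 4 within the result search doc set.
message_citations = MessageSpecificCitations(citation_map={1: 4})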
28 changes: 21 additions & 7 deletions backend/danswer/chat/process_message.py
@@ -7,12 +7,15 @@
 from sqlalchemy.orm import Session
 
 from danswer.chat.chat_utils import create_chat_chain
+from danswer.chat.models import AllCitations
 from danswer.chat.models import CitationInfo
 from danswer.chat.models import CustomToolResponse
 from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import FinalUsedContextDocsResponse
 from danswer.chat.models import ImageGenerationDisplay
 from danswer.chat.models import LLMRelevanceFilterResponse
 from danswer.chat.models import MessageResponseIDInfo
+from danswer.chat.models import MessageSpecificCitations
 from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import BING_API_KEY
@@ -85,6 +88,7 @@
 )
 from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
 from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
+from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
 from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
 from danswer.tools.search.search_tool import SearchResponseSummary
 from danswer.tools.search.search_tool import SearchTool
@@ -100,9 +104,9 @@
 logger = setup_logger()
 
 
-def translate_citations(
+def _translate_citations(
     citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
-) -> dict[int, int]:
+) -> MessageSpecificCitations:
     """Always cites the first instance of the document_id, assumes the db_docs
     are sorted in the order displayed in the UI"""
     doc_id_to_saved_doc_id_map: dict[str, int] = {}
@@ -117,7 +121,7 @@ def translate_citations(
             citation.citation_num
         ] = doc_id_to_saved_doc_id_map[citation.document_id]
 
-    return citation_to_saved_doc_id_map
+    return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
 
 
 def _handle_search_tool_response_summary(
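A standalone illustration (hypothetical data, plain strings instead of the DB models) of the first-occurrence mapping the helper builds:

# --- illustration, not part of the commit ---
# Citation [1] cites "doc-abc"; db docs are ordered as displayed in the UI.
db_doc_ids = ["doc-xyz", "doc-abc", "doc-abc"]

doc_id_to_index: dict[str, int] = {}
for idx, doc_id in enumerate(db_doc_ids):
    if doc_id not in doc_id_to_index:
        doc_id_to_index[doc_id] = idx  # first instance of the document_id wins

citation_map = {1: doc_id_to_index["doc-abc"]}
assert citation_map == {1: 1}  # maps to the first occurrence, not the duplicate at 2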
@@ -239,11 +243,14 @@ def _get_force_search_settings(
     StreamingError
     | QADocsResponse
     | LLMRelevanceFilterResponse
+    | FinalUsedContextDocsResponse
     | ChatMessageDetail
     | DanswerAnswerPiece
+    | AllCitations
     | CitationInfo
     | ImageGenerationDisplay
     | CustomToolResponse
+    | MessageSpecificCitations
     | MessageResponseIDInfo
 )
 ChatPacketStream = Iterator[ChatPacket]
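(With these three union members added, every consumer of ChatPacketStream should be prepared to receive, and handle or ignore, AllCitations, FinalUsedContextDocsResponse, and MessageSpecificCitations packets.)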
@@ -688,9 +695,13 @@ def stream_chat_message_objects(
                 )
 
                 yield LLMRelevanceFilterResponse(
-                    relevant_chunk_indices=llm_indices
+                    llm_selected_doc_indices=llm_indices
                 )
 
+            elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
+                yield FinalUsedContextDocsResponse(
+                    final_context_docs=packet.response
+                )
             elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
                 img_generation_response = cast(
                     list[ImageGenerationResponse], packet.response
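Here packet.response is the list[LlmDoc] yielded by SearchTool or InternetSearchTool under FINAL_CONTEXT_DOCUMENTS_ID (see the tool diffs below), so the new packet surfaces exactly the documents that were fed to the LLM.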
@@ -743,12 +754,13 @@ def stream_chat_message_objects(
 
         # Post-LLM answer processing
         try:
-            db_citations = None
+            message_specific_citations: MessageSpecificCitations | None = None
             if reference_db_search_docs:
-                db_citations = translate_citations(
+                message_specific_citations = _translate_citations(
                     citations_list=answer.citations,
                     db_docs=reference_db_search_docs,
                 )
+                yield AllCitations(citations=answer.citations)
 
             # Saving Gen AI answer and responding with message info
             tool_name_to_tool_id: dict[str, int] = {}
@@ -765,7 +777,9 @@
                 reference_docs=reference_db_search_docs,
                 files=ai_message_files,
                 token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
-                citations=db_citations,
+                citations=message_specific_citations.citation_map
+                if message_specific_citations
+                else None,
                 error=None,
                 tool_calls=[
                     ToolCall(
@@ -419,7 +419,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
     )
 
     # Get the chunks fed to the LLM only, then fill with other docs
-    llm_doc_inds = answer.llm_chunks_indices or []
+    llm_doc_inds = answer.llm_selected_doc_indices or []
     llm_docs = [top_docs[i] for i in llm_doc_inds]
     remaining_docs = [
         doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
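A standalone illustration (hypothetical data) of the reorder-then-fill pattern above: LLM-selected docs come first, in selection order, followed by the rest in display order:

# --- illustration, not part of the commit ---
top_docs = ["d0", "d1", "d2", "d3"]
llm_doc_inds = [2, 0]  # indices of the docs the LLM actually used

llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds]
assert llm_docs + remaining_docs == ["d2", "d0", "d1", "d3"]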
17 changes: 8 additions & 9 deletions backend/danswer/llm/answering/answer.py
@@ -52,7 +52,7 @@
 from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
 from danswer.tools.message import build_tool_message
 from danswer.tools.message import ToolCallSummary
-from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
+from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
 from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
 from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
 from danswer.tools.search.search_tool import SearchResponseSummary
@@ -433,7 +433,7 @@ def _raw_output_for_non_explicit_tool_calling_llms(
         if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
             final_context_documents = None
             for response in tool_runner.tool_responses():
-                if response.id == FINAL_CONTEXT_DOCUMENTS:
+                if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                     final_context_documents = cast(list[LlmDoc], response.response)
                 yield response

@@ -501,12 +501,10 @@ def _process_stream(
     message = None
 
     # special things we need to keep track of for the SearchTool
-    search_results: list[
-        LlmDoc
-    ] | None = None  # raw results that will be displayed to the user
-    final_context_docs: list[
-        LlmDoc
-    ] | None = None  # processed docs to feed into the LLM
+    # raw results that will be displayed to the user
+    search_results: list[LlmDoc] | None = None
+    # processed docs to feed into the LLM
+    final_context_docs: list[LlmDoc] | None = None
 
     for message in stream:
         if isinstance(message, ToolCallKickoff) or isinstance(
@@ -525,8 +523,9 @@
                         SearchResponseSummary, message.response
                     ).top_sections
                 ]
-            elif message.id == FINAL_CONTEXT_DOCUMENTS:
+            elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
                 final_context_docs = cast(list[LlmDoc], message.response)
+                yield message
 
             elif (
                 message.id == SEARCH_DOC_CONTENT_ID
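Note the added yield message: previously the FINAL_CONTEXT_DOCUMENTS packet was consumed here and only cached in final_context_docs; re-yielding it is what lets the packet reach stream_chat_message_objects, which re-emits it to clients as FinalUsedContextDocsResponse.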
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
@@ -371,7 +371,7 @@ def get_search_answer(
     elif isinstance(packet, QADocsResponse):
         qa_response.docs = packet
     elif isinstance(packet, LLMRelevanceFilterResponse):
-        qa_response.llm_chunks_indices = packet.relevant_chunk_indices
+        qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
     elif isinstance(packet, DanswerQuotes):
         qa_response.quotes = packet
     elif isinstance(packet, CitationInfo):
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/models.py
@@ -62,7 +62,7 @@ class OneShotQAResponse(BaseModel):
     quotes: DanswerQuotes | None = None
     citations: list[CitationInfo] | None = None
     docs: QADocsResponse | None = None
-    llm_chunks_indices: list[int] | None = None
+    llm_selected_doc_indices: list[int] | None = None
     error_msg: str | None = None
     answer_valid: bool = True  # Reflexion result, default True if Reflexion not run
     chat_message_id: int | None = None
26 changes: 24 additions & 2 deletions backend/danswer/server/utils.py
@@ -1,9 +1,31 @@
 import json
+from datetime import datetime
+from typing import Any
 
 
-def get_json_line(json_dict: dict) -> str:
-    return json.dumps(json_dict) + "\n"
+class DateTimeEncoder(json.JSONEncoder):
+    """Custom JSON encoder that converts datetime objects to ISO format strings."""
+
+    def default(self, obj: Any) -> Any:
+        if isinstance(obj, datetime):
+            return obj.isoformat()
+        return super().default(obj)
+
+
+def get_json_line(
+    json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
+) -> str:
+    """
+    Convert a dictionary to a JSON string with datetime handling, and add a newline.
+
+    Args:
+        json_dict: The dictionary to be converted to JSON.
+        encoder: JSON encoder class to use, defaults to DateTimeEncoder.
+
+    Returns:
+        A JSON string representation of the input dictionary with a newline character.
+    """
+    return json.dumps(json_dict, cls=encoder) + "\n"
 
 
 def mask_string(sensitive_str: str) -> str:
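A quick usage sketch (assuming the module path above) of why the encoder matters: plain json.dumps raises TypeError on datetime values.

# --- illustration, not part of the commit ---
from datetime import datetime

from danswer.server.utils import get_json_line

payload = {"message": "done", "sent_at": datetime(2024, 10, 2, 12, 0)}
# json.dumps(payload) would raise TypeError (datetime is not JSON serializable);
# DateTimeEncoder.default converts it to an ISO-8601 string instead.
print(get_json_line(payload))
# -> {"message": "done", "sent_at": "2024-10-02T12:00:00"}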
4 changes: 2 additions & 2 deletions backend/danswer/tools/internet_search/internet_search_tool.py
@@ -20,7 +20,7 @@
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
 from danswer.tools.internet_search.models import InternetSearchResponse
 from danswer.tools.internet_search.models import InternetSearchResult
-from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
+from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
 from danswer.tools.tool import Tool
 from danswer.tools.tool import ToolResponse
 from danswer.utils.logger import setup_logger
@@ -224,7 +224,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
         ]
 
         yield ToolResponse(
-            id=FINAL_CONTEXT_DOCUMENTS,
+            id=FINAL_CONTEXT_DOCUMENTS_ID,
             response=llm_docs,
         )
10 changes: 5 additions & 5 deletions backend/danswer/tools/search/search_tool.py
@@ -45,7 +45,7 @@
 SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
 SEARCH_DOC_CONTENT_ID = "search_doc_content"
 SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
-FINAL_CONTEXT_DOCUMENTS = "final_context_documents"
+FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
 SEARCH_EVALUATION_ID = "llm_doc_eval"
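Note that only the Python identifier changes here; the string value "final_context_documents" is unchanged, so serialized packet IDs on the wire stay the same.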


@@ -179,7 +179,7 @@ def build_tool_message_content(
         self, *args: ToolResponse
     ) -> str | list[str | dict[str, Any]]:
         final_context_docs_response = next(
-            response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS
+            response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS_ID
         )
         final_context_docs = cast(list[LlmDoc], final_context_docs_response.response)
@@ -260,7 +260,7 @@ def _build_response_for_specified_sections(
             for section in final_context_sections
         ]
 
-        yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs)
+        yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
 
     def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
         query = cast(str, kwargs["query"])
@@ -343,12 +343,12 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
             llm_doc_from_inference_section(section) for section in pruned_sections
         ]
 
-        yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs)
+        yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
 
     def final_result(self, *args: ToolResponse) -> JSON_ro:
         final_docs = cast(
             list[LlmDoc],
-            next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS),
+            next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID),
         )
         # NOTE: need to do this json.loads(doc.json()) stuff because there are some
         # subfields that are not serializable by default (datetime)