diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 1f1f15ea700..97d5b9e7275 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -60,7 +60,11 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: class LLMRelevanceFilterResponse(BaseModel): - relevant_chunk_indices: list[int] + llm_selected_doc_indices: list[int] + + +class FinalUsedContextDocsResponse(BaseModel): + final_context_docs: list[LlmDoc] class RelevanceAnalysis(BaseModel): @@ -93,6 +97,16 @@ class CitationInfo(BaseModel): document_id: str +class AllCitations(BaseModel): + citations: list[CitationInfo] + + +# This is a mapping of the citation number to the document index within +# the result search doc set +class MessageSpecificCitations(BaseModel): + citation_map: dict[int, int] + + class MessageResponseIDInfo(BaseModel): user_message_id: int | None reserved_assistant_message_id: int @@ -138,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer): predicted_flow: QueryFlow predicted_search: SearchType eval_res_valid: bool | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None error_msg: str | None = None diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 2eea2cfc20f..fa13f245ccb 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,12 +7,15 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import FinalUsedContextDocsResponse from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import MessageResponseIDInfo +from danswer.chat.models import MessageSpecificCitations from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import BING_API_KEY @@ -85,6 +88,7 @@ ) from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse from danswer.tools.internet_search.internet_search_tool import InternetSearchTool +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool @@ -100,9 +104,9 @@ logger = setup_logger() -def translate_citations( +def _translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] -) -> dict[int, int]: +) -> MessageSpecificCitations: """Always cites the first instance of the document_id, assumes the db_docs are sorted in the order displayed in the UI""" doc_id_to_saved_doc_id_map: dict[str, int] = {} @@ -117,7 +121,7 @@ def translate_citations( citation.citation_num ] = doc_id_to_saved_doc_id_map[citation.document_id] - return citation_to_saved_doc_id_map + return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map) def _handle_search_tool_response_summary( @@ -239,11 +243,14 @@ def _get_force_search_settings( StreamingError | QADocsResponse | LLMRelevanceFilterResponse + | FinalUsedContextDocsResponse | ChatMessageDetail | DanswerAnswerPiece + | AllCitations | CitationInfo | ImageGenerationDisplay | CustomToolResponse + | MessageSpecificCitations | MessageResponseIDInfo ) ChatPacketStream = Iterator[ChatPacket] @@ -688,9 +695,13 @@ def stream_chat_message_objects( ) yield LLMRelevanceFilterResponse( - relevant_chunk_indices=llm_indices + llm_selected_doc_indices=llm_indices ) + elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: + yield FinalUsedContextDocsResponse( + final_context_docs=packet.response + ) elif packet.id == IMAGE_GENERATION_RESPONSE_ID: img_generation_response = cast( list[ImageGenerationResponse], packet.response @@ -743,12 +754,13 @@ def stream_chat_message_objects( # Post-LLM answer processing try: - db_citations = None + message_specific_citations: MessageSpecificCitations | None = None if reference_db_search_docs: - db_citations = translate_citations( + message_specific_citations = _translate_citations( citations_list=answer.citations, db_docs=reference_db_search_docs, ) + yield AllCitations(citations=answer.citations) # Saving Gen AI answer and responding with message info tool_name_to_tool_id: dict[str, int] = {} @@ -765,7 +777,9 @@ def stream_chat_message_objects( reference_docs=reference_db_search_docs, files=ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), - citations=db_citations, + citations=message_specific_citations.citation_map + if message_specific_citations + else None, error=None, tool_calls=[ ToolCall( diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 09ea4e05332..7057d7c2e4b 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -419,7 +419,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non ) # Get the chunks fed to the LLM only, then fill with other docs - llm_doc_inds = answer.llm_chunks_indices or [] + llm_doc_inds = answer.llm_selected_doc_indices or [] llm_docs = [top_docs[i] for i in llm_doc_inds] remaining_docs = [ doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 0a0a1c52afa..a036421da2e 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -52,7 +52,7 @@ from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.message import build_tool_message from danswer.tools.message import ToolCallSummary -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary @@ -433,7 +433,7 @@ def _raw_output_for_non_explicit_tool_calling_llms( if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: final_context_documents = None for response in tool_runner.tool_responses(): - if response.id == FINAL_CONTEXT_DOCUMENTS: + if response.id == FINAL_CONTEXT_DOCUMENTS_ID: final_context_documents = cast(list[LlmDoc], response.response) yield response @@ -501,12 +501,10 @@ def _process_stream( message = None # special things we need to keep track of for the SearchTool - search_results: list[ - LlmDoc - ] | None = None # raw results that will be displayed to the user - final_context_docs: list[ - LlmDoc - ] | None = None # processed docs to feed into the LLM + # raw results that will be displayed to the user + search_results: list[LlmDoc] | None = None + # processed docs to feed into the LLM + final_context_docs: list[LlmDoc] | None = None for message in stream: if isinstance(message, ToolCallKickoff) or isinstance( @@ -525,8 +523,9 @@ def _process_stream( SearchResponseSummary, message.response ).top_sections ] - elif message.id == FINAL_CONTEXT_DOCUMENTS: + elif message.id == FINAL_CONTEXT_DOCUMENTS_ID: final_context_docs = cast(list[LlmDoc], message.response) + yield message elif ( message.id == SEARCH_DOC_CONTENT_ID diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index a5a0fe0dad5..3f83ad19551 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -371,7 +371,7 @@ def get_search_answer( elif isinstance(packet, QADocsResponse): qa_response.docs = packet elif isinstance(packet, LLMRelevanceFilterResponse): - qa_response.llm_chunks_indices = packet.relevant_chunk_indices + qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices elif isinstance(packet, DanswerQuotes): qa_response.quotes = packet elif isinstance(packet, CitationInfo): diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index d7e81975630..fceb78de7aa 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -62,7 +62,7 @@ class OneShotQAResponse(BaseModel): quotes: DanswerQuotes | None = None citations: list[CitationInfo] | None = None docs: QADocsResponse | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None error_msg: str | None = None answer_valid: bool = True # Reflexion result, default True if Reflexion not run chat_message_id: int | None = None diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index bf535661878..53ed5b426ba 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -1,9 +1,31 @@ import json +from datetime import datetime from typing import Any -def get_json_line(json_dict: dict) -> str: - return json.dumps(json_dict) + "\n" +class DateTimeEncoder(json.JSONEncoder): + """Custom JSON encoder that converts datetime objects to ISO format strings.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +def get_json_line( + json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder +) -> str: + """ + Convert a dictionary to a JSON string with datetime handling, and add a newline. + + Args: + json_dict: The dictionary to be converted to JSON. + encoder: JSON encoder class to use, defaults to DateTimeEncoder. + + Returns: + A JSON string representation of the input dictionary with a newline character. + """ + return json.dumps(json_dict, cls=encoder) + "\n" def mask_string(sensitive_str: str) -> str: diff --git a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/internet_search/internet_search_tool.py index 2640afcdf83..3012eb465f4 100644 --- a/backend/danswer/tools/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/internet_search/internet_search_tool.py @@ -20,7 +20,7 @@ from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.tools.internet_search.models import InternetSearchResponse from danswer.tools.internet_search.models import InternetSearchResult -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.logger import setup_logger @@ -224,7 +224,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: ] yield ToolResponse( - id=FINAL_CONTEXT_DOCUMENTS, + id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs, ) diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 13d3a304b06..cbfaf4f3d92 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -45,7 +45,7 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" SEARCH_DOC_CONTENT_ID = "search_doc_content" SECTION_RELEVANCE_LIST_ID = "section_relevance_list" -FINAL_CONTEXT_DOCUMENTS = "final_context_documents" +FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" SEARCH_EVALUATION_ID = "llm_doc_eval" @@ -179,7 +179,7 @@ def build_tool_message_content( self, *args: ToolResponse ) -> str | list[str | dict[str, Any]]: final_context_docs_response = next( - response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS + response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS_ID ) final_context_docs = cast(list[LlmDoc], final_context_docs_response.response) @@ -260,7 +260,7 @@ def _build_response_for_specified_sections( for section in final_context_sections ] - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) @@ -343,12 +343,12 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: llm_doc_from_inference_section(section) for section in pruned_sections ] - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) def final_result(self, *args: ToolResponse) -> JSON_ro: final_docs = cast( list[LlmDoc], - next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS), + next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID), ) # NOTE: need to do this json.loads(doc.json()) stuff because there are some # subfields that are not serializable by default (datetime) diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index 0d5d1987f34..143f81a57d4 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -7,7 +7,10 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import FinalUsedContextDocsResponse +from danswer.chat.models import LlmDoc from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError @@ -41,7 +44,7 @@ router = APIRouter(prefix="/chat") -def translate_doc_response_to_simple_doc( +def _translate_doc_response_to_simple_doc( doc_response: QADocsResponse, ) -> list[SimpleDoc]: return [ @@ -60,6 +63,23 @@ def translate_doc_response_to_simple_doc( ] +def _get_final_context_doc_indices( + final_context_docs: list[LlmDoc] | None, + simple_search_docs: list[SimpleDoc] | None, +) -> list[int] | None: + """ + this function returns a list of indices of the simple search docs + that were actually fed to the LLM. + """ + if final_context_docs is None or simple_search_docs is None: + return None + + final_context_doc_ids = {doc.document_id for doc in final_context_docs} + return [ + i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids + ] + + def remove_answer_citations(answer: str) -> str: pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)" @@ -120,17 +140,29 @@ def handle_simplified_chat_message( ) response = ChatBasicResponse() + final_context_docs: list[LlmDoc] = [] answer = "" for packet in packets: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): - response.simple_search_docs = translate_doc_response_to_simple_doc(packet) + response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.message_id = packet.message_id + elif isinstance(packet, FinalUsedContextDocsResponse): + final_context_docs = packet.final_context_docs + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } + + response.final_context_doc_indices = _get_final_context_doc_indices( + final_context_docs, response.simple_search_docs + ) response.answer = answer if answer: @@ -152,6 +184,8 @@ def handle_send_message_simple_with_history( if len(req.messages) == 0: raise HTTPException(status_code=400, detail="Messages cannot be zero length") + # This is a sanity check to make sure the chat history is valid + # It must start with a user message and alternate between user and assistant expected_role = MessageType.USER for msg in req.messages: if not msg.message: @@ -246,19 +280,31 @@ def handle_send_message_simple_with_history( ) response = ChatBasicResponse() + final_context_docs: list[LlmDoc] = [] answer = "" for packet in packets: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): - response.simple_search_docs = translate_doc_response_to_simple_doc(packet) + response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.message_id = packet.message_id elif isinstance(packet, LLMRelevanceFilterResponse): - response.llm_chunks_indices = packet.relevant_chunk_indices + response.llm_selected_doc_indices = packet.llm_selected_doc_indices + elif isinstance(packet, FinalUsedContextDocsResponse): + final_context_docs = packet.final_context_docs + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } + + response.final_context_doc_indices = _get_final_context_doc_indices( + final_context_docs, response.simple_search_docs + ) response.answer = answer if answer: diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index b0ce553ebe0..e2a32d3c56d 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -74,4 +74,7 @@ class ChatBasicResponse(BaseModel): simple_search_docs: list[SimpleDoc] | None = None error_msg: str | None = None message_id: int | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None + final_context_doc_indices: list[int] | None = None + # this is a map of the citation number to the document id + cited_documents: dict[int, str] | None = None diff --git a/backend/tests/api/test_api.py b/backend/tests/api/test_api.py index 059c40824d5..9a3571ef585 100644 --- a/backend/tests/api/test_api.py +++ b/backend/tests/api/test_api.py @@ -101,4 +101,4 @@ def test_handle_send_message_simple_with_history(client: TestClient) -> None: resp_json = response.json() # persona must have LLM relevance enabled for this to pass - assert len(resp_json["llm_chunks_indices"]) > 0 + assert len(resp_json["llm_selected_doc_indices"]) > 0