From 37932ccf24c52a5a23642d612224fe73985d5f7a Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 12 Sep 2024 13:38:25 -0700 Subject: [PATCH] minor fixes --- backend/danswer/chat/models.py | 4 +++ backend/danswer/chat/process_message.py | 4 ++- backend/danswer/server/utils.py | 7 +++-- .../server/query_and_chat/chat_backend.py | 28 ++++++++++--------- .../danswer/server/query_and_chat/models.py | 3 +- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index a5beca116ab..97d5b9e7275 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -97,6 +97,10 @@ class CitationInfo(BaseModel): document_id: str +class AllCitations(BaseModel): + citations: list[CitationInfo] + + # This is a mapping of the citation number to the document index within # the result search doc set class MessageSpecificCitations(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index b240558d0a3..fa13f245ccb 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece @@ -245,6 +246,7 @@ def _get_force_search_settings( | FinalUsedContextDocsResponse | ChatMessageDetail | DanswerAnswerPiece + | AllCitations | CitationInfo | ImageGenerationDisplay | CustomToolResponse @@ -758,7 +760,7 @@ def stream_chat_message_objects( citations_list=answer.citations, db_docs=reference_db_search_docs, ) - yield message_specific_citations + yield AllCitations(citations=answer.citations) # Saving Gen AI answer and responding with message info tool_name_to_tool_id: dict[str, int] = {} diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index dc979054cba..53ed5b426ba 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -12,17 +12,20 @@ def default(self, obj: Any) -> Any: return super().default(obj) -def get_json_line(json_dict: dict[str, Any]) -> str: +def get_json_line( + json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder +) -> str: """ Convert a dictionary to a JSON string with datetime handling, and add a newline. Args: json_dict: The dictionary to be converted to JSON. + encoder: JSON encoder class to use, defaults to DateTimeEncoder. Returns: A JSON string representation of the input dictionary with a newline character. """ - return json.dumps(json_dict, cls=DateTimeEncoder) + "\n" + return json.dumps(json_dict, cls=encoder) + "\n" def mask_string(sensitive_str: str) -> str: diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index b5f3909e081..143f81a57d4 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -7,11 +7,11 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import FinalUsedContextDocsResponse from danswer.chat.models import LlmDoc from danswer.chat.models import LLMRelevanceFilterResponse -from danswer.chat.models import MessageSpecificCitations from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.chat.process_message import stream_chat_message_objects @@ -74,14 +74,10 @@ def _get_final_context_doc_indices( if final_context_docs is None or simple_search_docs is None: return None - final_context_indices: list[int] = [] - for i, simple_doc in enumerate(simple_search_docs): - for doc in final_context_docs: - if doc.document_id == simple_doc.id: - final_context_indices.append(i) - break - - return final_context_indices + final_context_doc_ids = {doc.document_id for doc in final_context_docs} + return [ + i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids + ] def remove_answer_citations(answer: str) -> str: @@ -158,8 +154,11 @@ def handle_simplified_chat_message( response.message_id = packet.message_id elif isinstance(packet, FinalUsedContextDocsResponse): final_context_docs = packet.final_context_docs - elif isinstance(packet, MessageSpecificCitations): - response.cited_documents = packet.citation_map + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } response.final_context_doc_indices = _get_final_context_doc_indices( final_context_docs, response.simple_search_docs @@ -297,8 +296,11 @@ def handle_send_message_simple_with_history( response.llm_selected_doc_indices = packet.llm_selected_doc_indices elif isinstance(packet, FinalUsedContextDocsResponse): final_context_docs = packet.final_context_docs - elif isinstance(packet, MessageSpecificCitations): - response.cited_documents = packet.citation_map + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } response.final_context_doc_indices = _get_final_context_doc_indices( final_context_docs, response.simple_search_docs diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index 19fc65b7273..e2a32d3c56d 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -76,4 +76,5 @@ class ChatBasicResponse(BaseModel): message_id: int | None = None llm_selected_doc_indices: list[int] | None = None final_context_doc_indices: list[int] | None = None - cited_documents: dict[int, int] | None = None + # this is a map of the citation number to the document id + cited_documents: dict[int, str] | None = None