Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Sep 12, 2024
1 parent 265f64b commit 37932cc
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
4 changes: 4 additions & 0 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class CitationInfo(BaseModel):
document_id: str


class AllCitations(BaseModel):
citations: list[CitationInfo]


# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
Expand Down Expand Up @@ -245,6 +246,7 @@ def _get_force_search_settings(
| FinalUsedContextDocsResponse
| ChatMessageDetail
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
Expand Down Expand Up @@ -758,7 +760,7 @@ def stream_chat_message_objects(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield message_specific_citations
yield AllCitations(citations=answer.citations)

# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
Expand Down
7 changes: 5 additions & 2 deletions backend/danswer/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ def default(self, obj: Any) -> Any:
return super().default(obj)


def get_json_line(json_dict: dict[str, Any]) -> str:
def get_json_line(
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
) -> str:
"""
Convert a dictionary to a JSON string with datetime handling, and add a newline.
Args:
json_dict: The dictionary to be converted to JSON.
encoder: JSON encoder class to use, defaults to DateTimeEncoder.
Returns:
A JSON string representation of the input dictionary with a newline character.
"""
return json.dumps(json_dict, cls=DateTimeEncoder) + "\n"
return json.dumps(json_dict, cls=encoder) + "\n"


def mask_string(sensitive_str: str) -> str:
Expand Down
28 changes: 15 additions & 13 deletions backend/ee/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.process_message import stream_chat_message_objects
Expand Down Expand Up @@ -74,14 +74,10 @@ def _get_final_context_doc_indices(
if final_context_docs is None or simple_search_docs is None:
return None

final_context_indices: list[int] = []
for i, simple_doc in enumerate(simple_search_docs):
for doc in final_context_docs:
if doc.document_id == simple_doc.id:
final_context_indices.append(i)
break

return final_context_indices
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
]


def remove_answer_citations(answer: str) -> str:
Expand Down Expand Up @@ -158,8 +154,11 @@ def handle_simplified_chat_message(
response.message_id = packet.message_id
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, MessageSpecificCitations):
response.cited_documents = packet.citation_map
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}

response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
Expand Down Expand Up @@ -297,8 +296,11 @@ def handle_send_message_simple_with_history(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, MessageSpecificCitations):
response.cited_documents = packet.citation_map
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}

response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
Expand Down
3 changes: 2 additions & 1 deletion backend/ee/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ class ChatBasicResponse(BaseModel):
message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None
final_context_doc_indices: list[int] | None = None
cited_documents: dict[int, int] | None = None
# this is a map of the citation number to the document id
cited_documents: dict[int, str] | None = None

0 comments on commit 37932cc

Please sign in to comment.