Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simple apis now cited/context doc indices #2419

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:


class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
llm_selected_doc_indices: list[int]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed this everywhere



class FinalUsedContextDocsResponse(BaseModel):
final_context_docs: list[LlmDoc]


class RelevanceAnalysis(BaseModel):
Expand Down Expand Up @@ -93,6 +97,16 @@ class CitationInfo(BaseModel):
document_id: str


class AllCitations(BaseModel):
citations: list[CitationInfo]


# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
citation_map: dict[int, int]


class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
Expand Down Expand Up @@ -138,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_chunks_indices: list[int] | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None


Expand Down
28 changes: 21 additions & 7 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
Expand Down Expand Up @@ -85,6 +88,7 @@
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
Expand All @@ -100,9 +104,9 @@
logger = setup_logger()


def translate_citations(
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
) -> MessageSpecificCitations:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
Expand All @@ -117,7 +121,7 @@ def translate_citations(
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]

return citation_to_saved_doc_id_map
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)


def _handle_search_tool_response_summary(
Expand Down Expand Up @@ -239,11 +243,14 @@ def _get_force_search_settings(
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
Expand Down Expand Up @@ -688,9 +695,13 @@ def stream_chat_message_objects(
)

yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
llm_selected_doc_indices=llm_indices
)

elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
Expand Down Expand Up @@ -743,12 +754,13 @@ def stream_chat_message_objects(

# Post-LLM answer processing
try:
db_citations = None
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
db_citations = translate_citations(
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)

# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
Expand All @@ -765,7 +777,9 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non
)

# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
Expand Down
17 changes: 8 additions & 9 deletions backend/danswer/llm/answering/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
Expand Down Expand Up @@ -433,7 +433,7 @@ def _raw_output_for_non_explicit_tool_calling_llms(
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS:
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response

Expand Down Expand Up @@ -501,12 +501,10 @@ def _process_stream(
message = None

# special things we need to keep track of for the SearchTool
search_results: list[
LlmDoc
] | None = None # raw results that will be displayed to the user
final_context_docs: list[
LlmDoc
] | None = None # processed docs to feed into the LLM
# raw results that will be displayed to the user
search_results: list[LlmDoc] | None = None
# processed docs to feed into the LLM
final_context_docs: list[LlmDoc] | None = None

for message in stream:
if isinstance(message, ToolCallKickoff) or isinstance(
Expand All @@ -525,8 +523,9 @@ def _process_stream(
SearchResponseSummary, message.response
).top_sections
]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we yield the final context docs

elif message.id == FINAL_CONTEXT_DOCUMENTS:
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_docs = cast(list[LlmDoc], message.response)
yield message

elif (
message.id == SEARCH_DOC_CONTENT_ID
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def get_search_answer(
elif isinstance(packet, QADocsResponse):
qa_response.docs = packet
elif isinstance(packet, LLMRelevanceFilterResponse):
qa_response.llm_chunks_indices = packet.relevant_chunk_indices
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, DanswerQuotes):
qa_response.quotes = packet
elif isinstance(packet, CitationInfo):
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class OneShotQAResponse(BaseModel):
quotes: DanswerQuotes | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_chunks_indices: list[int] | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
chat_message_id: int | None = None
Expand Down
26 changes: 24 additions & 2 deletions backend/danswer/server/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,31 @@
import json
from datetime import datetime
from typing import Any


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made this function able to encode datetime objects

def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""

def default(self, obj: Any) -> Any:
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)


def get_json_line(
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
) -> str:
"""
Convert a dictionary to a JSON string with datetime handling, and add a newline.

Args:
json_dict: The dictionary to be converted to JSON.
encoder: JSON encoder class to use, defaults to DateTimeEncoder.

Returns:
A JSON string representation of the input dictionary with a newline character.
"""
return json.dumps(json_dict, cls=encoder) + "\n"


def mask_string(sensitive_str: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/tools/internet_search/internet_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.internet_search.models import InternetSearchResponse
from danswer.tools.internet_search.models import InternetSearchResult
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -224,7 +224,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
]

yield ToolResponse(
id=FINAL_CONTEXT_DOCUMENTS,
id=FINAL_CONTEXT_DOCUMENTS_ID,
response=llm_docs,
)

Expand Down
10 changes: 5 additions & 5 deletions backend/danswer/tools/search/search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
FINAL_CONTEXT_DOCUMENTS = "final_context_documents"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
SEARCH_EVALUATION_ID = "llm_doc_eval"


Expand Down Expand Up @@ -179,7 +179,7 @@ def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
final_context_docs_response = next(
response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS
response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS_ID
)
final_context_docs = cast(list[LlmDoc], final_context_docs_response.response)

Expand Down Expand Up @@ -260,7 +260,7 @@ def _build_response_for_specified_sections(
for section in final_context_sections
]

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
Expand Down Expand Up @@ -343,12 +343,12 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
llm_doc_from_inference_section(section) for section in pruned_sections
]

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def final_result(self, *args: ToolResponse) -> JSON_ro:
final_docs = cast(
list[LlmDoc],
next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS),
next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID),
)
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
# subfields that are not serializable by default (datetime)
Expand Down
Loading
Loading