From 2574a4f6ceb0dbb051ead94ee49831333890cd4d Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Fri, 11 Oct 2024 14:26:58 +0200 Subject: [PATCH] feat(chat): chat on project documents (#18) * feat(chat): chat on top of all documents * feat(chat): remove extra print statements * feat[Chat]: chat with references for the answers * feat[chat]: display multiple references of chat * feat(chat): push chat reference data * fix(chat): clean reference ui --- backend/app/api/v1/chat.py | 118 ++++++++++++- backend/app/processing/file_preprocessing.py | 5 +- backend/app/vectorstore/chroma.py | 51 ++++++ .../app/(app)/projects/[projectId]/page.tsx | 2 +- frontend/src/components/ChatBox.tsx | 4 + .../src/components/ChatReferenceDrawer.tsx | 37 ++++ frontend/src/components/ui/ChatBubble.tsx | 167 +++++++++++++++++- frontend/src/interfaces/chat.ts | 16 ++ frontend/src/interfaces/processSteps.ts | 1 + 9 files changed, 386 insertions(+), 15 deletions(-) create mode 100644 frontend/src/components/ChatReferenceDrawer.tsx diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py index 6afe530..a63a410 100644 --- a/backend/app/api/v1/chat.py +++ b/backend/app/api/v1/chat.py @@ -1,3 +1,5 @@ +from collections import defaultdict +import re import traceback from typing import Optional from app.database import get_db @@ -24,22 +26,60 @@ class ChatRequest(BaseModel): logger = Logger() +def group_by_start_end(references): + grouped_references = defaultdict( + lambda: {"start": None, "end": None, "references": []} + ) + + for ref in references: + key = (ref["start"], ref["end"]) + + # Initialize start and end if not already set + if grouped_references[key]["start"] is None: + grouped_references[key]["start"] = ref["start"] + grouped_references[key]["end"] = ref["end"] + + # Check if a reference with the same asset_id already exists + existing_ref = None + for existing in grouped_references[key]["references"]: + if ( + existing["asset_id"] == ref["asset_id"] + and existing["page_number"] == ref["page_number"] + ): + existing_ref = existing + break + + if existing_ref: + # Append the source if asset_id already exists + existing_ref["source"].extend(ref["source"]) + else: + # Otherwise, add the new reference + grouped_references[key]["references"].append(ref) + + return list(grouped_references.values()) + + @chat_router.post("/project/{project_id}", status_code=200) def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_db)): try: vectorstore = ChromaDB(f"panda-etl-{project_id}") - docs = vectorstore.get_relevant_docs( + + docs, doc_ids, _ = vectorstore.get_relevant_segments( chat_request.query, settings.max_relevant_docs ) - file_names = project_repository.get_assets_filename( - db, [metadata["doc_id"] for metadata in docs["metadatas"][0]] - ) - extracted_documents = docs["documents"][0] + unique_doc_ids = list(set(doc_ids)) + file_names = project_repository.get_assets_filename(db, unique_doc_ids) + + doc_id_to_filename = { + doc_id: filename for doc_id, filename in zip(unique_doc_ids, file_names) + } + + ordered_file_names = [doc_id_to_filename[doc_id] for doc_id in doc_ids] docs_formatted = [ {"filename": filename, "quote": quote} - for filename, quote in zip(file_names, extracted_documents) + for filename, quote in zip(ordered_file_names, docs) ] api_key = user_repository.get_user_api_key(db) @@ -63,11 +103,72 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d ) conversation_id = str(conversation.id) + content = response["response"] + text_reference = None + text_references = [] + + for reference in response["references"]: + sentence = reference["sentence"] + + closest_docs = [] + closest_metdatas = [] + reference_contexts = [] + for reference_content in reference["references"]: + + doc_sent, _, doc_metadata = vectorstore.get_relevant_segments( + reference_content["sentence"], k=3, num_surrounding_sentences=0 + ) + closest_docs.extend(doc_sent) + closest_metdatas.extend(doc_metadata) + reference_contexts.extend( + [reference_content["sentence"]] * len(doc_sent) + ) + + if not closest_docs: + continue + + for iter, _ in enumerate(closest_docs): + + metadata = closest_metdatas[iter] + + if ( + text_reference is None + or text_reference["asset_id"] != metadata["asset_id"] + ): + if text_reference is not None: + text_references.append(text_reference) + + index = content.find(sentence) + if index != -1: + text_reference = { + "asset_id": metadata["asset_id"], + "project_id": metadata["project_id"], + "page_number": metadata["page_number"], + "filename": ( + metadata["filename"] + if "filename" in metadata + else project_repository.get_assets_filename( + db, [metadata["asset_id"]] + )[0] + ), + "source": [reference_contexts[iter]], + } + + text_reference["start"] = index + text_reference["end"] = index + len(sentence) + + elif text_reference["asset_id"] == metadata["asset_id"]: + index = content.find(sentence) + text_reference["end"] = index + len(sentence) + + # group text references based on start and end + refs = group_by_start_end(text_references) + conversation_repository.create_conversation_message( db, conversation_id=conversation_id, query=chat_request.query, - response=response["response"], + response=content, ) return { @@ -75,7 +176,8 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d "message": "chat response successfully returned!", "data": { "conversation_id": conversation_id, - "response": response["response"], + "response": content, + "response_references": refs, }, } diff --git a/backend/app/processing/file_preprocessing.py b/backend/app/processing/file_preprocessing.py index 259db33..0d66ba6 100644 --- a/backend/app/processing/file_preprocessing.py +++ b/backend/app/processing/file_preprocessing.py @@ -22,7 +22,7 @@ def process_file(asset_id: int): file_preprocessor.submit(preprocess_file, asset_id) -def process_segmentation(project_id: int, asset_content_id: int, api_key: str): +def process_segmentation(project_id: int, asset_content_id: int, asset_file_name: str): try: with SessionLocal() as db: asset_content = project_repository.get_asset_content(db, asset_content_id) @@ -37,6 +37,7 @@ def process_segmentation(project_id: int, asset_content_id: int, api_key: str): metadatas=[ { "asset_id": asset_content.asset_id, + "filename": asset_file_name, "project_id": project_id, "page_number": asset_content.content["page_number_data"][index], } @@ -117,7 +118,7 @@ def preprocess_file(asset_id: int): process_segmentation, asset.project_id, asset_content.id, - api_key, + asset.filename, ) except Exception as e: diff --git a/backend/app/vectorstore/chroma.py b/backend/app/vectorstore/chroma.py index 9f8fd88..c99a725 100644 --- a/backend/app/vectorstore/chroma.py +++ b/backend/app/vectorstore/chroma.py @@ -141,6 +141,57 @@ def get_relevant_docs( relevant_data, self._similarity_threshold ) + def get_relevant_segments( + self, + question: str, + k: int = None, + num_surrounding_sentences: int = 3, + ) -> list[dict]: + k = k or self._max_samples + + relevant_docs = self.get_relevant_docs( + question, + k=k, + ) + + segments = [] + doc_ids = [] + metadatas = [] + # Iterate over each document's metadata and fetch surrounding sentences + for index, metadata in enumerate(relevant_docs["metadatas"][0]): + pdf_content = "" + segment_data = [relevant_docs["documents"][0][index]] + + # Get previous sentences + prev_id = metadata.get("previous_sentence_id") + for _ in range(num_surrounding_sentences): + if prev_id != -1: + prev_sentence = self.get_relevant_docs_by_id(ids=[prev_id]) + segment_data = [prev_sentence["documents"][0]] + segment_data + prev_id = prev_sentence["metadatas"][0].get( + "previous_sentence_id", -1 + ) + else: + break + + # Get next sentences + next_id = metadata.get("next_sentence_id") + for _ in range(num_surrounding_sentences): + if next_id != -1: + next_sentence = self.get_relevant_docs_by_id(ids=[next_id]) + segment_data.append(next_sentence["documents"][0]) + next_id = next_sentence["metadatas"][0].get("next_sentence_id", -1) + else: + break + + # Add the segment data to the PDF content + pdf_content += "\n" + " ".join(segment_data) + segments.append(pdf_content) + doc_ids.append(metadata["asset_id"]) + metadatas.append(metadata) + + return (segments, doc_ids, metadatas) + def get_relevant_docs_by_id(self, ids: Iterable[str]) -> List[dict]: """ Returns relevant question answers based on ids diff --git a/frontend/src/app/(app)/projects/[projectId]/page.tsx b/frontend/src/app/(app)/projects/[projectId]/page.tsx index b241869..d1780f9 100644 --- a/frontend/src/app/(app)/projects/[projectId]/page.tsx +++ b/frontend/src/app/(app)/projects/[projectId]/page.tsx @@ -134,7 +134,7 @@ export default function Project() { const projectTabs = [ { id: "assets", label: "Docs" }, { id: "processes", label: "Processes" }, - // { id: "chat", label: "Chat", badge: "beta" }, + { id: "chat", label: "Chat", badge: "beta" }, ]; const breadcrumbItems = [ diff --git a/frontend/src/components/ChatBox.tsx b/frontend/src/components/ChatBox.tsx index 9e54ffb..df7231e 100644 --- a/frontend/src/components/ChatBox.tsx +++ b/frontend/src/components/ChatBox.tsx @@ -8,6 +8,7 @@ import ChatLoader from "./ChatLoader"; import { useQuery } from "@tanstack/react-query"; import { motion } from "framer-motion"; import ChatBubble from "@/components/ui/ChatBubble"; +import { ChatReferences } from "@/interfaces/chat"; export const NoChatPlaceholder = ({ isLoading }: { isLoading: boolean }) => { return ( @@ -31,6 +32,7 @@ export const NoChatPlaceholder = ({ isLoading }: { isLoading: boolean }) => { interface ChatMessage { sender: string; text: string; + references?: Array; timestamp: Date; } @@ -81,6 +83,7 @@ const ChatBox = ({ const bot_response = { sender: "bot", text: response.response, + references: response.response_references, timestamp: new Date(), }; setLoading(false); @@ -121,6 +124,7 @@ const ChatBox = ({ diff --git a/frontend/src/components/ChatReferenceDrawer.tsx b/frontend/src/components/ChatReferenceDrawer.tsx new file mode 100644 index 0000000..6c5a6f8 --- /dev/null +++ b/frontend/src/components/ChatReferenceDrawer.tsx @@ -0,0 +1,37 @@ +"use client"; +import React from "react"; +import Drawer from "./ui/Drawer"; +import HighlightPdfViewer from "../ee/components/HighlightPdfViewer"; +import { FlattenedSource, Source } from "@/interfaces/processSteps"; +import { BASE_STORAGE_URL } from "@/constants"; + +interface IProps { + filename: string; + project_id: number; + sources: FlattenedSource[]; + isOpen?: boolean; + onCancel: () => void; +} + +const ChatReferenceDrawer = ({ + isOpen = true, + project_id, + sources, + filename, + onCancel, +}: IProps) => { + let file_url = null; + if (project_id) { + file_url = `${BASE_STORAGE_URL}/${project_id}/${filename}`; + } + + return ( + + {file_url && ( + + )} + + ); +}; + +export default ChatReferenceDrawer; diff --git a/frontend/src/components/ui/ChatBubble.tsx b/frontend/src/components/ui/ChatBubble.tsx index 873f6de..6527c0d 100644 --- a/frontend/src/components/ui/ChatBubble.tsx +++ b/frontend/src/components/ui/ChatBubble.tsx @@ -1,13 +1,17 @@ -import React from "react"; +import React, { useEffect, useMemo, useState } from "react"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; import rehypeRaw from "rehype-raw"; import { markify_text } from "@/lib/utils"; +import { ChatReference, ChatReferences } from "@/interfaces/chat"; +import ChatReferenceDrawer from "../ChatReferenceDrawer"; +import { FileIcon } from "lucide-react"; interface ChatBubbleProps { message: string; timestamp: Date; sender: "user" | "bot"; + references?: ChatReferences[]; } export const ChatBubbleWrapper: React.FC<{ @@ -28,13 +32,168 @@ export const ChatBubbleWrapper: React.FC<{ const ChatBubble: React.FC = ({ message, timestamp, + references, sender, }) => { + const [selectedReference, setSelectedReference] = useState< + ChatReference | undefined + >(); + const [OpenDrawer, setOpenDrawer] = useState(false); + const [indexMap, setIndexMap] = useState<{ [key: string]: number }>({}); + const [flatChatReferences, setFlatChatReferences] = useState( + [], + ); + + const handleReferenceClick = (reference: ChatReference) => { + setSelectedReference(reference); + setOpenDrawer(true); + }; + + useEffect(() => { + if (references) { + const indexMap: { [key: string]: number } = {}; + let counter = 1; + for (const reference_data of references) { + for (const reference of reference_data["references"]) { + const identifier = `${reference.asset_id}_${reference.page_number}`; + if (identifier in indexMap) { + continue; + } else { + indexMap[identifier] = counter; + counter += 1; + } + } + } + setIndexMap(indexMap); + + // preprocess doc references + const flatChatReferences: ChatReference[] = []; + + for (const reference_data of references) { + for (const reference of reference_data["references"]) { + var exists = false; + for (let i = 0; i < flatChatReferences.length; i++) { + const ref = flatChatReferences[i]; + + if ( + ref.asset_id == reference.asset_id && + ref.page_number == reference.page_number + ) { + if (ref.source.includes(reference.source[0])) { + exists = true; + break; + } + ref.source.push(reference.source[0]); + exists = true; + break; + } + } + + if (!exists) { + flatChatReferences.push(reference); + } + } + } + setFlatChatReferences(flatChatReferences); + } + }, [references]); + + let lastEnd = 0; + + const combinedMarkdown = references?.reduce( + (acc, item: ChatReferences, index: number) => { + const beforeText = message.slice(lastEnd, item.end); + + const referenceSpan = item["references"] + .map((item2: ChatReference, index2: number) => { + return `[${indexMap[`${item2.asset_id}_${item2.page_number}`]}]`; + }) + .join(" "); + + acc += `${beforeText}${referenceSpan}`; + + lastEnd = item.end; + + return acc; + }, + "", + ); + + const finalMarkdown = combinedMarkdown + message.slice(lastEnd); + + // Handle click event on the reference markers + const handleMarkerClick = (event: React.MouseEvent) => { + const target = event.target as HTMLElement; + if (target.classList.contains("reference-marker")) { + const splitted_index = target.dataset.index?.split("_"); + if (splitted_index && references) { + const index0 = parseInt(splitted_index[0]); + const index1 = parseInt(splitted_index[1]); + handleReferenceClick(references[index0]["references"][index1]); + } + } + }; + return ( - - {markify_text(message)} - + {references && references.length > 0 ? ( +
+ + {finalMarkdown} + +
+ ) : ( + + {markify_text(message)} + + )} + + {references && references.length > 0 && ( +
+
References
+ {flatChatReferences.map((item: ChatReference, index: number) => { + return ( +
+
+ {indexMap[`${item.asset_id}_${item.page_number}`]}. +
+
handleReferenceClick(item)} + > + + + {item.filename} + +
Page: {item.page_number}
+
+
+ ); + })} +
+ )} + {selectedReference && ( + { + return { + source: item, + page_number: selectedReference.page_number, + filename: selectedReference.filename, + }; + })} + filename={selectedReference.filename} + project_id={selectedReference.project_id} + onCancel={() => setOpenDrawer(false)} + /> + )}
{timestamp.toLocaleTimeString()}
diff --git a/frontend/src/interfaces/chat.ts b/frontend/src/interfaces/chat.ts index 4fcc338..aa7fca6 100644 --- a/frontend/src/interfaces/chat.ts +++ b/frontend/src/interfaces/chat.ts @@ -11,3 +11,19 @@ export interface ChatResponse { export interface ChatStatusResponse { status: boolean; } + +export interface ChatReference { + asset_id: number; + project_id: number; + filename: string; + start: number; + end: number; + page_number: number; + source: string[]; +} + +export interface ChatReferences { + references: ChatReference[]; + start: number; + end: number; +} diff --git a/frontend/src/interfaces/processSteps.ts b/frontend/src/interfaces/processSteps.ts index dd2385b..b87aaf0 100644 --- a/frontend/src/interfaces/processSteps.ts +++ b/frontend/src/interfaces/processSteps.ts @@ -37,4 +37,5 @@ export interface Source { export interface FlattenedSource { source: string; page_number: number; + filename?: string; }