feat(chat): chat on project documents (#18)
* feat(chat): chat on top of all documents

* feat(chat): remove extra print statements

* feat[Chat]: chat with references for the answers

* feat[chat]: display multiple references of chat

* feat(chat): push chat reference data

* fix(chat): clean reference ui
ArslanSaleem authored Oct 11, 2024
1 parent 9591060 commit 2574a4f
Showing 9 changed files with 386 additions and 15 deletions.
118 changes: 110 additions & 8 deletions backend/app/api/v1/chat.py
@@ -1,3 +1,5 @@
from collections import defaultdict
import re
import traceback
from typing import Optional
from app.database import get_db
@@ -24,22 +26,60 @@ class ChatRequest(BaseModel):
logger = Logger()


def group_by_start_end(references):
    grouped_references = defaultdict(
        lambda: {"start": None, "end": None, "references": []}
    )

    for ref in references:
        key = (ref["start"], ref["end"])

        # Initialize start and end if not already set
        if grouped_references[key]["start"] is None:
            grouped_references[key]["start"] = ref["start"]
            grouped_references[key]["end"] = ref["end"]

        # Check if a reference with the same asset_id already exists
        existing_ref = None
        for existing in grouped_references[key]["references"]:
            if (
                existing["asset_id"] == ref["asset_id"]
                and existing["page_number"] == ref["page_number"]
            ):
                existing_ref = existing
                break

        if existing_ref:
            # Append the source if asset_id already exists
            existing_ref["source"].extend(ref["source"])
        else:
            # Otherwise, add the new reference
            grouped_references[key]["references"].append(ref)

    return list(grouped_references.values())
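For orientation, here is a minimal sketch of what this grouping produces; the asset IDs, page numbers, and quotes are invented for illustration and assume the function above is in scope:

refs = [
    {"start": 0, "end": 42, "asset_id": 1, "page_number": 3, "source": ["quote A"]},
    {"start": 0, "end": 42, "asset_id": 1, "page_number": 3, "source": ["quote B"]},
    {"start": 0, "end": 42, "asset_id": 2, "page_number": 1, "source": ["quote C"]},
]

grouped = group_by_start_end(refs)
# One group for the (0, 42) span: the two asset 1 entries are merged into a single
# reference whose "source" list becomes ["quote A", "quote B"], while asset 2 keeps
# its own entry. The result is a list of {"start", "end", "references"} dicts.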


@chat_router.post("/project/{project_id}", status_code=200)
def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_db)):
    try:
        vectorstore = ChromaDB(f"panda-etl-{project_id}")
        docs = vectorstore.get_relevant_docs(

        docs, doc_ids, _ = vectorstore.get_relevant_segments(
            chat_request.query, settings.max_relevant_docs
        )

        file_names = project_repository.get_assets_filename(
            db, [metadata["doc_id"] for metadata in docs["metadatas"][0]]
        )
        extracted_documents = docs["documents"][0]
        unique_doc_ids = list(set(doc_ids))
        file_names = project_repository.get_assets_filename(db, unique_doc_ids)

        doc_id_to_filename = {
            doc_id: filename for doc_id, filename in zip(unique_doc_ids, file_names)
        }

        ordered_file_names = [doc_id_to_filename[doc_id] for doc_id in doc_ids]

        docs_formatted = [
            {"filename": filename, "quote": quote}
            for filename, quote in zip(file_names, extracted_documents)
            for filename, quote in zip(ordered_file_names, docs)
        ]

        api_key = user_repository.get_user_api_key(db)
@@ -63,19 +103,81 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_db)):
        )
        conversation_id = str(conversation.id)

        content = response["response"]
        text_reference = None
        text_references = []

        for reference in response["references"]:
            sentence = reference["sentence"]

            closest_docs = []
            closest_metadatas = []
            reference_contexts = []
            for reference_content in reference["references"]:

                doc_sent, _, doc_metadata = vectorstore.get_relevant_segments(
                    reference_content["sentence"], k=3, num_surrounding_sentences=0
                )
                closest_docs.extend(doc_sent)
                closest_metadatas.extend(doc_metadata)
                reference_contexts.extend(
                    [reference_content["sentence"]] * len(doc_sent)
                )

            if not closest_docs:
                continue

            for idx, _ in enumerate(closest_docs):

                metadata = closest_metadatas[idx]

                if (
                    text_reference is None
                    or text_reference["asset_id"] != metadata["asset_id"]
                ):
                    if text_reference is not None:
                        text_references.append(text_reference)

                    index = content.find(sentence)
                    if index != -1:
                        text_reference = {
                            "asset_id": metadata["asset_id"],
                            "project_id": metadata["project_id"],
                            "page_number": metadata["page_number"],
                            "filename": (
                                metadata["filename"]
                                if "filename" in metadata
                                else project_repository.get_assets_filename(
                                    db, [metadata["asset_id"]]
                                )[0]
                            ),
                            "source": [reference_contexts[idx]],
                        }

                        text_reference["start"] = index
                        text_reference["end"] = index + len(sentence)

                elif text_reference["asset_id"] == metadata["asset_id"]:
                    index = content.find(sentence)
                    text_reference["end"] = index + len(sentence)

        # group text references based on start and end
        refs = group_by_start_end(text_references)

        conversation_repository.create_conversation_message(
            db,
            conversation_id=conversation_id,
            query=chat_request.query,
            response=response["response"],
            response=content,
        )

        return {
            "status": "success",
            "message": "chat response successfully returned!",
            "data": {
                "conversation_id": conversation_id,
                "response": response["response"],
                "response": content,
                "response_references": refs,
            },
        }
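As a rough usage sketch of this endpoint (the host, port, and route prefix are assumptions, not taken from the repository's configuration):

import requests

# Hypothetical local dev server; project 1 must already have preprocessed documents.
resp = requests.post(
    "http://localhost:8000/v1/chat/project/1",
    json={"query": "What is the notice period in the contract?"},
)
data = resp.json()["data"]

print(data["conversation_id"])
print(data["response"])
# Each entry of response_references describes one (start, end) span of the answer text
# together with the assets, page numbers, and source sentences that support it.
for group in data["response_references"]:
    print(group["start"], group["end"], [r["filename"] for r in group["references"]])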

5 changes: 3 additions & 2 deletions backend/app/processing/file_preprocessing.py
@@ -22,7 +22,7 @@ def process_file(asset_id: int):
    file_preprocessor.submit(preprocess_file, asset_id)


def process_segmentation(project_id: int, asset_content_id: int, api_key: str):
def process_segmentation(project_id: int, asset_content_id: int, asset_file_name: str):
    try:
        with SessionLocal() as db:
            asset_content = project_repository.get_asset_content(db, asset_content_id)
@@ -37,6 +37,7 @@ def process_segmentation(project_id: int, asset_content_id: int, api_key: str):
            metadatas=[
                {
                    "asset_id": asset_content.asset_id,
                    "filename": asset_file_name,
                    "project_id": project_id,
                    "page_number": asset_content.content["page_number_data"][index],
                }
@@ -117,7 +118,7 @@ def preprocess_file(asset_id: int):
            process_segmentation,
            asset.project_id,
            asset_content.id,
            api_key,
            asset.filename,
        )

    except Exception as e:
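The practical effect of this change is that every embedded sentence now carries its file name in the vector store metadata, so chat references can show a filename without an extra database lookup. A minimal sketch of the metadata written per sentence (values are invented):

metadata = {
    "asset_id": 42,
    "filename": "contract.pdf",   # added by this commit via asset_file_name
    "project_id": 7,
    "page_number": 3,
}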
51 changes: 51 additions & 0 deletions backend/app/vectorstore/chroma.py
@@ -141,6 +141,57 @@ def get_relevant_docs(
            relevant_data, self._similarity_threshold
        )

    def get_relevant_segments(
        self,
        question: str,
        k: int = None,
        num_surrounding_sentences: int = 3,
    ) -> tuple[list[str], list[int], list[dict]]:
        k = k or self._max_samples

        relevant_docs = self.get_relevant_docs(
            question,
            k=k,
        )

        segments = []
        doc_ids = []
        metadatas = []
        # Iterate over each document's metadata and fetch surrounding sentences
        for index, metadata in enumerate(relevant_docs["metadatas"][0]):
            pdf_content = ""
            segment_data = [relevant_docs["documents"][0][index]]

            # Get previous sentences
            prev_id = metadata.get("previous_sentence_id")
            for _ in range(num_surrounding_sentences):
                if prev_id != -1:
                    prev_sentence = self.get_relevant_docs_by_id(ids=[prev_id])
                    segment_data = [prev_sentence["documents"][0]] + segment_data
                    prev_id = prev_sentence["metadatas"][0].get(
                        "previous_sentence_id", -1
                    )
                else:
                    break

            # Get next sentences
            next_id = metadata.get("next_sentence_id")
            for _ in range(num_surrounding_sentences):
                if next_id != -1:
                    next_sentence = self.get_relevant_docs_by_id(ids=[next_id])
                    segment_data.append(next_sentence["documents"][0])
                    next_id = next_sentence["metadatas"][0].get("next_sentence_id", -1)
                else:
                    break

            # Add the segment data to the PDF content
            pdf_content += "\n" + " ".join(segment_data)
            segments.append(pdf_content)
            doc_ids.append(metadata["asset_id"])
            metadatas.append(metadata)

        return (segments, doc_ids, metadatas)
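A small usage sketch of this method as the chat endpoint consumes it (collection name and query are placeholders):

# Assumes the collection was populated by the file preprocessing step above.
vectorstore = ChromaDB("panda-etl-1")

segments, doc_ids, metadatas = vectorstore.get_relevant_segments(
    "What is the termination clause?", k=5, num_surrounding_sentences=3
)

for segment, doc_id, meta in zip(segments, doc_ids, metadatas):
    # Each segment is the matched sentence padded with up to 3 neighbouring
    # sentences on each side; doc_id is the asset_id from the segment metadata.
    print(doc_id, meta.get("page_number"), segment[:80])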

    def get_relevant_docs_by_id(self, ids: Iterable[str]) -> List[dict]:
        """
        Returns relevant question answers based on ids
2 changes: 1 addition & 1 deletion frontend/src/app/(app)/projects/[projectId]/page.tsx
@@ -134,7 +134,7 @@ export default function Project() {
  const projectTabs = [
    { id: "assets", label: "Docs" },
    { id: "processes", label: "Processes" },
    // { id: "chat", label: "Chat", badge: "beta" },
    { id: "chat", label: "Chat", badge: "beta" },
  ];

  const breadcrumbItems = [
4 changes: 4 additions & 0 deletions frontend/src/components/ChatBox.tsx
@@ -8,6 +8,7 @@ import ChatLoader from "./ChatLoader";
import { useQuery } from "@tanstack/react-query";
import { motion } from "framer-motion";
import ChatBubble from "@/components/ui/ChatBubble";
import { ChatReferences } from "@/interfaces/chat";

export const NoChatPlaceholder = ({ isLoading }: { isLoading: boolean }) => {
  return (
@@ -31,6 +32,7 @@ export const NoChatPlaceholder = ({ isLoading }: { isLoading: boolean }) => {
interface ChatMessage {
  sender: string;
  text: string;
  references?: Array<ChatReferences>;
  timestamp: Date;
}

@@ -81,6 +83,7 @@ const ChatBox = ({
      const bot_response = {
        sender: "bot",
        text: response.response,
        references: response.response_references,
        timestamp: new Date(),
      };
      setLoading(false);
@@ -121,6 +124,7 @@
            <ChatBubble
              message={message.text}
              timestamp={message.timestamp}
              references={message.references}
              sender={message.sender as "user" | "bot"}
            />
          </motion.div>
37 changes: 37 additions & 0 deletions frontend/src/components/ChatReferenceDrawer.tsx
@@ -0,0 +1,37 @@
"use client";
import React from "react";
import Drawer from "./ui/Drawer";
import HighlightPdfViewer from "../ee/components/HighlightPdfViewer";
import { FlattenedSource, Source } from "@/interfaces/processSteps";
import { BASE_STORAGE_URL } from "@/constants";

interface IProps {
  filename: string;
  project_id: number;
  sources: FlattenedSource[];
  isOpen?: boolean;
  onCancel: () => void;
}

const ChatReferenceDrawer = ({
  isOpen = true,
  project_id,
  sources,
  filename,
  onCancel,
}: IProps) => {
  let file_url = null;
  if (project_id) {
    file_url = `${BASE_STORAGE_URL}/${project_id}/${filename}`;
  }

  return (
    <Drawer isOpen={isOpen} onClose={onCancel} title="Chat Reference">
      {file_url && (
        <HighlightPdfViewer file={file_url} highlightSources={sources} />
      )}
    </Drawer>
  );
};

export default ChatReferenceDrawer;
