feat[chat]: vectorize extraction result for improved chat content #45

Merged · 2 commits · Oct 29, 2024
18 changes: 18 additions & 0 deletions backend/app/api/v1/chat.py
@@ -65,6 +65,24 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d

ordered_file_names = [doc_id_to_filename[doc_id] for doc_id in doc_ids]

extract_vectorstore = ChromaDB(
    f"panda-etl-extraction-{project_id}",
    similarity_threshold=settings.chat_extraction_doc_threshold,
)

# Retrieve reference documents from the extraction results stored in the DB
extraction_docs = extract_vectorstore.get_relevant_docs(
    chat_request.query,
    k=settings.chat_extraction_max_docs,
)

# Merge extraction references into the per-document chat context
for extraction_doc in extraction_docs["metadatas"][0]:
    index = next(
        (i for i, item in enumerate(ordered_file_names) if item == extraction_doc["filename"]),
        None,
    )
    if index is None:
        ordered_file_names.append(extraction_doc["filename"])
        docs.append(extraction_doc["reference"])
    else:
        docs[index] = f'{extraction_doc["reference"]}\n\n{docs[index]}'

docs_formatted = [
{"filename": filename, "quote": quote}
for filename, quote in zip(ordered_file_names, docs)
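For review context, a minimal self-contained sketch of the merge above, assuming get_relevant_docs returns a Chroma-style result dict with metadatas nested as a list of lists; the filenames and reference strings are illustrative:

ordered_file_names = ["a.pdf"]
docs = ["existing chat context for a.pdf"]
extraction_docs = {
    "metadatas": [[
        {"filename": "a.pdf", "reference": "extraction reference for a"},  # merged into the existing doc
        {"filename": "b.pdf", "reference": "extraction reference for b"},  # appended as a new doc
    ]]
}

for extraction_doc in extraction_docs["metadatas"][0]:
    index = next(
        (i for i, item in enumerate(ordered_file_names) if item == extraction_doc["filename"]),
        None,
    )
    if index is None:
        ordered_file_names.append(extraction_doc["filename"])
        docs.append(extraction_doc["reference"])
    else:
        docs[index] = f'{extraction_doc["reference"]}\n\n{docs[index]}'

assert ordered_file_names == ["a.pdf", "b.pdf"]
assert docs[0] == "extraction reference for a\n\nexisting chat context for a.pdf"

Note that the extraction reference is prepended, so it lands ahead of the existing chat context for the same file.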
4 changes: 4 additions & 0 deletions backend/app/config.py
@@ -26,6 +26,10 @@ class Settings(BaseSettings):
openai_api_key: str = ""
openai_embedding_model: str = "text-embedding-ada-002"

# Extraction references for chat
chat_extraction_doc_threshold: float = 0.5
chat_extraction_max_docs: int = 50

class Config:
    env_file = ".env"

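Since Settings loads from .env (see the Config class above), these knobs can presumably also be tuned per deployment; a minimal sketch of a programmatic override, with the import path and values assumed:

from app.config import Settings  # import path assumed from the diff

# Hypothetical tuning: stricter similarity cutoff, fewer extraction docs per query
settings = Settings(
    chat_extraction_doc_threshold=0.6,
    chat_extraction_max_docs=25,
)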
48 changes: 47 additions & 1 deletion backend/app/processing/process_queue.py
@@ -46,8 +46,9 @@ def process_step_task(
# Initial DB operations (open and fetch relevant data)
with SessionLocal() as db:
    process = process_repository.get_process(db, process_id)
    project_id = process.project_id
    process_step = process_repository.get_process_step(db, process_step_id)

    filename = process_step.asset.filename
    if process.status == ProcessStatus.STOPPED:
        return False  # Stop processing if the process is stopped

@@ -84,6 +85,15 @@ def process_step_task(
output_references=data["context"],
)

# Vectorize extraction result so chat can retrieve it later
try:
    vectorize_extraction_process_step(
        project_id=project_id,
        process_step_id=process_step_id,
        filename=filename,
        references=data["context"],
    )
except Exception:
    logger.error(f"Failed to vectorize extraction results for chat: {traceback.format_exc()}")

success = True

except CreditLimitExceededException:
@@ -361,3 +371,39 @@ def update_process_step_status(
process_repository.update_process_step_status(
db, process_step, status, output=output, output_references=output_references
)

def vectorize_extraction_process_step(project_id: int, process_step_id: int, filename: str, references: list) -> None:
    """Vectorize extraction results and store them in the vector database."""
    field_references = {}

    # Concatenate the sources of each reference, grouped by field name
    for extraction_references in references:
        for extraction_reference in extraction_references:
            sources = extraction_reference.get("sources", [])
            if sources:
                sources_catenated = "\n".join(sources)
                field_references.setdefault(extraction_reference["name"], "")
                field_references[extraction_reference["name"]] += (
                    "\n" + sources_catenated if field_references[extraction_reference["name"]] else sources_catenated
                )

    # Only proceed if there are references to add
    if not field_references:
        return

    # Initialize vectorstore
    vectorstore = ChromaDB(f"panda-etl-extraction-{project_id}")

    docs = [f"{filename} {key}" for key in field_references]
    metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": reference,
        }
        for reference in field_references.values()
    ]

    # Add documents to vectorstore
    vectorstore.add_docs(docs=docs, metadatas=metadatas)
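A worked example of what this writes, using the same nested references shape the tests below exercise (filenames and values illustrative):

references = [
    [{"name": "summary", "sources": ["p1 quote"]}],
    [{"name": "summary", "sources": ["p2 quote"]}, {"name": "total", "sources": ["42"]}],
]
# For filename "report.pdf", project 1, process step 7, the add_docs call becomes:
#   docs      = ["report.pdf summary", "report.pdf total"]
#   metadatas = [
#       {"project_id": 1, "process_step_id": 7, "filename": "report.pdf",
#        "reference": "p1 quote\np2 quote"},
#       {"project_id": 1, "process_step_id": 7, "filename": "report.pdf",
#        "reference": "42"},
#   ]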
107 changes: 106 additions & 1 deletion backend/tests/processing/test_process_queue.py
@@ -1,11 +1,12 @@
from app.requests.schemas import ExtractFieldsResponse
import pytest
from unittest.mock import MagicMock, Mock, patch
from app.processing.process_queue import (
    handle_exceptions,
    extract_process,
    update_process_step_status,
    find_best_match_for_short_reference,
    vectorize_extraction_process_step,
)
from app.exceptions import CreditLimitExceededException
from app.models import ProcessStepStatus
@@ -180,3 +181,107 @@ def test_chroma_db_initialization(mock_extract_data, mock_chroma):

    mock_chroma.assert_called_with(f"panda-etl-{process.project_id}", similarity_threshold=3)
    assert mock_chroma.call_count >= 1

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_single_reference(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 123
    process_step_id = 1
    filename = "sample_file"
    references = [
        [
            {"name": "field1", "sources": ["source1", "source2"]}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # Expected docs and metadatas added to ChromaDB
    expected_docs = ["sample_file field1"]
    expected_metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source1\nsource2"
        }
    ]

    # Assertions
    mock_vectorstore.add_docs.assert_called_once_with(
        docs=expected_docs,
        metadatas=expected_metadatas
    )

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_multiple_references_concatenation(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 456
    process_step_id = 2
    filename = "test_file"
    references = [
        [
            {"name": "field1", "sources": ["source1", "source2"]},
            {"name": "field1", "sources": ["source3"]}
        ],
        [
            {"name": "field2", "sources": ["source4"]}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # Expected docs and metadatas added to ChromaDB
    expected_docs = ["test_file field1", "test_file field2"]
    expected_metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source1\nsource2\nsource3"
        },
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source4"
        }
    ]

    # Assertions
    mock_vectorstore.add_docs.assert_called_once_with(
        docs=expected_docs,
        metadatas=expected_metadatas
    )

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_empty_sources(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 789
    process_step_id = 3
    filename = "empty_sources_file"
    references = [
        [
            {"name": "field1", "sources": []}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # Expect no call to add_docs, since all sources are empty
    mock_vectorstore.add_docs.assert_not_called()
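These three tests can presumably be run in isolation with pytest's keyword filter, e.g. "pytest backend/tests/processing/test_process_queue.py -k vectorize_extraction_process_step".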