From 94633224a0f4c3d8156ead5fb993e014ac07d78f Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 31 Jul 2024 13:32:04 -0400 Subject: [PATCH 01/30] changes --- src/backend/services/auth/jwt.py | 1 + src/backend/services/auth/utils.py | 1 - src/backend/services/file.py | 85 +++++++++++++++++++++++++++--- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/src/backend/services/auth/jwt.py b/src/backend/services/auth/jwt.py index 81364a5036..cc5fe294bc 100644 --- a/src/backend/services/auth/jwt.py +++ b/src/backend/services/auth/jwt.py @@ -59,6 +59,7 @@ def decode_jwt(self, token: str) -> dict: dict: Decoded JWT token payload. """ try: + print("secret key", self.secret_key) decoded_payload = jwt.decode( token, self.secret_key, algorithms=[self.ALGORITHM] ) diff --git a/src/backend/services/auth/utils.py b/src/backend/services/auth/utils.py index 4ac9e3a261..aa686d21fd 100644 --- a/src/backend/services/auth/utils.py +++ b/src/backend/services/auth/utils.py @@ -75,7 +75,6 @@ def get_header_user_id(request: Request) -> str: authorization = request.headers.get("Authorization") _, token = authorization.split(" ") decoded = JWTService().decode_jwt(token) - return decoded["context"]["id"] # Auth disabled else: diff --git a/src/backend/services/file.py b/src/backend/services/file.py index c9d708a10c..6d33658465 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,4 +1,6 @@ import io +import os +import uuid from copy import deepcopy from typing import Any, Optional @@ -9,6 +11,7 @@ from pypdf import PdfReader from python_calamine.pandas import pandas_monkeypatch +from backend.services.compass import Compass import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud from backend.config.tools import ToolName @@ -47,8 +50,7 @@ def get_file_service(): class FileService: @property def is_compass_enabled(self) -> bool: - # TODO Scott: add compass env variable anc check here - return False + return os.getenv("ENABLE_COMPASS_FILE_STORAGE", "false").lower() == "true" # All these functions will eventually support file operations on Compass async def create_conversation_files( @@ -94,15 +96,19 @@ async def create_conversation_files( ) ) - uploaded_files = file_crud.batch_create_files(session, files_to_upload) - uploaded_file_ids = [file.id for file in uploaded_files] - for file_id in uploaded_file_ids: + uploaded_files = [] + if self.is_compass_enabled: + uploaded_files = await index_files_to_compass(session, files, user_id) + else: + uploaded_files = file_crud.batch_create_files(session, files_to_upload) + + for file in uploaded_files: conversation_crud.create_conversation_file_association( session, ConversationFileAssociation( conversation_id=conversation_id, user_id=user_id, - file_id=file_id, + file_id=file.id, ), ) @@ -279,6 +285,73 @@ def get_files_by_message_id( if message.file_ids is not None: files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files + +async def index_files_to_compass(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str) -> None: + uploaded_files = [] + + for file in files: + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + file_bytes = await file.read() + cleaned_content = file_bytes.decode("utf-8").replace("\x00", "") + cleaned_content_bytes = cleaned_content.encode("utf-8") + new_file_id = uuid.uuid4() + + try: + compass = Compass() + except Exception as e: + print(f"Error initializing Compass: {e}") + + file_text = compass.invoke( + action=Compass.ValidActions.PROCESS_FILE, + parameters={ + "file_id": new_file_id, + "file_text": cleaned_content_bytes, + }, + )[0].content["text"] + + # try: + # compass = Compass() + # compass.invoke( + # action=Compass.ValidActions.CREATE_INDEX, + # parameters={ + # "index": new_file_id, + # }, + # ) + # compass.invoke( + # action=Compass.ValidActions.CREATE, + # parameters={ + # "index": new_file_id, + # "file_id": new_file_id, + # "file_text": file_text, + # }, + # ) + # compass.invoke( + # action=Compass.ValidActions.ADD_CONTEXT, + # parameters={ + # "index": new_file_id, + # "file_id": new_file_id, + # "context": { + # "file_name": filename, + # "file_size": file.size, + # }, + # }, + # ) + # compass.invoke( + # action=Compass.ValidActions.REFRESH, + # parameters={"index": new_file_id}, + # ) + + # uploaded_files.append(File( + # id=new_file_id, + # file_name=filename, + # file_size=file.size, + # file_path=filename, + # user_id=user_id, + # )) + # except Exception as e: + # print(f"Error Creating File Index in Compass: {e}") + + return uploaded_files def attach_conversation_id_to_files( From 93209dac3a950a171d55aed01611daa2208762e4 Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 31 Jul 2024 16:18:39 -0400 Subject: [PATCH 02/30] debug --- src/backend/routers/conversation.py | 5 +++ src/backend/services/auth/jwt.py | 1 - src/backend/services/auth/utils.py | 1 + src/backend/services/compass.py | 24 ++++++------ src/backend/services/file.py | 61 ++++++++++++++++++----------- 5 files changed, 57 insertions(+), 35 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index b00abb1fcd..68c0162284 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -2,6 +2,7 @@ from fastapi import File as RequestFile from fastapi import Form, HTTPException, Request from fastapi import UploadFile as FastAPIUploadFile +from backend.services.compass import Compass from backend.chat.custom.custom import CustomChat from backend.chat.custom.utils import get_deployment @@ -453,6 +454,10 @@ async def batch_upload_file( ) # TODO: check if file already exists in DB once we have files per agents + compass = Compass() + result = compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, parameters={"index": "foobar"} + ) try: uploaded_files = await get_file_service().create_conversation_files( session, files, user_id, conversation.id diff --git a/src/backend/services/auth/jwt.py b/src/backend/services/auth/jwt.py index cc5fe294bc..81364a5036 100644 --- a/src/backend/services/auth/jwt.py +++ b/src/backend/services/auth/jwt.py @@ -59,7 +59,6 @@ def decode_jwt(self, token: str) -> dict: dict: Decoded JWT token payload. """ try: - print("secret key", self.secret_key) decoded_payload = jwt.decode( token, self.secret_key, algorithms=[self.ALGORITHM] ) diff --git a/src/backend/services/auth/utils.py b/src/backend/services/auth/utils.py index aa686d21fd..4ac9e3a261 100644 --- a/src/backend/services/auth/utils.py +++ b/src/backend/services/auth/utils.py @@ -75,6 +75,7 @@ def get_header_user_id(request: Request) -> str: authorization = request.headers.get("Authorization") _, token = authorization.split(" ") decoded = JWTService().decode_jwt(token) + return decoded["context"]["id"] # Auth disabled else: diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index e0e5566d5f..37288503f9 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -234,23 +234,23 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: ) # Check if filename is specified for file-related actions - if not parameters.get("filename", None) and not parameters.get( - "file_text", None - ): - logger.error( - event=f"[Compass] Error processing file: No filename or file_text specified in parameters {parameters}" - ) - return None + # if not parameters.get("filename", None) and not parameters.get( + # "file_text", None + # ): + # logger.error( + # event=f"[Compass] Error processing file: No filename or file_text specified in parameters {parameters}" + # ) + # return None file_id = parameters["file_id"] filename = parameters.get("filename", None) file_text = parameters.get("file_text", None) - if filename and not os.path.exists(filename): - logger.error( - event=f"[Compass] Error processing file: Invalid filename {filename} in parameters {parameters}" - ) - return None + # if filename and not os.path.exists(filename): + # logger.error( + # event=f"[Compass] Error processing file: Invalid filename {filename} in parameters {parameters}" + # ) + # return None parser_config = self.parser_config or parameters.get("parser_config", None) metadata_config = metadata_config = self.metadata_config or parameters.get( diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 6d33658465..8a0debae69 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -21,6 +21,7 @@ from backend.database_models.database import DBSessionDep from backend.database_models.file import File as FileModel from backend.schemas.file import File, UpdateFileRequest +from backend.config.settings import Settings MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -50,7 +51,7 @@ def get_file_service(): class FileService: @property def is_compass_enabled(self) -> bool: - return os.getenv("ENABLE_COMPASS_FILE_STORAGE", "false").lower() == "true" + return True # All these functions will eventually support file operations on Compass async def create_conversation_files( @@ -97,8 +98,9 @@ async def create_conversation_files( ) uploaded_files = [] + compass = None if self.is_compass_enabled: - uploaded_files = await index_files_to_compass(session, files, user_id) + pass else: uploaded_files = file_crud.batch_create_files(session, files_to_upload) @@ -288,26 +290,41 @@ def get_files_by_message_id( async def index_files_to_compass(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str) -> None: uploaded_files = [] - - for file in files: - filename = file.filename.encode("ascii", "ignore").decode("utf-8") - file_bytes = await file.read() - cleaned_content = file_bytes.decode("utf-8").replace("\x00", "") - cleaned_content_bytes = cleaned_content.encode("utf-8") - new_file_id = uuid.uuid4() - - try: - compass = Compass() - except Exception as e: - print(f"Error initializing Compass: {e}") - - file_text = compass.invoke( - action=Compass.ValidActions.PROCESS_FILE, - parameters={ - "file_id": new_file_id, - "file_text": cleaned_content_bytes, - }, - )[0].content["text"] + compass = None + try: + compass = Compass() + compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={"index": "test_index"}, + ) + except Exception as e: + print(f"Error initializing Compass: {e}") + + + # for file in files: + # filename = file.filename.encode("ascii", "ignore").decode("utf-8") + # file_bytes = await file.read() + # cleaned_content = file_bytes.decode("utf-8").replace("\x00", "") + # cleaned_content_bytes = cleaned_content.encode("utf-8") + # new_file_id = str(uuid.uuid4()) + + # compass = None + # try: + # compass = Compass() + # except Exception as e: + # print(f"Error initializing Compass: {e}") + + # compass.invoke( + # action=Compass.ValidActions.CREATE_INDEX, + # parameters={"index": new_file_id}, + # ) + # file_text = compass.invoke( + # action=Compass.ValidActions.PROCESS_FILE, + # parameters={ + # "file_id": new_file_id, + # "file_text": cleaned_content_bytes, + # }, + # )[0].content["text"] # try: # compass = Compass() From 7c57b68f1f3a26e5ed86a99b7d760fe4061d0390 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 1 Aug 2024 14:34:52 -0400 Subject: [PATCH 03/30] upload works, fetching docs work --- src/backend/routers/conversation.py | 7 +- src/backend/schemas/file.py | 2 +- src/backend/services/compass.py | 3 +- src/backend/services/conversation.py | 2 +- src/backend/services/file.py | 219 ++++++++++++++------------- 5 files changed, 122 insertions(+), 111 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 68c0162284..4139a6c455 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -454,13 +454,10 @@ async def batch_upload_file( ) # TODO: check if file already exists in DB once we have files per agents - compass = Compass() - result = compass.invoke( - action=Compass.ValidActions.CREATE_INDEX, parameters={"index": "foobar"} - ) + try: uploaded_files = await get_file_service().create_conversation_files( - session, files, user_id, conversation.id + session, files, user_id, conversation.id, ctx, ) except Exception as e: raise HTTPException( diff --git a/src/backend/schemas/file.py b/src/backend/schemas/file.py index acf084767b..35e3f7e7c4 100644 --- a/src/backend/schemas/file.py +++ b/src/backend/schemas/file.py @@ -10,7 +10,7 @@ class File(BaseModel): updated_at: datetime.datetime user_id: str - conversation_id: str + conversation_id: Optional[str] = None file_name: str file_path: str file_size: int = Field(default=0, ge=0) diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 37288503f9..614a1bd49d 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -157,7 +157,7 @@ def _create(self, parameters: dict, **kwargs: Any) -> Dict[str, str]: docs=compass_docs, ) if error is not None: - message = ("[Compass] Error inserting document: {error}",) + message = (f"[Compass] Error inserting document: {error}",) logger.error(event=message) raise Exception(message) @@ -267,6 +267,7 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: custom_context=parameters.get("custom_context", None), ) else: + print("raw parsing") return self._raw_parsing( text=file_text, file_id=file_id, diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index 1e41dd9f2d..99c73898d9 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -15,8 +15,8 @@ from backend.schemas.file import File from backend.schemas.message import Message from backend.services.chat import generate_chat_response -from backend.services.context import get_context from backend.services.file import attach_conversation_id_to_files, get_file_service +from backend.services.context import get_context DEFAULT_TITLE = "New Conversation" GENERATE_TITLE_PROMPT = """# TASK diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 8a0debae69..47a04a0f4b 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -2,11 +2,12 @@ import os import uuid from copy import deepcopy +from datetime import datetime from typing import Any, Optional import pandas as pd from docx import Document -from fastapi import HTTPException +from fastapi import HTTPException, Depends from fastapi import UploadFile as FastAPIUploadFile from pypdf import PdfReader from python_calamine.pandas import pandas_monkeypatch @@ -22,6 +23,9 @@ from backend.database_models.file import File as FileModel from backend.schemas.file import File, UpdateFileRequest from backend.config.settings import Settings +from backend.schemas.context import Context +from backend.services.context import get_context + MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -60,6 +64,7 @@ async def create_conversation_files( files: list[FastAPIUploadFile], user_id: str, conversation_id: str, + ctx: Context = Depends(get_context), ) -> list[File]: """ Create files and associations with conversation @@ -73,36 +78,11 @@ async def create_conversation_files( Returns: list[File]: The files that were created """ - files_to_upload = [] - for file in files: - content = await get_file_content(file) - cleaned_content = content.replace("\x00", "") - filename = file.filename.encode("ascii", "ignore").decode("utf-8") - conversation = conversation_crud.get_conversation( - session, conversation_id, user_id - ) - if not conversation: - raise HTTPException( - status_code=404, - detail=f"Conversation with ID: {conversation_id} not found.", - ) - - files_to_upload.append( - FileModel( - file_name=filename, - file_size=file.size, - file_path=filename, - file_content=cleaned_content, - user_id=conversation.user_id, - ) - ) - uploaded_files = [] - compass = None if self.is_compass_enabled: - pass + uploaded_files = await insert_files_in_compass(session, files, user_id, conversation_id) else: - uploaded_files = file_crud.batch_create_files(session, files_to_upload) + uploaded_files = await insert_files_in_db(session, files, user_id, conversation_id) for file in uploaded_files: conversation_crud.create_conversation_file_association( @@ -154,7 +134,10 @@ def get_files_by_agent_id( if artifact.get("type") == "local_file" ] - files = file_crud.get_files_by_ids(session, file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(file_ids, user_id) + else: + files = file_crud.get_files_by_ids(session, file_ids, user_id) return files @@ -184,7 +167,10 @@ def get_files_by_conversation_id( files = [] if file_ids is not None: - files = file_crud.get_files_by_ids(session, file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(file_ids, user_id) + else: + files = file_crud.get_files_by_ids(session, file_ids, user_id) return files @@ -235,7 +221,10 @@ def get_files_by_ids( Returns: list[File]: The files that were created """ - files = file_crud.get_files_by_ids(session, file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(file_ids, user_id) + else: + files = file_crud.get_files_by_ids(session, file_ids, user_id) return files def update_file( @@ -285,88 +274,112 @@ def get_files_by_message_id( message = message_crud.get_message(session, message_id, user_id) files = [] if message.file_ids is not None: - files = file_crud.get_files_by_ids(session, message.file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(message.file_ids, user_id) + else: + files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files + +def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: + compass = Compass() + files = [] + for file_id in file_ids: + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ).result["doc"]["content"] + + files.append(File( + id=file_id, + file_name=fetched_doc["file_name"], + file_size=fetched_doc["file_size"], + file_path=fetched_doc["file_path"], + user_id=user_id, + created_at=datetime.fromisoformat(fetched_doc["created_at"]), + updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), + )) + + return files + +async def insert_files_in_db(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, conversation_id: str) -> list[File]: + files_to_upload = [] + for file in files: + content = await get_file_content(file) + cleaned_content = content.replace("\x00", "") + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + + files_to_upload.append( + FileModel( + file_name=filename, + file_size=file.size, + file_path=filename, + file_content=cleaned_content, + user_id=user_id, + ) + ) + + uploaded_files = file_crud.batch_create_files(session, files_to_upload) + return uploaded_files + -async def index_files_to_compass(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str) -> None: +async def insert_files_in_compass(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, conversation_id: str) -> list[File]: uploaded_files = [] compass = None try: compass = Compass() - compass.invoke( - action=Compass.ValidActions.CREATE_INDEX, - parameters={"index": "test_index"}, - ) except Exception as e: print(f"Error initializing Compass: {e}") + for file in files: + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + file_bytes = await file.read() + cleaned_content = file_bytes.decode("utf-8", errors="ignore").replace("\x00", "") + new_file_id = str(uuid.uuid4()) + + # Create new index for file + compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": new_file_id, + }, + ) + compass.invoke( + action=Compass.ValidActions.CREATE, + parameters={ + "index": new_file_id, + "file_id": new_file_id, + "file_text": file_bytes, + }, + ) + compass.invoke( + action=Compass.ValidActions.ADD_CONTEXT, + parameters={ + "index": new_file_id, + "file_id": new_file_id, + "context": { + "file_name": filename, + "file_path": filename, + "file_size": file.size, + "user_id": user_id, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + }, + }, + ) + compass.invoke( + action=Compass.ValidActions.REFRESH, + parameters={"index": new_file_id}, + ) - # for file in files: - # filename = file.filename.encode("ascii", "ignore").decode("utf-8") - # file_bytes = await file.read() - # cleaned_content = file_bytes.decode("utf-8").replace("\x00", "") - # cleaned_content_bytes = cleaned_content.encode("utf-8") - # new_file_id = str(uuid.uuid4()) - - # compass = None - # try: - # compass = Compass() - # except Exception as e: - # print(f"Error initializing Compass: {e}") - - # compass.invoke( - # action=Compass.ValidActions.CREATE_INDEX, - # parameters={"index": new_file_id}, - # ) - # file_text = compass.invoke( - # action=Compass.ValidActions.PROCESS_FILE, - # parameters={ - # "file_id": new_file_id, - # "file_text": cleaned_content_bytes, - # }, - # )[0].content["text"] - - # try: - # compass = Compass() - # compass.invoke( - # action=Compass.ValidActions.CREATE_INDEX, - # parameters={ - # "index": new_file_id, - # }, - # ) - # compass.invoke( - # action=Compass.ValidActions.CREATE, - # parameters={ - # "index": new_file_id, - # "file_id": new_file_id, - # "file_text": file_text, - # }, - # ) - # compass.invoke( - # action=Compass.ValidActions.ADD_CONTEXT, - # parameters={ - # "index": new_file_id, - # "file_id": new_file_id, - # "context": { - # "file_name": filename, - # "file_size": file.size, - # }, - # }, - # ) - # compass.invoke( - # action=Compass.ValidActions.REFRESH, - # parameters={"index": new_file_id}, - # ) - - # uploaded_files.append(File( - # id=new_file_id, - # file_name=filename, - # file_size=file.size, - # file_path=filename, - # user_id=user_id, - # )) - # except Exception as e: - # print(f"Error Creating File Index in Compass: {e}") + uploaded_files.append(FileModel( + file_name=filename, + id=new_file_id, + file_size=file.size, + file_path=filename, + user_id=user_id, + created_at=datetime.now(), + updated_at=datetime.now(), + )) return uploaded_files From daf94e2c6260d760891d04436368a491869ed0aa Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 1 Aug 2024 15:44:50 -0400 Subject: [PATCH 04/30] getting chat to work --- src/backend/chat/custom/custom.py | 2 +- src/backend/config/settings.py | 6 ++++++ src/backend/config/tools.py | 12 +++++------ src/backend/schemas/file.py | 1 + src/backend/services/file.py | 11 ++++++---- src/backend/tools/files.py | 36 +++++++++++++++++++------------ 6 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 7656837888..e7fe7089f1 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -302,7 +302,7 @@ def add_files_to_chat_history( num_words = min(25, word_count) preview = " ".join(file.file_content.split()[:num_words]) - files_message += f"Filename: {file.file_name}\nWord Count: {word_count} Preview: {preview}\n\n" + files_message += f"Filename: {file.file_name}\nFile ID: {file.id}\nWord Count: {word_count} Preview: {preview}\n\n" chat_history.append(ChatMessage(message=files_message, role=ChatRole.SYSTEM)) return chat_history diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 5eae1a6608..e543b8bcd7 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -80,6 +80,12 @@ class FeatureFlags(BaseSettings, BaseModel): "USE_COMMUNITY_FEATURES", "use_community_features" ), ) + use_compass_file_storage: Optional[bool] = Field( + default=False, + validation_alias=AliasChoices( + "USE_COMPASS_FILE_STORAGE", "use_compass_file_storage" + ), + ) class PythonToolSettings(BaseSettings, BaseModel): diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 07fd389428..86ac2af727 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -66,9 +66,9 @@ class ToolName(StrEnum): "type": "str", "required": True, }, - "filenames": { - "description": "A list of one or more uploaded filename strings to search over", - "type": "list", + "files": { + "description": "A list of one or more tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", "required": True, }, }, @@ -82,9 +82,9 @@ class ToolName(StrEnum): display_name="Read Document", implementation=ReadFileTool, parameter_definitions={ - "filename": { - "description": "The name of the attached file to read.", - "type": "str", + "files": { + "description": "A list of one or more tuples of (filename, file ID) to read over", + "type": "list[tuple[str, str]]", "required": True, } }, diff --git a/src/backend/schemas/file.py b/src/backend/schemas/file.py index 35e3f7e7c4..08e924a3af 100644 --- a/src/backend/schemas/file.py +++ b/src/backend/schemas/file.py @@ -11,6 +11,7 @@ class File(BaseModel): user_id: str conversation_id: Optional[str] = None + file_content: Optional[str] = None file_name: str file_path: str file_size: int = Field(default=0, ge=0) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 47a04a0f4b..a7aee9dc60 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -15,7 +15,7 @@ from backend.services.compass import Compass import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud -from backend.config.tools import ToolName +# from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import message as message_crud from backend.database_models.conversation import ConversationFileAssociation @@ -25,6 +25,7 @@ from backend.config.settings import Settings from backend.schemas.context import Context from backend.services.context import get_context +from backend.config.settings import Settings MAX_FILE_SIZE = 20_000_000 # 20MB @@ -55,7 +56,7 @@ def get_file_service(): class FileService: @property def is_compass_enabled(self) -> bool: - return True + return Settings().feature_flags.use_compass_file_storage # All these functions will eventually support file operations on Compass async def create_conversation_files( @@ -120,11 +121,12 @@ def get_files_by_agent_id( files = [] agent_tool_metadata = agent.tools_metadata if agent_tool_metadata is not None: + # fix circular import artifacts = next( tool_metadata.artifacts for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == ToolName.Read_File - or tool_metadata.tool_name == ToolName.Search_File + if tool_metadata.tool_name == "read_document" + or tool_metadata.tool_name == "search_file" ) # TODO scott: enumerate type names (?), different types for local vs. compass? @@ -294,6 +296,7 @@ def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: file_name=fetched_doc["file_name"], file_size=fetched_doc["file_size"], file_path=fetched_doc["file_path"], + file_content=fetched_doc["text"], user_id=user_id, created_at=datetime.fromisoformat(fetched_doc["created_at"]), updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 6994393f64..3215bad59b 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -2,7 +2,8 @@ import backend.crud.file as file_crud from backend.tools.base import BaseTool - +from backend.services.compass import Compass +from backend.services.file import get_file_service class ReadFileTool(BaseTool): """ @@ -20,14 +21,19 @@ def is_available(cls) -> bool: return True async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: - file_name = parameters.get("filename", "") + files = parameters.get("files", []) session = kwargs.get("session") user_id = kwargs.get("user_id") - if not file_name: + if not files: return [] - files = file_crud.get_files_by_file_names(session, [file_name], user_id) + # files = file_crud.get_files_by_file_names(session, [file_name], user_id) + compass = Compass() + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": index_name, "file_id": file_id}, + ).result["doc"]["content"] if not files: return [] @@ -35,9 +41,9 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: file = files[0] return [ { - "text": file.file_content, - "title": file.file_name, - "url": file.file_path, + "text": fetched_doc.get("text", ""), + "title": fetched_doc.get("title", ""), + "url": fetched_doc.get("url", ""), } ] @@ -59,19 +65,21 @@ def is_available(cls) -> bool: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("search_query") - file_names = parameters.get("filenames") + files = parameters.get("files") session = kwargs.get("session") user_id = kwargs.get("user_id") - if not query or not file_names: + if not query or not files: return [] - file_names = [ - file_name.encode("ascii", "ignore").decode("utf-8") - for file_name in file_names - ] + # file_names = [ + # file_name.encode("ascii", "ignore").decode("utf-8") + # for file_name in file_names + # ] - files = file_crud.get_files_by_file_names(session, file_names, user_id) + # files = file_crud.get_files_by_file_names(session, file_names, user_id) + file_ids = [file_id for file_name, file_id in files] + files = get_file_service().get_files_by_ids(session, file_ids, user_id) if not files: return [] From 39ed17e4124f14c0190ae11ee3b92da4b6b02933 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 1 Aug 2024 16:19:05 -0400 Subject: [PATCH 05/30] cleaing up --- src/backend/config/tools.py | 8 +- src/backend/routers/conversation.py | 8 +- src/backend/services/compass.py | 1 - src/backend/services/conversation.py | 2 +- src/backend/services/file.py | 111 ++++++++++++++++----------- src/backend/tools/files.py | 32 ++++---- 6 files changed, 96 insertions(+), 66 deletions(-) diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 86ac2af727..249ad61acf 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -67,7 +67,7 @@ class ToolName(StrEnum): "required": True, }, "files": { - "description": "A list of one or more tuples of (filename, file ID) to search over", + "description": "A list of files represented as tuples of (filename, file ID) to search over", "type": "list[tuple[str, str]]", "required": True, }, @@ -82,9 +82,9 @@ class ToolName(StrEnum): display_name="Read Document", implementation=ReadFileTool, parameter_definitions={ - "files": { - "description": "A list of one or more tuples of (filename, file ID) to read over", - "type": "list[tuple[str, str]]", + "file": { + "description": "A file represented as a tuple of (filename, file ID) to read over", + "type": "tuple[str, str]", "required": True, } }, diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 4139a6c455..3b9b4ddce9 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -2,7 +2,6 @@ from fastapi import File as RequestFile from fastapi import Form, HTTPException, Request from fastapi import UploadFile as FastAPIUploadFile -from backend.services.compass import Compass from backend.chat.custom.custom import CustomChat from backend.chat.custom.utils import get_deployment @@ -27,6 +26,7 @@ UploadFileResponse, ) from backend.schemas.metrics import DEFAULT_METRICS_AGENT, agent_to_metrics_agent +from backend.services.compass import Compass from backend.services.context import get_context from backend.services.conversation import ( DEFAULT_TITLE, @@ -457,7 +457,11 @@ async def batch_upload_file( try: uploaded_files = await get_file_service().create_conversation_files( - session, files, user_id, conversation.id, ctx, + session, + files, + user_id, + conversation.id, + ctx, ) except Exception as e: raise HTTPException( diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 614a1bd49d..af8cf0da58 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -267,7 +267,6 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: custom_context=parameters.get("custom_context", None), ) else: - print("raw parsing") return self._raw_parsing( text=file_text, file_id=file_id, diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index 99c73898d9..1e41dd9f2d 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -15,8 +15,8 @@ from backend.schemas.file import File from backend.schemas.message import Message from backend.services.chat import generate_chat_response -from backend.services.file import attach_conversation_id_to_files, get_file_service from backend.services.context import get_context +from backend.services.file import attach_conversation_id_to_files, get_file_service DEFAULT_TITLE = "New Conversation" GENERATE_TITLE_PROMPT = """# TASK diff --git a/src/backend/services/file.py b/src/backend/services/file.py index a7aee9dc60..463eacecb1 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -7,26 +7,25 @@ import pandas as pd from docx import Document -from fastapi import HTTPException, Depends +from fastapi import Depends, HTTPException from fastapi import UploadFile as FastAPIUploadFile from pypdf import PdfReader from python_calamine.pandas import pandas_monkeypatch -from backend.services.compass import Compass import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud +from backend.config.settings import Settings + # from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import message as message_crud from backend.database_models.conversation import ConversationFileAssociation from backend.database_models.database import DBSessionDep from backend.database_models.file import File as FileModel -from backend.schemas.file import File, UpdateFileRequest -from backend.config.settings import Settings from backend.schemas.context import Context +from backend.schemas.file import File, UpdateFileRequest +from backend.services.compass import Compass from backend.services.context import get_context -from backend.config.settings import Settings - MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -44,6 +43,7 @@ pandas_monkeypatch() file_service = None +compass = None def get_file_service(): @@ -53,12 +53,21 @@ def get_file_service(): return file_service +def get_compass(): + global compass + if compass is None: + try: + compass = Compass() + except Exception as e: + print(f"Error initializing Compass: {e}") + return compass + + class FileService: @property def is_compass_enabled(self) -> bool: return Settings().feature_flags.use_compass_file_storage - # All these functions will eventually support file operations on Compass async def create_conversation_files( self, session: DBSessionDep, @@ -81,9 +90,11 @@ async def create_conversation_files( """ uploaded_files = [] if self.is_compass_enabled: - uploaded_files = await insert_files_in_compass(session, files, user_id, conversation_id) + uploaded_files = await insert_files_in_compass(files, user_id) else: - uploaded_files = await insert_files_in_db(session, files, user_id, conversation_id) + uploaded_files = await insert_files_in_db( + session, files, user_id, conversation_id + ) for file in uploaded_files: conversation_crud.create_conversation_file_association( @@ -282,29 +293,40 @@ def get_files_by_message_id( files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files + def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: - compass = Compass() files = [] for file_id in file_ids: - fetched_doc = compass.invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, - ).result["doc"]["content"] - - files.append(File( - id=file_id, - file_name=fetched_doc["file_name"], - file_size=fetched_doc["file_size"], - file_path=fetched_doc["file_path"], - file_content=fetched_doc["text"], - user_id=user_id, - created_at=datetime.fromisoformat(fetched_doc["created_at"]), - updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), - )) + fetched_doc = ( + get_compass() + .invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ) + .result["doc"]["content"] + ) + + files.append( + File( + id=file_id, + file_name=fetched_doc["file_name"], + file_size=fetched_doc["file_size"], + file_path=fetched_doc["file_path"], + file_content=fetched_doc["text"], + user_id=user_id, + created_at=datetime.fromisoformat(fetched_doc["created_at"]), + updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), + ) + ) return files -async def insert_files_in_db(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, conversation_id: str) -> list[File]: + +async def insert_files_in_db( + session: DBSessionDep, + files: list[FastAPIUploadFile], + user_id: str, +) -> list[File]: files_to_upload = [] for file in files: content = await get_file_content(file) @@ -324,10 +346,12 @@ async def insert_files_in_db(session: DBSessionDep, files: list[FastAPIUploadFil uploaded_files = file_crud.batch_create_files(session, files_to_upload) return uploaded_files - -async def insert_files_in_compass(session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, conversation_id: str) -> list[File]: + +async def insert_files_in_compass( + files: list[FastAPIUploadFile], + user_id: str, +) -> list[File]: uploaded_files = [] - compass = None try: compass = Compass() except Exception as e: @@ -336,17 +360,16 @@ async def insert_files_in_compass(session: DBSessionDep, files: list[FastAPIUplo for file in files: filename = file.filename.encode("ascii", "ignore").decode("utf-8") file_bytes = await file.read() - cleaned_content = file_bytes.decode("utf-8", errors="ignore").replace("\x00", "") new_file_id = str(uuid.uuid4()) # Create new index for file - compass.invoke( + get_compass().invoke( action=Compass.ValidActions.CREATE_INDEX, parameters={ "index": new_file_id, }, ) - compass.invoke( + get_compass().invoke( action=Compass.ValidActions.CREATE, parameters={ "index": new_file_id, @@ -354,7 +377,7 @@ async def insert_files_in_compass(session: DBSessionDep, files: list[FastAPIUplo "file_text": file_bytes, }, ) - compass.invoke( + get_compass().invoke( action=Compass.ValidActions.ADD_CONTEXT, parameters={ "index": new_file_id, @@ -369,20 +392,22 @@ async def insert_files_in_compass(session: DBSessionDep, files: list[FastAPIUplo }, }, ) - compass.invoke( + get_compass().invoke( action=Compass.ValidActions.REFRESH, parameters={"index": new_file_id}, ) - uploaded_files.append(FileModel( - file_name=filename, - id=new_file_id, - file_size=file.size, - file_path=filename, - user_id=user_id, - created_at=datetime.now(), - updated_at=datetime.now(), - )) + uploaded_files.append( + File( + file_name=filename, + id=new_file_id, + file_size=file.size, + file_path=filename, + user_id=user_id, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + ) return uploaded_files diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 3215bad59b..180cfaa93f 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -1,9 +1,11 @@ from typing import Any, Dict, List import backend.crud.file as file_crud -from backend.tools.base import BaseTool +from backend.config.settings import Settings from backend.services.compass import Compass from backend.services.file import get_file_service +from backend.tools.base import BaseTool + class ReadFileTool(BaseTool): """ @@ -28,12 +30,12 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not files: return [] - # files = file_crud.get_files_by_file_names(session, [file_name], user_id) - compass = Compass() - fetched_doc = compass.invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": index_name, "file_id": file_id}, - ).result["doc"]["content"] + if Settings().feature_flags.use_compass_file_storage: + file_ids = [file_id for _, file_id in files] + files = get_file_service().get_files_by_ids(session, file_ids, user_id) + else: + # TODO get file by file id not file name + files = file_crud.get_files_by_file_names(session, file_names, user_id) if not files: return [] @@ -72,14 +74,14 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not query or not files: return [] - # file_names = [ - # file_name.encode("ascii", "ignore").decode("utf-8") - # for file_name in file_names - # ] - - # files = file_crud.get_files_by_file_names(session, file_names, user_id) - file_ids = [file_id for file_name, file_id in files] - files = get_file_service().get_files_by_ids(session, file_ids, user_id) + files = [] + if Settings().feature_flags.use_compass_file_storage: + file_ids = [file_id for _, file_id in files] + files = get_file_service().get_files_by_ids(session, file_ids, user_id) + else: + # TODO get file by file id not file name + file_names = [file_name for file_name, _ in files] + files = file_crud.get_files_by_file_names(session, file_names, user_id) if not files: return [] From 1ab404877914aad8a468458fbe58e1110d6b8d47 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 1 Aug 2024 18:05:56 -0400 Subject: [PATCH 06/30] fixed --- src/backend/services/file.py | 10 ++---- src/backend/tools/files.py | 68 +++++++++++++++++++++++++----------- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 463eacecb1..4f4208d587 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -297,15 +297,11 @@ def get_files_by_message_id( def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: files = [] for file_id in file_ids: - fetched_doc = ( - get_compass() - .invoke( + fetched_doc = get_compass().invoke( action=Compass.ValidActions.GET_DOCUMENT, parameters={"index": file_id, "file_id": file_id}, - ) - .result["doc"]["content"] - ) - + ).result["doc"]["content"] + files.append( File( id=file_id, diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 180cfaa93f..a71d32bc7b 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -3,7 +3,7 @@ import backend.crud.file as file_crud from backend.config.settings import Settings from backend.services.compass import Compass -from backend.services.file import get_file_service +from backend.services.file import get_file_service, get_compass from backend.tools.base import BaseTool @@ -14,6 +14,7 @@ class ReadFileTool(BaseTool): NAME = "read_document" MAX_NUM_CHUNKS = 10 + SEARCH_LIMIT = 5 def __init__(self): pass @@ -43,9 +44,9 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: file = files[0] return [ { - "text": fetched_doc.get("text", ""), - "title": fetched_doc.get("title", ""), - "url": fetched_doc.get("url", ""), + "text": file.file_content, + "title": file.file_name, + "url": file.file_path, } ] @@ -57,6 +58,7 @@ class SearchFileTool(BaseTool): NAME = "search_file" MAX_NUM_CHUNKS = 10 + SEARCH_LIMIT = 5 def __init__(self): pass @@ -74,26 +76,52 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not query or not files: return [] - files = [] - if Settings().feature_flags.use_compass_file_storage: + compass_file_stroage_enabled = Settings().feature_flags.use_compass_file_storage + retrieved_files = [] + if compass_file_stroage_enabled: file_ids = [file_id for _, file_id in files] - files = get_file_service().get_files_by_ids(session, file_ids, user_id) + retrieved_files = get_file_service().get_files_by_ids(session, file_ids, user_id) else: # TODO get file by file id not file name file_names = [file_name for file_name, _ in files] - files = file_crud.get_files_by_file_names(session, file_names, user_id) + retrieved_files = file_crud.get_files_by_file_names(session, file_names, user_id) - if not files: + if not retrieved_files: return [] - results = [] - for file in files: - results.append( - { - "text": file.file_content, - "title": file.file_name, - "url": file.file_path, - } - ) - - return results + if compass_file_stroage_enabled: + results = [] + for file in retrieved_files: + hits = get_compass().invoke( + action=Compass.ValidActions.SEARCH, + parameters={"index": file.id, "query": query, "top_k": self.SEARCH_LIMIT}, + ).result["hits"] + results.extend(hits) + + chunks = sorted( + [ + { + "text": chunk["content"]["text"], + "score": chunk["score"], + "url": result["content"].get("url", ""), + "title": result["content"].get("title", ""), + } + for result in results + for chunk in result["chunks"] + ], + key=lambda x: x["score"], + reverse=True, + )[:self.SEARCH_LIMIT] + + return chunks + else: + results = [] + for file in files: + results.append( + { + "text": file.file_content, + "title": file.file_name, + "url": file.file_path, + } + ) + return results From 2ed5e0ccd38e5a18bbfbddc29a33d4b3646af22d Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 2 Aug 2024 15:47:24 -0400 Subject: [PATCH 07/30] fixed both search and read docs tool --- src/backend/config/tools.py | 2 +- src/backend/services/file.py | 56 +++++++++++++------ src/backend/tools/files.py | 102 +++++++++++++++++------------------ 3 files changed, 92 insertions(+), 68 deletions(-) diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 249ad61acf..8e40f8e78a 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -83,7 +83,7 @@ class ToolName(StrEnum): implementation=ReadFileTool, parameter_definitions={ "file": { - "description": "A file represented as a tuple of (filename, file ID) to read over", + "description": "A file represented as a tuple (filename, file ID) to read over", "type": "tuple[str, str]", "required": True, } diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 4f4208d587..0b82925449 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,9 +1,7 @@ import io import os import uuid -from copy import deepcopy from datetime import datetime -from typing import Any, Optional import pandas as pd from docx import Document @@ -26,6 +24,7 @@ from backend.schemas.file import File, UpdateFileRequest from backend.services.compass import Compass from backend.services.context import get_context +from backend.services.logger.utils import get_logger MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -42,6 +41,7 @@ # Monkey patch Pandas to use Calamine for Excel reading because Calamine is faster than Pandas pandas_monkeypatch() +logger = get_logger() file_service = None compass = None @@ -59,7 +59,7 @@ def get_compass(): try: compass = Compass() except Exception as e: - print(f"Error initializing Compass: {e}") + logger.error(f"Error initializing Compass: {e}") return compass @@ -122,6 +122,8 @@ def get_files_by_agent_id( Returns: list[File]: The files that were created """ + from backend.config.tools import ToolName + agent = agent_crud.get_agent_by_id(session, agent_id) if agent is None: raise HTTPException( @@ -132,12 +134,11 @@ def get_files_by_agent_id( files = [] agent_tool_metadata = agent.tools_metadata if agent_tool_metadata is not None: - # fix circular import artifacts = next( tool_metadata.artifacts for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == "read_document" - or tool_metadata.tool_name == "search_file" + if tool_metadata.tool_name == ToolName.READ_DOCUMENT + or tool_metadata.tool_name == ToolName.SEARCH_FILE ) # TODO scott: enumerate type names (?), different types for local vs. compass? @@ -217,12 +218,15 @@ def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> F Returns: File: The file that was created """ - file = file_crud.get_file(session, file_id, user_id) + if self.is_compass_enabled: + file = get_file_in_compass(file_id, user_id) + else: + file = file_crud.get_file(session, file_id, user_id) return file def get_files_by_ids( self, session: DBSessionDep, file_ids: list[str], user_id: str - ) -> list[FileModel]: + ) -> list[File]: """ Get files by IDs @@ -294,14 +298,40 @@ def get_files_by_message_id( return files +def get_file_in_compass(file_id: str, user_id: str) -> File: + fetched_doc = ( + get_compass() + .invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ) + .result["doc"]["content"] + ) + + return File( + id=file_id, + file_name=fetched_doc["file_name"], + file_size=fetched_doc["file_size"], + file_path=fetched_doc["file_path"], + file_content=fetched_doc["text"], + user_id=user_id, + created_at=datetime.fromisoformat(fetched_doc["created_at"]), + updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), + ) + + def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: files = [] for file_id in file_ids: - fetched_doc = get_compass().invoke( + fetched_doc = ( + get_compass() + .invoke( action=Compass.ValidActions.GET_DOCUMENT, parameters={"index": file_id, "file_id": file_id}, - ).result["doc"]["content"] - + ) + .result["doc"]["content"] + ) + files.append( File( id=file_id, @@ -348,10 +378,6 @@ async def insert_files_in_compass( user_id: str, ) -> list[File]: uploaded_files = [] - try: - compass = Compass() - except Exception as e: - print(f"Error initializing Compass: {e}") for file in files: filename = file.filename.encode("ascii", "ignore").decode("utf-8") diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index a71d32bc7b..e83c7e443b 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -2,11 +2,45 @@ import backend.crud.file as file_crud from backend.config.settings import Settings +from backend.schemas.file import File from backend.services.compass import Compass -from backend.services.file import get_file_service, get_compass +from backend.services.file import get_compass, get_file_service from backend.tools.base import BaseTool +def compass_file_search( + file_ids: List[str], query: str, search_limit: int = 5 +) -> List[Dict[str, Any]]: + results = [] + for file_id in file_ids: + hits = ( + get_compass() + .invoke( + action=Compass.ValidActions.SEARCH, + parameters={"index": file_id, "query": query, "top_k": search_limit}, + ) + .result["hits"] + ) + results.extend(hits) + + chunks = sorted( + [ + { + "text": chunk["content"]["text"], + "score": chunk["score"], + "url": result["content"].get("title", ""), + "title": result["content"].get("title", ""), + } + for result in results + for chunk in result["chunks"] + ], + key=lambda x: x["score"], + reverse=True, + )[:search_limit] + + return chunks + + class ReadFileTool(BaseTool): """ This class reads a file from the file system. @@ -24,29 +58,22 @@ def is_available(cls) -> bool: return True async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: - files = parameters.get("files", []) + file = parameters.get("file", []) session = kwargs.get("session") user_id = kwargs.get("user_id") - - if not files: + if not file: return [] - if Settings().feature_flags.use_compass_file_storage: - file_ids = [file_id for _, file_id in files] - files = get_file_service().get_files_by_ids(session, file_ids, user_id) - else: - # TODO get file by file id not file name - files = file_crud.get_files_by_file_names(session, file_names, user_id) - - if not files: + _, file_id = file + retrieved_file = get_file_service().get_file_by_id(session, file_id, user_id) + if not retrieved_file: return [] - file = files[0] return [ { - "text": file.file_content, - "title": file.file_name, - "url": file.file_path, + "text": retrieved_file.file_content, + "title": retrieved_file.file_name, + "url": retrieved_file.file_path, } ] @@ -76,45 +103,16 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not query or not files: return [] - compass_file_stroage_enabled = Settings().feature_flags.use_compass_file_storage - retrieved_files = [] - if compass_file_stroage_enabled: + if Settings().feature_flags.use_compass_file_storage: file_ids = [file_id for _, file_id in files] - retrieved_files = get_file_service().get_files_by_ids(session, file_ids, user_id) + return compass_file_search(file_ids, query, search_limit=self.SEARCH_LIMIT) else: - # TODO get file by file id not file name - file_names = [file_name for file_name, _ in files] - retrieved_files = file_crud.get_files_by_file_names(session, file_names, user_id) - - if not retrieved_files: - return [] + retrieved_files = get_file_service().get_files_by_ids( + session, file_ids, user_id + ) + if not retrieved_files: + return [] - if compass_file_stroage_enabled: - results = [] - for file in retrieved_files: - hits = get_compass().invoke( - action=Compass.ValidActions.SEARCH, - parameters={"index": file.id, "query": query, "top_k": self.SEARCH_LIMIT}, - ).result["hits"] - results.extend(hits) - - chunks = sorted( - [ - { - "text": chunk["content"]["text"], - "score": chunk["score"], - "url": result["content"].get("url", ""), - "title": result["content"].get("title", ""), - } - for result in results - for chunk in result["chunks"] - ], - key=lambda x: x["score"], - reverse=True, - )[:self.SEARCH_LIMIT] - - return chunks - else: results = [] for file in files: results.append( From a0932532a6efd8823c1525ea5c8aee2ed059d057 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 5 Aug 2024 15:40:52 -0400 Subject: [PATCH 08/30] saving --- src/backend/services/file.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 0b82925449..b9f9609eb5 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -203,7 +203,11 @@ def delete_file_from_conversation( conversation_crud.delete_conversation_file_association( session, conversation_id, file_id, user_id ) - file_crud.delete_file(session, file_id, user_id) + + if self.is_compass_enabled: + delete_file_in_compass(file_id, user_id) + else: + file_crud.delete_file(session, file_id, user_id) return def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> File: @@ -244,6 +248,7 @@ def get_files_by_ids( files = file_crud.get_files_by_ids(session, file_ids, user_id) return files + # TODO: compass def update_file( self, session: DBSessionDep, file: File, new_file: UpdateFileRequest ) -> File: @@ -261,6 +266,7 @@ def update_file( updated_file = file_crud.update_file(session, file, new_file) return updated_file + # TODO: compass def bulk_delete_files( self, session: DBSessionDep, file_ids: list[str], user_id: str ) -> None: @@ -298,6 +304,13 @@ def get_files_by_message_id( return files +# Compass Operations +def delete_file_in_compass(file_id: str, user_id: str) -> None: + get_compass().invoke( + action=Compass.ValidActions.DELETE_INDEX, + parameters={"index": file_id} + ) + def get_file_in_compass(file_id: str, user_id: str) -> File: fetched_doc = ( get_compass() From 6efe82f5c8fc4fc54fc96fb0284e99f9f8a8f750 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 5 Aug 2024 16:40:48 -0400 Subject: [PATCH 09/30] remove update, add delete --- src/backend/routers/conversation.py | 44 ++--------------------------- src/backend/services/compass.py | 2 +- src/backend/services/file.py | 30 +++++++------------- 3 files changed, 14 insertions(+), 62 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 3b9b4ddce9..a647a93af3 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -502,42 +502,6 @@ async def list_files( return files_with_conversation_id -@router.put("/{conversation_id}/files/{file_id}", response_model=FilePublic) -async def update_file( - conversation_id: str, - file_id: str, - new_file: UpdateFileRequest, - session: DBSessionDep, - ctx: Context = Depends(get_context), -) -> FilePublic: - """ - Update a file by ID. - - Args: - conversation_id (str): Conversation ID. - file_id (str): File ID. - new_file (UpdateFileRequest): New file data. - session (DBSessionDep): Database session. - ctx (Context): Context object. - - Returns: - FilePublic: Updated file. - - Raises: - HTTPException: If the conversation with the given ID is not found. - """ - user_id = ctx.get_user_id() - _ = validate_conversation(session, conversation_id, user_id) - _ = validate_file(session, file_id, user_id) - - file = get_file_service().get_file_by_id(session, file_id, user_id) - file = get_file_service().update_file(session, file, new_file) - files_with_conversation_id = attach_conversation_id_to_files( - conversation_id, [file] - ) - return files_with_conversation_id[0] - - @router.delete("/{conversation_id}/files/{file_id}") async def delete_file( conversation_id: str, @@ -560,13 +524,11 @@ async def delete_file( HTTPException: If the conversation with the given ID is not found. """ user_id = ctx.get_user_id() - _ = validate_conversation(session, conversation_id, user_id) - _ = validate_file(session, file_id, user_id) - - file = get_file_service().get_file_by_id(session, file_id, user_id) + # _ = validate_conversation(session, conversation_id, user_id) + # _ = validate_file(session, file_id, user_id) # Delete the File DB object - get_file_service().delete_file_from_conversation( + get_file_service().delete_file_by_id( session, conversation_id, file_id, user_id ) diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index af8cf0da58..25956a3db6 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -111,7 +111,7 @@ def invoke( return self.compass_client.create_index( index_name=parameters["index"] ) - case self.ValidActions.CREATE_INDEX.value: + case self.ValidActions.DELETE_INDEX.value: return self.compass_client.delete_index( index_name=parameters["index"] ) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index b9f9609eb5..5b983ce989 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -188,7 +188,7 @@ def get_files_by_conversation_id( return files - def delete_file_from_conversation( + def delete_file_by_id( self, session: DBSessionDep, conversation_id: str, file_id: str, user_id: str ) -> None: """ @@ -208,6 +208,7 @@ def delete_file_from_conversation( delete_file_in_compass(file_id, user_id) else: file_crud.delete_file(session, file_id, user_id) + return def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> File: @@ -248,25 +249,7 @@ def get_files_by_ids( files = file_crud.get_files_by_ids(session, file_ids, user_id) return files - # TODO: compass - def update_file( - self, session: DBSessionDep, file: File, new_file: UpdateFileRequest - ) -> File: - """ - Update file - - Args: - session (DBSessionDep): The database session - file (File): The file to update - new_file (UpdateFileRequest): The new file data - Returns: - File: The updated file - """ - updated_file = file_crud.update_file(session, file, new_file) - return updated_file - - # TODO: compass def bulk_delete_files( self, session: DBSessionDep, file_ids: list[str], user_id: str ) -> None: @@ -278,7 +261,13 @@ def bulk_delete_files( file_ids (list[str]): The file IDs user_id (str): The user ID """ - file_crud.bulk_delete_files(session, file_ids, user_id) + if self.is_compass_enabled: + for file_id in file_ids: + delete_file_in_compass(file_id, user_id) + else: + file_crud.bulk_delete_files(session, file_ids, user_id) + + return def get_files_by_message_id( self, session: DBSessionDep, message_id: str, user_id: str @@ -306,6 +295,7 @@ def get_files_by_message_id( # Compass Operations def delete_file_in_compass(file_id: str, user_id: str) -> None: + # todo: validate all files exists before deleting get_compass().invoke( action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} From e81cd6d148002ac8c4285fef5ab201fb7fb90b6d Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 6 Aug 2024 12:47:00 -0400 Subject: [PATCH 10/30] fixes --- src/backend/routers/conversation.py | 4 +--- src/backend/services/file.py | 7 +++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index a647a93af3..f30a389550 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -528,9 +528,7 @@ async def delete_file( # _ = validate_file(session, file_id, user_id) # Delete the File DB object - get_file_service().delete_file_by_id( - session, conversation_id, file_id, user_id - ) + get_file_service().delete_file_by_id(session, conversation_id, file_id, user_id) return DeleteFileResponse() diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 5b983ce989..1ee6dd8059 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -249,7 +249,6 @@ def get_files_by_ids( files = file_crud.get_files_by_ids(session, file_ids, user_id) return files - def bulk_delete_files( self, session: DBSessionDep, file_ids: list[str], user_id: str ) -> None: @@ -267,7 +266,7 @@ def bulk_delete_files( else: file_crud.bulk_delete_files(session, file_ids, user_id) - return + return def get_files_by_message_id( self, session: DBSessionDep, message_id: str, user_id: str @@ -297,10 +296,10 @@ def get_files_by_message_id( def delete_file_in_compass(file_id: str, user_id: str) -> None: # todo: validate all files exists before deleting get_compass().invoke( - action=Compass.ValidActions.DELETE_INDEX, - parameters={"index": file_id} + action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} ) + def get_file_in_compass(file_id: str, user_id: str) -> File: fetched_doc = ( get_compass() From c8d93b6c5f547ce10ce4197186f528ed8e8ff2a0 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 6 Aug 2024 13:06:01 -0400 Subject: [PATCH 11/30] Fix tests --- src/backend/routers/conversation.py | 5 +- src/backend/services/file.py | 27 ++++++- .../tests/routers/test_conversation.py | 79 ------------------- src/backend/tests/services/test_file.py | 0 4 files changed, 28 insertions(+), 83 deletions(-) create mode 100644 src/backend/tests/services/test_file.py diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index f30a389550..06d10f6f38 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -42,7 +42,6 @@ attach_conversation_id_to_files, get_file_service, validate_batch_file_size, - validate_file, validate_file_size, ) from backend.services.logger.utils import get_logger @@ -524,8 +523,8 @@ async def delete_file( HTTPException: If the conversation with the given ID is not found. """ user_id = ctx.get_user_id() - # _ = validate_conversation(session, conversation_id, user_id) - # _ = validate_file(session, file_id, user_id) + _ = validate_conversation(session, conversation_id, user_id) + get_file_service().validate_file(session, file_id, user_id) # Delete the File DB object get_file_service().delete_file_by_id(session, conversation_id, file_id, user_id) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 1ee6dd8059..b85256d46d 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -291,6 +291,31 @@ def get_files_by_message_id( files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files + def validate_file(self, session: DBSessionDep, file_id: str, user_id: str) -> File: + """Validates if a file exists and belongs to the user + + Args: + session (DBSessionDep): Database session + file_id (str): File ID + user_id (str): User ID + + Returns: + File: File object + + Raises: + HTTPException: If the file is not found + """ + if self.is_compass_enabled: + file = get_file_in_compass(file_id, user_id) + else: + file = file_crud.get_file(session, file_id, user_id) + + if not file: + raise HTTPException( + status_code=404, + detail=f"File with ID: {file_id} not found.", + ) + # Compass Operations def delete_file_in_compass(file_id: str, user_id: str) -> None: @@ -605,7 +630,7 @@ def validate_batch_file_size( user_id (str): The user ID files (list[FastAPIUploadFile]): The files to validate - Raises: + Raises:p HTTPException: If the file size is too large """ total_batch_size = 0 diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index dd5fcc9788..0b3faae5bc 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -926,85 +926,6 @@ def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provide assert response.json() == {"detail": "User-Id required in request headers."} -def test_update_file_name( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=conversation.user_id, - ) - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/{file.id}", - headers={"User-Id": conversation.user_id}, - json={"file_name": "new name"}, - ) - response_file = response.json() - - assert response.status_code == 200 - assert response_file["file_name"] == "new name" - - # Check if the file was updated - file = ( - session.query(File).filter_by(id=file.id, user_id=conversation.user_id).first() - ) - assert file is not None - assert file.file_name == "new name" - - -def test_fail_update_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/123", - json={"file_name": "new name"}, - headers={"User-Id": conversation.user_id}, - ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"File with ID: 123 not found."} - - -def test_fail_update_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, -) -> None: - response = session_client.put( - f"/v1/conversations/123/files/123", - json={"file_name": "new name"}, - headers={"User-Id": user.id}, - ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"Conversation with ID: 123 not found."} - - -def test_fail_update_file_missing_user_id( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=conversation.user_id, - ) - - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/{file.id}", - json={"file_name": "new name"}, - ) - - assert response.status_code == 401 - assert response.json() == {"detail": "User-Id required in request headers."} - - def test_delete_file( session_client: TestClient, session: Session, diff --git a/src/backend/tests/services/test_file.py b/src/backend/tests/services/test_file.py new file mode 100644 index 0000000000..e69de29bb2 From 00a9926c77c0643cefb7d13b5bcd3ec07b46b40e Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 6 Aug 2024 16:04:06 -0400 Subject: [PATCH 12/30] fixing tests --- src/backend/routers/conversation.py | 3 +- src/backend/services/file.py | 71 ++++-------- src/backend/tests/conftest.py | 27 +++++ .../tests/routers/test_conversation.py | 106 ++++++------------ 4 files changed, 88 insertions(+), 119 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 06d10f6f38..d25bd65184 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -42,6 +42,7 @@ attach_conversation_id_to_files, get_file_service, validate_batch_file_size, + validate_file, validate_file_size, ) from backend.services.logger.utils import get_logger @@ -524,7 +525,7 @@ async def delete_file( """ user_id = ctx.get_user_id() _ = validate_conversation(session, conversation_id, user_id) - get_file_service().validate_file(session, file_id, user_id) + validate_file(session, file_id, user_id) # Delete the File DB object get_file_service().delete_file_by_id(session, conversation_id, file_id, user_id) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index b85256d46d..1883375f5c 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -60,6 +60,7 @@ def get_compass(): compass = Compass() except Exception as e: logger.error(f"Error initializing Compass: {e}") + raise e return compass @@ -92,9 +93,7 @@ async def create_conversation_files( if self.is_compass_enabled: uploaded_files = await insert_files_in_compass(files, user_id) else: - uploaded_files = await insert_files_in_db( - session, files, user_id, conversation_id - ) + uploaded_files = await insert_files_in_db(session, files, user_id) for file in uploaded_files: conversation_crud.create_conversation_file_association( @@ -291,30 +290,31 @@ def get_files_by_message_id( files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files - def validate_file(self, session: DBSessionDep, file_id: str, user_id: str) -> File: - """Validates if a file exists and belongs to the user - Args: - session (DBSessionDep): Database session - file_id (str): File ID - user_id (str): User ID +def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: + """Validates if a file exists and belongs to the user - Returns: - File: File object + Args: + session (DBSessionDep): Database session + file_id (str): File ID + user_id (str): User ID - Raises: - HTTPException: If the file is not found - """ - if self.is_compass_enabled: - file = get_file_in_compass(file_id, user_id) - else: - file = file_crud.get_file(session, file_id, user_id) + Returns: + File: File object - if not file: - raise HTTPException( - status_code=404, - detail=f"File with ID: {file_id} not found.", - ) + Raises: + HTTPException: If the file is not found + """ + if Settings().feature_flags.use_compass_file_storage: + file = get_file_in_compass(file_id, user_id) + else: + file = file_crud.get_file(session, file_id, user_id) + + if not file: + raise HTTPException( + status_code=404, + detail=f"File with ID: {file_id} not found.", + ) # Compass Operations @@ -481,31 +481,6 @@ def attach_conversation_id_to_files( return results -def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: - """Validates if a file exists and belongs to the user - - Args: - session (DBSessionDep): Database session - file_id (str): File ID - user_id (str): User ID - - Returns: - File: File object - - Raises: - HTTPException: If the file is not found - """ - file = file_crud.get_file(session, file_id, user_id) - - if not file: - raise HTTPException( - status_code=404, - detail=f"File with ID: {file_id} not found.", - ) - - return file - - def get_file_extension(file_name: str) -> str: """Returns the file extension diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index a1cff889d6..48833f58aa 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -15,6 +15,7 @@ from backend.schemas.deployment import Deployment from backend.schemas.organization import Organization from backend.schemas.user import User +from backend.services.compass import Compass from backend.tests.factories import get_factory DATABASE_URL = os.environ["DATABASE_URL"] @@ -204,3 +205,29 @@ def mock_available_model_deployments(request): with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock + + +@pytest.fixture +def mock_compass_settings(): + with patch("backend.services.file.Settings") as MockSettings: + mock_settings = MockSettings.return_value + mock_settings.feature_flags.use_compass_file_storage = os.getenv( + "ENABLE_COMPASS_FILE_STORAGE", False + ) + mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL") + mock_settings.tools.compass.api_parser_url = os.getenv( + "COHERE_COMPASS_API_PARSER_URL" + ) + mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME") + mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD") + yield mock_settings + + +@pytest.fixture +def mock_compass(): + with patch("backend.services.file.get_compass") as mock: + try: + compass = Compass() + except Exception as e: + raise e + yield compass diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index 0b3faae5bc..1a41e97314 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -16,7 +16,7 @@ ) from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.schemas.user import User -from backend.services.file import MAX_FILE_SIZE, MAX_TOTAL_FILE_SIZE +from backend.services.file import MAX_FILE_SIZE, MAX_TOTAL_FILE_SIZE, get_file_service from backend.tests.factories import get_factory @@ -587,9 +587,7 @@ def test_search_conversations_no_conversations( # FILES def test_list_files( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file = get_factory("File", session).create( file_name="test_file.txt", @@ -616,9 +614,7 @@ def test_list_files( def test_list_files_no_files( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.get( @@ -631,9 +627,7 @@ def test_list_files_no_files( def test_list_files_missing_user_id( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.get(f"/v1/conversations/{conversation.id}/files") @@ -643,12 +637,9 @@ def test_list_files_missing_user_id( def test_upload_file_existing_conversation( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" conversation = get_factory("Conversation", session).create(user_id=user.id) file_doc = {"file": open(file_path, "rb")} @@ -660,29 +651,25 @@ def test_upload_file_existing_conversation( ) file = response.json() - - file_in_db = session.get(File, file.get("id")) - assert file_in_db is not None assert response.status_code == 200 + files = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id + ) + assert len(files) == 1 assert "Mariana_Trench" in file["file_name"] - assert conversation.file_ids == [file_in_db.id] - - # File should not exist in the directory - assert not os.path.exists(saved_file_path) + assert conversation.file_ids == [file["id"]] def test_upload_file_nonexistent_conversation_creates_new_conversation( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" file_doc = {"file": open(file_path, "rb")} response = session_client.post( "/v1/conversations/upload_file", files=file_doc, headers={"User-Id": user.id} ) + assert response.status_code == 200 file = response.json() @@ -696,22 +683,20 @@ def test_upload_file_nonexistent_conversation_creates_new_conversation( .filter_by(id=conversation_file_association.conversation_id) .first() ) - file_in_db = session.get(File, file.get("id")) - assert file_in_db is not None - assert response.status_code == 200 + + files = get_file_service().get_files_by_conversation_id( + session, created_conversation.user_id, created_conversation.id + ) + assert len(files) == 1 + assert "Mariana_Trench" in file["file_name"] + assert created_conversation.file_ids == [file["id"]] assert created_conversation is not None assert created_conversation.user_id == user.id assert conversation_file_association is not None - assert "Mariana_Trench" in file.get("file_name") - - # File should not exist in the directory - assert not os.path.exists(saved_file_path) def test_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/test_data/Mariana_Trench.pdf" file_doc = {"file": open(file_path, "rb")} @@ -723,7 +708,7 @@ def test_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( def test_batch_upload_file_existing_conversation( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/test_data/Mariana_Trench.pdf", @@ -731,12 +716,6 @@ def test_batch_upload_file_existing_conversation( "Tapas.pdf": "src/backend/tests/test_data/Tapas.pdf", "Mount_Everest.pdf": "src/backend/tests/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -767,13 +746,14 @@ def test_batch_upload_file_existing_conversation( ) assert conversation_file_association is not None - # File should not exist in the directory - for saved_file_path in saved_file_paths: - assert not os.path.exists(saved_file_path) + files_stored = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id + ) + assert len(files_stored) == len(file_paths) def test_batch_upload_total_files_exceeds_limit( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: _ = get_factory("Conversation", session).create(user_id=user.id) file_paths = { @@ -810,7 +790,7 @@ def test_batch_upload_total_files_exceeds_limit( def test_batch_upload_single_file_exceeds_limit( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: _ = get_factory("Conversation", session).create(user_id=user.id) file_paths = { @@ -840,7 +820,7 @@ def test_batch_upload_single_file_exceeds_limit( def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/test_data/Mariana_Trench.pdf", @@ -848,12 +828,6 @@ def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( "Tapas.pdf": "src/backend/tests/test_data/Tapas.pdf", "Mount_Everest.pdf": "src/backend/tests/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -895,13 +869,14 @@ def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( ) assert conversation_file_association is not None - # File should not exist in the directory - for saved_file_path in saved_file_paths: - assert not os.path.exists(saved_file_path) + files_stored = get_file_service().get_files_by_conversation_id( + session, created_conversation.user_id, created_conversation.id + ) + assert len(files_stored) == len(file_paths) def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/test_data/Mariana_Trench.pdf", @@ -909,12 +884,6 @@ def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provide "Tapas.pdf": "src/backend/tests/test_data/Tapas.pdf", "Mount_Everest.pdf": "src/backend/tests/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -930,6 +899,7 @@ def test_delete_file( session_client: TestClient, session: Session, user: User, + mock_compass_settings, ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) file = get_factory("File", session).create( @@ -971,9 +941,7 @@ def test_delete_file( def test_fail_delete_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.delete( @@ -986,9 +954,7 @@ def test_fail_delete_nonexistent_file( def test_fail_delete_file_missing_user_id( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) file = get_factory("File", session).create( From b51e35cff5c52956b3e4e81e2143b6009b593d5b Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 7 Aug 2024 10:38:34 -0400 Subject: [PATCH 13/30] saving --- src/backend/services/file.py | 70 +++++++++------- src/backend/tests/conftest.py | 10 --- .../tests/routers/test_conversation.py | 80 ++++++++++--------- 3 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 1883375f5c..73a0296661 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,6 +1,7 @@ import io import os import uuid +import asyncio from datetime import datetime import pandas as pd @@ -326,14 +327,18 @@ def delete_file_in_compass(file_id: str, user_id: str) -> None: def get_file_in_compass(file_id: str, user_id: str) -> File: - fetched_doc = ( - get_compass() - .invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, + fetched_doc = None + try: + fetched_doc = ( + get_compass() + .invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ) + .result["doc"]["content"] ) - .result["doc"]["content"] - ) + except Exception: + return fetched_doc return File( id=file_id, @@ -350,14 +355,17 @@ def get_file_in_compass(file_id: str, user_id: str) -> File: def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: files = [] for file_id in file_ids: - fetched_doc = ( - get_compass() - .invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, + try: + fetched_doc = ( + get_compass() + .invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ) + .result["doc"]["content"] ) - .result["doc"]["content"] - ) + except Exception as e: + print(e) files.append( File( @@ -410,14 +418,14 @@ async def insert_files_in_compass( filename = file.filename.encode("ascii", "ignore").decode("utf-8") file_bytes = await file.read() new_file_id = str(uuid.uuid4()) - - # Create new index for file + print("CREATE INDEX") get_compass().invoke( action=Compass.ValidActions.CREATE_INDEX, parameters={ "index": new_file_id, }, ) + print("CREATE DOCUMENT") get_compass().invoke( action=Compass.ValidActions.CREATE, parameters={ @@ -426,21 +434,21 @@ async def insert_files_in_compass( "file_text": file_bytes, }, ) - get_compass().invoke( - action=Compass.ValidActions.ADD_CONTEXT, - parameters={ - "index": new_file_id, - "file_id": new_file_id, - "context": { - "file_name": filename, - "file_path": filename, - "file_size": file.size, - "user_id": user_id, - "created_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat(), - }, - }, - ) + # get_compass().invoke( + # action=Compass.ValidActions.ADD_CONTEXT, + # parameters={ + # "index": new_file_id, + # "file_id": new_file_id, + # "context": { + # "file_name": filename, + # "file_path": filename, + # "file_size": file.size, + # "user_id": user_id, + # "created_at": datetime.now().isoformat(), + # "updated_at": datetime.now().isoformat(), + # }, + # }, + # ) get_compass().invoke( action=Compass.ValidActions.REFRESH, parameters={"index": new_file_id}, diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 48833f58aa..b232007c41 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -221,13 +221,3 @@ def mock_compass_settings(): mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME") mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD") yield mock_settings - - -@pytest.fixture -def mock_compass(): - with patch("backend.services.file.get_compass") as mock: - try: - compass = Compass() - except Exception as e: - raise e - yield compass diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index 1a41e97314..40298e8d0d 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -589,16 +589,18 @@ def test_search_conversations_no_conversations( def test_list_files( session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=user.id, - ) - conversation = get_factory("Conversation", session).create( - user_id=user.id, - ) - _ = get_factory("ConversationFileAssociation", session).create( - conversation_id=conversation.id, user_id=user.id, file_id=file.id + conversation = get_factory("Conversation", session).create(user_id=user.id) + files = [("files", ("Mariana_Trench.pdf", open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb")))] + + response = session_client.post( + "/v1/conversations/batch_upload_file", + headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, ) + assert response.status_code == 200 + files = response.json() + uploaded_file = files[0] response = session_client.get( f"/v1/conversations/{conversation.id}/files", @@ -609,8 +611,8 @@ def test_list_files( response = response.json() assert len(response) == 1 response_file = response[0] - assert response_file["id"] == file.id - assert response_file["file_name"] == "test_file.txt" + assert response_file["id"] == uploaded_file["id"] + assert response_file["file_name"] == uploaded_file["file_name"] def test_list_files_no_files( @@ -902,42 +904,42 @@ def test_delete_file( mock_compass_settings, ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - id="file_id", - file_name="test_file.txt", - user_id=conversation.user_id, - ) - _ = get_factory("ConversationFileAssociation", session).create( - conversation_id=conversation.id, user_id=user.id, file_id=file.id + files = [("files", ("Mariana_Trench.pdf", open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb")))] + + response = session_client.post( + "/v1/conversations/batch_upload_file", + headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, ) + assert response.status_code == 200 + files = response.json() + uploaded_file = files[0] response = session_client.delete( - f"/v1/conversations/{conversation.id}/files/{file.id}", + f"/v1/conversations/{conversation.id}/files/{uploaded_file['id']}", headers={"User-Id": conversation.user_id}, ) - assert response.status_code == 200 assert response.json() == {} - # Check if File - db_file = ( - session.query(File) - .filter(File.id == "file_id", File.user_id == user.id) - .first() - ) - assert db_file is None - - conversation_file_association = ( - session.query(ConversationFileAssociation) - .filter(File.id == "file_id", File.user_id == user.id) - .first() - ) - assert conversation_file_association is None - - conversation = ( - session.query(Conversation).filter(Conversation.id == conversation.id).first() - ) - assert conversation.file_ids == [] + # # Check if File + # file = get_file_service().get_file_by_id( + # session, uploaded_file["id"], conversation.user_id + # ) + # assert file is None + + # conversation_file_association = ( + # session.query(ConversationFileAssociation) + # .filter(File.id == uploaded_file["id"], File.user_id == user.id) + # .first() + # ) + # assert conversation_file_association is None + + # conversation = ( + # session.query(Conversation).filter(Conversation.id == conversation.id).first() + # ) + # assert conversation.file_ids == [] def test_fail_delete_nonexistent_file( From 1856e67602f8bf9cdc0ed9555ba76b85ef59e212 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 8 Aug 2024 00:06:22 -0400 Subject: [PATCH 14/30] refactor indexing --- src/backend/chat/custom/custom.py | 3 +- src/backend/routers/agent.py | 64 +++++- src/backend/routers/chat.py | 4 - src/backend/services/chat.py | 31 +-- src/backend/services/file.py | 188 +++++++++++++----- .../tests/routers/test_conversation.py | 20 +- src/backend/tools/files.py | 46 ++++- 7 files changed, 263 insertions(+), 93 deletions(-) diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index e7fe7089f1..b77a9a2370 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -257,8 +257,9 @@ async def call_tools( model_deployment=deployment_model, user_id=ctx.get_user_id(), trace_id=ctx.get_trace_id(), - agent_id=ctx.get_agent_id(), agent_tool_metadata=ctx.get_agent_tool_metadata(), + agent_id=ctx.get_agent_id(), + conversation_id=ctx.get_conversation_id() ) # If the tool returns a list of outputs, append each output to the tool_results list diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index d998996c25..8413834d57 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -1,6 +1,10 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends +from fastapi import File as RequestFile +from fastapi import HTTPException +from fastapi import UploadFile as FastAPIUploadFile from backend.config.routers import RouterName +from backend.config.settings import Settings from backend.crud import agent as agent_crud from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.database_models.agent import Agent as AgentModel @@ -21,6 +25,13 @@ UpdateAgentToolMetadataRequest, ) from backend.schemas.context import Context +from backend.schemas.file import ( + DeleteFileResponse, + FilePublic, + ListFile, + UpdateFileRequest, + UploadFileResponse, +) from backend.schemas.metrics import ( DEFAULT_METRICS_AGENT, GenericResponseMessage, @@ -33,6 +44,14 @@ validate_agent_tool_metadata_exists, ) from backend.services.context import get_context +from backend.services.file import ( + attach_conversation_id_to_files, + get_file_service, + index_agent_files, + validate_batch_file_size, + validate_file, + validate_file_size, +) from backend.services.request_validators import ( validate_create_agent_request, validate_update_agent_request, @@ -98,6 +117,24 @@ async def create_agent( await update_or_create_tool_metadata( created_agent, tool_metadata, session, ctx ) + + # Consolidate agent files into one index in compass + if get_file_service().is_compass_enabled: + if agent_tool_metadata is not None: + artifacts = next( + tool_metadata.artifacts + for tool_metadata in agent_tool_metadata + if tool_metadata.tool_name == ToolName.READ_DOCUMENT + or tool_metadata.tool_name == ToolName.SEARCH_FILE + ) + + file_ids = [ + artifact.get("id") + for artifact in artifacts + if artifact.get("type") == "local_file" + ] + + await index_agent_files(file_ids, create_agent.id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -476,6 +513,31 @@ async def delete_agent_tool_metadata( return DeleteAgentToolMetadata() +@router.post("/batch_upload_file", response_model=list[UploadFileResponse]) +async def batch_upload_file( + session: DBSessionDep, + files: list[FastAPIUploadFile] = RequestFile(...), + ctx: Context = Depends(get_context), +) -> UploadFileResponse: + user_id = ctx.get_user_id() + validate_batch_file_size(session, user_id, files) + + uploaded_files = [] + try: + uploaded_files = await get_file_service().create_agent_files( + session, + files, + user_id, + ctx, + ) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error while uploading file(s): {e}." + ) + + return uploaded_files + + # Default Agent Router default_agent_router = APIRouter( prefix="/v1/default_agent", diff --git a/src/backend/routers/chat.py b/src/backend/routers/chat.py index 30f9f063b9..a4782ad6f5 100644 --- a/src/backend/routers/chat.py +++ b/src/backend/routers/chat.py @@ -78,7 +78,6 @@ async def chat_stream( ( session, chat_request, - file_paths, response_message, should_store, managed_tools, @@ -92,7 +91,6 @@ async def chat_stream( CustomChat().chat( chat_request, stream=True, - file_paths=file_paths, managed_tools=managed_tools, session=session, ctx=ctx, @@ -152,7 +150,6 @@ async def chat( ( session, chat_request, - file_paths, response_message, should_store, managed_tools, @@ -165,7 +162,6 @@ async def chat( CustomChat().chat( chat_request, stream=False, - file_paths=file_paths, managed_tools=managed_tools, ctx=ctx, ), diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index e13347a4f6..bbf679fc0e 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -52,7 +52,6 @@ from backend.schemas.conversation import UpdateConversationRequest from backend.schemas.search_query import SearchQuery from backend.schemas.tool import Tool, ToolCall, ToolCallDelta -from backend.services.file import get_file_service from backend.services.generators import AsyncGeneratorContextManager from backend.services.logger.utils import get_logger @@ -142,9 +141,7 @@ def process_chat( id=str(uuid4()), ) - file_paths = None if isinstance(chat_request, CohereChatRequest): - file_paths = handle_file_retrieval(session, user_id, chat_request.file_ids) if should_store: attach_files_to_messages( session, @@ -169,7 +166,6 @@ def process_chat( return ( session, chat_request, - file_paths, chatbot_message, should_store, managed_tools, @@ -221,8 +217,10 @@ def get_or_create_conversation( """ conversation_id = chat_request.conversation_id or "" conversation = conversation_crud.get_conversation(session, conversation_id, user_id) - + print("CONVERSATION DEBUG", conversation_id) + print(conversation) if conversation is None: + print("HERE DEBUG") # Get the first 5 words of the user message as the title title = " ".join(user_message.split()[:5]) @@ -310,29 +308,6 @@ def create_message( return message -def handle_file_retrieval( - session: DBSessionDep, user_id: str, file_ids: List[str] | None = None -) -> list[str] | None: - """ - Retrieve file paths from the database. - - Args: - session (DBSessionDep): Database session. - user_id (str): User ID. - file_ids (List): List of File IDs. - - Returns: - list[str] | None: List of file paths or None. - """ - file_paths = None - # Use file_ids if provided - if file_ids is not None: - files = get_file_service().get_files_by_ids(session, file_ids, user_id) - file_paths = [file.file_path for file in files] - - return file_paths - - def attach_files_to_messages( session: DBSessionDep, user_id: str, diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 73a0296661..53b6f54bf5 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,7 +1,7 @@ +import asyncio import io import os import uuid -import asyncio from datetime import datetime import pandas as pd @@ -90,9 +90,10 @@ async def create_conversation_files( Returns: list[File]: The files that were created """ - uploaded_files = [] if self.is_compass_enabled: - uploaded_files = await insert_files_in_compass(files, user_id) + uploaded_files = await insert_files_in_compass( + files, user_id, conversation_id + ) else: uploaded_files = await insert_files_in_db(session, files, user_id) @@ -108,6 +109,33 @@ async def create_conversation_files( return uploaded_files + async def create_agent_files( + self, + session: DBSessionDep, + files: list[FastAPIUploadFile], + user_id: str, + ctx: Context = Depends(get_context), + ) -> list[File]: + """ + Create files and associations with conversation + + Args: + session (DBSessionDep): The database session + files (list[FastAPIUploadFile]): The files to upload + user_id (str): The user ID + conversation_id (str): The conversation ID + + Returns: + list[File]: The files that were created + """ + uploaded_files = [] + if self.is_compass_enabled: + uploaded_files = await insert_files_in_compass(files, user_id) + else: + uploaded_files = await insert_files_in_db(session, files, user_id) + + return uploaded_files + def get_files_by_agent_id( self, session: DBSessionDep, user_id: str, agent_id: str ) -> list[File]: @@ -131,7 +159,6 @@ def get_files_by_agent_id( detail=f"Agent with ID: {agent_id} not found.", ) - files = [] agent_tool_metadata = agent.tools_metadata if agent_tool_metadata is not None: artifacts = next( @@ -149,7 +176,7 @@ def get_files_by_agent_id( ] if self.is_compass_enabled: - files = get_files_in_compass(file_ids, user_id) + files = get_files_in_compass(agent_id, file_ids, user_id) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) @@ -182,7 +209,7 @@ def get_files_by_conversation_id( files = [] if file_ids is not None: if self.is_compass_enabled: - files = get_files_in_compass(file_ids, user_id) + files = get_files_in_compass(conversation_id, file_ids, user_id) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) @@ -229,26 +256,6 @@ def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> F file = file_crud.get_file(session, file_id, user_id) return file - def get_files_by_ids( - self, session: DBSessionDep, file_ids: list[str], user_id: str - ) -> list[File]: - """ - Get files by IDs - - Args: - session (DBSessionDep): The database session - file_ids (list[str]): The file IDs - user_id (str): The user ID - - Returns: - list[File]: The files that were created - """ - if self.is_compass_enabled: - files = get_files_in_compass(file_ids, user_id) - else: - files = file_crud.get_files_by_ids(session, file_ids, user_id) - return files - def bulk_delete_files( self, session: DBSessionDep, file_ids: list[str], user_id: str ) -> None: @@ -352,7 +359,7 @@ def get_file_in_compass(file_id: str, user_id: str) -> File: ) -def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: +def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[File]: files = [] for file_id in file_ids: try: @@ -360,12 +367,14 @@ def get_files_in_compass(file_ids: list[str], user_id: str) -> list[File]: get_compass() .invoke( action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, + parameters={"index": index, "file_id": file_id}, ) .result["doc"]["content"] ) except Exception as e: - print(e) + raise HTTPException( + status_code=404, detail=f"File: {file_id} not found in Compass." + ) files.append( File( @@ -408,50 +417,123 @@ async def insert_files_in_db( return uploaded_files +async def index_agent_files( + file_ids, + agent_id, +) -> None: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": agent_id, + }, + ) + + for file_id in file_ids: + try: + fetched_doc = ( + get_compass() + .invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ) + .result["doc"]["content"] + ) + get_compass().invoke( + action=Compass.ValidActions.CREATE, + parameters={ + "index": agent_id, + "file_id": file_id, + "file_text": fetched_doc["text"], + "custom_context": { + "file_id": file_id + } + }, + ) + get_compass().invoke( + action=Compass.ValidActions.ADD_CONTEXT, + parameters={ + "index": agent_id, + "file_id": file_id, + "context": { + "file_name": fetched_doc["file_name"], + "file_path": fetched_doc["file_path"], + "file_size": fetched_doc["file_size"], + "user_id": fetched_doc["user_id"], + "created_at": fetched_doc["created_at"], + "updated_at": fetched_doc["updated_at"], + }, + }, + ) + get_compass().invoke( + action=Compass.ValidActions.REFRESH, + parameters={"index": agent_id}, + ) + # Remove the temporary file index entry + get_compass.invoke( + action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} + ) + except Exception as e: + print(e) + + async def insert_files_in_compass( files: list[FastAPIUploadFile], user_id: str, + index: str = None, ) -> list[File]: - uploaded_files = [] + if index is not None: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": index, + }, + ) + uploaded_files = [] for file in files: filename = file.filename.encode("ascii", "ignore").decode("utf-8") file_bytes = await file.read() new_file_id = str(uuid.uuid4()) - print("CREATE INDEX") + + # Create temporary index for individual file. + # Consolidate them under one agent index during agent creation + if index is None: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": new_file_id, + }, + ) + get_compass().invoke( - action=Compass.ValidActions.CREATE_INDEX, + action=Compass.ValidActions.CREATE, parameters={ - "index": new_file_id, + "index": new_file_id if index is None else index, + "file_id": new_file_id, + "file_text": file_bytes, + "custom_context": { + "file_id": new_file_id + } }, ) - print("CREATE DOCUMENT") get_compass().invoke( - action=Compass.ValidActions.CREATE, + action=Compass.ValidActions.ADD_CONTEXT, parameters={ - "index": new_file_id, + "index": new_file_id if index is None else index, "file_id": new_file_id, - "file_text": file_bytes, + "context": { + "file_name": filename, + "file_path": filename, + "file_size": file.size, + "user_id": user_id, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + }, }, ) - # get_compass().invoke( - # action=Compass.ValidActions.ADD_CONTEXT, - # parameters={ - # "index": new_file_id, - # "file_id": new_file_id, - # "context": { - # "file_name": filename, - # "file_path": filename, - # "file_size": file.size, - # "user_id": user_id, - # "created_at": datetime.now().isoformat(), - # "updated_at": datetime.now().isoformat(), - # }, - # }, - # ) get_compass().invoke( action=Compass.ValidActions.REFRESH, - parameters={"index": new_file_id}, + parameters={"index": new_file_id if index is None else index}, ) uploaded_files.append( diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index 40298e8d0d..56262a27b7 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -590,7 +590,15 @@ def test_list_files( session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) - files = [("files", ("Mariana_Trench.pdf", open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb")))] + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] response = session_client.post( "/v1/conversations/batch_upload_file", @@ -904,7 +912,15 @@ def test_delete_file( mock_compass_settings, ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) - files = [("files", ("Mariana_Trench.pdf", open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb")))] + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] response = session_client.post( "/v1/conversations/batch_upload_file", diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index e83c7e443b..9ab689c29b 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -6,18 +6,50 @@ from backend.services.compass import Compass from backend.services.file import get_compass, get_file_service from backend.tools.base import BaseTool +from backend.compass_sdk import SearchFilter def compass_file_search( - file_ids: List[str], query: str, search_limit: int = 5 + file_ids: List[str], conversation_id: str, agent_id: str, query: str, search_limit: int = 5 ) -> List[Dict[str, Any]]: results = [] - for file_id in file_ids: + + search_filters = [ + SearchFilter( + field="content.file_id", + type=SearchFilter.FilterType.EQ, + value=file_id + ) for file_id in file_ids + ] + # Search conversation ID index + hits = ( + get_compass() + .invoke( + action=Compass.ValidActions.SEARCH, + parameters={ + "index": conversation_id, + "query": query, + "top_k": search_limit, + "filters": search_filters + }, + ) + .result["hits"] + ) + print("HITS", hits) + results.extend(hits) + + # Search agent ID index + if agent_id: hits = ( get_compass() .invoke( action=Compass.ValidActions.SEARCH, - parameters={"index": file_id, "query": query, "top_k": search_limit}, + parameters={ + "index": agent_id, + "query": query, + "top_k": search_limit, + "filters": search_filters + }, ) .result["hits"] ) @@ -97,15 +129,21 @@ def is_available(cls) -> bool: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("search_query") files = parameters.get("files") + + agent_id = kwargs.get("agent_id") + conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") user_id = kwargs.get("user_id") + print("AGENT ID DEBUG", agent_id) + print("CONVERSATION ID DEBUG", conversation_id) + print("FILES DEBUG", files) if not query or not files: return [] if Settings().feature_flags.use_compass_file_storage: file_ids = [file_id for _, file_id in files] - return compass_file_search(file_ids, query, search_limit=self.SEARCH_LIMIT) + return compass_file_search(file_ids, conversation_id, agent_id, query, search_limit=self.SEARCH_LIMIT) else: retrieved_files = get_file_service().get_files_by_ids( session, file_ids, user_id From 3421a0fd6c3a572aa9e7be3c57fb8f29aaa5bab5 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 8 Aug 2024 16:25:29 -0400 Subject: [PATCH 15/30] refactoring agent vs convo files --- src/backend/chat/custom/custom.py | 2 +- src/backend/routers/agent.py | 6 +- src/backend/routers/conversation.py | 15 +- src/backend/services/compass.py | 9 +- src/backend/services/file.py | 230 ++++++++++++++-------------- src/backend/tools/files.py | 78 ++++++---- 6 files changed, 180 insertions(+), 160 deletions(-) diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index b77a9a2370..3a0a9d98e7 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -259,7 +259,7 @@ async def call_tools( trace_id=ctx.get_trace_id(), agent_tool_metadata=ctx.get_agent_tool_metadata(), agent_id=ctx.get_agent_id(), - conversation_id=ctx.get_conversation_id() + conversation_id=ctx.get_conversation_id(), ) # If the tool returns a list of outputs, append each output to the tool_results list diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 8413834d57..814f7ea2b7 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -45,12 +45,9 @@ ) from backend.services.context import get_context from backend.services.file import ( - attach_conversation_id_to_files, get_file_service, index_agent_files, validate_batch_file_size, - validate_file, - validate_file_size, ) from backend.services.request_validators import ( validate_create_agent_request, @@ -134,7 +131,8 @@ async def create_agent( if artifact.get("type") == "local_file" ] - await index_agent_files(file_ids, create_agent.id) + if len(file_ids) > 0: + await index_agent_files(file_ids, create_agent.id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index d25bd65184..498eebe10b 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -231,12 +231,11 @@ async def delete_conversation( HTTPException: If the conversation with the given ID is not found. """ user_id = ctx.get_user_id() - _ = validate_conversation(session, conversation_id, user_id) - conversation = conversation_crud.get_conversation(session, conversation_id, user_id) - - if conversation.file_ids: - get_file_service().bulk_delete_files(session, conversation.file_ids, user_id) + conversation = validate_conversation(session, conversation_id, user_id) + get_file_service().delete_all_conversation_files( + session, conversation.id, conversation.file_ids, user_id + ) conversation_crud.delete_conversation(session, conversation_id, user_id) return DeleteConversationResponse() @@ -525,10 +524,12 @@ async def delete_file( """ user_id = ctx.get_user_id() _ = validate_conversation(session, conversation_id, user_id) - validate_file(session, file_id, user_id) + validate_file(session, file_id, user_id, conversation_id) # Delete the File DB object - get_file_service().delete_file_by_id(session, conversation_id, file_id, user_id) + get_file_service().delete_conversation_file_by_id( + session, conversation_id, file_id, user_id + ) return DeleteFileResponse() diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 25956a3db6..6783ee6eff 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -271,9 +271,12 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: text=file_text, file_id=file_id, bytes_content=isinstance(file_text, bytes), + custom_context=parameters.get("custom_context", {}), ) - def _raw_parsing(self, text: str, file_id: str, bytes_content: bool): + def _raw_parsing( + self, text: str, file_id: str, bytes_content: bool, custom_context: dict + ): text_bytes = str.encode(text) if not bytes_content else text if len(text_bytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES: logger.error( @@ -300,7 +303,9 @@ def _raw_parsing(self, text: str, file_id: str, bytes_content: bool): if res.ok: docs = [CompassDocument(**doc) for doc in res.json()["docs"]] for doc in docs: - additional_metadata = CompassParserClient._get_metadata(doc=doc) + additional_metadata = CompassParserClient._get_metadata( + doc=doc, custom_context=custom_context + ) doc.content = {**doc.content, **additional_metadata} else: docs = [] diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 53b6f54bf5..92a6990996 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -48,6 +48,12 @@ def get_file_service(): + """ + Initialize a singular instance of FileService if not initialized yet + + Returns: + FileService: The singleton FileService instance + """ global file_service if file_service is None: file_service = FileService() @@ -55,6 +61,12 @@ def get_file_service(): def get_compass(): + """ + Initialize a singular instance of Compass if not initialized yet + + Returns: + Compass: The singleton Compass instance + """ global compass if compass is None: try: @@ -66,8 +78,20 @@ def get_compass(): class FileService: + """ + FileService class + + This class manages interfacing with different file storage solutions. Currently it supports storing files in the Postgres DB and or using Compass. + By default Toolkit will run with Postgres DB as the storage solution for files. + To enable Compass as the storage solution, set the `use_compass_file_storage` feature flag to `true` in the .env or .configuration file. + Also be sure to add the appropriate Compass environment variables to the .env or .configuration file. + """ + @property def is_compass_enabled(self) -> bool: + """ + Returns whether Compass is enabled as the file storage solution + """ return Settings().feature_flags.use_compass_file_storage async def create_conversation_files( @@ -79,13 +103,14 @@ async def create_conversation_files( ctx: Context = Depends(get_context), ) -> list[File]: """ - Create files and associations with conversation + Create files and associations with a conversation Args: session (DBSessionDep): The database session files (list[FastAPIUploadFile]): The files to upload user_id (str): The user ID conversation_id (str): The conversation ID + ctx (Context): Context object Returns: list[File]: The files that were created @@ -117,19 +142,21 @@ async def create_agent_files( ctx: Context = Depends(get_context), ) -> list[File]: """ - Create files and associations with conversation + Create files and associations with an agent Args: session (DBSessionDep): The database session files (list[FastAPIUploadFile]): The files to upload user_id (str): The user ID - conversation_id (str): The conversation ID - Returns: list[File]: The files that were created """ uploaded_files = [] if self.is_compass_enabled: + """ + Since agents are created after the files are upload we index files into dummy indices first + We later consolidate them in index_agent_files() to a singular index when an agent is created. + """ uploaded_files = await insert_files_in_compass(files, user_id) else: uploaded_files = await insert_files_in_db(session, files, user_id) @@ -215,7 +242,7 @@ def get_files_by_conversation_id( return files - def delete_file_by_id( + def delete_conversation_file_by_id( self, session: DBSessionDep, conversation_id: str, file_id: str, user_id: str ) -> None: """ @@ -232,44 +259,38 @@ def delete_file_by_id( ) if self.is_compass_enabled: - delete_file_in_compass(file_id, user_id) + delete_file_in_compass(conversation_id, file_id, user_id) else: file_crud.delete_file(session, file_id, user_id) return - def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> File: - """ - Get file by ID - - Args: - session (DBSessionDep): The database session - file_id (str): The file ID - user_id (str): The user ID - - Returns: - File: The file that was created - """ - if self.is_compass_enabled: - file = get_file_in_compass(file_id, user_id) - else: - file = file_crud.get_file(session, file_id, user_id) - return file - - def bulk_delete_files( - self, session: DBSessionDep, file_ids: list[str], user_id: str + def delete_all_conversation_files( + self, + session: DBSessionDep, + conversation_id: str, + file_ids: list[str], + user_id: str, ) -> None: """ - Bulk delete files + Delete all files associated with a conversation Args: session (DBSessionDep): The database session + conversation_id (str): The conversation ID file_ids (list[str]): The file IDs user_id (str): The user ID """ if self.is_compass_enabled: - for file_id in file_ids: - delete_file_in_compass(file_id, user_id) + try: + get_compass().invoke( + action=Compass.ValidActions.DELETE_INDEX, + parameters={"index": conversation_id}, + ) + except Exception as e: + logger.error( + f"Error deleting conversation {conversation_id} files from Compass: {e}" + ) else: file_crud.bulk_delete_files(session, file_ids, user_id) @@ -293,13 +314,17 @@ def get_files_by_message_id( files = [] if message.file_ids is not None: if self.is_compass_enabled: - files = get_files_in_compass(message.file_ids, user_id) + files = get_files_in_compass( + message.conversation_id, message.file_ids, user_id + ) else: files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files -def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: +def validate_file( + session: DBSessionDep, file_id: str, user_id: str, index: str = None +) -> File: """Validates if a file exists and belongs to the user Args: @@ -314,7 +339,7 @@ def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: HTTPException: If the file is not found """ if Settings().feature_flags.use_compass_file_storage: - file = get_file_in_compass(file_id, user_id) + file = get_files_in_compass(index, file_id, user_id)[0] else: file = file_crud.get_file(session, file_id, user_id) @@ -326,36 +351,11 @@ def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: # Compass Operations -def delete_file_in_compass(file_id: str, user_id: str) -> None: +def delete_file_in_compass(index: str, file_id: str, user_id: str) -> None: # todo: validate all files exists before deleting get_compass().invoke( - action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} - ) - - -def get_file_in_compass(file_id: str, user_id: str) -> File: - fetched_doc = None - try: - fetched_doc = ( - get_compass() - .invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, - ) - .result["doc"]["content"] - ) - except Exception: - return fetched_doc - - return File( - id=file_id, - file_name=fetched_doc["file_name"], - file_size=fetched_doc["file_size"], - file_path=fetched_doc["file_path"], - file_content=fetched_doc["text"], - user_id=user_id, - created_at=datetime.fromisoformat(fetched_doc["created_at"]), - updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), + action=Compass.ValidActions.DELETE, + parameters={"index": index, "file_id": file_id}, ) @@ -421,12 +421,15 @@ async def index_agent_files( file_ids, agent_id, ) -> None: - get_compass().invoke( - action=Compass.ValidActions.CREATE_INDEX, - parameters={ - "index": agent_id, - }, - ) + try: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": agent_id, + }, + ) + except Exception as e: + logger.Error(f"Fail") for file_id in file_ids: try: @@ -445,16 +448,7 @@ async def index_agent_files( "file_id": file_id, "file_text": fetched_doc["text"], "custom_context": { - "file_id": file_id - } - }, - ) - get_compass().invoke( - action=Compass.ValidActions.ADD_CONTEXT, - parameters={ - "index": agent_id, - "file_id": file_id, - "context": { + "file_id": file_id, "file_name": fetched_doc["file_name"], "file_path": fetched_doc["file_path"], "file_size": fetched_doc["file_size"], @@ -482,12 +476,15 @@ async def insert_files_in_compass( index: str = None, ) -> list[File]: if index is not None: - get_compass().invoke( - action=Compass.ValidActions.CREATE_INDEX, - parameters={ - "index": index, - }, - ) + try: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": index, + }, + ) + except Exception as e: + logger.error(f"[Compass File] Failed to create index: {index}, error: {e}") uploaded_files = [] for file in files: @@ -495,46 +492,47 @@ async def insert_files_in_compass( file_bytes = await file.read() new_file_id = str(uuid.uuid4()) - # Create temporary index for individual file. + # Create temporary index for individual file (files not associated with conversations) # Consolidate them under one agent index during agent creation if index is None: + try: + get_compass().invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": new_file_id, + }, + ) + except Exception as e: + logger.error( + f"[Compass File] Failed to create index: {index}, error: {e}" + ) + + try: get_compass().invoke( - action=Compass.ValidActions.CREATE_INDEX, + action=Compass.ValidActions.CREATE, parameters={ - "index": new_file_id, + "index": new_file_id if index is None else index, + "file_id": new_file_id, + "file_text": file_bytes, + "custom_context": { + "file_id": new_file_id, + "file_name": filename, + "file_path": filename, + "file_size": file.size, + "user_id": user_id, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + }, }, ) - - get_compass().invoke( - action=Compass.ValidActions.CREATE, - parameters={ - "index": new_file_id if index is None else index, - "file_id": new_file_id, - "file_text": file_bytes, - "custom_context": { - "file_id": new_file_id - } - }, - ) - get_compass().invoke( - action=Compass.ValidActions.ADD_CONTEXT, - parameters={ - "index": new_file_id if index is None else index, - "file_id": new_file_id, - "context": { - "file_name": filename, - "file_path": filename, - "file_size": file.size, - "user_id": user_id, - "created_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat(), - }, - }, - ) - get_compass().invoke( - action=Compass.ValidActions.REFRESH, - parameters={"index": new_file_id if index is None else index}, - ) + get_compass().invoke( + action=Compass.ValidActions.REFRESH, + parameters={"index": new_file_id if index is None else index}, + ) + except Exception as e: + logger.error( + f"[Compass File] Failed to create document on index: {index}, error: {e}" + ) uploaded_files.append( File( diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 9ab689c29b..ff4ef3d25c 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -1,41 +1,46 @@ from typing import Any, Dict, List import backend.crud.file as file_crud +from backend.compass_sdk import SearchFilter from backend.config.settings import Settings from backend.schemas.file import File from backend.services.compass import Compass from backend.services.file import get_compass, get_file_service from backend.tools.base import BaseTool -from backend.compass_sdk import SearchFilter def compass_file_search( - file_ids: List[str], conversation_id: str, agent_id: str, query: str, search_limit: int = 5 + file_ids: List[str], + conversation_id: str, + agent_id: str, + query: str, + search_limit: int = 5, ) -> List[Dict[str, Any]]: results = [] search_filters = [ SearchFilter( - field="content.file_id", + field="content.file_id.keyword", type=SearchFilter.FilterType.EQ, - value=file_id - ) for file_id in file_ids + value=file_id, + ) + for file_id in file_ids ] + # Search conversation ID index hits = ( get_compass() .invoke( action=Compass.ValidActions.SEARCH, parameters={ - "index": conversation_id, - "query": query, + "index": conversation_id, + "query": query, "top_k": search_limit, - "filters": search_filters + "filters": search_filters, }, ) .result["hits"] ) - print("HITS", hits) results.extend(hits) # Search agent ID index @@ -45,10 +50,10 @@ def compass_file_search( .invoke( action=Compass.ValidActions.SEARCH, parameters={ - "index": agent_id, - "query": query, + "index": agent_id, + "query": query, "top_k": search_limit, - "filters": search_filters + "filters": search_filters, }, ) .result["hits"] @@ -90,24 +95,36 @@ def is_available(cls) -> bool: return True async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: - file = parameters.get("file", []) + file = parameters.get("file") + session = kwargs.get("session") user_id = kwargs.get("user_id") + agent_id = kwargs.get("agent_id") + conversation_id = kwargs.get("conversation_id") if not file: return [] _, file_id = file - retrieved_file = get_file_service().get_file_by_id(session, file_id, user_id) - if not retrieved_file: - return [] + if Settings().feature_flags.use_compass_file_storage: + return compass_file_search( + [file_id], + conversation_id, + agent_id, + "*", + search_limit=self.SEARCH_LIMIT, + ) + else: + retrieved_file = file_crud.get_file(session, file_id, user_id) + if not retrieved_file: + return [] - return [ - { - "text": retrieved_file.file_content, - "title": retrieved_file.file_name, - "url": retrieved_file.file_path, - } - ] + return [ + { + "text": retrieved_file.file_content, + "title": retrieved_file.file_name, + "url": retrieved_file.file_path, + } + ] class SearchFileTool(BaseTool): @@ -135,19 +152,20 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: session = kwargs.get("session") user_id = kwargs.get("user_id") - print("AGENT ID DEBUG", agent_id) - print("CONVERSATION ID DEBUG", conversation_id) - print("FILES DEBUG", files) if not query or not files: return [] if Settings().feature_flags.use_compass_file_storage: file_ids = [file_id for _, file_id in files] - return compass_file_search(file_ids, conversation_id, agent_id, query, search_limit=self.SEARCH_LIMIT) - else: - retrieved_files = get_file_service().get_files_by_ids( - session, file_ids, user_id + return compass_file_search( + file_ids, + conversation_id, + agent_id, + query, + search_limit=self.SEARCH_LIMIT, ) + else: + retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id) if not retrieved_files: return [] From d64d49b9592553368a5f8164a11516379750a104 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 8 Aug 2024 18:54:25 -0400 Subject: [PATCH 16/30] code complete --- docker-compose.yml | 16 +-- src/backend/routers/agent.py | 82 ++++++++----- src/backend/services/file.py | 215 +++++++++++++++++++++++------------ 3 files changed, 208 insertions(+), 105 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 10c88df36a..f2c0d75af1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,14 +110,14 @@ services: ignore: - node_modules/ - terrarium: - image: ghcr.io/cohere-ai/terrarium:latest - ports: - - "8080:8080" - expose: - - "8080" - networks: - - proxynet + # terrarium: + # image: ghcr.io/cohere-ai/terrarium:latest + # ports: + # - "8080:8080" + # expose: + # - "8080" + # networks: + # - proxynet volumes: db: diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 814f7ea2b7..f1f2f42d70 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -5,6 +5,7 @@ from backend.config.routers import RouterName from backend.config.settings import Settings +from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.database_models.agent import Agent as AgentModel @@ -25,13 +26,7 @@ UpdateAgentToolMetadataRequest, ) from backend.schemas.context import Context -from backend.schemas.file import ( - DeleteFileResponse, - FilePublic, - ListFile, - UpdateFileRequest, - UploadFileResponse, -) +from backend.schemas.file import DeleteFileResponse, UploadFileResponse from backend.schemas.metrics import ( DEFAULT_METRICS_AGENT, GenericResponseMessage, @@ -45,8 +40,8 @@ ) from backend.services.context import get_context from backend.services.file import ( + consolidate_agent_files_in_compass, get_file_service, - index_agent_files, validate_batch_file_size, ) from backend.services.request_validators import ( @@ -115,24 +110,26 @@ async def create_agent( created_agent, tool_metadata, session, ctx ) - # Consolidate agent files into one index in compass - if get_file_service().is_compass_enabled: - if agent_tool_metadata is not None: - artifacts = next( - tool_metadata.artifacts - for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == ToolName.READ_DOCUMENT - or tool_metadata.tool_name == ToolName.SEARCH_FILE - ) - - file_ids = [ - artifact.get("id") - for artifact in artifacts - if artifact.get("type") == "local_file" - ] - - if len(file_ids) > 0: - await index_agent_files(file_ids, create_agent.id) + # Consolidate agent files into one index in compass + if get_file_service().is_compass_enabled and created_agent.tools_metadata: + artifacts = next( + ( + tool_metadata.artifacts + for tool_metadata in created_agent.tools_metadata + if tool_metadata.tool_name == ToolName.Read_File + or tool_metadata.tool_name == ToolName.Search_File + ), + [], + ) + file_ids = list( + set( + artifact.get("id") + for artifact in artifacts + if artifact.get("type") == "local_file" + ) + ) + if len(file_ids) > 0: + await consolidate_agent_files_in_compass(file_ids, created_agent.id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -530,12 +527,43 @@ async def batch_upload_file( ) except Exception as e: raise HTTPException( - status_code=500, detail=f"Error while uploading file(s): {e}." + status_code=500, detail=f"Error while uploading agent file(s): {e}." ) return uploaded_files +@router.delete("/{agent_id}/files/{file_id}") +async def delete_file( + agent_id: str, + file_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), +) -> DeleteFileResponse: + """ + Delete a file by ID. + + Args: + agent_id (str): Agent ID. + file_id (str): File ID. + session (DBSessionDep): Database session. + + Returns: + DeleteFile: Empty response. + + Raises: + HTTPException: If the agent with the given ID is not found. + """ + user_id = ctx.get_user_id() + _ = validate_agent_exists(session, agent_id) + validate_file(session, file_id, user_id, agent_id) + + # Delete the File DB object + get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id) + + return DeleteFileResponse() + + # Default Agent Router default_agent_router = APIRouter( prefix="/v1/default_agent", diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 92a6990996..b35c50b77a 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -155,7 +155,7 @@ async def create_agent_files( if self.is_compass_enabled: """ Since agents are created after the files are upload we index files into dummy indices first - We later consolidate them in index_agent_files() to a singular index when an agent is created. + We later consolidate them in consolidate_agent_files_in_compass() to a singular index when an agent is created. """ uploaded_files = await insert_files_in_compass(files, user_id) else: @@ -186,21 +186,26 @@ def get_files_by_agent_id( detail=f"Agent with ID: {agent_id} not found.", ) - agent_tool_metadata = agent.tools_metadata - if agent_tool_metadata is not None: + agent_tools_metadata = agent.tools_metadata + if agent_tools_metadata is not None: artifacts = next( - tool_metadata.artifacts - for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == ToolName.READ_DOCUMENT - or tool_metadata.tool_name == ToolName.SEARCH_FILE + ( + tool_metadata.artifacts + for tool_metadata in agent_tools_metadata + if tool_metadata.tool_name == ToolName.Read_File + or tool_metadata.tool_name == ToolName.Search_File + ), + [], # Default value if the generator is empty ) - # TODO scott: enumerate type names (?), different types for local vs. compass? - file_ids = [ - artifact.get("id") - for artifact in artifacts - if artifact.get("type") == "local_file" - ] + # Remove duplicates, since file can be associated to both read document and search file tools + file_ids = list( + set( + artifact.get("id") + for artifact in artifacts + if artifact.get("type") == "local_file" + ) + ) if self.is_compass_enabled: files = get_files_in_compass(agent_id, file_ids, user_id) @@ -246,7 +251,7 @@ def delete_conversation_file_by_id( self, session: DBSessionDep, conversation_id: str, file_id: str, user_id: str ) -> None: """ - Delete file from conversation + Delete a file asociated with a conversation Args: session (DBSessionDep): The database session @@ -265,6 +270,25 @@ def delete_conversation_file_by_id( return + def delete_agent_file_by_id( + self, session: DBSessionDep, agent_id: str, file_id: str, user_id: str + ) -> None: + """ + Delete a file asociated with an agent + + Args: + session (DBSessionDep): The database session + agent_id (str): The agent ID + file_id (str): The file ID + user_id (str): The user ID + """ + if self.is_compass_enabled: + delete_file_in_compass(agent_id, file_id, user_id) + else: + file_crud.delete_file(session, file_id, user_id) + + return + def delete_all_conversation_files( self, session: DBSessionDep, @@ -322,44 +346,42 @@ def get_files_by_message_id( return files -def validate_file( - session: DBSessionDep, file_id: str, user_id: str, index: str = None -) -> File: - """Validates if a file exists and belongs to the user +# Compass Operations +def delete_file_in_compass(index: str, file_id: str, user_id: str) -> None: + """ + Delete a file from Compass Args: - session (DBSessionDep): Database session - file_id (str): File ID - user_id (str): User ID - - Returns: - File: File object + index (str): The index + file_id (str): The file ID + user_id (str): The user ID Raises: HTTPException: If the file is not found """ - if Settings().feature_flags.use_compass_file_storage: - file = get_files_in_compass(index, file_id, user_id)[0] - else: - file = file_crud.get_file(session, file_id, user_id) - - if not file: - raise HTTPException( - status_code=404, - detail=f"File with ID: {file_id} not found.", + try: + get_compass().invoke( + action=Compass.ValidActions.DELETE, + parameters={"index": index, "file_id": file_id}, + ) + except Exception as e: + logger.error( + f"[Compass File] Error deleting file {file_id} on index {index} from Compass: {e}" ) -# Compass Operations -def delete_file_in_compass(index: str, file_id: str, user_id: str) -> None: - # todo: validate all files exists before deleting - get_compass().invoke( - action=Compass.ValidActions.DELETE, - parameters={"index": index, "file_id": file_id}, - ) +def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[File]: + """ + Get files from Compass + Args: + index (str): The index + file_ids (list[str]): The file IDs + user_id (str): The user ID -def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[File]: + Returns: + list[File]: The files that were created + """ files = [] for file_id in file_ids: try: @@ -392,35 +414,19 @@ def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[ return files -async def insert_files_in_db( - session: DBSessionDep, - files: list[FastAPIUploadFile], - user_id: str, -) -> list[File]: - files_to_upload = [] - for file in files: - content = await get_file_content(file) - cleaned_content = content.replace("\x00", "") - filename = file.filename.encode("ascii", "ignore").decode("utf-8") - - files_to_upload.append( - FileModel( - file_name=filename, - file_size=file.size, - file_path=filename, - file_content=cleaned_content, - user_id=user_id, - ) - ) - - uploaded_files = file_crud.batch_create_files(session, files_to_upload) - return uploaded_files - - -async def index_agent_files( +async def consolidate_agent_files_in_compass( file_ids, agent_id, ) -> None: + """ + Consolidate files into a single index (agent ID) in Compass. + We do this because when agents are created after a file is uploaded, the file is not associated with the agent. + We consolidate them in a single index to under one agent ID when an agent is created. + + Args: + file_ids (list[str]): The file IDs + agent_id (str): The agent ID + """ try: get_compass().invoke( action=Compass.ValidActions.CREATE_INDEX, @@ -429,7 +435,9 @@ async def index_agent_files( }, ) except Exception as e: - logger.Error(f"Fail") + logger.Error( + f"[Compass File] Error creating index for agent files: {agent_id}, error: {e}" + ) for file_id in file_ids: try: @@ -463,11 +471,13 @@ async def index_agent_files( parameters={"index": agent_id}, ) # Remove the temporary file index entry - get_compass.invoke( + get_compass().invoke( action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} ) except Exception as e: - print(e) + logger.error( + f"[Compass File] Error consolidating file {file_id} into agent {agent_id}, error: {e}" + ) async def insert_files_in_compass( @@ -549,6 +559,71 @@ async def insert_files_in_compass( return uploaded_files +# Misc +def validate_file( + session: DBSessionDep, file_id: str, user_id: str, index: str = None +) -> File: + """Validates if a file exists and belongs to the user + + Args: + session (DBSessionDep): Database session + file_id (str): File ID + user_id (str): User ID + + Returns: + File: File object + + Raises: + HTTPException: If the file is not found + """ + if Settings().feature_flags.use_compass_file_storage: + file = get_files_in_compass(index, file_id, user_id)[0] + else: + file = file_crud.get_file(session, file_id, user_id) + + if not file: + raise HTTPException( + status_code=404, + detail=f"File with ID: {file_id} not found.", + ) + + +async def insert_files_in_db( + session: DBSessionDep, + files: list[FastAPIUploadFile], + user_id: str, +) -> list[File]: + """ + Insert files into the database + + Args: + session (DBSessionDep): The database session + files (list[FastAPIUploadFile]): The files to upload + user_id (str): The user ID + + Returns: + list[File]: The files that were created + """ + files_to_upload = [] + for file in files: + content = await get_file_content(file) + cleaned_content = content.replace("\x00", "") + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + + files_to_upload.append( + FileModel( + file_name=filename, + file_size=file.size, + file_path=filename, + file_content=cleaned_content, + user_id=user_id, + ) + ) + + uploaded_files = file_crud.batch_create_files(session, files_to_upload) + return uploaded_files + + def attach_conversation_id_to_files( conversation_id: str, files: list[FileModel] ) -> list[File]: From 523ba64c59066ff607183429a657ca2341702bb7 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 00:36:12 -0400 Subject: [PATCH 17/30] router test done --- src/backend/routers/agent.py | 7 +++- src/backend/routers/conversation.py | 5 ++- src/backend/services/file.py | 15 ++++++- src/backend/tests/routers/test_agent.py | 1 + .../tests/routers/test_conversation.py | 42 +++++++++---------- 5 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index f1f2f42d70..8cfefe67dd 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -111,7 +111,10 @@ async def create_agent( ) # Consolidate agent files into one index in compass - if get_file_service().is_compass_enabled and created_agent.tools_metadata: + if ( + Settings().feature_flags.use_compass_file_storage + and created_agent.tools_metadata + ): artifacts = next( ( tool_metadata.artifacts @@ -508,7 +511,7 @@ async def delete_agent_tool_metadata( return DeleteAgentToolMetadata() -@router.post("/batch_upload_file", response_model=list[UploadFileResponse]) +@router.post("/batch_upload_files", response_model=list[UploadFileResponse]) async def batch_upload_file( session: DBSessionDep, files: list[FastAPIUploadFile] = RequestFile(...), diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 498eebe10b..bee93d0ace 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -6,6 +6,7 @@ from backend.chat.custom.custom import CustomChat from backend.chat.custom.utils import get_deployment from backend.config.routers import RouterName +from backend.config.settings import Settings from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud from backend.database_models import Conversation as ConversationModel @@ -351,7 +352,9 @@ async def upload_file( """ user_id = ctx.get_user_id() - validate_file_size(session, user_id, file) + # Currently do not limit file size for Compass + if Settings().feature_flags.use_compass_file_storage is False: + validate_file_size(session, user_id, file) # Create new conversation if not conversation_id: diff --git a/src/backend/services/file.py b/src/backend/services/file.py index b35c50b77a..cba044f76f 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -208,6 +208,9 @@ def get_files_by_agent_id( ) if self.is_compass_enabled: + print( + f"Getting files in compass for agent {agent_id}, {len(file_ids)} files" + ) files = get_files_in_compass(agent_id, file_ids, user_id) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) @@ -395,7 +398,7 @@ def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[ ) except Exception as e: raise HTTPException( - status_code=404, detail=f"File: {file_id} not found in Compass." + status_code=404, detail=f"File with ID: {file_id} not found." ) files.append( @@ -438,6 +441,10 @@ async def consolidate_agent_files_in_compass( logger.Error( f"[Compass File] Error creating index for agent files: {agent_id}, error: {e}" ) + raise HTTPException( + status_code=500, + detail=f"Error creating index for agent files: {agent_id}, error: {e}", + ) for file_id in file_ids: try: @@ -478,6 +485,10 @@ async def consolidate_agent_files_in_compass( logger.error( f"[Compass File] Error consolidating file {file_id} into agent {agent_id}, error: {e}" ) + raise HTTPException( + status_code=500, + detail=f"Error consolidating file {file_id} into agent {agent_id}, error: {e}", + ) async def insert_files_in_compass( @@ -577,7 +588,7 @@ def validate_file( HTTPException: If the file is not found """ if Settings().feature_flags.use_compass_file_storage: - file = get_files_in_compass(index, file_id, user_id)[0] + file = get_files_in_compass(index, [file_id], user_id)[0] else: file = file_crud.get_file(session, file_id, user_id) diff --git a/src/backend/tests/routers/test_agent.py b/src/backend/tests/routers/test_agent.py index e612e1f2c3..f1362ef438 100644 --- a/src/backend/tests/routers/test_agent.py +++ b/src/backend/tests/routers/test_agent.py @@ -9,6 +9,7 @@ from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.schemas.metrics import MetricsData, MetricsMessageType +from backend.services.file import get_file_service from backend.services.metrics import report_metrics from backend.tests.factories import get_factory diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index 56262a27b7..4dfb023ab0 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -266,7 +266,7 @@ def test_delete_conversation( user: User, ) -> None: conversation = get_factory("Conversation", session).create( - title="test title", user_id=user.id + title="test title", user_id=user.id, id="conversation_id" ) response = session_client.delete( f"/v1/conversations/{conversation.id}", @@ -279,7 +279,7 @@ def test_delete_conversation( # Check if the conversation was deleted conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=conversation.user_id) + .filter_by(id="conversation_id", user_id=user.id) .first() ) assert conversation is None @@ -343,7 +343,7 @@ def test_delete_conversation_with_messages( user: User, ) -> None: conversation = get_factory("Conversation", session).create( - title="test title", user_id=user.id + title="test title", user_id=user.id, id="conversation_id" ) _ = get_factory("Message", session).create( text="test message", @@ -362,7 +362,7 @@ def test_delete_conversation_with_messages( # Check if the conversation was deleted conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=user.id) + .filter_by(id="conversation_id", user_id=user.id) .first() ) assert conversation is None @@ -939,23 +939,23 @@ def test_delete_file( assert response.status_code == 200 assert response.json() == {} - # # Check if File - # file = get_file_service().get_file_by_id( - # session, uploaded_file["id"], conversation.user_id - # ) - # assert file is None - - # conversation_file_association = ( - # session.query(ConversationFileAssociation) - # .filter(File.id == uploaded_file["id"], File.user_id == user.id) - # .first() - # ) - # assert conversation_file_association is None - - # conversation = ( - # session.query(Conversation).filter(Conversation.id == conversation.id).first() - # ) - # assert conversation.file_ids == [] + # Check if File + files = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id + ) + assert files == [] + + conversation_file_association = ( + session.query(ConversationFileAssociation) + .filter(File.id == uploaded_file["id"], File.user_id == user.id) + .first() + ) + assert conversation_file_association is None + + conversation = ( + session.query(Conversation).filter(Conversation.id == conversation.id).first() + ) + assert conversation.file_ids == [] def test_fail_delete_nonexistent_file( From c44307272f77c0dc324aa04d9b375db06f39e2b8 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 09:37:51 -0400 Subject: [PATCH 18/30] clean up --- docker-compose.yml | 16 ++++++++-------- src/backend/routers/agent.py | 2 +- src/backend/services/chat.py | 3 --- src/backend/services/compass.py | 15 --------------- src/backend/services/file.py | 9 +-------- src/backend/tests/routers/test_agent.py | 2 -- 6 files changed, 10 insertions(+), 37 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index f2c0d75af1..10c88df36a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,14 +110,14 @@ services: ignore: - node_modules/ - # terrarium: - # image: ghcr.io/cohere-ai/terrarium:latest - # ports: - # - "8080:8080" - # expose: - # - "8080" - # networks: - # - proxynet + terrarium: + image: ghcr.io/cohere-ai/terrarium:latest + ports: + - "8080:8080" + expose: + - "8080" + networks: + - proxynet volumes: db: diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 8cfefe67dd..e16a311949 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -511,7 +511,7 @@ async def delete_agent_tool_metadata( return DeleteAgentToolMetadata() -@router.post("/batch_upload_files", response_model=list[UploadFileResponse]) +@router.post("/batch_upload_file", response_model=list[UploadFileResponse]) async def batch_upload_file( session: DBSessionDep, files: list[FastAPIUploadFile] = RequestFile(...), diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index bbf679fc0e..cad1bc38e6 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -217,10 +217,7 @@ def get_or_create_conversation( """ conversation_id = chat_request.conversation_id or "" conversation = conversation_crud.get_conversation(session, conversation_id, user_id) - print("CONVERSATION DEBUG", conversation_id) - print(conversation) if conversation is None: - print("HERE DEBUG") # Get the first 5 words of the user message as the title title = " ".join(user_message.split()[:5]) diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 6783ee6eff..78c661b1b1 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -233,25 +233,10 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: f"[Compass] Error processing file: No file_id specified in parameters {parameters}" ) - # Check if filename is specified for file-related actions - # if not parameters.get("filename", None) and not parameters.get( - # "file_text", None - # ): - # logger.error( - # event=f"[Compass] Error processing file: No filename or file_text specified in parameters {parameters}" - # ) - # return None - file_id = parameters["file_id"] filename = parameters.get("filename", None) file_text = parameters.get("file_text", None) - # if filename and not os.path.exists(filename): - # logger.error( - # event=f"[Compass] Error processing file: Invalid filename {filename} in parameters {parameters}" - # ) - # return None - parser_config = self.parser_config or parameters.get("parser_config", None) metadata_config = metadata_config = self.metadata_config or parameters.get( "metadata_config", None diff --git a/src/backend/services/file.py b/src/backend/services/file.py index cba044f76f..6c7a83161c 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,6 +1,4 @@ -import asyncio import io -import os import uuid from datetime import datetime @@ -14,15 +12,13 @@ import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud from backend.config.settings import Settings - -# from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import message as message_crud from backend.database_models.conversation import ConversationFileAssociation from backend.database_models.database import DBSessionDep from backend.database_models.file import File as FileModel from backend.schemas.context import Context -from backend.schemas.file import File, UpdateFileRequest +from backend.schemas.file import File from backend.services.compass import Compass from backend.services.context import get_context from backend.services.logger.utils import get_logger @@ -208,9 +204,6 @@ def get_files_by_agent_id( ) if self.is_compass_enabled: - print( - f"Getting files in compass for agent {agent_id}, {len(file_ids)} files" - ) files = get_files_in_compass(agent_id, file_ids, user_id) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) diff --git a/src/backend/tests/routers/test_agent.py b/src/backend/tests/routers/test_agent.py index f1362ef438..97daec442b 100644 --- a/src/backend/tests/routers/test_agent.py +++ b/src/backend/tests/routers/test_agent.py @@ -9,8 +9,6 @@ from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.schemas.metrics import MetricsData, MetricsMessageType -from backend.services.file import get_file_service -from backend.services.metrics import report_metrics from backend.tests.factories import get_factory From 6816984c3f9b9ac93831a59e09661b9cc0c9159d Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 09:40:02 -0400 Subject: [PATCH 19/30] clean up --- src/backend/services/compass.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 78c661b1b1..9208ec5135 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -233,10 +233,25 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: f"[Compass] Error processing file: No file_id specified in parameters {parameters}" ) + # Check if filename is specified for file-related actions + if not parameters.get("filename", None) and not parameters.get( + "file_text", None + ): + logger.error( + event=f"[Compass] Error processing file: No filename or file_text specified in parameters {parameters}" + ) + return None + file_id = parameters["file_id"] filename = parameters.get("filename", None) file_text = parameters.get("file_text", None) + if filename and not os.path.exists(filename): + logger.error( + event=f"[Compass] Error processing file: Invalid filename {filename} in parameters {parameters}" + ) + return None + parser_config = self.parser_config or parameters.get("parser_config", None) metadata_config = metadata_config = self.metadata_config or parameters.get( "metadata_config", None From 1f4bd63615c3a6d753ccae822b21ddc9afb3896e Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 09:50:07 -0400 Subject: [PATCH 20/30] clean up --- src/backend/schemas/file.py | 2 +- src/backend/tools/files.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/backend/schemas/file.py b/src/backend/schemas/file.py index 08e924a3af..6621484462 100644 --- a/src/backend/schemas/file.py +++ b/src/backend/schemas/file.py @@ -11,7 +11,7 @@ class File(BaseModel): user_id: str conversation_id: Optional[str] = None - file_content: Optional[str] = None + file_content: Optional[str] = None # Used interally file_name: str file_path: str file_size: int = Field(default=0, ge=0) diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index ff4ef3d25c..3f3dc8b1b7 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -18,6 +18,8 @@ def compass_file_search( ) -> List[Dict[str, Any]]: results = [] + # Note: Compass search currently has an issue where the type of the context is not directly referenced + # Temporarily add `.keyword` to workaround this issue. search_filters = [ SearchFilter( field="content.file_id.keyword", From 4e1b60cfe7c50d157bb8901e6bdfa53711fffedc Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 13:39:23 -0400 Subject: [PATCH 21/30] clean up --- src/backend/schemas/file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/schemas/file.py b/src/backend/schemas/file.py index 6621484462..45ab301221 100644 --- a/src/backend/schemas/file.py +++ b/src/backend/schemas/file.py @@ -11,7 +11,7 @@ class File(BaseModel): user_id: str conversation_id: Optional[str] = None - file_content: Optional[str] = None # Used interally + file_content: Optional[str] = None # Used interally file_name: str file_path: str file_size: int = Field(default=0, ge=0) From 3760af43aec3811836097775ed21456bca0a89cb Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 14:12:22 -0400 Subject: [PATCH 22/30] bug fix for DB flow --- docker-compose.yml | 16 ++++++++-------- src/backend/config/configuration.template.yaml | 1 + src/backend/tools/files.py | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 10c88df36a..408d3ed48e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,14 +110,14 @@ services: ignore: - node_modules/ - terrarium: - image: ghcr.io/cohere-ai/terrarium:latest - ports: - - "8080:8080" - expose: - - "8080" - networks: - - proxynet + # terrarium: + # image: ghcr.io/cohere-ai/terrarium:latest + # ports: + # - "8080:8080" + # expose: + # - "8080" + # networks: + # - proxynet volumes: db: diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index 1292c52545..8ba168dc9d 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -44,6 +44,7 @@ feature_flags: # Experimental features use_experimental_langchain: false use_agents_view: false + use_compass_file_storage: false # Community features use_community_features: true auth: diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 3f3dc8b1b7..d1f2c28556 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -157,8 +157,8 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not query or not files: return [] + file_ids = [file_id for _, file_id in files] if Settings().feature_flags.use_compass_file_storage: - file_ids = [file_id for _, file_id in files] return compass_file_search( file_ids, conversation_id, @@ -172,7 +172,7 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: return [] results = [] - for file in files: + for file in retrieved_files: results.append( { "text": file.file_content, From f1a500b6bf3bcf199f855102d43f1d1b45d5d598 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 9 Aug 2024 14:50:26 -0400 Subject: [PATCH 23/30] test bug fixed --- src/backend/config/secrets.template.yaml | 5 +++++ src/backend/tests/conftest.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/backend/config/secrets.template.yaml b/src/backend/config/secrets.template.yaml index 1a18d1be62..dc1ced0eb6 100644 --- a/src/backend/config/secrets.template.yaml +++ b/src/backend/config/secrets.template.yaml @@ -36,3 +36,8 @@ auth: client_id: client_secret: well_known_endpoint: +compass: + username: + password: + api_url: + parser_url: \ No newline at end of file diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index b232007c41..38cdc63d2f 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -212,8 +212,8 @@ def mock_compass_settings(): with patch("backend.services.file.Settings") as MockSettings: mock_settings = MockSettings.return_value mock_settings.feature_flags.use_compass_file_storage = os.getenv( - "ENABLE_COMPASS_FILE_STORAGE", False - ) + "ENABLE_COMPASS_FILE_STORAGE", "False" + ).lower() in ("true", "1") mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL") mock_settings.tools.compass.api_parser_url = os.getenv( "COHERE_COMPASS_API_PARSER_URL" From 68612c32a8a99e712a45913101fc85e01bbaae23 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 12 Aug 2024 18:51:54 -0400 Subject: [PATCH 24/30] conflicts --- src/backend/chat/custom/custom.py | 4 +- src/backend/routers/agent.py | 2 +- src/backend/routers/conversation.py | 20 +-- src/backend/routers/snapshot.py | 4 +- src/backend/services/conversation.py | 7 +- src/backend/services/file.py | 136 +++++++++++------- src/backend/services/snapshot.py | 3 +- .../tests/routers/test_conversation.py | 12 +- 8 files changed, 113 insertions(+), 75 deletions(-) diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index b557071c00..071dd46bb6 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -157,13 +157,13 @@ async def call_chat( if chat_request.file_ids or chat_request.agent_id: if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names: files = get_file_service().get_files_by_conversation_id( - session, user_id, ctx.get_conversation_id() + session, user_id, ctx.get_conversation_id(), ctx ) agent_files = [] if agent_id: agent_files = get_file_service().get_files_by_agent_id( - session, user_id, agent_id + session, user_id, agent_id, ctx ) all_files = files + agent_files diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 461e95e603..991f4e5946 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -702,7 +702,7 @@ async def delete_file( validate_file(session, file_id, user_id, agent_id) # Delete the File DB object - get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id) + get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id, ctx) return DeleteFileResponse() diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 2ee21d92f0..e12fb9f319 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -88,10 +88,10 @@ async def get_conversation( ) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files(conversation.id, files) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) _ = validate_conversation(session, conversation_id, user_id) conversation = ConversationPublic( @@ -143,7 +143,7 @@ async def list_conversations( results = [] for conversation in conversations: files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files( conversation.id, files @@ -195,9 +195,9 @@ async def update_conversation( ) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) files_with_conversation_id = attach_conversation_id_to_files(conversation.id, files) return ConversationPublic( id=conversation.id, @@ -235,7 +235,7 @@ async def delete_conversation( conversation = validate_conversation(session, conversation_id, user_id) get_file_service().delete_all_conversation_files( - session, conversation.id, conversation.file_ids, user_id + session, conversation.id, conversation.file_ids, user_id, ctx ) conversation_crud.delete_conversation(session, conversation_id, user_id) @@ -301,7 +301,7 @@ async def search_conversations( results = [] for conversation in filtered_documents: files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files( conversation.id, files @@ -386,7 +386,7 @@ async def upload_file( # Handle uploading File try: upload_file = await get_file_service().create_conversation_files( - session, [file], user_id, conversation.id + session, [file], user_id, conversation.id, ctx ) except Exception as e: raise HTTPException( @@ -497,7 +497,7 @@ async def list_files( _ = validate_conversation(session, conversation_id, user_id) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation_id + session, user_id, conversation_id, ctx ) files_with_conversation_id = attach_conversation_id_to_files(conversation_id, files) return files_with_conversation_id @@ -530,7 +530,7 @@ async def delete_file( # Delete the File DB object get_file_service().delete_conversation_file_by_id( - session, conversation_id, file_id, user_id + session, conversation_id, file_id, user_id, ctx ) return DeleteFileResponse() diff --git a/src/backend/routers/snapshot.py b/src/backend/routers/snapshot.py index 93e76080ef..9a9be80c70 100644 --- a/src/backend/routers/snapshot.py +++ b/src/backend/routers/snapshot.py @@ -60,7 +60,9 @@ async def create_snapshot( snapshot = snapshot_crud.get_snapshot_by_last_message_id(session, last_message_id) if not snapshot: - snapshot = wrap_create_snapshot(session, last_message_id, user_id, conversation) + snapshot = wrap_create_snapshot( + session, last_message_id, user_id, conversation, ctx + ) snapshot_link = wrap_create_snapshot_link(session, snapshot.id, user_id) diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index 385485c466..1c00ca3407 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -15,6 +15,7 @@ from backend.schemas.file import File from backend.schemas.message import Message from backend.services.chat import generate_chat_response +from backend.services.compass import Compass from backend.services.context import get_context from backend.services.file import attach_conversation_id_to_files, get_file_service @@ -99,7 +100,7 @@ def extract_details_from_conversation( def get_messages_with_files( - session: DBSessionDep, user_id: str, messages: list[MessageModel] + session: DBSessionDep, user_id: str, messages: list[MessageModel], ctx: Context ) -> list[Message]: """ Get messages and use the file service to get the files associated with each message @@ -115,7 +116,9 @@ def get_messages_with_files( messages_with_file = [] for message in messages: - files = get_file_service().get_files_by_message_id(session, message.id, user_id) + files = get_file_service().get_files_by_message_id( + session, message.id, user_id, ctx + ) files_with_conversation_id = attach_conversation_id_to_files( message.conversation_id, files ) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 4eb8160e32..c950ee42e5 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -22,7 +22,7 @@ from backend.services.agent import validate_agent_exists from backend.services.compass import Compass from backend.services.context import get_context -from backend.services.logger.utils import get_logger +from backend.services.logger.utils import LoggerFactory MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -39,10 +39,11 @@ # Monkey patch Pandas to use Calamine for Excel reading because Calamine is faster than Pandas pandas_monkeypatch() -logger = get_logger() file_service = None compass = None +logger = LoggerFactory().get_logger() + def get_file_service(): """ @@ -65,11 +66,14 @@ def get_compass(): Compass: The singleton Compass instance """ global compass + if compass is None: try: compass = Compass() except Exception as e: - logger.error(f"Error initializing Compass: {e}") + logger.error( + event=f"[Compass File Service] Error initializing Compass: {e}" + ) raise e return compass @@ -97,7 +101,7 @@ async def create_conversation_files( files: list[FastAPIUploadFile], user_id: str, conversation_id: str, - ctx: Context = Depends(get_context), + ctx: Context, ) -> list[File]: """ Create files and associations with a conversation @@ -114,7 +118,7 @@ async def create_conversation_files( """ if self.is_compass_enabled: uploaded_files = await insert_files_in_compass( - files, user_id, conversation_id + files, user_id, ctx, conversation_id ) else: uploaded_files = await insert_files_in_db(session, files, user_id) @@ -136,7 +140,7 @@ async def create_agent_files( session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, - ctx: Context = Depends(get_context), + ctx: Context, ) -> list[File]: """ Create files and associations with an agent @@ -154,14 +158,14 @@ async def create_agent_files( Since agents are created after the files are upload we index files into dummy indices first We later consolidate them in consolidate_agent_files_in_compass() to a singular index when an agent is created. """ - uploaded_files = await insert_files_in_compass(files, user_id) + uploaded_files = await insert_files_in_compass(files, ctx, user_id) else: uploaded_files = await insert_files_in_db(session, files, user_id) return uploaded_files def get_files_by_agent_id( - self, session: DBSessionDep, user_id: str, agent_id: str + self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context ) -> list[File]: """ Get files by agent ID @@ -201,14 +205,14 @@ def get_files_by_agent_id( ) if self.is_compass_enabled: - files = get_files_in_compass(agent_id, file_ids, user_id) + files = get_files_in_compass(agent_id, file_ids, user_id, ctx) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) return files def get_files_by_conversation_id( - self, session: DBSessionDep, user_id: str, conversation_id: str + self, session: DBSessionDep, user_id: str, conversation_id: str, ctx: Context ) -> list[FileModel]: """ Get files by conversation ID @@ -234,14 +238,19 @@ def get_files_by_conversation_id( files = [] if file_ids is not None: if self.is_compass_enabled: - files = get_files_in_compass(conversation_id, file_ids, user_id) + files = get_files_in_compass(conversation_id, file_ids, user_id, ctx) else: files = file_crud.get_files_by_ids(session, file_ids, user_id) return files def delete_conversation_file_by_id( - self, session: DBSessionDep, conversation_id: str, file_id: str, user_id: str + self, + session: DBSessionDep, + conversation_id: str, + file_id: str, + user_id: str, + ctx: Context, ) -> None: """ Delete a file asociated with a conversation @@ -257,14 +266,19 @@ def delete_conversation_file_by_id( ) if self.is_compass_enabled: - delete_file_in_compass(conversation_id, file_id, user_id) + delete_file_in_compass(conversation_id, file_id, user_id, ctx) else: file_crud.delete_file(session, file_id, user_id) return def delete_agent_file_by_id( - self, session: DBSessionDep, agent_id: str, file_id: str, user_id: str + self, + session: DBSessionDep, + agent_id: str, + file_id: str, + user_id: str, + ctx: Context, ) -> None: """ Delete a file asociated with an agent @@ -276,7 +290,7 @@ def delete_agent_file_by_id( user_id (str): The user ID """ if self.is_compass_enabled: - delete_file_in_compass(agent_id, file_id, user_id) + delete_file_in_compass(agent_id, file_id, user_id, ctx) else: file_crud.delete_file(session, file_id, user_id) @@ -288,6 +302,7 @@ def delete_all_conversation_files( conversation_id: str, file_ids: list[str], user_id: str, + ctx: Context = Depends(get_context), ) -> None: """ Delete all files associated with a conversation @@ -297,16 +312,20 @@ def delete_all_conversation_files( conversation_id (str): The conversation ID file_ids (list[str]): The file IDs user_id (str): The user ID + ctx (Context): Context object """ + logger = ctx.get_logger() + if self.is_compass_enabled: + compass = get_compass() try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.DELETE_INDEX, parameters={"index": conversation_id}, ) except Exception as e: logger.error( - f"Error deleting conversation {conversation_id} files from Compass: {e}" + event=f"[Compass File Service] Error deleting conversation {conversation_id} files from Compass: {e}" ) else: file_crud.bulk_delete_files(session, file_ids, user_id) @@ -314,7 +333,7 @@ def delete_all_conversation_files( return def get_files_by_message_id( - self, session: DBSessionDep, message_id: str, user_id: str + self, session: DBSessionDep, message_id: str, user_id: str, ctx: Context ) -> list[File]: """ Get message files @@ -332,7 +351,7 @@ def get_files_by_message_id( if message.file_ids is not None: if self.is_compass_enabled: files = get_files_in_compass( - message.conversation_id, message.file_ids, user_id + message.conversation_id, message.file_ids, user_id, ctx ) else: files = file_crud.get_files_by_ids(session, message.file_ids, user_id) @@ -340,7 +359,9 @@ def get_files_by_message_id( # Compass Operations -def delete_file_in_compass(index: str, file_id: str, user_id: str) -> None: +def delete_file_in_compass( + index: str, file_id: str, user_id: str, ctx: Context +) -> None: """ Delete a file from Compass @@ -348,22 +369,28 @@ def delete_file_in_compass(index: str, file_id: str, user_id: str) -> None: index (str): The index file_id (str): The file ID user_id (str): The user ID + ctx (Context): Context object Raises: HTTPException: If the file is not found """ + logger = ctx.get_logger() + compass = get_compass() + try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.DELETE, parameters={"index": index, "file_id": file_id}, ) except Exception as e: logger.error( - f"[Compass File] Error deleting file {file_id} on index {index} from Compass: {e}" + event=f"[Compass File Service] Error deleting file {file_id} on index {index} from Compass: {e}" ) -def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[File]: +def get_files_in_compass( + index: str, file_ids: list[str], user_id: str, ctx: Context +) -> list[File]: """ Get files from Compass @@ -375,17 +402,15 @@ def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[ Returns: list[File]: The files that were created """ + compass = get_compass() + files = [] for file_id in file_ids: try: - fetched_doc = ( - get_compass() - .invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": index, "file_id": file_id}, - ) - .result["doc"]["content"] - ) + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": index, "file_id": file_id}, + ).result["doc"]["content"] except Exception as e: raise HTTPException( status_code=404, detail=f"File with ID: {file_id} not found." @@ -410,6 +435,7 @@ def get_files_in_compass(index: str, file_ids: list[str], user_id: str) -> list[ async def consolidate_agent_files_in_compass( file_ids, agent_id, + ctx: Context, ) -> None: """ Consolidate files into a single index (agent ID) in Compass. @@ -419,9 +445,13 @@ async def consolidate_agent_files_in_compass( Args: file_ids (list[str]): The file IDs agent_id (str): The agent ID + ctx (Context): Context object """ + logger = ctx.get_logger() + compass = get_compass() + try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.CREATE_INDEX, parameters={ "index": agent_id, @@ -429,7 +459,7 @@ async def consolidate_agent_files_in_compass( ) except Exception as e: logger.Error( - f"[Compass File] Error creating index for agent files: {agent_id}, error: {e}" + event=f"[Compass File Service] Error creating index for agent files: {agent_id}, error: {e}" ) raise HTTPException( status_code=500, @@ -438,15 +468,11 @@ async def consolidate_agent_files_in_compass( for file_id in file_ids: try: - fetched_doc = ( - get_compass() - .invoke( - action=Compass.ValidActions.GET_DOCUMENT, - parameters={"index": file_id, "file_id": file_id}, - ) - .result["doc"]["content"] - ) - get_compass().invoke( + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ).result["doc"]["content"] + compass().invoke( action=Compass.ValidActions.CREATE, parameters={ "index": agent_id, @@ -463,17 +489,17 @@ async def consolidate_agent_files_in_compass( }, }, ) - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.REFRESH, parameters={"index": agent_id}, ) # Remove the temporary file index entry - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} ) except Exception as e: logger.error( - f"[Compass File] Error consolidating file {file_id} into agent {agent_id}, error: {e}" + event=f"[Compass File Service] Error consolidating file {file_id} into agent {agent_id}, error: {e}" ) raise HTTPException( status_code=500, @@ -484,18 +510,24 @@ async def consolidate_agent_files_in_compass( async def insert_files_in_compass( files: list[FastAPIUploadFile], user_id: str, + ctx: Context, index: str = None, ) -> list[File]: + logger = ctx.get_logger() + compass = get_compass() + if index is not None: try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.CREATE_INDEX, parameters={ "index": index, }, ) except Exception as e: - logger.error(f"[Compass File] Failed to create index: {index}, error: {e}") + logger.error( + event=f"[Compass File Service] Failed to create index: {index}, error: {e}" + ) uploaded_files = [] for file in files: @@ -507,7 +539,7 @@ async def insert_files_in_compass( # Consolidate them under one agent index during agent creation if index is None: try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.CREATE_INDEX, parameters={ "index": new_file_id, @@ -515,11 +547,11 @@ async def insert_files_in_compass( ) except Exception as e: logger.error( - f"[Compass File] Failed to create index: {index}, error: {e}" + event=f"[Compass File Service] Failed to create index: {index}, error: {e}" ) try: - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.CREATE, parameters={ "index": new_file_id if index is None else index, @@ -536,13 +568,13 @@ async def insert_files_in_compass( }, }, ) - get_compass().invoke( + compass.invoke( action=Compass.ValidActions.REFRESH, parameters={"index": new_file_id if index is None else index}, ) except Exception as e: logger.error( - f"[Compass File] Failed to create document on index: {index}, error: {e}" + event=f"[Compass File Service] Failed to create document on index: {index}, error: {e}" ) uploaded_files.append( diff --git a/src/backend/services/snapshot.py b/src/backend/services/snapshot.py index 7e953efb0c..888df2c995 100644 --- a/src/backend/services/snapshot.py +++ b/src/backend/services/snapshot.py @@ -58,6 +58,7 @@ def wrap_create_snapshot( last_message_id: str, user_id: str, conversation: Conversation, + ctx: Context, ) -> SnapshotModel: snapshot_agent = None if conversation.agent_id: @@ -72,7 +73,7 @@ def wrap_create_snapshot( tools_metadata=tools_metadata, ) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) snapshot_data = SnapshotData( title=conversation.title, description=conversation.description, diff --git a/src/backend/tests/routers/test_conversation.py b/src/backend/tests/routers/test_conversation.py index a993b05bfb..bd0f42c769 100644 --- a/src/backend/tests/routers/test_conversation.py +++ b/src/backend/tests/routers/test_conversation.py @@ -1,5 +1,5 @@ import os -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -663,7 +663,7 @@ def test_upload_file_existing_conversation( file = response.json() assert response.status_code == 200 files = get_file_service().get_files_by_conversation_id( - session, conversation.user_id, conversation.id + session, conversation.user_id, conversation.id, MagicMock() ) assert len(files) == 1 assert "Mariana_Trench" in file["file_name"] @@ -695,7 +695,7 @@ def test_upload_file_nonexistent_conversation_creates_new_conversation( ) files = get_file_service().get_files_by_conversation_id( - session, created_conversation.user_id, created_conversation.id + session, created_conversation.user_id, created_conversation.id, MagicMock() ) assert len(files) == 1 assert "Mariana_Trench" in file["file_name"] @@ -757,7 +757,7 @@ def test_batch_upload_file_existing_conversation( assert conversation_file_association is not None files_stored = get_file_service().get_files_by_conversation_id( - session, conversation.user_id, conversation.id + session, conversation.user_id, conversation.id, MagicMock() ) assert len(files_stored) == len(file_paths) @@ -880,7 +880,7 @@ def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( assert conversation_file_association is not None files_stored = get_file_service().get_files_by_conversation_id( - session, created_conversation.user_id, created_conversation.id + session, created_conversation.user_id, created_conversation.id, MagicMock() ) assert len(files_stored) == len(file_paths) @@ -941,7 +941,7 @@ def test_delete_file( # Check if File files = get_file_service().get_files_by_conversation_id( - session, conversation.user_id, conversation.id + session, conversation.user_id, conversation.id, MagicMock() ) assert files == [] From 8b10b2d9268157f43f1da7273747b25cee31e26a Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 13:02:33 -0400 Subject: [PATCH 25/30] fix --- src/backend/services/file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index c950ee42e5..6292b75bfb 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -594,7 +594,7 @@ async def insert_files_in_compass( # Misc def validate_file( - session: DBSessionDep, file_id: str, user_id: str, index: str = None + session: DBSessionDep, file_id: str, user_id: str, ctx: Context, index: str = None ) -> File: """Validates if a file exists and belongs to the user @@ -610,7 +610,7 @@ def validate_file( HTTPException: If the file is not found """ if Settings().feature_flags.use_compass_file_storage: - file = get_files_in_compass(index, [file_id], user_id)[0] + file = get_files_in_compass(index, [file_id], user_id, ctx)[0] else: file = file_crud.get_file(session, file_id, user_id) From c5cc36f644f672313e418ddbbc418e985a24fad1 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 13:37:16 -0400 Subject: [PATCH 26/30] add chat test --- src/backend/tests/routers/test_chat.py | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/backend/tests/routers/test_chat.py b/src/backend/tests/routers/test_chat.py index 2400dc9269..136c689411 100644 --- a/src/backend/tests/routers/test_chat.py +++ b/src/backend/tests/routers/test_chat.py @@ -976,6 +976,43 @@ def validate_chat_streaming_response( # Check if the conversation was created correctly validate_conversation(session, user, conversation_id, expected_num_messages) +@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") +def test_streaming_chat_with_files(session_client_chat: TestClient, session_chat: Session, user: User, mock_compass_settings): + # Create convo + conversation = get_factory("Conversation", session_chat).create(user_id=user.id) + + # Upload the files + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] + + response = session_client_chat.post( + "/v1/conversations/batch_upload_file", + headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, + ) + + assert response.status_code == 200 + file_id = response.json()[0]["id"] + + # Send the chat request + response = session_client_chat.post( + "/v1/chat", + json={"message": "Hello", "max_tokens": 10, "file_ids": [file_id], "tools": [{"name": "search_file"}]}, + headers={ + "User-Id": user.id, + "Deployment-Name": ModelDeploymentName.CoherePlatform, + }, + ) + + assert response.status_code == 200 def validate_conversation( session: Session, user: User, conversation_id: str, expected_num_messages: int From c3dc2b114c2dabb66c99c1d0634cf46f093f693f Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 14:58:44 -0400 Subject: [PATCH 27/30] PR feedback --- docker-compose.yml | 16 ++++++++-------- src/backend/routers/agent.py | 13 +++++++------ src/backend/services/file.py | 5 ++--- src/backend/tests/routers/test_chat.py | 16 ++++++++++++++-- src/backend/tools/files.py | 5 +++++ 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 8b4c482315..bf065fefa8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,14 +110,14 @@ services: ignore: - node_modules/ - # terrarium: - # image: ghcr.io/cohere-ai/terrarium:latest - # ports: - # - "8080:8080" - # expose: - # - "8080" - # networks: - # - proxynet + terrarium: + image: ghcr.io/cohere-ai/terrarium:latest + ports: + - "8080:8080" + expose: + - "8080" + networks: + - proxynet volumes: db: diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 991f4e5946..e07494075e 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -56,6 +56,7 @@ validate_update_agent_request, validate_user_header, ) +from backend.tools.files import FileToolsArtifactTypes router = APIRouter( prefix="/v1/agents", @@ -114,6 +115,7 @@ async def create_agent( ) # Consolidate agent files into one index in compass + file_tools = [ToolName.Read_File, ToolName.Search_File] if ( Settings().feature_flags.use_compass_file_storage and created_agent.tools_metadata @@ -122,8 +124,7 @@ async def create_agent( ( tool_metadata.artifacts for tool_metadata in created_agent.tools_metadata - if tool_metadata.tool_name == ToolName.Read_File - or tool_metadata.tool_name == ToolName.Search_File + if tool_metadata.tool_name in file_tools ), [], ) @@ -131,10 +132,10 @@ async def create_agent( set( artifact.get("id") for artifact in artifacts - if artifact.get("type") == "local_file" + if artifact.get("type") == FileToolsArtifactTypes.local_file ) ) - if len(file_ids) > 0: + if file_ids: await consolidate_agent_files_in_compass(file_ids, created_agent.id) if deployment_db and model_db: @@ -677,14 +678,14 @@ async def batch_upload_file( @router.delete("/{agent_id}/files/{file_id}") -async def delete_file( +async def delete_agent_file( agent_id: str, file_id: str, session: DBSessionDep, ctx: Context = Depends(get_context), ) -> DeleteFileResponse: """ - Delete a file by ID. + Delete an agent file by ID. Args: agent_id (str): Agent ID. diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 6292b75bfb..6a18790498 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -11,7 +11,6 @@ import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud from backend.config.settings import Settings -from backend.crud import agent as agent_crud from backend.crud import message as message_crud from backend.database_models.conversation import ConversationFileAssociation from backend.database_models.database import DBSessionDep @@ -23,6 +22,7 @@ from backend.services.compass import Compass from backend.services.context import get_context from backend.services.logger.utils import LoggerFactory +from backend.tools.files import FileToolsArtifactTypes MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -195,12 +195,11 @@ def get_files_by_agent_id( [], # Default value if the generator is empty ) - # Remove duplicates, since file can be associated to both read document and search file tools file_ids = list( set( artifact.get("id") for artifact in artifacts - if artifact.get("type") == "local_file" + if artifact.get("type") == FileToolsArtifactTypes.local_file ) ) diff --git a/src/backend/tests/routers/test_chat.py b/src/backend/tests/routers/test_chat.py index 136c689411..ca0f6af92c 100644 --- a/src/backend/tests/routers/test_chat.py +++ b/src/backend/tests/routers/test_chat.py @@ -976,8 +976,14 @@ def validate_chat_streaming_response( # Check if the conversation was created correctly validate_conversation(session, user, conversation_id, expected_num_messages) + @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_streaming_chat_with_files(session_client_chat: TestClient, session_chat: Session, user: User, mock_compass_settings): +def test_streaming_chat_with_files( + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_compass_settings, +): # Create convo conversation = get_factory("Conversation", session_chat).create(user_id=user.id) @@ -1005,7 +1011,12 @@ def test_streaming_chat_with_files(session_client_chat: TestClient, session_chat # Send the chat request response = session_client_chat.post( "/v1/chat", - json={"message": "Hello", "max_tokens": 10, "file_ids": [file_id], "tools": [{"name": "search_file"}]}, + json={ + "message": "Hello", + "max_tokens": 10, + "file_ids": [file_id], + "tools": [{"name": "search_file"}], + }, headers={ "User-Id": user.id, "Deployment-Name": ModelDeploymentName.CoherePlatform, @@ -1014,6 +1025,7 @@ def test_streaming_chat_with_files(session_client_chat: TestClient, session_chat assert response.status_code == 200 + def validate_conversation( session: Session, user: User, conversation_id: str, expected_num_messages: int ) -> None: diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 2591986a22..bd70149329 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -1,3 +1,4 @@ +from enum import StrEnum from typing import Any, Dict, List import backend.crud.file as file_crud @@ -9,6 +10,10 @@ from backend.tools.base import BaseTool +class FileToolsArtifactTypes(StrEnum): + local_file = "local_file" + + def compass_file_search( file_ids: List[str], conversation_id: str, From 623f6398bfe7dcfa2b447dff97fe362826e131d0 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 15:16:11 -0400 Subject: [PATCH 28/30] fix import --- src/backend/tests/integration/conftest.py | 1 - src/backend/tests/unit/routers/test_agent.py | 1 - src/backend/tests/unit/routers/test_conversation.py | 1 - 3 files changed, 3 deletions(-) diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 85989c07cd..56ec693c56 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -19,7 +19,6 @@ from backend.schemas.organization import Organization from backend.schemas.user import User from backend.services.compass import Compass -from backend.tests.factories import get_factory from backend.tests.unit.factories import get_factory DATABASE_URL = os.environ["DATABASE_URL"] diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index 4293bdd382..50b28da994 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -11,7 +11,6 @@ from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.database_models.snapshot import Snapshot from backend.schemas.metrics import MetricsData, MetricsMessageType -from backend.tests.factories import get_factory from backend.tests.unit.factories import get_factory diff --git a/src/backend/tests/unit/routers/test_conversation.py b/src/backend/tests/unit/routers/test_conversation.py index 9e4c81f86a..d18ecf0b81 100644 --- a/src/backend/tests/unit/routers/test_conversation.py +++ b/src/backend/tests/unit/routers/test_conversation.py @@ -13,7 +13,6 @@ ) from backend.schemas.user import User from backend.services.file import MAX_FILE_SIZE, MAX_TOTAL_FILE_SIZE, get_file_service -from backend.tests.factories import get_factory from backend.tests.unit.factories import get_factory From 6237ad174b958996b0dbdb724a5054f3e27dfff8 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 15:30:29 -0400 Subject: [PATCH 29/30] circular import fix --- src/backend/services/file.py | 2 +- src/backend/tools/files.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 6a18790498..4ada3944ac 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -22,7 +22,6 @@ from backend.services.compass import Compass from backend.services.context import get_context from backend.services.logger.utils import LoggerFactory -from backend.tools.files import FileToolsArtifactTypes MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -179,6 +178,7 @@ def get_files_by_agent_id( list[File]: The files that were created """ from backend.config.tools import ToolName + from backend.tools.files import FileToolsArtifactTypes agent = validate_agent_exists(session, agent_id, user_id) diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index bd70149329..436aa5ed8e 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -6,7 +6,7 @@ from backend.config.settings import Settings from backend.schemas.file import File from backend.services.compass import Compass -from backend.services.file import get_compass, get_file_service +from backend.services.file import get_compass from backend.tools.base import BaseTool From a1246738104b0d76a3f536fc676bbaa0f7d77ce2 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 13 Aug 2024 17:42:26 -0400 Subject: [PATCH 30/30] fix more stuff --- src/backend/routers/conversation.py | 2 +- src/backend/services/file.py | 6 +++++- src/backend/tests/integration/conftest.py | 16 ---------------- src/backend/tests/unit/conftest.py | 16 ++++++++++++++++ src/backend/tests/unit/routers/test_chat.py | 2 +- .../tests/unit/routers/test_conversation.py | 9 ++------- 6 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index e12fb9f319..37c0e094dc 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -526,7 +526,7 @@ async def delete_file( """ user_id = ctx.get_user_id() _ = validate_conversation(session, conversation_id, user_id) - validate_file(session, file_id, user_id, conversation_id) + validate_file(session, file_id, user_id, conversation_id, ctx) # Delete the File DB object get_file_service().delete_conversation_file_by_id( diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 4ada3944ac..26e5e6478c 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -402,6 +402,7 @@ def get_files_in_compass( list[File]: The files that were created """ compass = get_compass() + logger = ctx.get_logger() files = [] for file_id in file_ids: @@ -411,6 +412,9 @@ def get_files_in_compass( parameters={"index": index, "file_id": file_id}, ).result["doc"]["content"] except Exception as e: + logger.error( + event=f"[Compass File Service] Error fetching file {file_id} on index {index} from Compass: {e}" + ) raise HTTPException( status_code=404, detail=f"File with ID: {file_id} not found." ) @@ -593,7 +597,7 @@ async def insert_files_in_compass( # Misc def validate_file( - session: DBSessionDep, file_id: str, user_id: str, ctx: Context, index: str = None + session: DBSessionDep, file_id: str, user_id: str, index: str, ctx: Context ) -> File: """Validates if a file exists and belongs to the user diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 56ec693c56..96c9359a84 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -225,19 +225,3 @@ def mock_available_model_deployments(request): with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock - - -@pytest.fixture -def mock_compass_settings(): - with patch("backend.services.file.Settings") as MockSettings: - mock_settings = MockSettings.return_value - mock_settings.feature_flags.use_compass_file_storage = os.getenv( - "ENABLE_COMPASS_FILE_STORAGE", "False" - ).lower() in ("true", "1") - mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL") - mock_settings.tools.compass.api_parser_url = os.getenv( - "COHERE_COMPASS_API_PARSER_URL" - ) - mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME") - mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD") - yield mock_settings diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 3d4ccd7c0b..7486e7cd67 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -204,3 +204,19 @@ def mock_available_model_deployments(request): with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock + + +@pytest.fixture +def mock_compass_settings(): + with patch("backend.services.file.Settings") as MockSettings: + mock_settings = MockSettings.return_value + mock_settings.feature_flags.use_compass_file_storage = os.getenv( + "ENABLE_COMPASS_FILE_STORAGE", "False" + ).lower() in ("true", "1") + mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL") + mock_settings.tools.compass.api_parser_url = os.getenv( + "COHERE_COMPASS_API_PARSER_URL" + ) + mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME") + mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD") + yield mock_settings diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 9b072b2168..2f0ea0f1b5 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -993,7 +993,7 @@ def test_streaming_chat_with_files( "files", ( "Mariana_Trench.pdf", - open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), ), ) ] diff --git a/src/backend/tests/unit/routers/test_conversation.py b/src/backend/tests/unit/routers/test_conversation.py index d18ecf0b81..dee6bd9b61 100644 --- a/src/backend/tests/unit/routers/test_conversation.py +++ b/src/backend/tests/unit/routers/test_conversation.py @@ -476,11 +476,10 @@ def test_list_files( "files", ( "Mariana_Trench.pdf", - open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), ), ) ] - response = session_client.post( "/v1/conversations/batch_upload_file", headers={"User-Id": conversation.user_id}, @@ -531,7 +530,6 @@ def test_upload_file_existing_conversation( session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" conversation = get_factory("Conversation", session).create(user_id=user.id) file_doc = {"file": open(file_path, "rb")} @@ -556,7 +554,6 @@ def test_upload_file_nonexistent_conversation_creates_new_conversation( session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" file_doc = {"file": open(file_path, "rb")} response = session_client.post( @@ -606,8 +603,6 @@ def test_batch_upload_file_existing_conversation( file_paths = { "Mariana_Trench.pdf": "src/backend/tests/unit/test_data/Mariana_Trench.pdf", "Cardistry.pdf": "src/backend/tests/unit/test_data/Cardistry.pdf", - "Tapas.pdf": "src/backend/tests/unit/test_data/Tapas.pdf", - "Mount_Everest.pdf": "src/backend/tests/unit/test_data/Mount_Everest.pdf", } files = [ ("files", (file_name, open(file_path, "rb"))) @@ -800,7 +795,7 @@ def test_delete_file( "files", ( "Mariana_Trench.pdf", - open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"), + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), ), ) ]