diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 1e7d6fc9c3..071dd46bb6 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -157,13 +157,13 @@ async def call_chat( if chat_request.file_ids or chat_request.agent_id: if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names: files = get_file_service().get_files_by_conversation_id( - session, user_id, ctx.get_conversation_id() + session, user_id, ctx.get_conversation_id(), ctx ) agent_files = [] if agent_id: agent_files = get_file_service().get_files_by_agent_id( - session, user_id, agent_id + session, user_id, agent_id, ctx ) all_files = files + agent_files @@ -259,7 +259,7 @@ def add_files_to_chat_history( num_words = min(25, word_count) preview = " ".join(file.file_content.split()[:num_words]) - files_message += f"Filename: {file.file_name}\nWord Count: {word_count} Preview: {preview}\n\n" + files_message += f"Filename: {file.file_name}\nFile ID: {file.id}\nWord Count: {word_count} Preview: {preview}\n\n" chat_history.append(ChatMessage(message=files_message, role=ChatRole.SYSTEM)) return chat_history diff --git a/src/backend/chat/custom/tool_calls.py b/src/backend/chat/custom/tool_calls.py index c904694438..e10fa5b713 100644 --- a/src/backend/chat/custom/tool_calls.py +++ b/src/backend/chat/custom/tool_calls.py @@ -89,6 +89,7 @@ async def _call_tool_async( user_id=ctx.get_user_id(), trace_id=ctx.get_trace_id(), agent_id=ctx.get_agent_id(), + conversation_id=ctx.get_conversation_id(), agent_tool_metadata=ctx.get_agent_tool_metadata(), ) except Exception as e: diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index bcce26e568..5c404d7b9a 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -44,6 +44,7 @@ feature_flags: # Experimental features use_experimental_langchain: false use_agents_view: false + use_compass_file_storage: false # Community features use_community_features: true auth: diff --git a/src/backend/config/secrets.template.yaml b/src/backend/config/secrets.template.yaml index 1a18d1be62..dc1ced0eb6 100644 --- a/src/backend/config/secrets.template.yaml +++ b/src/backend/config/secrets.template.yaml @@ -36,3 +36,8 @@ auth: client_id: client_secret: well_known_endpoint: +compass: + username: + password: + api_url: + parser_url: \ No newline at end of file diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 7e7dea6b4a..e823f7505e 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -101,6 +101,12 @@ class FeatureFlags(BaseSettings, BaseModel): "USE_COMMUNITY_FEATURES", "use_community_features" ), ) + use_compass_file_storage: Optional[bool] = Field( + default=False, + validation_alias=AliasChoices( + "USE_COMPASS_FILE_STORAGE", "use_compass_file_storage" + ), + ) class PythonToolSettings(BaseSettings, BaseModel): diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index e777dad1ad..967a520462 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -66,9 +66,9 @@ class ToolName(StrEnum): "type": "str", "required": True, }, - "filenames": { - "description": "A list of one or more uploaded filename strings to search over", - "type": "list", + "files": { + "description": "A list of files represented as tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", "required": True, }, }, @@ -82,9 +82,9 @@ class ToolName(StrEnum): display_name="Read Document", implementation=ReadFileTool, parameter_definitions={ - "filename": { - "description": "The name of the attached file to read.", - "type": "str", + "file": { + "description": "A file represented as a tuple (filename, file ID) to read over", + "type": "tuple[str, str]", "required": True, } }, diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index b994320281..e07494075e 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -1,9 +1,14 @@ import asyncio from typing import Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends +from fastapi import File as RequestFile +from fastapi import HTTPException +from fastapi import UploadFile as FastAPIUploadFile from backend.config.routers import RouterName +from backend.config.settings import Settings +from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.crud import snapshot as snapshot_crud @@ -28,6 +33,7 @@ ) from backend.schemas.context import Context from backend.schemas.deployment import Deployment as DeploymentSchema +from backend.schemas.file import DeleteFileResponse, UploadFileResponse from backend.schemas.metrics import ( DEFAULT_METRICS_AGENT, GenericResponseMessage, @@ -40,11 +46,17 @@ validate_agent_tool_metadata_exists, ) from backend.services.context import get_context +from backend.services.file import ( + consolidate_agent_files_in_compass, + get_file_service, + validate_batch_file_size, +) from backend.services.request_validators import ( validate_create_agent_request, validate_update_agent_request, validate_user_header, ) +from backend.tools.files import FileToolsArtifactTypes router = APIRouter( prefix="/v1/agents", @@ -101,6 +113,31 @@ async def create_agent( await update_or_create_tool_metadata( created_agent, tool_metadata, session, ctx ) + + # Consolidate agent files into one index in compass + file_tools = [ToolName.Read_File, ToolName.Search_File] + if ( + Settings().feature_flags.use_compass_file_storage + and created_agent.tools_metadata + ): + artifacts = next( + ( + tool_metadata.artifacts + for tool_metadata in created_agent.tools_metadata + if tool_metadata.tool_name in file_tools + ), + [], + ) + file_ids = list( + set( + artifact.get("id") + for artifact in artifacts + if artifact.get("type") == FileToolsArtifactTypes.local_file + ) + ) + if file_ids: + await consolidate_agent_files_in_compass(file_ids, created_agent.id) + if deployment_db and model_db: deployment_config = ( agent.deployment_config @@ -615,6 +652,62 @@ async def delete_agent_tool_metadata( return DeleteAgentToolMetadata() +@router.post("/batch_upload_file", response_model=list[UploadFileResponse]) +async def batch_upload_file( + session: DBSessionDep, + files: list[FastAPIUploadFile] = RequestFile(...), + ctx: Context = Depends(get_context), +) -> UploadFileResponse: + user_id = ctx.get_user_id() + validate_batch_file_size(session, user_id, files) + + uploaded_files = [] + try: + uploaded_files = await get_file_service().create_agent_files( + session, + files, + user_id, + ctx, + ) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error while uploading agent file(s): {e}." + ) + + return uploaded_files + + +@router.delete("/{agent_id}/files/{file_id}") +async def delete_agent_file( + agent_id: str, + file_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), +) -> DeleteFileResponse: + """ + Delete an agent file by ID. + + Args: + agent_id (str): Agent ID. + file_id (str): File ID. + session (DBSessionDep): Database session. + + Returns: + DeleteFile: Empty response. + + Raises: + HTTPException: If the agent with the given ID is not found. + """ + user_id = ctx.get_user_id() + _ = validate_agent_exists(session, agent_id) + validate_file(session, file_id, user_id, agent_id) + + # Delete the File DB object + get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id, ctx) + + return DeleteFileResponse() + + # Default Agent Router default_agent_router = APIRouter( prefix="/v1/default_agent", diff --git a/src/backend/routers/chat.py b/src/backend/routers/chat.py index 28f730f170..0d89edfe81 100644 --- a/src/backend/routers/chat.py +++ b/src/backend/routers/chat.py @@ -77,7 +77,6 @@ async def chat_stream( ( session, chat_request, - file_paths, response_message, should_store, managed_tools, @@ -91,7 +90,6 @@ async def chat_stream( CustomChat().chat( chat_request, stream=True, - file_paths=file_paths, managed_tools=managed_tools, session=session, ctx=ctx, @@ -152,7 +150,6 @@ async def chat( ( session, chat_request, - file_paths, response_message, should_store, managed_tools, @@ -165,7 +162,6 @@ async def chat( CustomChat().chat( chat_request, stream=False, - file_paths=file_paths, managed_tools=managed_tools, ctx=ctx, ), diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index dfef24b342..d9f87a56b1 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -8,6 +8,7 @@ from backend.chat.custom.custom import CustomChat from backend.chat.custom.utils import get_deployment from backend.config.routers import RouterName +from backend.config.settings import Settings from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud from backend.database_models import Conversation as ConversationModel @@ -87,10 +88,10 @@ async def get_conversation( ) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files(conversation.id, files) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) _ = validate_conversation(session, conversation_id, user_id) conversation = ConversationPublic( @@ -142,7 +143,7 @@ async def list_conversations( results = [] for conversation in conversations: files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files( conversation.id, files @@ -194,9 +195,9 @@ async def update_conversation( ) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) files_with_conversation_id = attach_conversation_id_to_files(conversation.id, files) return ConversationPublic( id=conversation.id, @@ -231,12 +232,11 @@ async def delete_conversation( HTTPException: If the conversation with the given ID is not found. """ user_id = ctx.get_user_id() - _ = validate_conversation(session, conversation_id, user_id) - conversation = conversation_crud.get_conversation(session, conversation_id, user_id) - - if conversation.file_ids: - get_file_service().bulk_delete_files(session, conversation.file_ids, user_id) + conversation = validate_conversation(session, conversation_id, user_id) + get_file_service().delete_all_conversation_files( + session, conversation.id, conversation.file_ids, user_id, ctx + ) conversation_crud.delete_conversation(session, conversation_id, user_id) return DeleteConversationResponse() @@ -301,7 +301,7 @@ async def search_conversations( results = [] for conversation in filtered_documents: files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation.id + session, user_id, conversation.id, ctx ) files_with_conversation_id = attach_conversation_id_to_files( conversation.id, files @@ -351,7 +351,9 @@ async def upload_file( """ user_id = ctx.get_user_id() - validate_file_size(session, user_id, file) + # Currently do not limit file size for Compass + if Settings().feature_flags.use_compass_file_storage is False: + validate_file_size(session, user_id, file) # Create new conversation if not conversation_id: @@ -384,7 +386,7 @@ async def upload_file( # Handle uploading File try: upload_file = await get_file_service().create_conversation_files( - session, [file], user_id, conversation.id + session, [file], user_id, conversation.id, ctx ) except Exception as e: raise HTTPException( @@ -453,9 +455,14 @@ async def batch_upload_file( ) # TODO: check if file already exists in DB once we have files per agents + try: uploaded_files = await get_file_service().create_conversation_files( - session, files, user_id, conversation.id + session, + files, + user_id, + conversation.id, + ctx, ) except Exception as e: raise HTTPException( @@ -490,48 +497,12 @@ async def list_files( _ = validate_conversation(session, conversation_id, user_id) files = get_file_service().get_files_by_conversation_id( - session, user_id, conversation_id + session, user_id, conversation_id, ctx ) files_with_conversation_id = attach_conversation_id_to_files(conversation_id, files) return files_with_conversation_id -@router.put("/{conversation_id}/files/{file_id}", response_model=FilePublic) -async def update_file( - conversation_id: str, - file_id: str, - new_file: UpdateFileRequest, - session: DBSessionDep, - ctx: Context = Depends(get_context), -) -> FilePublic: - """ - Update a file by ID. - - Args: - conversation_id (str): Conversation ID. - file_id (str): File ID. - new_file (UpdateFileRequest): New file data. - session (DBSessionDep): Database session. - ctx (Context): Context object. - - Returns: - FilePublic: Updated file. - - Raises: - HTTPException: If the conversation with the given ID is not found. - """ - user_id = ctx.get_user_id() - _ = validate_conversation(session, conversation_id, user_id) - _ = validate_file(session, file_id, user_id) - - file = get_file_service().get_file_by_id(session, file_id, user_id) - file = get_file_service().update_file(session, file, new_file) - files_with_conversation_id = attach_conversation_id_to_files( - conversation_id, [file] - ) - return files_with_conversation_id[0] - - @router.delete("/{conversation_id}/files/{file_id}") async def delete_file( conversation_id: str, @@ -555,13 +526,11 @@ async def delete_file( """ user_id = ctx.get_user_id() _ = validate_conversation(session, conversation_id, user_id) - _ = validate_file(session, file_id, user_id) - - file = get_file_service().get_file_by_id(session, file_id, user_id) + validate_file(session, file_id, user_id, conversation_id, ctx) # Delete the File DB object - get_file_service().delete_file_from_conversation( - session, conversation_id, file_id, user_id + get_file_service().delete_conversation_file_by_id( + session, conversation_id, file_id, user_id, ctx ) return DeleteFileResponse() diff --git a/src/backend/routers/snapshot.py b/src/backend/routers/snapshot.py index 93e76080ef..9a9be80c70 100644 --- a/src/backend/routers/snapshot.py +++ b/src/backend/routers/snapshot.py @@ -60,7 +60,9 @@ async def create_snapshot( snapshot = snapshot_crud.get_snapshot_by_last_message_id(session, last_message_id) if not snapshot: - snapshot = wrap_create_snapshot(session, last_message_id, user_id, conversation) + snapshot = wrap_create_snapshot( + session, last_message_id, user_id, conversation, ctx + ) snapshot_link = wrap_create_snapshot_link(session, snapshot.id, user_id) diff --git a/src/backend/schemas/file.py b/src/backend/schemas/file.py index acf084767b..45ab301221 100644 --- a/src/backend/schemas/file.py +++ b/src/backend/schemas/file.py @@ -10,7 +10,8 @@ class File(BaseModel): updated_at: datetime.datetime user_id: str - conversation_id: str + conversation_id: Optional[str] = None + file_content: Optional[str] = None # Used interally file_name: str file_path: str file_size: int = Field(default=0, ge=0) diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index 66984e170c..12df1828d1 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -54,7 +54,6 @@ from backend.schemas.search_query import SearchQuery from backend.schemas.tool import Tool, ToolCall, ToolCallDelta from backend.services.agent import validate_agent_exists -from backend.services.file import get_file_service from backend.services.generators import AsyncGeneratorContextManager @@ -138,9 +137,7 @@ def process_chat( id=str(uuid4()), ) - file_paths = None if isinstance(chat_request, CohereChatRequest): - file_paths = handle_file_retrieval(session, user_id, chat_request.file_ids) if should_store: attach_files_to_messages( session, @@ -165,7 +162,6 @@ def process_chat( return ( session, chat_request, - file_paths, chatbot_message, should_store, managed_tools, @@ -217,7 +213,6 @@ def get_or_create_conversation( """ conversation_id = chat_request.conversation_id or "" conversation = conversation_crud.get_conversation(session, conversation_id, user_id) - if conversation is None: # Get the first 5 words of the user message as the title title = " ".join(user_message.split()[:5]) @@ -306,29 +301,6 @@ def create_message( return message -def handle_file_retrieval( - session: DBSessionDep, user_id: str, file_ids: List[str] | None = None -) -> list[str] | None: - """ - Retrieve file paths from the database. - - Args: - session (DBSessionDep): Database session. - user_id (str): User ID. - file_ids (List): List of File IDs. - - Returns: - list[str] | None: List of file paths or None. - """ - file_paths = None - # Use file_ids if provided - if file_ids is not None: - files = get_file_service().get_files_by_ids(session, file_ids, user_id) - file_paths = [file.file_path for file in files] - - return file_paths - - def attach_files_to_messages( session: DBSessionDep, user_id: str, diff --git a/src/backend/services/compass.py b/src/backend/services/compass.py index 66faa39745..5af71e59cc 100644 --- a/src/backend/services/compass.py +++ b/src/backend/services/compass.py @@ -111,7 +111,7 @@ def invoke( return self.compass_client.create_index( index_name=parameters["index"] ) - case self.ValidActions.CREATE_INDEX.value: + case self.ValidActions.DELETE_INDEX.value: return self.compass_client.delete_index( index_name=parameters["index"] ) @@ -157,7 +157,7 @@ def _create(self, parameters: dict, **kwargs: Any) -> Dict[str, str]: docs=compass_docs, ) if error is not None: - message = ("[Compass] Error inserting document: {error}",) + message = (f"[Compass] Error inserting document: {error}",) logger.error(event=message) raise Exception(message) @@ -271,9 +271,12 @@ def _process_file(self, parameters: dict, **kwargs: Any) -> None: text=file_text, file_id=file_id, bytes_content=isinstance(file_text, bytes), + custom_context=parameters.get("custom_context", {}), ) - def _raw_parsing(self, text: str, file_id: str, bytes_content: bool): + def _raw_parsing( + self, text: str, file_id: str, bytes_content: bool, custom_context: dict + ): text_bytes = str.encode(text) if not bytes_content else text if len(text_bytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES: logger.error( @@ -300,7 +303,9 @@ def _raw_parsing(self, text: str, file_id: str, bytes_content: bool): if res.ok: docs = [CompassDocument(**doc) for doc in res.json()["docs"]] for doc in docs: - additional_metadata = CompassParserClient._get_metadata(doc=doc) + additional_metadata = CompassParserClient._get_metadata( + doc=doc, custom_context=custom_context + ) doc.content = {**doc.content, **additional_metadata} else: docs = [] diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index 385485c466..1c00ca3407 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -15,6 +15,7 @@ from backend.schemas.file import File from backend.schemas.message import Message from backend.services.chat import generate_chat_response +from backend.services.compass import Compass from backend.services.context import get_context from backend.services.file import attach_conversation_id_to_files, get_file_service @@ -99,7 +100,7 @@ def extract_details_from_conversation( def get_messages_with_files( - session: DBSessionDep, user_id: str, messages: list[MessageModel] + session: DBSessionDep, user_id: str, messages: list[MessageModel], ctx: Context ) -> list[Message]: """ Get messages and use the file service to get the files associated with each message @@ -115,7 +116,9 @@ def get_messages_with_files( messages_with_file = [] for message in messages: - files = get_file_service().get_files_by_message_id(session, message.id, user_id) + files = get_file_service().get_files_by_message_id( + session, message.id, user_id, ctx + ) files_with_conversation_id = attach_conversation_id_to_files( message.conversation_id, files ) diff --git a/src/backend/services/file.py b/src/backend/services/file.py index 4d660a492f..26e5e6478c 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -1,22 +1,27 @@ import io +import uuid +from datetime import datetime import pandas as pd from docx import Document -from fastapi import HTTPException +from fastapi import Depends, HTTPException from fastapi import UploadFile as FastAPIUploadFile from python_calamine.pandas import pandas_monkeypatch import backend.crud.conversation as conversation_crud import backend.crud.file as file_crud -from backend.config.tools import ToolName -from backend.crud import agent as agent_crud +from backend.config.settings import Settings from backend.crud import message as message_crud from backend.database_models.conversation import ConversationFileAssociation from backend.database_models.database import DBSessionDep from backend.database_models.file import File as FileModel -from backend.schemas.file import File, UpdateFileRequest +from backend.schemas.context import Context +from backend.schemas.file import File from backend.services import utils from backend.services.agent import validate_agent_exists +from backend.services.compass import Compass +from backend.services.context import get_context +from backend.services.logger.utils import LoggerFactory MAX_FILE_SIZE = 20_000_000 # 20MB MAX_TOTAL_FILE_SIZE = 1_000_000_000 # 1GB @@ -34,81 +39,132 @@ pandas_monkeypatch() file_service = None +compass = None + +logger = LoggerFactory().get_logger() def get_file_service(): + """ + Initialize a singular instance of FileService if not initialized yet + + Returns: + FileService: The singleton FileService instance + """ global file_service if file_service is None: file_service = FileService() return file_service +def get_compass(): + """ + Initialize a singular instance of Compass if not initialized yet + + Returns: + Compass: The singleton Compass instance + """ + global compass + + if compass is None: + try: + compass = Compass() + except Exception as e: + logger.error( + event=f"[Compass File Service] Error initializing Compass: {e}" + ) + raise e + return compass + + class FileService: + """ + FileService class + + This class manages interfacing with different file storage solutions. Currently it supports storing files in the Postgres DB and or using Compass. + By default Toolkit will run with Postgres DB as the storage solution for files. + To enable Compass as the storage solution, set the `use_compass_file_storage` feature flag to `true` in the .env or .configuration file. + Also be sure to add the appropriate Compass environment variables to the .env or .configuration file. + """ + @property def is_compass_enabled(self) -> bool: - # TODO Scott: add compass env variable anc check here - return False + """ + Returns whether Compass is enabled as the file storage solution + """ + return Settings().feature_flags.use_compass_file_storage - # All these functions will eventually support file operations on Compass async def create_conversation_files( self, session: DBSessionDep, files: list[FastAPIUploadFile], user_id: str, conversation_id: str, + ctx: Context, ) -> list[File]: """ - Create files and associations with conversation + Create files and associations with a conversation Args: session (DBSessionDep): The database session files (list[FastAPIUploadFile]): The files to upload user_id (str): The user ID conversation_id (str): The conversation ID + ctx (Context): Context object Returns: list[File]: The files that were created """ - files_to_upload = [] - for file in files: - content = await get_file_content(file) - cleaned_content = content.replace("\x00", "") - filename = file.filename.encode("ascii", "ignore").decode("utf-8") - conversation = conversation_crud.get_conversation( - session, conversation_id, user_id - ) - if not conversation: - raise HTTPException( - status_code=404, - detail=f"Conversation with ID: {conversation_id} not found.", - ) - - files_to_upload.append( - FileModel( - file_name=filename, - file_size=file.size, - file_path=filename, - file_content=cleaned_content, - user_id=conversation.user_id, - ) + if self.is_compass_enabled: + uploaded_files = await insert_files_in_compass( + files, user_id, ctx, conversation_id ) + else: + uploaded_files = await insert_files_in_db(session, files, user_id) - uploaded_files = file_crud.batch_create_files(session, files_to_upload) - uploaded_file_ids = [file.id for file in uploaded_files] - for file_id in uploaded_file_ids: + for file in uploaded_files: conversation_crud.create_conversation_file_association( session, ConversationFileAssociation( conversation_id=conversation_id, user_id=user_id, - file_id=file_id, + file_id=file.id, ), ) return uploaded_files + async def create_agent_files( + self, + session: DBSessionDep, + files: list[FastAPIUploadFile], + user_id: str, + ctx: Context, + ) -> list[File]: + """ + Create files and associations with an agent + + Args: + session (DBSessionDep): The database session + files (list[FastAPIUploadFile]): The files to upload + user_id (str): The user ID + Returns: + list[File]: The files that were created + """ + uploaded_files = [] + if self.is_compass_enabled: + """ + Since agents are created after the files are upload we index files into dummy indices first + We later consolidate them in consolidate_agent_files_in_compass() to a singular index when an agent is created. + """ + uploaded_files = await insert_files_in_compass(files, ctx, user_id) + else: + uploaded_files = await insert_files_in_db(session, files, user_id) + + return uploaded_files + def get_files_by_agent_id( - self, session: DBSessionDep, user_id: str, agent_id: str + self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context ) -> list[File]: """ Get files by agent ID @@ -121,6 +177,9 @@ def get_files_by_agent_id( Returns: list[File]: The files that were created """ + from backend.config.tools import ToolName + from backend.tools.files import FileToolsArtifactTypes + agent = validate_agent_exists(session, agent_id, user_id) files = [] @@ -136,19 +195,23 @@ def get_files_by_agent_id( [], # Default value if the generator is empty ) - # TODO scott: enumerate type names (?), different types for local vs. compass? - file_ids = [ - artifact.get("id") - for artifact in artifacts - if artifact.get("type") == "local_file" - ] + file_ids = list( + set( + artifact.get("id") + for artifact in artifacts + if artifact.get("type") == FileToolsArtifactTypes.local_file + ) + ) - files = file_crud.get_files_by_ids(session, file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(agent_id, file_ids, user_id, ctx) + else: + files = file_crud.get_files_by_ids(session, file_ids, user_id) return files def get_files_by_conversation_id( - self, session: DBSessionDep, user_id: str, conversation_id: str + self, session: DBSessionDep, user_id: str, conversation_id: str, ctx: Context ) -> list[FileModel]: """ Get files by conversation ID @@ -173,15 +236,23 @@ def get_files_by_conversation_id( files = [] if file_ids is not None: - files = file_crud.get_files_by_ids(session, file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass(conversation_id, file_ids, user_id, ctx) + else: + files = file_crud.get_files_by_ids(session, file_ids, user_id) return files - def delete_file_from_conversation( - self, session: DBSessionDep, conversation_id: str, file_id: str, user_id: str + def delete_conversation_file_by_id( + self, + session: DBSessionDep, + conversation_id: str, + file_id: str, + user_id: str, + ctx: Context, ) -> None: """ - Delete file from conversation + Delete a file asociated with a conversation Args: session (DBSessionDep): The database session @@ -192,73 +263,76 @@ def delete_file_from_conversation( conversation_crud.delete_conversation_file_association( session, conversation_id, file_id, user_id ) - file_crud.delete_file(session, file_id, user_id) - return - - def get_file_by_id(self, session: DBSessionDep, file_id: str, user_id: str) -> File: - """ - Get file by ID - Args: - session (DBSessionDep): The database session - file_id (str): The file ID - user_id (str): The user ID + if self.is_compass_enabled: + delete_file_in_compass(conversation_id, file_id, user_id, ctx) + else: + file_crud.delete_file(session, file_id, user_id) - Returns: - File: The file that was created - """ - file = file_crud.get_file(session, file_id, user_id) - return file + return - def get_files_by_ids( - self, session: DBSessionDep, file_ids: list[str], user_id: str - ) -> list[FileModel]: + def delete_agent_file_by_id( + self, + session: DBSessionDep, + agent_id: str, + file_id: str, + user_id: str, + ctx: Context, + ) -> None: """ - Get files by IDs + Delete a file asociated with an agent Args: session (DBSessionDep): The database session - file_ids (list[str]): The file IDs + agent_id (str): The agent ID + file_id (str): The file ID user_id (str): The user ID - - Returns: - list[File]: The files that were created """ - files = file_crud.get_files_by_ids(session, file_ids, user_id) - return files - - def update_file( - self, session: DBSessionDep, file: File, new_file: UpdateFileRequest - ) -> File: - """ - Update file - - Args: - session (DBSessionDep): The database session - file (File): The file to update - new_file (UpdateFileRequest): The new file data + if self.is_compass_enabled: + delete_file_in_compass(agent_id, file_id, user_id, ctx) + else: + file_crud.delete_file(session, file_id, user_id) - Returns: - File: The updated file - """ - updated_file = file_crud.update_file(session, file, new_file) - return updated_file + return - def bulk_delete_files( - self, session: DBSessionDep, file_ids: list[str], user_id: str + def delete_all_conversation_files( + self, + session: DBSessionDep, + conversation_id: str, + file_ids: list[str], + user_id: str, + ctx: Context = Depends(get_context), ) -> None: """ - Bulk delete files + Delete all files associated with a conversation Args: session (DBSessionDep): The database session + conversation_id (str): The conversation ID file_ids (list[str]): The file IDs user_id (str): The user ID + ctx (Context): Context object """ - file_crud.bulk_delete_files(session, file_ids, user_id) + logger = ctx.get_logger() + + if self.is_compass_enabled: + compass = get_compass() + try: + compass.invoke( + action=Compass.ValidActions.DELETE_INDEX, + parameters={"index": conversation_id}, + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Error deleting conversation {conversation_id} files from Compass: {e}" + ) + else: + file_crud.bulk_delete_files(session, file_ids, user_id) + + return def get_files_by_message_id( - self, session: DBSessionDep, message_id: str, user_id: str + self, session: DBSessionDep, message_id: str, user_id: str, ctx: Context ) -> list[File]: """ Get message files @@ -274,31 +348,257 @@ def get_files_by_message_id( message = message_crud.get_message(session, message_id, user_id) files = [] if message.file_ids is not None: - files = file_crud.get_files_by_ids(session, message.file_ids, user_id) + if self.is_compass_enabled: + files = get_files_in_compass( + message.conversation_id, message.file_ids, user_id, ctx + ) + else: + files = file_crud.get_files_by_ids(session, message.file_ids, user_id) return files -def attach_conversation_id_to_files( - conversation_id: str, files: list[FileModel] +# Compass Operations +def delete_file_in_compass( + index: str, file_id: str, user_id: str, ctx: Context +) -> None: + """ + Delete a file from Compass + + Args: + index (str): The index + file_id (str): The file ID + user_id (str): The user ID + ctx (Context): Context object + + Raises: + HTTPException: If the file is not found + """ + logger = ctx.get_logger() + compass = get_compass() + + try: + compass.invoke( + action=Compass.ValidActions.DELETE, + parameters={"index": index, "file_id": file_id}, + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Error deleting file {file_id} on index {index} from Compass: {e}" + ) + + +def get_files_in_compass( + index: str, file_ids: list[str], user_id: str, ctx: Context ) -> list[File]: - results = [] + """ + Get files from Compass + + Args: + index (str): The index + file_ids (list[str]): The file IDs + user_id (str): The user ID + + Returns: + list[File]: The files that were created + """ + compass = get_compass() + logger = ctx.get_logger() + + files = [] + for file_id in file_ids: + try: + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": index, "file_id": file_id}, + ).result["doc"]["content"] + except Exception as e: + logger.error( + event=f"[Compass File Service] Error fetching file {file_id} on index {index} from Compass: {e}" + ) + raise HTTPException( + status_code=404, detail=f"File with ID: {file_id} not found." + ) + + files.append( + File( + id=file_id, + file_name=fetched_doc["file_name"], + file_size=fetched_doc["file_size"], + file_path=fetched_doc["file_path"], + file_content=fetched_doc["text"], + user_id=user_id, + created_at=datetime.fromisoformat(fetched_doc["created_at"]), + updated_at=datetime.fromisoformat(fetched_doc["updated_at"]), + ) + ) + + return files + + +async def consolidate_agent_files_in_compass( + file_ids, + agent_id, + ctx: Context, +) -> None: + """ + Consolidate files into a single index (agent ID) in Compass. + We do this because when agents are created after a file is uploaded, the file is not associated with the agent. + We consolidate them in a single index to under one agent ID when an agent is created. + + Args: + file_ids (list[str]): The file IDs + agent_id (str): The agent ID + ctx (Context): Context object + """ + logger = ctx.get_logger() + compass = get_compass() + + try: + compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": agent_id, + }, + ) + except Exception as e: + logger.Error( + event=f"[Compass File Service] Error creating index for agent files: {agent_id}, error: {e}" + ) + raise HTTPException( + status_code=500, + detail=f"Error creating index for agent files: {agent_id}, error: {e}", + ) + + for file_id in file_ids: + try: + fetched_doc = compass.invoke( + action=Compass.ValidActions.GET_DOCUMENT, + parameters={"index": file_id, "file_id": file_id}, + ).result["doc"]["content"] + compass().invoke( + action=Compass.ValidActions.CREATE, + parameters={ + "index": agent_id, + "file_id": file_id, + "file_text": fetched_doc["text"], + "custom_context": { + "file_id": file_id, + "file_name": fetched_doc["file_name"], + "file_path": fetched_doc["file_path"], + "file_size": fetched_doc["file_size"], + "user_id": fetched_doc["user_id"], + "created_at": fetched_doc["created_at"], + "updated_at": fetched_doc["updated_at"], + }, + }, + ) + compass.invoke( + action=Compass.ValidActions.REFRESH, + parameters={"index": agent_id}, + ) + # Remove the temporary file index entry + compass.invoke( + action=Compass.ValidActions.DELETE_INDEX, parameters={"index": file_id} + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Error consolidating file {file_id} into agent {agent_id}, error: {e}" + ) + raise HTTPException( + status_code=500, + detail=f"Error consolidating file {file_id} into agent {agent_id}, error: {e}", + ) + + +async def insert_files_in_compass( + files: list[FastAPIUploadFile], + user_id: str, + ctx: Context, + index: str = None, +) -> list[File]: + logger = ctx.get_logger() + compass = get_compass() + + if index is not None: + try: + compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": index, + }, + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Failed to create index: {index}, error: {e}" + ) + + uploaded_files = [] for file in files: - results.append( + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + file_bytes = await file.read() + new_file_id = str(uuid.uuid4()) + + # Create temporary index for individual file (files not associated with conversations) + # Consolidate them under one agent index during agent creation + if index is None: + try: + compass.invoke( + action=Compass.ValidActions.CREATE_INDEX, + parameters={ + "index": new_file_id, + }, + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Failed to create index: {index}, error: {e}" + ) + + try: + compass.invoke( + action=Compass.ValidActions.CREATE, + parameters={ + "index": new_file_id if index is None else index, + "file_id": new_file_id, + "file_text": file_bytes, + "custom_context": { + "file_id": new_file_id, + "file_name": filename, + "file_path": filename, + "file_size": file.size, + "user_id": user_id, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + }, + }, + ) + compass.invoke( + action=Compass.ValidActions.REFRESH, + parameters={"index": new_file_id if index is None else index}, + ) + except Exception as e: + logger.error( + event=f"[Compass File Service] Failed to create document on index: {index}, error: {e}" + ) + + uploaded_files.append( File( - id=file.id, - conversation_id=conversation_id, - file_name=file.file_name, - file_size=file.file_size, - file_path=file.file_path, - user_id=file.user_id, - created_at=file.created_at, - updated_at=file.updated_at, + file_name=filename, + id=new_file_id, + file_size=file.size, + file_path=filename, + user_id=user_id, + created_at=datetime.now(), + updated_at=datetime.now(), ) ) - return results + return uploaded_files -def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: + +# Misc +def validate_file( + session: DBSessionDep, file_id: str, user_id: str, index: str, ctx: Context +) -> File: """Validates if a file exists and belongs to the user Args: @@ -312,7 +612,10 @@ def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: Raises: HTTPException: If the file is not found """ - file = file_crud.get_file(session, file_id, user_id) + if Settings().feature_flags.use_compass_file_storage: + file = get_files_in_compass(index, [file_id], user_id, ctx)[0] + else: + file = file_crud.get_file(session, file_id, user_id) if not file: raise HTTPException( @@ -320,7 +623,61 @@ def validate_file(session: DBSessionDep, file_id: str, user_id: str) -> File: detail=f"File with ID: {file_id} not found.", ) - return file + +async def insert_files_in_db( + session: DBSessionDep, + files: list[FastAPIUploadFile], + user_id: str, +) -> list[File]: + """ + Insert files into the database + + Args: + session (DBSessionDep): The database session + files (list[FastAPIUploadFile]): The files to upload + user_id (str): The user ID + + Returns: + list[File]: The files that were created + """ + files_to_upload = [] + for file in files: + content = await get_file_content(file) + cleaned_content = content.replace("\x00", "") + filename = file.filename.encode("ascii", "ignore").decode("utf-8") + + files_to_upload.append( + FileModel( + file_name=filename, + file_size=file.size, + file_path=filename, + file_content=cleaned_content, + user_id=user_id, + ) + ) + + uploaded_files = file_crud.batch_create_files(session, files_to_upload) + return uploaded_files + + +def attach_conversation_id_to_files( + conversation_id: str, files: list[FileModel] +) -> list[File]: + results = [] + for file in files: + results.append( + File( + id=file.id, + conversation_id=conversation_id, + file_name=file.file_name, + file_size=file.file_size, + file_path=file.file_path, + user_id=file.user_id, + created_at=file.created_at, + updated_at=file.updated_at, + ) + ) + return results def get_file_extension(file_name: str) -> str: @@ -427,7 +784,7 @@ def validate_batch_file_size( user_id (str): The user ID files (list[FastAPIUploadFile]): The files to validate - Raises: + Raises:p HTTPException: If the file size is too large """ total_batch_size = 0 diff --git a/src/backend/services/snapshot.py b/src/backend/services/snapshot.py index 7e953efb0c..888df2c995 100644 --- a/src/backend/services/snapshot.py +++ b/src/backend/services/snapshot.py @@ -58,6 +58,7 @@ def wrap_create_snapshot( last_message_id: str, user_id: str, conversation: Conversation, + ctx: Context, ) -> SnapshotModel: snapshot_agent = None if conversation.agent_id: @@ -72,7 +73,7 @@ def wrap_create_snapshot( tools_metadata=tools_metadata, ) - messages = get_messages_with_files(session, user_id, conversation.messages) + messages = get_messages_with_files(session, user_id, conversation.messages, ctx) snapshot_data = SnapshotData( title=conversation.title, description=conversation.description, diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 5932ab18e4..96c9359a84 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -18,6 +18,7 @@ from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.organization import Organization from backend.schemas.user import User +from backend.services.compass import Compass from backend.tests.unit.factories import get_factory DATABASE_URL = os.environ["DATABASE_URL"] diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 3d4ccd7c0b..7486e7cd67 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -204,3 +204,19 @@ def mock_available_model_deployments(request): with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock + + +@pytest.fixture +def mock_compass_settings(): + with patch("backend.services.file.Settings") as MockSettings: + mock_settings = MockSettings.return_value + mock_settings.feature_flags.use_compass_file_storage = os.getenv( + "ENABLE_COMPASS_FILE_STORAGE", "False" + ).lower() in ("true", "1") + mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL") + mock_settings.tools.compass.api_parser_url = os.getenv( + "COHERE_COMPASS_API_PARSER_URL" + ) + mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME") + mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD") + yield mock_settings diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 63aea52ad3..2f0ea0f1b5 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -977,6 +977,55 @@ def validate_chat_streaming_response( validate_conversation(session, user, conversation_id, expected_num_messages) +@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") +def test_streaming_chat_with_files( + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_compass_settings, +): + # Create convo + conversation = get_factory("Conversation", session_chat).create(user_id=user.id) + + # Upload the files + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] + + response = session_client_chat.post( + "/v1/conversations/batch_upload_file", + headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, + ) + + assert response.status_code == 200 + file_id = response.json()[0]["id"] + + # Send the chat request + response = session_client_chat.post( + "/v1/chat", + json={ + "message": "Hello", + "max_tokens": 10, + "file_ids": [file_id], + "tools": [{"name": "search_file"}], + }, + headers={ + "User-Id": user.id, + "Deployment-Name": ModelDeploymentName.CoherePlatform, + }, + ) + + assert response.status_code == 200 + + def validate_conversation( session: Session, user: User, conversation_id: str, expected_num_messages: int ) -> None: diff --git a/src/backend/tests/unit/routers/test_conversation.py b/src/backend/tests/unit/routers/test_conversation.py index aba0ca6464..dee6bd9b61 100644 --- a/src/backend/tests/unit/routers/test_conversation.py +++ b/src/backend/tests/unit/routers/test_conversation.py @@ -1,4 +1,4 @@ -import os +from unittest.mock import MagicMock from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -12,7 +12,7 @@ Message, ) from backend.schemas.user import User -from backend.services.file import MAX_FILE_SIZE, MAX_TOTAL_FILE_SIZE +from backend.services.file import MAX_FILE_SIZE, MAX_TOTAL_FILE_SIZE, get_file_service from backend.tests.unit.factories import get_factory @@ -262,7 +262,7 @@ def test_delete_conversation( user: User, ) -> None: conversation = get_factory("Conversation", session).create( - title="test title", user_id=user.id + title="test title", user_id=user.id, id="conversation_id" ) response = session_client.delete( f"/v1/conversations/{conversation.id}", @@ -275,7 +275,7 @@ def test_delete_conversation( # Check if the conversation was deleted conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=conversation.user_id) + .filter_by(id="conversation_id", user_id=user.id) .first() ) assert conversation is None @@ -339,7 +339,7 @@ def test_delete_conversation_with_messages( user: User, ) -> None: conversation = get_factory("Conversation", session).create( - title="test title", user_id=user.id + title="test title", user_id=user.id, id="conversation_id" ) _ = get_factory("Message", session).create( text="test message", @@ -358,7 +358,7 @@ def test_delete_conversation_with_messages( # Check if the conversation was deleted conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=user.id) + .filter_by(id="conversation_id", user_id=user.id) .first() ) assert conversation is None @@ -468,20 +468,27 @@ def test_search_conversations_missing_user_id( # FILES def test_list_files( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=user.id, - ) - conversation = get_factory("Conversation", session).create( - user_id=user.id, - ) - _ = get_factory("ConversationFileAssociation", session).create( - conversation_id=conversation.id, user_id=user.id, file_id=file.id + conversation = get_factory("Conversation", session).create(user_id=user.id) + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] + response = session_client.post( + "/v1/conversations/batch_upload_file", + headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, ) + assert response.status_code == 200 + files = response.json() + uploaded_file = files[0] response = session_client.get( f"/v1/conversations/{conversation.id}/files", @@ -492,14 +499,12 @@ def test_list_files( response = response.json() assert len(response) == 1 response_file = response[0] - assert response_file["id"] == file.id - assert response_file["file_name"] == "test_file.txt" + assert response_file["id"] == uploaded_file["id"] + assert response_file["file_name"] == uploaded_file["file_name"] def test_list_files_no_files( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.get( @@ -512,9 +517,7 @@ def test_list_files_no_files( def test_list_files_missing_user_id( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.get(f"/v1/conversations/{conversation.id}/files") @@ -524,12 +527,9 @@ def test_list_files_missing_user_id( def test_upload_file_existing_conversation( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" conversation = get_factory("Conversation", session).create(user_id=user.id) file_doc = {"file": open(file_path, "rb")} @@ -541,29 +541,25 @@ def test_upload_file_existing_conversation( ) file = response.json() - - file_in_db = session.get(File, file.get("id")) - assert file_in_db is not None assert response.status_code == 200 + files = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id, MagicMock() + ) + assert len(files) == 1 assert "Mariana_Trench" in file["file_name"] - assert conversation.file_ids == [file_in_db.id] - - # File should not exist in the directory - assert not os.path.exists(saved_file_path) + assert conversation.file_ids == [file["id"]] def test_upload_file_nonexistent_conversation_creates_new_conversation( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf" - saved_file_path = "src/backend/data/Mariana_Trench.pdf" file_doc = {"file": open(file_path, "rb")} response = session_client.post( "/v1/conversations/upload_file", files=file_doc, headers={"User-Id": user.id} ) + assert response.status_code == 200 file = response.json() @@ -577,22 +573,20 @@ def test_upload_file_nonexistent_conversation_creates_new_conversation( .filter_by(id=conversation_file_association.conversation_id) .first() ) - file_in_db = session.get(File, file.get("id")) - assert file_in_db is not None - assert response.status_code == 200 + + files = get_file_service().get_files_by_conversation_id( + session, created_conversation.user_id, created_conversation.id, MagicMock() + ) + assert len(files) == 1 + assert "Mariana_Trench" in file["file_name"] + assert created_conversation.file_ids == [file["id"]] assert created_conversation is not None assert created_conversation.user_id == user.id assert conversation_file_association is not None - assert "Mariana_Trench" in file.get("file_name") - - # File should not exist in the directory - assert not os.path.exists(saved_file_path) def test_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf" file_doc = {"file": open(file_path, "rb")} @@ -604,20 +598,12 @@ def test_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( def test_batch_upload_file_existing_conversation( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/unit/test_data/Mariana_Trench.pdf", "Cardistry.pdf": "src/backend/tests/unit/test_data/Cardistry.pdf", - "Tapas.pdf": "src/backend/tests/unit/test_data/Tapas.pdf", - "Mount_Everest.pdf": "src/backend/tests/unit/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -648,13 +634,14 @@ def test_batch_upload_file_existing_conversation( ) assert conversation_file_association is not None - # File should not exist in the directory - for saved_file_path in saved_file_paths: - assert not os.path.exists(saved_file_path) + files_stored = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id, MagicMock() + ) + assert len(files_stored) == len(file_paths) def test_batch_upload_total_files_exceeds_limit( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: _ = get_factory("Conversation", session).create(user_id=user.id) file_paths = { @@ -691,7 +678,7 @@ def test_batch_upload_total_files_exceeds_limit( def test_batch_upload_single_file_exceeds_limit( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: _ = get_factory("Conversation", session).create(user_id=user.id) file_paths = { @@ -721,7 +708,7 @@ def test_batch_upload_single_file_exceeds_limit( def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/unit/test_data/Mariana_Trench.pdf", @@ -729,12 +716,6 @@ def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( "Tapas.pdf": "src/backend/tests/unit/test_data/Tapas.pdf", "Mount_Everest.pdf": "src/backend/tests/unit/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -776,13 +757,14 @@ def test_batch_upload_file_nonexistent_conversation_creates_new_conversation( ) assert conversation_file_association is not None - # File should not exist in the directory - for saved_file_path in saved_file_paths: - assert not os.path.exists(saved_file_path) + files_stored = get_file_service().get_files_by_conversation_id( + session, created_conversation.user_id, created_conversation.id, MagicMock() + ) + assert len(files_stored) == len(file_paths) def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provided( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: file_paths = { "Mariana_Trench.pdf": "src/backend/tests/unit/test_data/Mariana_Trench.pdf", @@ -790,12 +772,6 @@ def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provide "Tapas.pdf": "src/backend/tests/unit/test_data/Tapas.pdf", "Mount_Everest.pdf": "src/backend/tests/unit/test_data/Mount_Everest.pdf", } - saved_file_paths = [ - "src/backend/data/Mariana_Trench.pdf", - "src/backend/data/Cardistry.pdf", - "src/backend/data/Tapas.pdf", - "src/backend/data/Mount_Everest.pdf", - ] files = [ ("files", (file_name, open(file_path, "rb"))) for file_name, file_path in file_paths.items() @@ -807,119 +783,49 @@ def test_batch_upload_file_nonexistent_conversation_fails_if_user_id_not_provide assert response.json() == {"detail": "User-Id required in request headers."} -def test_update_file_name( +def test_delete_file( session_client: TestClient, session: Session, user: User, + mock_compass_settings, ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=conversation.user_id, - ) - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/{file.id}", - headers={"User-Id": conversation.user_id}, - json={"file_name": "new name"}, - ) - response_file = response.json() - - assert response.status_code == 200 - assert response_file["file_name"] == "new name" - - # Check if the file was updated - file = ( - session.query(File).filter_by(id=file.id, user_id=conversation.user_id).first() - ) - assert file is not None - assert file.file_name == "new name" - + files = [ + ( + "files", + ( + "Mariana_Trench.pdf", + open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"), + ), + ) + ] -def test_fail_update_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/123", - json={"file_name": "new name"}, + response = session_client.post( + "/v1/conversations/batch_upload_file", headers={"User-Id": conversation.user_id}, + files=files, + data={"conversation_id": conversation.id}, ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"File with ID: 123 not found."} - - -def test_fail_update_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, -) -> None: - response = session_client.put( - f"/v1/conversations/123/files/123", - json={"file_name": "new name"}, - headers={"User-Id": user.id}, - ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"Conversation with ID: 123 not found."} - - -def test_fail_update_file_missing_user_id( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - file_name="test_file.txt", - user_id=conversation.user_id, - ) - - response = session_client.put( - f"/v1/conversations/{conversation.id}/files/{file.id}", - json={"file_name": "new name"}, - ) - - assert response.status_code == 401 - assert response.json() == {"detail": "User-Id required in request headers."} - - -def test_delete_file( - session_client: TestClient, - session: Session, - user: User, -) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) - file = get_factory("File", session).create( - id="file_id", - file_name="test_file.txt", - user_id=conversation.user_id, - ) - _ = get_factory("ConversationFileAssociation", session).create( - conversation_id=conversation.id, user_id=user.id, file_id=file.id - ) + assert response.status_code == 200 + files = response.json() + uploaded_file = files[0] response = session_client.delete( - f"/v1/conversations/{conversation.id}/files/{file.id}", + f"/v1/conversations/{conversation.id}/files/{uploaded_file['id']}", headers={"User-Id": conversation.user_id}, ) - assert response.status_code == 200 assert response.json() == {} # Check if File - db_file = ( - session.query(File) - .filter(File.id == "file_id", File.user_id == user.id) - .first() + files = get_file_service().get_files_by_conversation_id( + session, conversation.user_id, conversation.id, MagicMock() ) - assert db_file is None + assert files == [] conversation_file_association = ( session.query(ConversationFileAssociation) - .filter(File.id == "file_id", File.user_id == user.id) + .filter(File.id == uploaded_file["id"], File.user_id == user.id) .first() ) assert conversation_file_association is None @@ -931,9 +837,7 @@ def test_delete_file( def test_fail_delete_nonexistent_file( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) response = session_client.delete( @@ -946,9 +850,7 @@ def test_fail_delete_nonexistent_file( def test_fail_delete_file_missing_user_id( - session_client: TestClient, - session: Session, - user: User, + session_client: TestClient, session: Session, user: User, mock_compass_settings ) -> None: conversation = get_factory("Conversation", session).create(user_id=user.id) file = get_factory("File", session).create( diff --git a/src/backend/tests/unit/services/test_file.py b/src/backend/tests/unit/services/test_file.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 0b4cfe2f00..436aa5ed8e 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -1,9 +1,90 @@ +from enum import StrEnum from typing import Any, Dict, List import backend.crud.file as file_crud +from backend.compass_sdk import SearchFilter +from backend.config.settings import Settings +from backend.schemas.file import File +from backend.services.compass import Compass +from backend.services.file import get_compass from backend.tools.base import BaseTool +class FileToolsArtifactTypes(StrEnum): + local_file = "local_file" + + +def compass_file_search( + file_ids: List[str], + conversation_id: str, + agent_id: str, + query: str, + search_limit: int = 5, +) -> List[Dict[str, Any]]: + results = [] + + # Note: Compass search currently has an issue where the type of the context is not directly referenced + # Temporarily add `.keyword` to workaround this issue. + search_filters = [ + SearchFilter( + field="content.file_id.keyword", + type=SearchFilter.FilterType.EQ, + value=file_id, + ) + for file_id in file_ids + ] + + # Search conversation ID index + hits = ( + get_compass() + .invoke( + action=Compass.ValidActions.SEARCH, + parameters={ + "index": conversation_id, + "query": query, + "top_k": search_limit, + "filters": search_filters, + }, + ) + .result["hits"] + ) + results.extend(hits) + + # Search agent ID index + if agent_id: + hits = ( + get_compass() + .invoke( + action=Compass.ValidActions.SEARCH, + parameters={ + "index": agent_id, + "query": query, + "top_k": search_limit, + "filters": search_filters, + }, + ) + .result["hits"] + ) + results.extend(hits) + + chunks = sorted( + [ + { + "text": chunk["content"]["text"], + "score": chunk["score"], + "url": result["content"].get("title", ""), + "title": result["content"].get("title", ""), + } + for result in results + for chunk in result["chunks"] + ], + key=lambda x: x["score"], + reverse=True, + )[:search_limit] + + return chunks + + class ReadFileTool(BaseTool): """ This class reads a file from the file system. @@ -11,6 +92,7 @@ class ReadFileTool(BaseTool): NAME = "read_document" MAX_NUM_CHUNKS = 10 + SEARCH_LIMIT = 5 def __init__(self): pass @@ -19,29 +101,37 @@ def __init__(self): def is_available(cls) -> bool: return True - async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: - file_name = parameters.get("filename", "") + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + file = parameters.get("file") + session = kwargs.get("session") user_id = kwargs.get("user_id") - - if not file_name: + agent_id = kwargs.get("agent_id") + conversation_id = kwargs.get("conversation_id") + if not file: return [] - files = file_crud.get_files_by_file_names(session, [file_name], user_id) - - if not files: - return [] + _, file_id = file + if Settings().feature_flags.use_compass_file_storage: + return compass_file_search( + [file_id], + conversation_id, + agent_id, + "*", + search_limit=self.SEARCH_LIMIT, + ) + else: + retrieved_file = file_crud.get_file(session, file_id, user_id) + if not retrieved_file: + return [] - file = files[0] - return [ - { - "text": file.file_content, - "title": file.file_name, - "url": file.file_path, - } - ] + return [ + { + "text": retrieved_file.file_content, + "title": retrieved_file.file_name, + "url": retrieved_file.file_path, + } + ] class SearchFileTool(BaseTool): @@ -51,6 +141,7 @@ class SearchFileTool(BaseTool): NAME = "search_file" MAX_NUM_CHUNKS = 10 + SEARCH_LIMIT = 5 def __init__(self): pass @@ -63,31 +154,37 @@ async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: query = parameters.get("search_query") - file_names = parameters.get("filenames") + files = parameters.get("files") + + agent_id = kwargs.get("agent_id") + conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") user_id = kwargs.get("user_id") - if not query or not file_names: + if not query or not files: return [] - file_names = [ - file_name.encode("ascii", "ignore").decode("utf-8") - for file_name in file_names - ] - - files = file_crud.get_files_by_file_names(session, file_names, user_id) - - if not files: - return [] - - results = [] - for file in files: - results.append( - { - "text": file.file_content, - "title": file.file_name, - "url": file.file_path, - } + file_ids = [file_id for _, file_id in files] + if Settings().feature_flags.use_compass_file_storage: + return compass_file_search( + file_ids, + conversation_id, + agent_id, + query, + search_limit=self.SEARCH_LIMIT, ) - - return results + else: + retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id) + if not retrieved_files: + return [] + + results = [] + for file in retrieved_files: + results.append( + { + "text": file.file_content, + "title": file.file_name, + "url": file.file_path, + } + ) + return results