Skip to content

Commit

Permalink
backend: upload + chat with compass (#569)
Browse files Browse the repository at this point in the history
* changes

* debug

* upload works, fetching docs work

* getting chat to work

* cleaing up

* fixed

* fixed both search and read docs tool

* saving

* remove update, add delete

* fixes

* Fix tests

* fixing tests

* saving

* refactor indexing

* refactoring agent vs convo files

* code complete

* router test done

* clean up

* clean up

* clean up

* clean up

* bug fix for DB flow

* test bug fixed

* conflicts

* fix

* add chat test

* PR feedback

* fix import

* circular import fix

* fix more stuff
  • Loading branch information
scott-cohere authored Aug 14, 2024
1 parent 74f1394 commit 0dbb283
Show file tree
Hide file tree
Showing 22 changed files with 916 additions and 439 deletions.
6 changes: 3 additions & 3 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ async def call_chat(
if chat_request.file_ids or chat_request.agent_id:
if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names:
files = get_file_service().get_files_by_conversation_id(
session, user_id, ctx.get_conversation_id()
session, user_id, ctx.get_conversation_id(), ctx
)

agent_files = []
if agent_id:
agent_files = get_file_service().get_files_by_agent_id(
session, user_id, agent_id
session, user_id, agent_id, ctx
)

all_files = files + agent_files
Expand Down Expand Up @@ -259,7 +259,7 @@ def add_files_to_chat_history(
num_words = min(25, word_count)
preview = " ".join(file.file_content.split()[:num_words])

files_message += f"Filename: {file.file_name}\nWord Count: {word_count} Preview: {preview}\n\n"
files_message += f"Filename: {file.file_name}\nFile ID: {file.id}\nWord Count: {word_count} Preview: {preview}\n\n"

chat_history.append(ChatMessage(message=files_message, role=ChatRole.SYSTEM))
return chat_history
1 change: 1 addition & 0 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def _call_tool_async(
user_id=ctx.get_user_id(),
trace_id=ctx.get_trace_id(),
agent_id=ctx.get_agent_id(),
conversation_id=ctx.get_conversation_id(),
agent_tool_metadata=ctx.get_agent_tool_metadata(),
)
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions src/backend/config/configuration.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ feature_flags:
# Experimental features
use_experimental_langchain: false
use_agents_view: false
use_compass_file_storage: false
# Community features
use_community_features: true
auth:
Expand Down
5 changes: 5 additions & 0 deletions src/backend/config/secrets.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ auth:
client_id:
client_secret:
well_known_endpoint:
compass:
username:
password:
api_url:
parser_url:
6 changes: 6 additions & 0 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class FeatureFlags(BaseSettings, BaseModel):
"USE_COMMUNITY_FEATURES", "use_community_features"
),
)
use_compass_file_storage: Optional[bool] = Field(
default=False,
validation_alias=AliasChoices(
"USE_COMPASS_FILE_STORAGE", "use_compass_file_storage"
),
)


class PythonToolSettings(BaseSettings, BaseModel):
Expand Down
12 changes: 6 additions & 6 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ class ToolName(StrEnum):
"type": "str",
"required": True,
},
"filenames": {
"description": "A list of one or more uploaded filename strings to search over",
"type": "list",
"files": {
"description": "A list of files represented as tuples of (filename, file ID) to search over",
"type": "list[tuple[str, str]]",
"required": True,
},
},
Expand All @@ -82,9 +82,9 @@ class ToolName(StrEnum):
display_name="Read Document",
implementation=ReadFileTool,
parameter_definitions={
"filename": {
"description": "The name of the attached file to read.",
"type": "str",
"file": {
"description": "A file represented as a tuple (filename, file ID) to read over",
"type": "tuple[str, str]",
"required": True,
}
},
Expand Down
95 changes: 94 additions & 1 deletion src/backend/routers/agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import asyncio
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from fastapi import File as RequestFile
from fastapi import HTTPException
from fastapi import UploadFile as FastAPIUploadFile

from backend.config.routers import RouterName
from backend.config.settings import Settings
from backend.config.tools import ToolName
from backend.crud import agent as agent_crud
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.crud import snapshot as snapshot_crud
Expand All @@ -28,6 +33,7 @@
)
from backend.schemas.context import Context
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.file import DeleteFileResponse, UploadFileResponse
from backend.schemas.metrics import (
DEFAULT_METRICS_AGENT,
GenericResponseMessage,
Expand All @@ -40,11 +46,17 @@
validate_agent_tool_metadata_exists,
)
from backend.services.context import get_context
from backend.services.file import (
consolidate_agent_files_in_compass,
get_file_service,
validate_batch_file_size,
)
from backend.services.request_validators import (
validate_create_agent_request,
validate_update_agent_request,
validate_user_header,
)
from backend.tools.files import FileToolsArtifactTypes

router = APIRouter(
prefix="/v1/agents",
Expand Down Expand Up @@ -101,6 +113,31 @@ async def create_agent(
await update_or_create_tool_metadata(
created_agent, tool_metadata, session, ctx
)

# Consolidate agent files into one index in compass
file_tools = [ToolName.Read_File, ToolName.Search_File]
if (
Settings().feature_flags.use_compass_file_storage
and created_agent.tools_metadata
):
artifacts = next(
(
tool_metadata.artifacts
for tool_metadata in created_agent.tools_metadata
if tool_metadata.tool_name in file_tools
),
[],
)
file_ids = list(
set(
artifact.get("id")
for artifact in artifacts
if artifact.get("type") == FileToolsArtifactTypes.local_file
)
)
if file_ids:
await consolidate_agent_files_in_compass(file_ids, created_agent.id)

if deployment_db and model_db:
deployment_config = (
agent.deployment_config
Expand Down Expand Up @@ -615,6 +652,62 @@ async def delete_agent_tool_metadata(
return DeleteAgentToolMetadata()


@router.post("/batch_upload_file", response_model=list[UploadFileResponse])
async def batch_upload_file(
session: DBSessionDep,
files: list[FastAPIUploadFile] = RequestFile(...),
ctx: Context = Depends(get_context),
) -> UploadFileResponse:
user_id = ctx.get_user_id()
validate_batch_file_size(session, user_id, files)

uploaded_files = []
try:
uploaded_files = await get_file_service().create_agent_files(
session,
files,
user_id,
ctx,
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error while uploading agent file(s): {e}."
)

return uploaded_files


@router.delete("/{agent_id}/files/{file_id}")
async def delete_agent_file(
agent_id: str,
file_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> DeleteFileResponse:
"""
Delete an agent file by ID.
Args:
agent_id (str): Agent ID.
file_id (str): File ID.
session (DBSessionDep): Database session.
Returns:
DeleteFile: Empty response.
Raises:
HTTPException: If the agent with the given ID is not found.
"""
user_id = ctx.get_user_id()
_ = validate_agent_exists(session, agent_id)
validate_file(session, file_id, user_id, agent_id)

# Delete the File DB object
get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id, ctx)

return DeleteFileResponse()


# Default Agent Router
default_agent_router = APIRouter(
prefix="/v1/default_agent",
Expand Down
4 changes: 0 additions & 4 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ async def chat_stream(
(
session,
chat_request,
file_paths,
response_message,
should_store,
managed_tools,
Expand All @@ -91,7 +90,6 @@ async def chat_stream(
CustomChat().chat(
chat_request,
stream=True,
file_paths=file_paths,
managed_tools=managed_tools,
session=session,
ctx=ctx,
Expand Down Expand Up @@ -152,7 +150,6 @@ async def chat(
(
session,
chat_request,
file_paths,
response_message,
should_store,
managed_tools,
Expand All @@ -165,7 +162,6 @@ async def chat(
CustomChat().chat(
chat_request,
stream=False,
file_paths=file_paths,
managed_tools=managed_tools,
ctx=ctx,
),
Expand Down
Loading

0 comments on commit 0dbb283

Please sign in to comment.