From eede835f1b710eb90fd8bccc32da286c4670a836 Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Mon, 7 Oct 2024 08:57:57 +0800 Subject: [PATCH 1/6] feat: fetch, update user profile; removing user profile --- backend/ee/enmedd/server/workspace/store.py | 18 +++++------ backend/enmedd/server/manage/users.py | 31 ++++++++++++++++--- web/src/app/profile/tabContent/profileTab.tsx | 2 +- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/backend/ee/enmedd/server/workspace/store.py b/backend/ee/enmedd/server/workspace/store.py index efab2776d82..a99850017ba 100644 --- a/backend/ee/enmedd/server/workspace/store.py +++ b/backend/ee/enmedd/server/workspace/store.py @@ -11,6 +11,7 @@ from ee.enmedd.server.workspace.models import AnalyticsScriptUpload from enmedd.configs.constants import FileOrigin +from enmedd.db.models import User from enmedd.dynamic_configs.factory import get_dynamic_config_store from enmedd.dynamic_configs.interface import ConfigNotFoundError from enmedd.file_store.file_store import get_default_file_store @@ -103,17 +104,14 @@ def upload_logo( return True -def upload_profile( - db_session: Session, - file: UploadFile | str, -) -> bool: +def upload_profile(db_session: Session, file: UploadFile | str, user: User) -> bool: content: IO[Any] if isinstance(file, str): - logger.info(f"Uploading logo from local path {file}") + logger.info(f"Uploading profile from local path {file}") if not os.path.isfile(file) or not is_valid_file_type(file): logger.error( - "Invalid file type- only .png, .jpg, and .jpeg files are allowed" + "Invalid file type - only .png, .jpg, and .jpeg files are allowed" ) return False @@ -124,19 +122,21 @@ def upload_profile( file_type = guess_file_type(file) else: - logger.info("Uploading logo from uploaded file") + logger.info("Uploading profile from uploaded file") if not file.filename or not is_valid_file_type(file.filename): raise HTTPException( status_code=400, - detail="Invalid file type- only .png, .jpg, and .jpeg files are allowed", + detail="Invalid file type - only .png, .jpg, and .jpeg files are allowed", ) content = file.file display_name = file.filename file_type = file.content_type or "image/jpeg" + file_name = f"{user.id}/{_PROFILE_FILENAME}" + file_store = get_default_file_store(db_session) file_store.save_file( - file_name=_PROFILE_FILENAME, + file_name=file_name, content=content, display_name=display_name, file_origin=FileOrigin.OTHER, diff --git a/backend/enmedd/server/manage/users.py b/backend/enmedd/server/manage/users.py index ee4838b618d..a88ba1e5177 100644 --- a/backend/enmedd/server/manage/users.py +++ b/backend/enmedd/server/manage/users.py @@ -381,22 +381,43 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse: def put_profile( file: UploadFile, db_session: Session = Depends(get_session), - _: User | None = Depends(current_user), + current_user: User = Depends(current_user), ) -> None: - upload_profile(file=file, db_session=db_session) + upload_profile(file=file, db_session=db_session, user=current_user) @router.get("/me/profile") def fetch_profile( db_session: Session = Depends(get_session), - _: User | None = Depends(current_user), # Ensure that the user is authenticated + current_user: User = Depends(current_user), ) -> Response: try: + file_path = f"{current_user.id}/{_PROFILE_FILENAME}" + file_store = get_default_file_store(db_session) - file_io = file_store.read_file(_PROFILE_FILENAME, mode="b") + file_io = file_store.read_file(file_path, mode="b") + return Response(content=file_io.read(), media_type="image/jpeg") except Exception: - raise HTTPException(status_code=404, detail="No logo file found") + raise HTTPException(status_code=404, detail="No profile file found") + + +@router.delete("/me/profile") +def remove_profile( + db_session: Session = Depends(get_session), + current_user: User = Depends(current_user), # Get the current user +) -> None: + try: + file_name = f"{current_user.id}/{_PROFILE_FILENAME}" + + file_store = get_default_file_store(db_session) + + file_store.delete_file(file_name) + + return {"detail": "Profile picture removed successfully."} + except Exception as e: + logger.error(f"Error removing profile picture: {str(e)}") + raise HTTPException(status_code=404, detail="Profile picture not found.") @router.get("/me") diff --git a/web/src/app/profile/tabContent/profileTab.tsx b/web/src/app/profile/tabContent/profileTab.tsx index 0759e638be4..8ce65c69a2e 100644 --- a/web/src/app/profile/tabContent/profileTab.tsx +++ b/web/src/app/profile/tabContent/profileTab.tsx @@ -26,7 +26,7 @@ export default function ProfileTab({ const [selectedFile, setSelectedFile] = useState(null); const [profileImageUrl, setProfileImageUrl] = useState(null); - // Fetch profile image on mount + // TODO: Update instantly when user profile changes useEffect(() => { const fetchProfileImage = async () => { try { From 06aa4284f27c235dde25c1948806ff0be8b13cce Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Mon, 7 Oct 2024 13:24:12 +0800 Subject: [PATCH 2/6] fix: get assistant by teamspace_id param --- backend/enmedd/db/assistant.py | 53 +++++-------------- .../enmedd/server/features/assistant/api.py | 23 ++------ 2 files changed, 19 insertions(+), 57 deletions(-) diff --git a/backend/enmedd/db/assistant.py b/backend/enmedd/db/assistant.py index 8a89545a1dd..91e6880523b 100644 --- a/backend/enmedd/db/assistant.py +++ b/backend/enmedd/db/assistant.py @@ -169,37 +169,45 @@ def get_prompts( def get_assistants( - # if user_id is `None` assume the user is an admin or auth is disabled user_id: UUID | None, db_session: Session, + teamspace_id: int | None = None, include_default: bool = True, include_deleted: bool = False, ) -> Sequence[Assistant]: stmt = select(Assistant).distinct() + if user_id is not None: - # Subquery to find all teams the user belongs to teamspaces_subquery = ( select(User__Teamspace.teamspace_id) .where(User__Teamspace.user_id == user_id) .subquery() ) - # Include assistants where the user is directly related or part of a teamspace that has access access_conditions = or_( Assistant.is_public == True, # noqa: E712 - Assistant.id.in_( # User has access through list of users with access + Assistant.id.in_( select(Assistant__User.assistant_id).where( Assistant__User.user_id == user_id ) ), - Assistant.id.in_( # User is part of a group that has access + Assistant.id.in_( select(Assistant__Teamspace.assistant_id).where( - Assistant__Teamspace.teamspace_id.in_(teamspaces_subquery) # type: ignore + Assistant__Teamspace.teamspace_id.in_(teamspaces_subquery) ) ), ) stmt = stmt.where(access_conditions) + if teamspace_id is not None: + stmt = stmt.where( + Assistant.id.in_( + select(Assistant__Teamspace.assistant_id).where( + Assistant__Teamspace.teamspace_id == teamspace_id + ) + ) + ) + if not include_default: stmt = stmt.where(Assistant.default_assistant.is_(False)) if not include_deleted: @@ -591,39 +599,6 @@ def get_assistants_by_ids( return assistants -def get_assistants_by_teamspace_id( - teamspace_id: int, - user: User | None, - db_session: Session, - include_deleted: bool = False, -) -> list[Assistant]: - stmt = ( - select(Assistant) - .join(Assistant__Teamspace) - .where(Assistant__Teamspace.teamspace_id == teamspace_id) - ) - - or_conditions = [] - - if user is not None and user.role != UserRole.ADMIN: - or_conditions.extend( - [Assistant.user_id == user.id, Assistant.user_id.is_(None)] - ) - - or_conditions.append(Assistant.is_public.is_(True)) - - if or_conditions: - stmt = stmt.where(or_(*or_conditions)) - - if not include_deleted: - stmt = stmt.where(Assistant.deleted.is_(False)) - - result = db_session.execute(stmt) - assistants = result.scalars().all() - - return assistants - - def get_prompt_by_name( prompt_name: str, user: User | None, db_session: Session ) -> Prompt | None: diff --git a/backend/enmedd/server/features/assistant/api.py b/backend/enmedd/server/features/assistant/api.py index fd78f345cd4..15523d9a463 100644 --- a/backend/enmedd/server/features/assistant/api.py +++ b/backend/enmedd/server/features/assistant/api.py @@ -10,7 +10,6 @@ from enmedd.db.assistant import create_update_assistant from enmedd.db.assistant import get_assistant_by_id from enmedd.db.assistant import get_assistants -from enmedd.db.assistant import get_assistants_by_teamspace_id from enmedd.db.assistant import mark_assistant_as_deleted from enmedd.db.assistant import mark_assistant_as_not_deleted from enmedd.db.assistant import update_all_assistants_display_priority @@ -157,6 +156,7 @@ def delete_assistant( @basic_router.get("") def list_assistants( + teamspace_id: int | None = None, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), include_deleted: bool = False, @@ -165,7 +165,10 @@ def list_assistants( return [ AssistantSnapshot.from_model(assistant) for assistant in get_assistants( - user_id=user_id, include_deleted=include_deleted, db_session=db_session + user_id=user_id, + teamspace_id=teamspace_id, + include_deleted=include_deleted, + db_session=db_session, ) ] @@ -186,22 +189,6 @@ def get_assistant( ) -@basic_router.get("/teamspace/{teamspace_id}") -def get_assistants_by_id( - teamspace_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> list[AssistantSnapshot]: - assistants = get_assistants_by_teamspace_id( - teamspace_id=teamspace_id, - user=user, - db_session=db_session, - include_deleted=False, - ) - - return [AssistantSnapshot.from_model(assistant) for assistant in assistants] - - @basic_router.get("/utils/prompt-explorer") def build_final_template_prompt( system_prompt: str, From c6c70ed0f3048dc4b49c7c6db0d36d2029b2363c Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Tue, 8 Oct 2024 09:40:52 +0800 Subject: [PATCH 3/6] feat: adding teamspace_id as parameters on selected endpoints; update chat-folder/chat-session --- backend/enmedd/db/chat.py | 15 ++++++++---- backend/enmedd/db/document_set.py | 16 +++++++++++++ backend/enmedd/db/folder.py | 19 +++++++++++++++ backend/enmedd/server/documents/connector.py | 19 +++++++++++++-- .../server/features/document_set/api.py | 18 +++++++++++---- backend/enmedd/server/features/folder/api.py | 23 +++++++++++++++---- backend/enmedd/server/manage/users.py | 20 ++++++++++++---- .../server/query_and_chat/chat_backend.py | 14 ++++++++--- 8 files changed, 121 insertions(+), 23 deletions(-) diff --git a/backend/enmedd/db/chat.py b/backend/enmedd/db/chat.py index 196ace2dd3a..3a0af55c347 100644 --- a/backend/enmedd/db/chat.py +++ b/backend/enmedd/db/chat.py @@ -1,5 +1,6 @@ from datetime import datetime from datetime import timedelta +from typing import Optional from uuid import UUID from sqlalchemy import delete @@ -70,11 +71,20 @@ def get_chat_session_by_id( def get_chat_sessions_by_user( user_id: UUID | None, + teamspace_id: Optional[int] | None, deleted: bool | None, db_session: Session, include_one_shot: bool = False, ) -> list[ChatSession]: - stmt = select(ChatSession).where(ChatSession.user_id == user_id) + stmt = select(ChatSession) + + if user_id is not None: + stmt = stmt.where(ChatSession.user_id == user_id) + + if teamspace_id is not None: + stmt = stmt.join(ChatSession__Teamspace).where( + ChatSession__Teamspace.teamspace_id == teamspace_id + ) if not include_one_shot: stmt = stmt.where(ChatSession.one_shot.is_(False)) @@ -147,7 +157,6 @@ def create_chat_session( prompt_override: PromptOverride | None = None, one_shot: bool = False, ) -> ChatSession: - # Create the chat session chat_session = ChatSession( user_id=user_id, assistant_id=assistant_id, @@ -157,11 +166,9 @@ def create_chat_session( one_shot=one_shot, ) - # Add the chat session to the session and commit it to generate an ID db_session.add(chat_session) db_session.commit() - # Add the relationship to the ChatSession__Teamspace table if teamspace_id is provided if teamspace_id: chat_session_teamspace_relationship = ChatSession__Teamspace( chat_session_id=chat_session.id, teamspace_id=teamspace_id diff --git a/backend/enmedd/db/document_set.py b/backend/enmedd/db/document_set.py index a0dae50e53c..83feaf60d68 100644 --- a/backend/enmedd/db/document_set.py +++ b/backend/enmedd/db/document_set.py @@ -14,6 +14,7 @@ from enmedd.db.models import DocumentByConnectorCredentialPair from enmedd.db.models import DocumentSet as DocumentSetDBModel from enmedd.db.models import DocumentSet__ConnectorCredentialPair +from enmedd.db.models import DocumentSet__Teamspace from enmedd.server.features.document_set.models import DocumentSetCreationRequest from enmedd.server.features.document_set.models import DocumentSetUpdateRequest from enmedd.utils.variable_functionality import fetch_versioned_implementation @@ -362,6 +363,21 @@ def fetch_all_document_sets(db_session: Session) -> Sequence[DocumentSetDBModel] return db_session.scalars(select(DocumentSetDBModel)).all() +def fetch_document_sets_by_teamspace( + teamspace_id: int, db_session: Session +) -> Sequence[DocumentSetDBModel]: + """Fetch document sets for a specific teamspace.""" + return ( + db_session.query(DocumentSetDBModel) + .join( + DocumentSet__Teamspace, + DocumentSetDBModel.id == DocumentSet__Teamspace.document_set_id, + ) + .filter(DocumentSet__Teamspace.teamspace_id == teamspace_id) + .all() + ) + + def fetch_user_document_sets( user_id: UUID | None, db_session: Session ) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]: diff --git a/backend/enmedd/db/folder.py b/backend/enmedd/db/folder.py index df0adf9f254..61d1ab341e5 100644 --- a/backend/enmedd/db/folder.py +++ b/backend/enmedd/db/folder.py @@ -4,6 +4,7 @@ from enmedd.db.chat import delete_chat_session from enmedd.db.models import ChatFolder +from enmedd.db.models import ChatFolder__Teamspace from enmedd.db.models import ChatSession from enmedd.utils.logger import setup_logger @@ -17,6 +18,24 @@ def get_user_folders( return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all() +def get_user_folders_in_teamspace( + user_id: UUID | None, + teamspace_id: int, + db_session: Session, +) -> list[ChatFolder]: + return ( + db_session.query(ChatFolder) + .join( + ChatFolder__Teamspace, ChatFolder.id == ChatFolder__Teamspace.chat_folder_id + ) + .filter( + ChatFolder.user_id == user_id, + ChatFolder__Teamspace.teamspace_id == teamspace_id, + ) + .all() + ) + + def update_folder_display_priority( user_id: UUID | None, display_priority_map: dict[int, int], diff --git a/backend/enmedd/server/documents/connector.py b/backend/enmedd/server/documents/connector.py index 083005fc1c0..7571ae596c2 100644 --- a/backend/enmedd/server/documents/connector.py +++ b/backend/enmedd/server/documents/connector.py @@ -1,6 +1,7 @@ import os import uuid from typing import cast +from typing import Optional from fastapi import APIRouter from fastapi import Depends @@ -65,6 +66,8 @@ from enmedd.db.index_attempt import create_index_attempt from enmedd.db.index_attempt import get_index_attempts_for_cc_pair from enmedd.db.index_attempt import get_latest_index_attempts +from enmedd.db.models import ConnectorCredentialPair +from enmedd.db.models import Teamspace__ConnectorCredentialPair from enmedd.db.models import User from enmedd.dynamic_configs.interface import ConfigNotFoundError from enmedd.file_store.file_store import get_default_file_store @@ -366,14 +369,26 @@ def upload_files( @router.get("/admin/connector/indexing-status") def get_connector_indexing_status( + teamspace_id: Optional[int] = None, secondary_index: bool = False, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[ConnectorIndexingStatus]: indexing_statuses: list[ConnectorIndexingStatus] = [] - # TODO: make this one query - cc_pairs = get_connector_credential_pairs(db_session) + if teamspace_id: + cc_pairs = ( + db_session.query(ConnectorCredentialPair) + .join(Teamspace__ConnectorCredentialPair) + .filter( + Teamspace__ConnectorCredentialPair.teamspace_id == teamspace_id, + Teamspace__ConnectorCredentialPair.is_current == True, # noqa: E712 + ) + .all() + ) + else: + cc_pairs = get_connector_credential_pairs(db_session) + cc_pair_identifiers = [ ConnectorCredentialPairIdentifier( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id diff --git a/backend/enmedd/server/features/document_set/api.py b/backend/enmedd/server/features/document_set/api.py index bf72904998a..cb8e5f9635c 100644 --- a/backend/enmedd/server/features/document_set/api.py +++ b/backend/enmedd/server/features/document_set/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException @@ -7,6 +9,7 @@ from enmedd.auth.users import current_user from enmedd.db.document_set import check_document_sets_are_public from enmedd.db.document_set import fetch_all_document_sets +from enmedd.db.document_set import fetch_document_sets_by_teamspace from enmedd.db.document_set import fetch_user_document_sets from enmedd.db.document_set import insert_document_set from enmedd.db.document_set import mark_document_set_as_to_be_deleted @@ -76,13 +79,20 @@ def delete_document_set( @router.get("/admin/document-set") def list_document_sets_admin( + teamspace_id: Optional[int] = None, _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[DocumentSet]: - return [ - DocumentSet.from_model(ds) - for ds in fetch_all_document_sets(db_session=db_session) - ] + if teamspace_id: + return [ + DocumentSet.from_model(ds) + for ds in fetch_document_sets_by_teamspace(teamspace_id, db_session) + ] + else: + return [ + DocumentSet.from_model(ds) + for ds in fetch_all_document_sets(db_session=db_session) + ] """Endpoints for non-admins""" diff --git a/backend/enmedd/server/features/folder/api.py b/backend/enmedd/server/features/folder/api.py index 70e69e2fba2..66f473a904f 100644 --- a/backend/enmedd/server/features/folder/api.py +++ b/backend/enmedd/server/features/folder/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException @@ -11,6 +13,7 @@ from enmedd.db.folder import create_folder from enmedd.db.folder import delete_folder from enmedd.db.folder import get_user_folders +from enmedd.db.folder import get_user_folders_in_teamspace from enmedd.db.folder import remove_chat_from_folder from enmedd.db.folder import rename_folder from enmedd.db.folder import update_folder_display_priority @@ -31,13 +34,23 @@ @router.get("") def get_folders( user: User = Depends(current_user), + teamspace_id: Optional[int] = None, db_session: Session = Depends(get_session), ) -> GetUserFoldersResponse: - folders = get_user_folders( - user_id=user.id if user else None, - db_session=db_session, - ) + if teamspace_id: + folders = get_user_folders_in_teamspace( + user_id=user.id if user else None, + teamspace_id=teamspace_id, + db_session=db_session, + ) + else: + folders = get_user_folders( + user_id=user.id if user else None, + db_session=db_session, + ) + folders.sort() + return GetUserFoldersResponse( folders=[ FolderResponse( @@ -79,7 +92,7 @@ def put_folder_display_priority( def create_folder_endpoint( request: FolderCreationRequest, user: User = Depends(current_user), - teamspace_id: int | None = None, + teamspace_id: Optional[int] = None, db_session: Session = Depends(get_session), ) -> int: chat_folder = create_folder( diff --git a/backend/enmedd/server/manage/users.py b/backend/enmedd/server/manage/users.py index a88ba1e5177..db3aa1ca345 100644 --- a/backend/enmedd/server/manage/users.py +++ b/backend/enmedd/server/manage/users.py @@ -46,6 +46,7 @@ from enmedd.db.models import AccessToken from enmedd.db.models import TwofactorAuth from enmedd.db.models import User +from enmedd.db.models import User__Teamspace from enmedd.db.users import change_user_password from enmedd.db.users import get_user_by_email from enmedd.db.users import list_users @@ -197,19 +198,28 @@ def list_all_users( q: str | None = None, accepted_page: int | None = None, invited_page: int | None = None, + teamspace_id: int | None = None, _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> AllUsersResponse: if not q: q = "" - users = [ - user - for user in list_users(db_session, q=q) - if not is_api_key_email_address(user.email) - ] + if teamspace_id is not None: + users = ( + db_session.query(User) + .join(User__Teamspace) + .filter(User__Teamspace.teamspace_id == teamspace_id) + .all() + ) + else: + users = list_users(db_session, q=q) + + users = [user for user in users if not is_api_key_email_address(user.email)] + accepted_emails = {user.email for user in users} invited_emails = get_invited_users() + if q: invited_emails = [ email for email in invited_emails if re.search(r"{}".format(q), email, re.I) diff --git a/backend/enmedd/server/query_and_chat/chat_backend.py b/backend/enmedd/server/query_and_chat/chat_backend.py index 90d785eb38b..46460c46dfd 100644 --- a/backend/enmedd/server/query_and_chat/chat_backend.py +++ b/backend/enmedd/server/query_and_chat/chat_backend.py @@ -1,5 +1,6 @@ import io import uuid +from typing import Optional from fastapi import APIRouter from fastapi import Depends @@ -76,6 +77,7 @@ @router.get("/get-user-chat-sessions") def get_user_chat_sessions( + teamspace_id: Optional[int] = None, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ChatSessionsResponse: @@ -83,11 +85,16 @@ def get_user_chat_sessions( try: chat_sessions = get_chat_sessions_by_user( - user_id=user_id, deleted=False, db_session=db_session + user_id=user_id, + teamspace_id=teamspace_id, + deleted=False, + db_session=db_session, ) except ValueError: - raise ValueError("Chat session does not exist or has been deleted") + raise HTTPException( + status_code=404, detail="Chat session does not exist or has been deleted" + ) return ChatSessionsResponse( sessions=[ @@ -207,6 +214,7 @@ def create_new_chat_session( chat_session_creation_request: ChatSessionCreationRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), + teamspace_id: Optional[int] = None, ) -> CreateChatSessionID: user_id = user.id if user is not None else None try: @@ -216,7 +224,7 @@ def create_new_chat_session( or "", # Leave the naming till later to prevent delay user_id=user_id, assistant_id=chat_session_creation_request.assistant_id, - teamspace_id=chat_session_creation_request.teamspace_id, + teamspace_id=teamspace_id, ) except Exception as e: logger.exception(e) From 888f0bd5128608983aa7380bcd5d735a414f5e96 Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Thu, 10 Oct 2024 13:22:29 +0800 Subject: [PATCH 4/6] feature: update patch endpoint for teamspace --- backend/ee/enmedd/db/teamspace.py | 169 +++++++++++++------ backend/ee/enmedd/server/teamspace/models.py | 12 +- 2 files changed, 128 insertions(+), 53 deletions(-) diff --git a/backend/ee/enmedd/db/teamspace.py b/backend/ee/enmedd/db/teamspace.py index 9521c3eeaf0..4fe1da51b71 100644 --- a/backend/ee/enmedd/db/teamspace.py +++ b/backend/ee/enmedd/db/teamspace.py @@ -3,6 +3,7 @@ from typing import Optional from uuid import UUID +from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import select from sqlalchemy.orm import Session @@ -146,11 +147,9 @@ def _add_user__teamspace_relationships__no_commit( db_session: Session, teamspace_id: int, user_ids: list[UUID], - creator_id: Optional[UUID], + creator_id: Optional[UUID] = None, ) -> list[User__Teamspace]: """NOTE: does not commit the transaction.""" - if creator_id is None: - return [] # if creator_id not in user_ids: # user_ids.append(creator_id) @@ -306,68 +305,138 @@ def _mark_teamspace__assistant_relationships_outdated__no_commit( teamspace__assistant_relationship.is_current = False -def update_teamspace( - db_session: Session, teamspace_id: int, teamspace: TeamspaceUpdate -) -> Teamspace: - stmt = select(Teamspace).where(Teamspace.id == teamspace_id) - db_teamspace = db_session.scalar(stmt) - if db_teamspace is None: - raise ValueError(f"Teamspace with id '{teamspace_id}' not found") - - _check_teamspace_is_modifiable(db_teamspace) - - cc_pairs_updated = set([cc_pair.id for cc_pair in db_teamspace.cc_pairs]) != set( - teamspace.cc_pair_ids +def _sync_relationships( + db_session: Session, + teamspace_id: int, + updated_user_ids: list[UUID], + updated_cc_pair_ids: list[int], + updated_document_set_ids: list[int], + updated_assistant_ids: list[int], +) -> None: + current_user_ids = set( + db_session.scalars( + select(User__Teamspace.user_id).where( + User__Teamspace.teamspace_id == teamspace_id + ) + ).all() + ) + current_cc_pair_ids = set( + db_session.scalars( + select(Teamspace__ConnectorCredentialPair.cc_pair_id).where( + Teamspace__ConnectorCredentialPair.teamspace_id == teamspace_id + ) + ).all() + ) + current_document_set_ids = set( + db_session.scalars( + select(DocumentSet__Teamspace.document_set_id).where( + DocumentSet__Teamspace.teamspace_id == teamspace_id + ) + ).all() ) - document_set_updated = set( - [document_set.id for document_set in db_teamspace.document_sets] - ) != set(teamspace.document_set_ids) - assistant_updated = set( - [assistant.id for assistant in db_teamspace.assistants] - ) != set(teamspace.assistant_ids) - users_updated = set([user.id for user in db_teamspace.users]) != set( - teamspace.user_ids + current_assistant_ids = set( + db_session.scalars( + select(Assistant__Teamspace.assistant_id).where( + Assistant__Teamspace.teamspace_id == teamspace_id + ) + ).all() ) - if users_updated: - _cleanup_user__teamspace_relationships__no_commit( - db_session=db_session, teamspace_id=teamspace_id + users_to_delete = current_user_ids - set(updated_user_ids) + users_to_add = set(updated_user_ids) - current_user_ids + + cc_pairs_to_delete = current_cc_pair_ids - set(updated_cc_pair_ids) + cc_pairs_to_add = set(updated_cc_pair_ids) - current_cc_pair_ids + + document_sets_to_delete = current_document_set_ids - set(updated_document_set_ids) + document_sets_to_add = set(updated_document_set_ids) - current_document_set_ids + + assistants_to_delete = current_assistant_ids - set(updated_assistant_ids) + assistants_to_add = set(updated_assistant_ids) - current_assistant_ids + + if users_to_delete: + db_session.execute( + delete(User__Teamspace) + .where(User__Teamspace.teamspace_id == teamspace_id) + .where(User__Teamspace.user_id.in_(users_to_delete)) ) + + if cc_pairs_to_delete: + db_session.execute( + delete(Teamspace__ConnectorCredentialPair) + .where(Teamspace__ConnectorCredentialPair.teamspace_id == teamspace_id) + .where( + Teamspace__ConnectorCredentialPair.cc_pair_id.in_(cc_pairs_to_delete) + ) + ) + + if document_sets_to_delete: + db_session.execute( + delete(DocumentSet__Teamspace) + .where(DocumentSet__Teamspace.teamspace_id == teamspace_id) + .where(DocumentSet__Teamspace.document_set_id.in_(document_sets_to_delete)) + ) + + if assistants_to_delete: + db_session.execute( + delete(Assistant__Teamspace) + .where(Assistant__Teamspace.teamspace_id == teamspace_id) + .where(Assistant__Teamspace.assistant_id.in_(assistants_to_delete)) + ) + + if users_to_add: _add_user__teamspace_relationships__no_commit( db_session=db_session, teamspace_id=teamspace_id, - user_ids=teamspace.user_ids, - ) - if cc_pairs_updated: - _mark_teamspace__cc_pair_relationships_outdated__no_commit( - db_session=db_session, teamspace_id=teamspace_id + user_ids=list(users_to_add), ) + + if cc_pairs_to_add: _add_teamspace__cc_pair_relationships__no_commit( db_session=db_session, - teamspace_id=db_teamspace.id, - cc_pair_ids=teamspace.cc_pair_ids, - ) - if document_set_updated: - _mark_teamspace__document_set_relationships_outdated__no_commit( - db_session=db_session, teamspace_id=teamspace_id + teamspace_id=teamspace_id, + cc_pair_ids=list(cc_pairs_to_add), ) + + if document_sets_to_add: _add_teamspace__document_set_relationships__no_commit( db_session=db_session, - teamspace_id=db_teamspace.id, - document_set_id=teamspace.document_set_ids, - ) - if assistant_updated: - _mark_teamspace__assistant_relationships_outdated__no_commit( - db_session=db_session, teamspace_id=teamspace_id + teamspace_id=teamspace_id, + document_set_ids=list(document_sets_to_add), ) + + if assistants_to_add: _add_teamspace__assistant_relationships__no_commit( db_session=db_session, - teamspace_id=db_teamspace.id, - assistant_id=teamspace.assistant_ids, + teamspace_id=teamspace_id, + assistant_ids=list(assistants_to_add), ) - # only needs to sync with Vespa if the cc_pairs have been updated - if cc_pairs_updated: + db_session.commit() + + +def update_teamspace( + db_session: Session, teamspace_id: int, teamspace: TeamspaceUpdate +) -> Teamspace: + stmt = select(Teamspace).where(Teamspace.id == teamspace_id) + db_teamspace = db_session.scalar(stmt) + if db_teamspace is None: + raise ValueError(f"Teamspace with id '{teamspace_id}' not found") + + _check_teamspace_is_modifiable(db_teamspace) + + _sync_relationships( + db_session=db_session, + teamspace_id=teamspace_id, + updated_user_ids=teamspace.user_ids, + updated_cc_pair_ids=teamspace.cc_pair_ids, + updated_document_set_ids=teamspace.document_set_ids, + updated_assistant_ids=teamspace.assistant_ids, + ) + + if set([cc_pair.id for cc_pair in db_teamspace.cc_pairs]) != set( + teamspace.cc_pair_ids + ): db_teamspace.is_up_to_date = False db_session.commit() @@ -403,6 +472,12 @@ def prepare_teamspace_for_deletion(db_session: Session, teamspace_id: int) -> No _mark_teamspace__cc_pair_relationships_outdated__no_commit( db_session=db_session, teamspace_id=teamspace_id ) + _mark_teamspace__document_set_relationships_outdated__no_commit( + db_session=db_session, teamspace_id=teamspace_id + ) + _mark_teamspace__assistant_relationships_outdated__no_commit( + db_session=db_session, teamspace_id=teamspace_id + ) _cleanup_token_rate_limit__teamspace_relationships__no_commit( db_session=db_session, teamspace_id=teamspace_id ) diff --git a/backend/ee/enmedd/server/teamspace/models.py b/backend/ee/enmedd/server/teamspace/models.py index b4f49564f32..d6cd8e7ab1c 100644 --- a/backend/ee/enmedd/server/teamspace/models.py +++ b/backend/ee/enmedd/server/teamspace/models.py @@ -124,17 +124,17 @@ def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace": class TeamspaceCreate(BaseModel): name: str user_ids: list[UUID] - cc_pair_ids: list[int] - document_set_ids: Optional[List[int]] = [] - assistant_ids: Optional[List[int]] = [] + cc_pair_ids: Optional[List[int]] = None + document_set_ids: Optional[List[int]] = None + assistant_ids: Optional[List[int]] = None workspace_id: Optional[int] = 0 class TeamspaceUpdate(BaseModel): user_ids: list[UUID] - cc_pair_ids: list[int] - document_set_ids: Optional[List[int]] = [] - assistant_ids: Optional[List[int]] = [] + cc_pair_ids: Optional[List[int]] = None + document_set_ids: Optional[List[int]] = None + assistant_ids: Optional[List[int]] = None class TeamspaceUserRole(str, Enum): From 9ac3423afb5d7c310633b76026db69efe20f4d4a Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Thu, 10 Oct 2024 20:08:33 +0800 Subject: [PATCH 5/6] feature: moving workspace_settings out of the KVstores; implement multi-teamspace feature settings --- .../24cad828e24c_settings_database_table.py | 64 +++++++++++ backend/enmedd/db/enums.py | 5 + backend/enmedd/db/models.py | 52 ++++++++- backend/enmedd/server/settings/api.py | 25 ++++- backend/enmedd/server/settings/models.py | 2 +- backend/enmedd/server/settings/store.py | 104 +++++++++++++++--- 6 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 backend/alembic/versions/24cad828e24c_settings_database_table.py diff --git a/backend/alembic/versions/24cad828e24c_settings_database_table.py b/backend/alembic/versions/24cad828e24c_settings_database_table.py new file mode 100644 index 00000000000..664e32c3256 --- /dev/null +++ b/backend/alembic/versions/24cad828e24c_settings_database_table.py @@ -0,0 +1,64 @@ +"""settings database table + +Revision ID: 24cad828e24c +Revises: 4644a2459b2b +Create Date: 2024-10-10 19:55:14.464493 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "24cad828e24c" +down_revision = "4644a2459b2b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "teamspace_settings", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("chat_page_enabled", sa.Boolean(), nullable=False), + sa.Column("search_page_enabled", sa.Boolean(), nullable=False), + sa.Column( + "default_page", + sa.Enum("CHAT", "SEARCH", name="pagetype", native_enum=False), + nullable=False, + ), + sa.Column("maximum_chat_retention_days", sa.Integer(), nullable=True), + sa.Column("teamspace_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["teamspace_id"], + ["teamspace.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "workspace_settings", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("chat_page_enabled", sa.Boolean(), nullable=False), + sa.Column("search_page_enabled", sa.Boolean(), nullable=False), + sa.Column( + "default_page", + sa.Enum("CHAT", "SEARCH", name="pagetype", native_enum=False), + nullable=False, + ), + sa.Column("maximum_chat_retention_days", sa.Integer(), nullable=True), + sa.Column("workspace_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("workspace_settings") + op.drop_table("teamspace_settings") + # ### end Alembic commands ### diff --git a/backend/enmedd/db/enums.py b/backend/enmedd/db/enums.py index 1078ffaac68..6ac7840f2c1 100644 --- a/backend/enmedd/db/enums.py +++ b/backend/enmedd/db/enums.py @@ -49,3 +49,8 @@ class WorkspaceSubscriptionPlan(str, PyEnum): class InstanceSubscriptionPlan(str, PyEnum): ENTERPRISE = "enterprise" PARTNER = "partner" + + +class PageType(str, PyEnum): + CHAT = "chat" + SEARCH = "search" diff --git a/backend/enmedd/db/models.py b/backend/enmedd/db/models.py index 15f20e4f62a..41a902ee900 100644 --- a/backend/enmedd/db/models.py +++ b/backend/enmedd/db/models.py @@ -43,6 +43,7 @@ from enmedd.db.enums import IndexingStatus from enmedd.db.enums import IndexModelStatus from enmedd.db.enums import InstanceSubscriptionPlan +from enmedd.db.enums import PageType from enmedd.db.enums import TaskStatus from enmedd.db.pydantic_type import PydanticType from enmedd.dynamic_configs.interface import JSON_ro @@ -1369,12 +1370,15 @@ class Teamspace(Base): secondary="token_rate_limit__teamspace", viewonly=True, ) - + chat_folders: Mapped[list[ChatFolder]] = relationship( "ChatFolder", secondary=ChatFolder__Teamspace.__table__, viewonly=True, ) + settings: Mapped["TeamspaceSettings"] = relationship( + "TeamspaceSettings", back_populates="teamspace", viewonly=False + ) """Tables related to Token Rate Limiting @@ -1556,7 +1560,9 @@ class Workspace(Base): back_populates="workspace", viewonly=True, ) - + settings: Mapped["WorkspaceSettings"] = relationship( + "WorkspaceSettings", back_populates="workspace", viewonly=False + ) instance: Mapped["Instance"] = relationship("Instance", back_populates="workspaces") @@ -1574,3 +1580,45 @@ class Instance(Base): workspaces: Mapped[list[Workspace] | None] = relationship( "Workspace", back_populates="instance" ) + + +class WorkspaceSettings(Base): + __tablename__ = "workspace_settings" + + id: Mapped[int] = mapped_column(primary_key=True) + chat_page_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + search_page_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + default_page: Mapped[PageType] = mapped_column( + Enum(PageType, native_enum=False), default=PageType.CHAT + ) + maximum_chat_retention_days: Mapped[int | None] = mapped_column( + Integer, nullable=True + ) + + workspace_id: Mapped[int | None] = mapped_column( + ForeignKey("workspace.id"), nullable=True + ) + workspace: Mapped["Workspace"] = relationship( + "Workspace", back_populates="settings" + ) + + +class TeamspaceSettings(Base): + __tablename__ = "teamspace_settings" + + id: Mapped[int] = mapped_column(primary_key=True) + chat_page_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + search_page_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + default_page: Mapped[PageType] = mapped_column( + Enum(PageType, native_enum=False), default=PageType.CHAT + ) + maximum_chat_retention_days: Mapped[int | None] = mapped_column( + Integer, nullable=True + ) + + teamspace_id: Mapped[int | None] = mapped_column( + ForeignKey("teamspace.id"), nullable=True + ) + teamspace: Mapped["Teamspace"] = relationship( + "Teamspace", back_populates="settings" + ) diff --git a/backend/enmedd/server/settings/api.py b/backend/enmedd/server/settings/api.py index 4c0fbde59a1..5f966dd4ce4 100644 --- a/backend/enmedd/server/settings/api.py +++ b/backend/enmedd/server/settings/api.py @@ -1,11 +1,15 @@ +from typing import Optional + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from sqlalchemy.orm import Session from enmedd.auth.users import current_admin_user from enmedd.auth.users import current_user +from enmedd.db.engine import get_session from enmedd.db.models import User -from enmedd.server.settings.models import Settings +from enmedd.server.settings.models import Setting from enmedd.server.settings.store import load_settings from enmedd.server.settings.store import store_settings @@ -16,15 +20,26 @@ @admin_router.put("") def put_settings( - settings: Settings, _: User | None = Depends(current_admin_user) + settings: Setting, + db: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), + teamspace_id: Optional[int] = None, ) -> None: try: settings.check_validity() except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - store_settings(settings) + + try: + store_settings(settings, db, teamspace_id) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @basic_router.get("") -def fetch_settings(_: User | None = Depends(current_user)) -> Settings: - return load_settings() +def fetch_settings( + db: Session = Depends(get_session), + _: User | None = Depends(current_user), + teamspace_id: Optional[int] = None, +) -> Setting: + return load_settings(db, teamspace_id) diff --git a/backend/enmedd/server/settings/models.py b/backend/enmedd/server/settings/models.py index b41bbb4936c..78b0ac9ee53 100644 --- a/backend/enmedd/server/settings/models.py +++ b/backend/enmedd/server/settings/models.py @@ -8,7 +8,7 @@ class PageType(str, Enum): SEARCH = "search" -class Settings(BaseModel): +class Setting(BaseModel): """General settings""" chat_page_enabled: bool = True diff --git a/backend/enmedd/server/settings/store.py b/backend/enmedd/server/settings/store.py index 7a3dffe142a..8fa87d9047c 100644 --- a/backend/enmedd/server/settings/store.py +++ b/backend/enmedd/server/settings/store.py @@ -1,23 +1,95 @@ -from typing import cast +from typing import Optional -from enmedd.dynamic_configs.factory import get_dynamic_config_store -from enmedd.dynamic_configs.interface import ConfigNotFoundError -from enmedd.server.settings.models import Settings +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.orm import Session -# TODO: replace the name here -_SETTINGS_KEY = "enmedd_settings" +from enmedd.db.engine import get_session +from enmedd.db.models import TeamspaceSettings +from enmedd.db.models import WorkspaceSettings +from enmedd.server.settings.models import PageType +from enmedd.server.settings.models import Setting -def load_settings() -> Settings: - dynamic_config_store = get_dynamic_config_store() - try: - settings = Settings(**cast(dict, dynamic_config_store.load(_SETTINGS_KEY))) - except ConfigNotFoundError: - settings = Settings() - dynamic_config_store.store(_SETTINGS_KEY, settings.dict()) +def load_settings( + db: Session = Depends(get_session), teamspace_id: Optional[int] = None +) -> Setting: + if teamspace_id: + settings_record = ( + db.query(TeamspaceSettings).filter_by(teamspace_id=teamspace_id).first() + ) + else: + settings_record = db.query(WorkspaceSettings).first() + + if not settings_record: + settings_record = ( + TeamspaceSettings( + chat_page_enabled=True, + search_page_enabled=True, + default_page=PageType.CHAT, + maximum_chat_retention_days=None, + teamspace_id=teamspace_id, + ) + if teamspace_id + else WorkspaceSettings( + chat_page_enabled=True, + search_page_enabled=True, + default_page=PageType.CHAT, + maximum_chat_retention_days=None, + ) + ) + db.add(settings_record) + db.commit() + db.refresh(settings_record) + + return Setting( + chat_page_enabled=settings_record.chat_page_enabled, + search_page_enabled=settings_record.search_page_enabled, + default_page=settings_record.default_page, + maximum_chat_retention_days=settings_record.maximum_chat_retention_days, + ) - return settings +def store_settings( + settings: Setting, + db: Session = Depends(get_session), + teamspace_id: Optional[int] = None, +) -> None: + if teamspace_id: + settings_record = ( + db.query(TeamspaceSettings).filter_by(teamspace_id=teamspace_id).first() + ) + else: + settings_record = db.query(WorkspaceSettings).first() -def store_settings(settings: Settings) -> None: - get_dynamic_config_store().store(_SETTINGS_KEY, settings.dict()) + if settings_record: + settings_record.chat_page_enabled = settings.chat_page_enabled + settings_record.search_page_enabled = settings.search_page_enabled + settings_record.default_page = settings.default_page + settings_record.maximum_chat_retention_days = ( + settings.maximum_chat_retention_days + ) + else: + new_record = ( + TeamspaceSettings( + chat_page_enabled=settings.chat_page_enabled, + search_page_enabled=settings.search_page_enabled, + default_page=settings.default_page, + maximum_chat_retention_days=settings.maximum_chat_retention_days, + teamspace_id=teamspace_id, + ) + if teamspace_id + else WorkspaceSettings( + chat_page_enabled=settings.chat_page_enabled, + search_page_enabled=settings.search_page_enabled, + default_page=settings.default_page, + maximum_chat_retention_days=settings.maximum_chat_retention_days, + ) + ) + db.add(new_record) + + try: + db.commit() + except Exception: + db.rollback() + raise HTTPException(status_code=500, detail="Failed to store settings.") From b21f4760deea12ad271cc740fe241cd285947563 Mon Sep 17 00:00:00 2001 From: Amboyandrey Date: Fri, 11 Oct 2024 10:55:30 +0800 Subject: [PATCH 6/6] feature: implement teamspace profile on the backend --- backend/ee/enmedd/main.py | 4 +- backend/ee/enmedd/server/teamspace/api.py | 71 +++++++++++++++++++-- backend/ee/enmedd/server/workspace/store.py | 45 +++++++++++++ 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/backend/ee/enmedd/main.py b/backend/ee/enmedd/main.py index 09c6e2367d3..acb7c6787e3 100644 --- a/backend/ee/enmedd/main.py +++ b/backend/ee/enmedd/main.py @@ -15,7 +15,8 @@ from ee.enmedd.server.reporting.usage_export_api import router as usage_export_router from ee.enmedd.server.saml import router as saml_router from ee.enmedd.server.seeding import seed_db -from ee.enmedd.server.teamspace.api import router as teamspace_router +from ee.enmedd.server.teamspace.api import admin_router as teamspace_admin_router +from ee.enmedd.server.teamspace.api import basic_router as teamspace_router from ee.enmedd.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) @@ -88,6 +89,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, chat_router) # Enterprise-only global settings include_router_with_global_prefix_prepended(application, workspaces_admin_router) + include_router_with_global_prefix_prepended(application, teamspace_admin_router) # Token rate limit settings include_router_with_global_prefix_prepended( application, token_rate_limit_settings_router diff --git a/backend/ee/enmedd/server/teamspace/api.py b/backend/ee/enmedd/server/teamspace/api.py index bd60bcc77fa..7c46a1fc762 100644 --- a/backend/ee/enmedd/server/teamspace/api.py +++ b/backend/ee/enmedd/server/teamspace/api.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Response +from fastapi import UploadFile from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -14,16 +16,24 @@ from ee.enmedd.server.teamspace.models import TeamspaceUpdate from ee.enmedd.server.teamspace.models import TeamspaceUserRole from ee.enmedd.server.teamspace.models import UpdateUserRoleRequest +from ee.enmedd.server.workspace.store import _LOGO_FILENAME +from ee.enmedd.server.workspace.store import upload_teamspace_logo from enmedd.auth.users import current_admin_user +from enmedd.auth.users import current_user from enmedd.db.engine import get_session from enmedd.db.models import User from enmedd.db.models import User__Teamspace from enmedd.db.users import get_user_by_email +from enmedd.file_store.file_store import get_default_file_store +from enmedd.utils.logger import setup_logger -router = APIRouter(prefix="/manage") +logger = setup_logger() +admin_router = APIRouter(prefix="/manage") +basic_router = APIRouter(prefix="/teamspace") -@router.get("/admin/teamspace/{teamspace_id}") + +@admin_router.get("/admin/teamspace/{teamspace_id}") def get_teamspace_by_id( teamspace_id: int, _: User = Depends(current_admin_user), @@ -37,7 +47,7 @@ def get_teamspace_by_id( return Teamspace.from_model(db_teamspace) -@router.get("/admin/teamspace") +@admin_router.get("/admin/teamspace") def list_teamspaces( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), @@ -46,7 +56,7 @@ def list_teamspaces( return [Teamspace.from_model(teamspace) for teamspace in teamspaces] -@router.post("/admin/teamspace") +@admin_router.post("/admin/teamspace") def create_teamspace( teamspace: TeamspaceCreate, current_user: User = Depends(current_admin_user), @@ -65,7 +75,7 @@ def create_teamspace( return Teamspace.from_model(db_teamspace) -@router.patch("/admin/teamspace/{teamspace_id}") +@admin_router.patch("/admin/teamspace/{teamspace_id}") def patch_teamspace( teamspace_id: int, teamspace: TeamspaceUpdate, @@ -80,7 +90,7 @@ def patch_teamspace( raise HTTPException(status_code=404, detail=str(e)) -@router.delete("/admin/teamspace/{teamspace_id}") +@admin_router.delete("/admin/teamspace/{teamspace_id}") def delete_teamspace( teamspace_id: int, _: User = Depends(current_admin_user), @@ -92,7 +102,7 @@ def delete_teamspace( raise HTTPException(status_code=404, detail=str(e)) -@router.patch("/admin/teamspace/user-role/{teamspace_id}") +@admin_router.patch("/admin/teamspace/user-role/{teamspace_id}") def update_teamspace_user_role( teamspace_id: int, body: UpdateUserRoleRequest, @@ -129,3 +139,50 @@ def update_teamspace_user_role( return { "message": f"User role updated to {body.new_role.value} for {body.user_email}" } + + +@admin_router.put("/admin/teamspace/logo") +def put_teamspace_logo( + teamspace_id: int, + file: UploadFile, + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + upload_teamspace_logo(teamspace_id=teamspace_id, file=file, db_session=db_session) + + +@admin_router.delete("/admin/teamspace/logo") +def remove_teamspace_logo( + teamspace_id: int, + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + try: + file_name = f"{teamspace_id}/{_LOGO_FILENAME}" + + file_store = get_default_file_store(db_session) + file_store.delete_file(file_name) + + return {"detail": "Teamspace logo removed successfully."} + except Exception as e: + logger.error(f"Error removing teamspace logo: {str(e)}") + raise HTTPException(status_code=404, detail="Teamspace logo not found.") + + +@basic_router.get("/logo") +def fetch_teamspace_logo( + teamspace_id: int, + _: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Response: + try: + file_path = f"{teamspace_id}/{_LOGO_FILENAME}" + + file_store = get_default_file_store(db_session) + file_io = file_store.read_file(file_path, mode="b") + + return Response(content=file_io.read(), media_type="image/jpeg") + except Exception: + raise HTTPException( + status_code=404, detail="No logo file found for the teamspace" + ) diff --git a/backend/ee/enmedd/server/workspace/store.py b/backend/ee/enmedd/server/workspace/store.py index a99850017ba..6cc45e675bc 100644 --- a/backend/ee/enmedd/server/workspace/store.py +++ b/backend/ee/enmedd/server/workspace/store.py @@ -143,3 +143,48 @@ def upload_profile(db_session: Session, file: UploadFile | str, user: User) -> b file_type=file_type, ) return True + + +def upload_teamspace_logo( + db_session: Session, + teamspace_id: int, + file: UploadFile | str, +) -> bool: + content: IO[Any] + + if isinstance(file, str): + logger.info(f"Uploading teamspace logo from local path {file}") + if not os.path.isfile(file) or not is_valid_file_type(file): + logger.error( + "Invalid file type - only .png, .jpg, and .jpeg files are allowed" + ) + return False + + with open(file, "rb") as file_handle: + file_content = file_handle.read() + content = BytesIO(file_content) + display_name = file + file_type = guess_file_type(file) + + else: + logger.info("Uploading teamspace logo from uploaded file") + if not file.filename or not is_valid_file_type(file.filename): + raise HTTPException( + status_code=400, + detail="Invalid file type - only .png, .jpg, and .jpeg files are allowed", + ) + content = file.file + display_name = file.filename + file_type = file.content_type or "image/jpeg" + + file_name = f"{teamspace_id}/{_LOGO_FILENAME}" + + file_store = get_default_file_store(db_session) + file_store.save_file( + file_name=file_name, + content=content, + display_name=display_name, + file_origin=FileOrigin.OTHER, + file_type=file_type, + ) + return True