Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/multi teamspace/workspace settings #176

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions backend/alembic/versions/24cad828e24c_settings_database_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""settings database table

Revision ID: 24cad828e24c
Revises: 4644a2459b2b
Create Date: 2024-10-10 19:55:14.464493

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "24cad828e24c"
down_revision = "4644a2459b2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"teamspace_settings",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("chat_page_enabled", sa.Boolean(), nullable=False),
sa.Column("search_page_enabled", sa.Boolean(), nullable=False),
sa.Column(
"default_page",
sa.Enum("CHAT", "SEARCH", name="pagetype", native_enum=False),
nullable=False,
),
sa.Column("maximum_chat_retention_days", sa.Integer(), nullable=True),
sa.Column("teamspace_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["teamspace_id"],
["teamspace.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"workspace_settings",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("chat_page_enabled", sa.Boolean(), nullable=False),
sa.Column("search_page_enabled", sa.Boolean(), nullable=False),
sa.Column(
"default_page",
sa.Enum("CHAT", "SEARCH", name="pagetype", native_enum=False),
nullable=False,
),
sa.Column("maximum_chat_retention_days", sa.Integer(), nullable=True),
sa.Column("workspace_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["workspace_id"],
["workspace.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("workspace_settings")
op.drop_table("teamspace_settings")
# ### end Alembic commands ###
169 changes: 122 additions & 47 deletions backend/ee/enmedd/db/teamspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional
from uuid import UUID

from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -146,11 +147,9 @@ def _add_user__teamspace_relationships__no_commit(
db_session: Session,
teamspace_id: int,
user_ids: list[UUID],
creator_id: Optional[UUID],
creator_id: Optional[UUID] = None,
) -> list[User__Teamspace]:
"""NOTE: does not commit the transaction."""
if creator_id is None:
return []

# if creator_id not in user_ids:
# user_ids.append(creator_id)
Expand Down Expand Up @@ -306,68 +305,138 @@ def _mark_teamspace__assistant_relationships_outdated__no_commit(
teamspace__assistant_relationship.is_current = False


def update_teamspace(
db_session: Session, teamspace_id: int, teamspace: TeamspaceUpdate
) -> Teamspace:
stmt = select(Teamspace).where(Teamspace.id == teamspace_id)
db_teamspace = db_session.scalar(stmt)
if db_teamspace is None:
raise ValueError(f"Teamspace with id '{teamspace_id}' not found")

_check_teamspace_is_modifiable(db_teamspace)

cc_pairs_updated = set([cc_pair.id for cc_pair in db_teamspace.cc_pairs]) != set(
teamspace.cc_pair_ids
def _sync_relationships(
db_session: Session,
teamspace_id: int,
updated_user_ids: list[UUID],
updated_cc_pair_ids: list[int],
updated_document_set_ids: list[int],
updated_assistant_ids: list[int],
) -> None:
current_user_ids = set(
db_session.scalars(
select(User__Teamspace.user_id).where(
User__Teamspace.teamspace_id == teamspace_id
)
).all()
)
current_cc_pair_ids = set(
db_session.scalars(
select(Teamspace__ConnectorCredentialPair.cc_pair_id).where(
Teamspace__ConnectorCredentialPair.teamspace_id == teamspace_id
)
).all()
)
current_document_set_ids = set(
db_session.scalars(
select(DocumentSet__Teamspace.document_set_id).where(
DocumentSet__Teamspace.teamspace_id == teamspace_id
)
).all()
)
document_set_updated = set(
[document_set.id for document_set in db_teamspace.document_sets]
) != set(teamspace.document_set_ids)
assistant_updated = set(
[assistant.id for assistant in db_teamspace.assistants]
) != set(teamspace.assistant_ids)
users_updated = set([user.id for user in db_teamspace.users]) != set(
teamspace.user_ids
current_assistant_ids = set(
db_session.scalars(
select(Assistant__Teamspace.assistant_id).where(
Assistant__Teamspace.teamspace_id == teamspace_id
)
).all()
)

if users_updated:
_cleanup_user__teamspace_relationships__no_commit(
db_session=db_session, teamspace_id=teamspace_id
users_to_delete = current_user_ids - set(updated_user_ids)
users_to_add = set(updated_user_ids) - current_user_ids

cc_pairs_to_delete = current_cc_pair_ids - set(updated_cc_pair_ids)
cc_pairs_to_add = set(updated_cc_pair_ids) - current_cc_pair_ids

document_sets_to_delete = current_document_set_ids - set(updated_document_set_ids)
document_sets_to_add = set(updated_document_set_ids) - current_document_set_ids

assistants_to_delete = current_assistant_ids - set(updated_assistant_ids)
assistants_to_add = set(updated_assistant_ids) - current_assistant_ids

if users_to_delete:
db_session.execute(
delete(User__Teamspace)
.where(User__Teamspace.teamspace_id == teamspace_id)
.where(User__Teamspace.user_id.in_(users_to_delete))
)

if cc_pairs_to_delete:
db_session.execute(
delete(Teamspace__ConnectorCredentialPair)
.where(Teamspace__ConnectorCredentialPair.teamspace_id == teamspace_id)
.where(
Teamspace__ConnectorCredentialPair.cc_pair_id.in_(cc_pairs_to_delete)
)
)

if document_sets_to_delete:
db_session.execute(
delete(DocumentSet__Teamspace)
.where(DocumentSet__Teamspace.teamspace_id == teamspace_id)
.where(DocumentSet__Teamspace.document_set_id.in_(document_sets_to_delete))
)

if assistants_to_delete:
db_session.execute(
delete(Assistant__Teamspace)
.where(Assistant__Teamspace.teamspace_id == teamspace_id)
.where(Assistant__Teamspace.assistant_id.in_(assistants_to_delete))
)

if users_to_add:
_add_user__teamspace_relationships__no_commit(
db_session=db_session,
teamspace_id=teamspace_id,
user_ids=teamspace.user_ids,
)
if cc_pairs_updated:
_mark_teamspace__cc_pair_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
user_ids=list(users_to_add),
)

if cc_pairs_to_add:
_add_teamspace__cc_pair_relationships__no_commit(
db_session=db_session,
teamspace_id=db_teamspace.id,
cc_pair_ids=teamspace.cc_pair_ids,
)
if document_set_updated:
_mark_teamspace__document_set_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
teamspace_id=teamspace_id,
cc_pair_ids=list(cc_pairs_to_add),
)

if document_sets_to_add:
_add_teamspace__document_set_relationships__no_commit(
db_session=db_session,
teamspace_id=db_teamspace.id,
document_set_id=teamspace.document_set_ids,
)
if assistant_updated:
_mark_teamspace__assistant_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
teamspace_id=teamspace_id,
document_set_ids=list(document_sets_to_add),
)

if assistants_to_add:
_add_teamspace__assistant_relationships__no_commit(
db_session=db_session,
teamspace_id=db_teamspace.id,
assistant_id=teamspace.assistant_ids,
teamspace_id=teamspace_id,
assistant_ids=list(assistants_to_add),
)

# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_session.commit()


def update_teamspace(
db_session: Session, teamspace_id: int, teamspace: TeamspaceUpdate
) -> Teamspace:
stmt = select(Teamspace).where(Teamspace.id == teamspace_id)
db_teamspace = db_session.scalar(stmt)
if db_teamspace is None:
raise ValueError(f"Teamspace with id '{teamspace_id}' not found")

_check_teamspace_is_modifiable(db_teamspace)

_sync_relationships(
db_session=db_session,
teamspace_id=teamspace_id,
updated_user_ids=teamspace.user_ids,
updated_cc_pair_ids=teamspace.cc_pair_ids,
updated_document_set_ids=teamspace.document_set_ids,
updated_assistant_ids=teamspace.assistant_ids,
)

if set([cc_pair.id for cc_pair in db_teamspace.cc_pairs]) != set(
teamspace.cc_pair_ids
):
db_teamspace.is_up_to_date = False

db_session.commit()
Expand Down Expand Up @@ -403,6 +472,12 @@ def prepare_teamspace_for_deletion(db_session: Session, teamspace_id: int) -> No
_mark_teamspace__cc_pair_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
)
_mark_teamspace__document_set_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
)
_mark_teamspace__assistant_relationships_outdated__no_commit(
db_session=db_session, teamspace_id=teamspace_id
)
_cleanup_token_rate_limit__teamspace_relationships__no_commit(
db_session=db_session, teamspace_id=teamspace_id
)
Expand Down
4 changes: 3 additions & 1 deletion backend/ee/enmedd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from ee.enmedd.server.reporting.usage_export_api import router as usage_export_router
from ee.enmedd.server.saml import router as saml_router
from ee.enmedd.server.seeding import seed_db
from ee.enmedd.server.teamspace.api import router as teamspace_router
from ee.enmedd.server.teamspace.api import admin_router as teamspace_admin_router
from ee.enmedd.server.teamspace.api import basic_router as teamspace_router
from ee.enmedd.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
Expand Down Expand Up @@ -88,6 +89,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, chat_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(application, workspaces_admin_router)
include_router_with_global_prefix_prepended(application, teamspace_admin_router)
# Token rate limit settings
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
Expand Down
Loading
Loading