Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: assign user roles on teamspace_user relationship table; assign chatsession to its teamspace #138

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""user-teamspace relationship table update

Revision ID: a912ff7ae8c5
Revises: ea802eba7d25
Create Date: 2024-09-23 12:50:33.508639

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a912ff7ae8c5"
down_revision = "ea802eba7d25"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user__teamspace",
sa.Column(
"role",
sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
nullable=False,
),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user__teamspace", "role")
# ### end Alembic commands ###
39 changes: 36 additions & 3 deletions backend/ee/enmedd/db/teamspace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from operator import and_
from typing import Optional
from uuid import UUID

from sqlalchemy import func
Expand All @@ -8,6 +9,7 @@

from ee.enmedd.server.teamspace.models import TeamspaceCreate
from ee.enmedd.server.teamspace.models import TeamspaceUpdate
from enmedd.auth.schemas import UserRole
from enmedd.db.models import Assistant__Teamspace
from enmedd.db.models import ConnectorCredentialPair
from enmedd.db.models import Document
Expand All @@ -18,6 +20,7 @@
from enmedd.db.models import TokenRateLimit__Teamspace
from enmedd.db.models import User
from enmedd.db.models import User__Teamspace
from enmedd.db.models import Workspace__Teamspace
from enmedd.server.documents.models import ConnectorCredentialPairIdentifier


Expand Down Expand Up @@ -140,11 +143,24 @@ def _check_teamspace_is_modifiable(teamspace: Teamspace) -> None:


def _add_user__teamspace_relationships__no_commit(
db_session: Session, teamspace_id: int, user_ids: list[UUID]
db_session: Session,
teamspace_id: int,
user_ids: list[UUID],
creator_id: Optional[UUID],
) -> list[User__Teamspace]:
"""NOTE: does not commit the transaction."""
if creator_id is None:
return []

if user_ids and creator_id not in user_ids:
user_ids.append(creator_id)

relationships = [
User__Teamspace(user_id=user_id, teamspace_id=teamspace_id)
User__Teamspace(
user_id=user_id,
teamspace_id=teamspace_id,
role=UserRole.ADMIN if user_id == creator_id else UserRole.BASIC,
)
for user_id in user_ids
]
db_session.add_all(relationships)
Expand Down Expand Up @@ -191,7 +207,20 @@ def _add_teamspace__assistant_relationships__no_commit(
return relationships


def insert_teamspace(db_session: Session, teamspace: TeamspaceCreate) -> Teamspace:
def _add_workspace__teamspace_relationship(
db_session: Session, workspace_id: int, teamspace_id: int
) -> Workspace__Teamspace:
relationship = Workspace__Teamspace(
workspace_id=workspace_id,
teamspace_id=teamspace_id,
)
db_session.add(relationship)
return relationship


def insert_teamspace(
db_session: Session, teamspace: TeamspaceCreate, creator_id: UUID
) -> Teamspace:
db_teamspace = Teamspace(name=teamspace.name)
db_session.add(db_teamspace)
db_session.flush() # give the group an ID
Expand All @@ -200,6 +229,7 @@ def insert_teamspace(db_session: Session, teamspace: TeamspaceCreate) -> Teamspa
db_session=db_session,
teamspace_id=db_teamspace.id,
user_ids=teamspace.user_ids,
creator_id=creator_id,
)
_add_teamspace__document_set_relationships__no_commit(
db_session=db_session,
Expand All @@ -216,6 +246,9 @@ def insert_teamspace(db_session: Session, teamspace: TeamspaceCreate) -> Teamspa
teamspace_id=db_teamspace.id,
cc_pair_ids=teamspace.cc_pair_ids,
)
_add_workspace__teamspace_relationship(
db_session, teamspace.workspace_id, db_teamspace.id
)

db_session.commit()
return db_teamspace
Expand Down
6 changes: 4 additions & 2 deletions backend/ee/enmedd/server/teamspace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ def list_teamspaces(
@router.post("/admin/teamspace")
def create_teamspace(
teamspace: TeamspaceCreate,
_: User = Depends(current_admin_user),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Teamspace:
try:
db_teamspace = insert_teamspace(db_session, teamspace)
db_teamspace = insert_teamspace(
db_session, teamspace, creator_id=current_user.id
)
except IntegrityError:
raise HTTPException(
400,
Expand Down
19 changes: 18 additions & 1 deletion backend/ee/enmedd/server/teamspace/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from enmedd.server.features.document_set.models import DocumentSet
from enmedd.server.manage.models import UserInfo
from enmedd.server.manage.models import UserPreferences
from enmedd.server.models import MinimalTeamspaceSnapshot
from enmedd.server.models import MinimalWorkspaceSnapshot
from enmedd.server.query_and_chat.models import ChatSessionDetails

Expand Down Expand Up @@ -63,6 +64,20 @@ def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
groups=[
MinimalTeamspaceSnapshot(
id=group.id,
name=group.name,
workspace=[
MinimalWorkspaceSnapshot(
id=workspace.id,
workspace_name=workspace.workspace_name,
)
for workspace in group.workspace
],
)
for group in cc_pair_relationship.cc_pair.groups
],
)
for cc_pair_relationship in teamspace_model.cc_pair_relationships
if cc_pair_relationship.is_current
Expand All @@ -77,9 +92,10 @@ def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
chat_sessions=[
ChatSessionDetails(
id=chat_session.id,
name=chat_session.description,
description=chat_session.description,
assistant_id=chat_session.assistant_id,
time_created=chat_session.time_created,
time_created=chat_session.time_created.isoformat(),
shared_status=chat_session.shared_status,
folder_id=chat_session.folder_id,
current_alternate_model=chat_session.current_alternate_model,
Expand All @@ -103,6 +119,7 @@ class TeamspaceCreate(BaseModel):
cc_pair_ids: list[int]
document_set_ids: Optional[List[int]] = []
assistant_ids: Optional[List[int]] = []
workspace_id: Optional[int] = 0


class TeamspaceUpdate(BaseModel):
Expand Down
12 changes: 12 additions & 0 deletions backend/enmedd/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from enmedd.db.models import ChatMessage
from enmedd.db.models import ChatMessage__SearchDoc
from enmedd.db.models import ChatSession
from enmedd.db.models import ChatSession__Teamspace
from enmedd.db.models import ChatSessionSharedStatus
from enmedd.db.models import Prompt
from enmedd.db.models import SearchDoc
Expand Down Expand Up @@ -141,10 +142,12 @@ def create_chat_session(
description: str,
user_id: UUID | None,
assistant_id: int | None = None,
teamspace_id: int | None = None, # Include teamspace_id as a parameter
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
) -> ChatSession:
# Create the chat session
chat_session = ChatSession(
user_id=user_id,
assistant_id=assistant_id,
Expand All @@ -154,9 +157,18 @@ def create_chat_session(
one_shot=one_shot,
)

# Add the chat session to the session and commit it to generate an ID
db_session.add(chat_session)
db_session.commit()

# Add the relationship to the ChatSession__Teamspace table if teamspace_id is provided
if teamspace_id:
chat_session_teamspace_relationship = ChatSession__Teamspace(
chat_session_id=chat_session.id, teamspace_id=teamspace_id
)
db_session.add(chat_session_teamspace_relationship)
db_session.commit()

return chat_session


Expand Down
4 changes: 4 additions & 0 deletions backend/enmedd/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,10 @@ class User__Teamspace(Base):
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)

role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)


class Teamspace__ConnectorCredentialPair(Base):
__tablename__ = "teamspace__connector_credential_pair"
Expand Down
17 changes: 17 additions & 0 deletions backend/enmedd/server/documents/cc_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from enmedd.server.documents.models import CCPairFullInfo
from enmedd.server.documents.models import ConnectorCredentialPairIdentifier
from enmedd.server.documents.models import ConnectorCredentialPairMetadata
from enmedd.server.models import MinimalTeamspaceSnapshot
from enmedd.server.models import MinimalWorkspaceSnapshot
from enmedd.server.models import StatusResponse

router = APIRouter(prefix="/manage")
Expand Down Expand Up @@ -64,11 +66,26 @@ def get_cc_pair_full_info(
db_session=db_session,
)

groups = [
MinimalTeamspaceSnapshot(
id=group.id,
name=group.name,
workspace=[
MinimalWorkspaceSnapshot(
id=workspace.id, workspace_name=workspace.workspace_name
)
for workspace in group.workspace
],
)
for group in cc_pair.groups
]

return CCPairFullInfo.from_models(
cc_pair_model=cc_pair,
index_attempt_models=list(index_attempts),
latest_deletion_attempt=latest_deletion_attempt,
num_docs_indexed=documents_indexed,
groups=groups,
)


Expand Down
7 changes: 5 additions & 2 deletions backend/enmedd/server/documents/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any
from typing import Optional
from uuid import UUID

from pydantic import BaseModel
Expand Down Expand Up @@ -132,7 +133,7 @@ class CCPairFullInfo(BaseModel):
credential: CredentialSnapshot
index_attempts: list[IndexAttemptSnapshot]
latest_deletion_attempt: DeletionAttemptSnapshot | None
groups: list[MinimalTeamspaceSnapshot] | None
groups: list[MinimalTeamspaceSnapshot] = []

@classmethod
def from_models(
Expand All @@ -141,6 +142,7 @@ def from_models(
index_attempt_models: list[IndexAttempt],
latest_deletion_attempt: DeletionAttemptSnapshot | None,
num_docs_indexed: int, # not ideal, but this must be computed separately
groups: list[MinimalTeamspaceSnapshot] = [],
) -> "CCPairFullInfo":
return cls(
id=cc_pair_model.id,
Expand All @@ -157,6 +159,7 @@ def from_models(
for index_attempt_model in index_attempt_models
],
latest_deletion_attempt=latest_deletion_attempt,
groups=groups,
)


Expand Down Expand Up @@ -193,7 +196,7 @@ class ConnectorCredentialPairDescriptor(BaseModel):
name: str | None
connector: ConnectorSnapshot
credential: CredentialSnapshot
groups: list[MinimalTeamspaceSnapshot] | None
groups: Optional[list[MinimalTeamspaceSnapshot]] = []


class RunConnectorRequest(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions backend/enmedd/server/feature_flags/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ class FeatureFlags(BaseModel):
"""Features Control"""

profile_page: bool = False
multi_teamspace: bool = True
multi_teamspace: bool = False
multi_workspace: bool = False
query_history: bool = False
whitelabelling: bool = True
whitelabelling: bool = False
share_chat: bool = False
explore_assistants: bool = False

Expand Down
16 changes: 15 additions & 1 deletion backend/enmedd/server/features/document_set/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from enmedd.server.features.document_set.models import DocumentSet
from enmedd.server.features.document_set.models import DocumentSetCreationRequest
from enmedd.server.features.document_set.models import DocumentSetUpdateRequest
from enmedd.server.models import MinimalTeamspaceSnapshot
from enmedd.server.models import MinimalWorkspaceSnapshot


router = APIRouter(prefix="/manage")
Expand Down Expand Up @@ -116,7 +118,19 @@ def list_document_sets(
is_up_to_date=document_set_db_model.is_up_to_date,
is_public=document_set_db_model.is_public,
users=[user.id for user in document_set_db_model.users],
groups=[group.id for group in document_set_db_model.groups],
groups=[
MinimalTeamspaceSnapshot(
id=group.id,
name=group.name,
workspace=[
MinimalWorkspaceSnapshot(
id=workspace.id, workspace_name=workspace.workspace_name
)
for workspace in group.workspace
],
)
for group in document_set_db_model.groups
],
)
for document_set_db_model, cc_pairs in document_set_info
]
Expand Down
17 changes: 16 additions & 1 deletion backend/enmedd/server/features/document_set/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID

from pydantic import BaseModel
Expand Down Expand Up @@ -52,7 +53,7 @@ class DocumentSet(BaseModel):
is_public: bool
# For Private Document Sets, who should be able to access these
users: list[UUID]
groups: list[MinimalTeamspaceSnapshot]
groups: Optional[list[MinimalTeamspaceSnapshot]]

@classmethod
def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
Expand All @@ -76,6 +77,20 @@ def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential
),
groups=[
MinimalTeamspaceSnapshot(
id=group.id,
name=group.name,
workspace=[
MinimalWorkspaceSnapshot(
id=workspace.id,
workspace_name=workspace.workspace_name,
)
for workspace in group.workspace
],
)
for group in cc_pair.groups
],
)
for cc_pair in document_set_model.connector_credential_pairs
],
Expand Down
Loading
Loading