Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/token rate limits and update teamspace user roles #167

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""chat_folder__teamspace relationship table

Revision ID: 4644a2459b2b
Revises: 8487a94e1e38
Create Date: 2024-10-01 20:50:16.637771

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4644a2459b2b"  # unique id of this migration
down_revision = "8487a94e1e38"  # migration this one builds on
branch_labels = None  # no named branches
depends_on = None  # no cross-branch dependencies


def upgrade() -> None:
    """Create the chat_folder__teamspace association table."""
    # Both columns are required and together form the composite primary
    # key, so each (chat_folder, teamspace) pairing is stored at most once.
    chat_folder_fk = sa.ForeignKeyConstraint(["chat_folder_id"], ["chat_folder.id"])
    teamspace_fk = sa.ForeignKeyConstraint(["teamspace_id"], ["teamspace.id"])
    op.create_table(
        "chat_folder__teamspace",
        sa.Column("chat_folder_id", sa.Integer(), nullable=False),
        sa.Column("teamspace_id", sa.Integer(), nullable=False),
        chat_folder_fk,
        teamspace_fk,
        sa.PrimaryKeyConstraint("chat_folder_id", "teamspace_id"),
    )


def downgrade() -> None:
    """Remove the chat_folder__teamspace association table."""
    op.drop_table("chat_folder__teamspace")
6 changes: 4 additions & 2 deletions backend/ee/enmedd/db/teamspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ee.enmedd.server.teamspace.models import TeamspaceCreate
from ee.enmedd.server.teamspace.models import TeamspaceUpdate
from enmedd.auth.schemas import UserRole
from ee.enmedd.server.teamspace.models import TeamspaceUserRole
from enmedd.db.models import Assistant__Teamspace
from enmedd.db.models import ConnectorCredentialPair
from enmedd.db.models import Document
Expand Down Expand Up @@ -159,7 +159,9 @@ def _add_user__teamspace_relationships__no_commit(
User__Teamspace(
user_id=user_id,
teamspace_id=teamspace_id,
role=UserRole.ADMIN if user_id == creator_id else UserRole.BASIC,
role=TeamspaceUserRole.ADMIN
if user_id == creator_id
else TeamspaceUserRole.BASIC,
)
for user_id in user_ids
]
Expand Down
43 changes: 43 additions & 0 deletions backend/ee/enmedd/server/teamspace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
from ee.enmedd.server.teamspace.models import Teamspace
from ee.enmedd.server.teamspace.models import TeamspaceCreate
from ee.enmedd.server.teamspace.models import TeamspaceUpdate
from ee.enmedd.server.teamspace.models import TeamspaceUserRole
from ee.enmedd.server.teamspace.models import UpdateUserRoleRequest
from enmedd.auth.users import current_admin_user
from enmedd.db.engine import get_session
from enmedd.db.models import User
from enmedd.db.models import User__Teamspace
from enmedd.db.users import get_user_by_email

router = APIRouter(prefix="/manage")

Expand Down Expand Up @@ -86,3 +90,42 @@ def delete_teamspace(
prepare_teamspace_for_deletion(db_session, teamspace_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))


@router.patch("/admin/teamspace/user-role/{teamspace_id}")
def update_teamspace_user_role(
    teamspace_id: int,
    body: UpdateUserRoleRequest,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Update a user's role within a specific teamspace.

    Fix: the endpoint returns a message dict, but was annotated ``-> None``;
    FastAPI derives the response model from the return annotation, so the
    payload was mis-declared. The annotation now matches the actual return.

    Raises:
        HTTPException(404): target user does not exist, or the user is not a
            member of the teamspace.
        HTTPException(400): the calling admin attempts to demote themselves.
    """
    user_to_update = get_user_by_email(email=body.user_email, db_session=db_session)
    if not user_to_update:
        raise HTTPException(status_code=404, detail="User not found")

    # Admins may not strip their own admin role; another admin must do it.
    if user_to_update.id == user.id and body.new_role != TeamspaceUserRole.ADMIN:
        raise HTTPException(
            status_code=400, detail="Cannot demote yourself from admin role!"
        )

    user_teamspace = (
        db_session.query(User__Teamspace)
        .filter(
            User__Teamspace.user_id == user_to_update.id,
            User__Teamspace.teamspace_id == teamspace_id,
        )
        .first()
    )

    if not user_teamspace:
        raise HTTPException(
            status_code=404, detail="User-Teamspace relationship not found"
        )

    user_teamspace.role = body.new_role

    db_session.commit()

    return {
        "message": f"User role updated to {body.new_role.value} for {body.user_email}"
    }
19 changes: 19 additions & 0 deletions backend/ee/enmedd/server/teamspace/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import List
from typing import Optional
from uuid import UUID
Expand All @@ -15,6 +16,7 @@
from enmedd.server.models import MinimalTeamspaceSnapshot
from enmedd.server.models import MinimalWorkspaceSnapshot
from enmedd.server.query_and_chat.models import ChatSessionDetails
from enmedd.server.token_rate_limits.models import TokenRateLimitDisplay


class Teamspace(BaseModel):
Expand All @@ -28,6 +30,7 @@ class Teamspace(BaseModel):
is_up_to_date: bool
is_up_for_deletion: bool
workspace: list[MinimalWorkspaceSnapshot]
token_rate_limit: Optional[TokenRateLimitDisplay] = None

@classmethod
def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
Expand Down Expand Up @@ -110,6 +113,11 @@ def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
)
for workspace in teamspace_model.workspace
],
token_rate_limit=(
TokenRateLimitDisplay.from_db(teamspace_model.token_rate_limit)
if teamspace_model.token_rate_limit is not None
else None
),
)


Expand All @@ -127,3 +135,14 @@ class TeamspaceUpdate(BaseModel):
cc_pair_ids: list[int]
document_set_ids: Optional[List[int]] = []
assistant_ids: Optional[List[int]] = []


class TeamspaceUserRole(str, Enum):
    """Role a user can hold inside a single teamspace.

    Inherits ``str`` so members compare equal to their plain string values
    and serialize naturally in request/response bodies.
    """

    BASIC = "basic"
    CREATOR = "creator"
    ADMIN = "admin"


class UpdateUserRoleRequest(BaseModel):
    """Request body for the teamspace user-role update endpoint."""

    # Email of the target user; resolved to a user record server-side.
    user_email: str
    # Role to assign within the teamspace.
    new_role: TeamspaceUserRole
36 changes: 36 additions & 0 deletions backend/enmedd/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from sqlalchemy.orm import Session

from enmedd.background.indexing.checkpointing import get_time_windows_for_index_attempt
from enmedd.background.indexing.tracer import EnmeddTracer
from enmedd.configs.app_configs import INDEXING_TRACER_INTERVAL
from enmedd.configs.app_configs import POLL_CONNECTOR_OFFSET
from enmedd.connectors.factory import instantiate_connector
from enmedd.connectors.interfaces import GenerateDocumentsOutput
Expand Down Expand Up @@ -35,6 +37,8 @@

logger = setup_logger()

INDEXING_TRACER_NUM_PRINT_ENTRIES = 5


def _get_document_generator(
db_session: Session,
Expand Down Expand Up @@ -139,6 +143,12 @@ def _run_indexing(
)
)

if INDEXING_TRACER_INTERVAL > 0:
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = EnmeddTracer()
tracer.start()
tracer.snap()

net_doc_change = 0
document_count = 0
chunk_count = 0
Expand All @@ -165,6 +175,10 @@ def _run_indexing(
)

all_connector_doc_ids: set[str] = set()

tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
Expand Down Expand Up @@ -215,6 +229,17 @@ def _run_indexing(
docs_removed_from_index=0,
)

tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.info(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)

run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
Expand Down Expand Up @@ -253,12 +278,23 @@ def _run_indexing(
credential_id=index_attempt.credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e

# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break

if INDEXING_TRACER_INTERVAL > 0:
logger.info(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.info("Memory tracer stopped.")

mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
Expand Down
77 changes: 77 additions & 0 deletions backend/enmedd/background/indexing/tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import tracemalloc

from enmedd.utils.logger import setup_logger

logger = setup_logger()

# How many stack frames tracemalloc keeps per allocation.
ENMEDD_TRACEMALLOC_FRAMES = 10


class EnmeddTracer:
    """Convenience wrapper around ``tracemalloc`` for spotting memory growth.

    Keeps three snapshots: the first one ever taken, the previous one, and
    the most recent, so callers can log either incremental (batch-to-batch)
    or cumulative (start-to-end) allocation diffs.
    """

    def __init__(self) -> None:
        # All three start unset; snap() populates and rotates them.
        self.snapshot_first: tracemalloc.Snapshot | None = None
        self.snapshot_prev: tracemalloc.Snapshot | None = None
        self.snapshot: tracemalloc.Snapshot | None = None

    def start(self) -> None:
        """Begin tracing allocations with the configured frame depth."""
        tracemalloc.start(ENMEDD_TRACEMALLOC_FRAMES)

    def stop(self) -> None:
        """Stop tracing; previously captured snapshots remain usable."""
        tracemalloc.stop()

    def snap(self) -> None:
        """Capture a filtered snapshot and rotate first/prev/current."""
        current = tracemalloc.take_snapshot().filter_traces(
            (
                # Drop frames from tracemalloc itself and the import
                # machinery -- pure noise when hunting application leaks.
                tracemalloc.Filter(False, tracemalloc.__file__),
                tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
                tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
            )
        )

        if self.snapshot_first is None:
            self.snapshot_first = current
        if self.snapshot is not None:
            self.snapshot_prev = self.snapshot
        self.snapshot = current

    def log_snapshot(self, numEntries: int) -> None:
        """Log the top ``numEntries`` allocation sites of the latest snapshot."""
        if not self.snapshot:
            return

        for stat in self.snapshot.statistics("traceback")[:numEntries]:
            logger.info(f"Tracer snap: {stat}")
            for frame_line in stat.traceback:
                logger.info(f"* {frame_line}")

    @staticmethod
    def log_diff(
        snap_current: tracemalloc.Snapshot,
        snap_previous: tracemalloc.Snapshot,
        numEntries: int,
    ) -> None:
        """Log the ``numEntries`` largest deltas between two snapshots."""
        for stat in snap_current.compare_to(snap_previous, "traceback")[:numEntries]:
            logger.info(f"Tracer diff: {stat}")
            for formatted in stat.traceback.format():
                logger.info(f"* {formatted}")

    def log_previous_diff(self, numEntries: int) -> None:
        """Log latest-vs-previous diff; no-op unless both snapshots exist."""
        if not self.snapshot or not self.snapshot_prev:
            return
        EnmeddTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)

    def log_first_diff(self, numEntries: int) -> None:
        """Log latest-vs-first (cumulative) diff; no-op unless both exist."""
        if not self.snapshot or not self.snapshot_first:
            return
        EnmeddTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
11 changes: 8 additions & 3 deletions backend/enmedd/background/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from enmedd.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from enmedd.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from enmedd.configs.app_configs import NUM_INDEXING_WORKERS
from enmedd.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from enmedd.db.connector import fetch_connectors
from enmedd.db.embedding_model import get_current_db_embedding_model
from enmedd.db.embedding_model import get_secondary_db_embedding_model
Expand Down Expand Up @@ -335,7 +336,11 @@ def kickoff_indexing_jobs(
return existing_jobs_copy


def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
Expand Down Expand Up @@ -364,7 +369,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_workers,
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
Expand All @@ -374,7 +379,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)

existing_jobs: dict[int, Future | SimpleJob] = {}

Expand Down
7 changes: 6 additions & 1 deletion backend/enmedd/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
Expand All @@ -255,7 +258,9 @@
MINI_CHUNK_SIZE = 150
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))

# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))

#####
# Miscellaneous
Expand Down
Loading