Skip to content

Commit

Permalink
Feature/token rate limits and update teamspace user roles (#167)
Browse files Browse the repository at this point in the history
* feat: teamspace_chatfolder relationship table; get assistant by teamspace_id

* feat: include token rate limit on fetch teamspace; update teamspace user role

---------

Co-authored-by: Kai Tecson <111247289+SchadenKai@users.noreply.github.com>
  • Loading branch information
Amboyandrey and SchadenKai authored Oct 5, 2024
1 parent add1c85 commit 8beda16
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 17 deletions.
6 changes: 4 additions & 2 deletions backend/ee/enmedd/db/teamspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ee.enmedd.server.teamspace.models import TeamspaceCreate
from ee.enmedd.server.teamspace.models import TeamspaceUpdate
from enmedd.auth.schemas import UserRole
from ee.enmedd.server.teamspace.models import TeamspaceUserRole
from enmedd.db.models import Assistant__Teamspace
from enmedd.db.models import ConnectorCredentialPair
from enmedd.db.models import Document
Expand Down Expand Up @@ -159,7 +159,9 @@ def _add_user__teamspace_relationships__no_commit(
User__Teamspace(
user_id=user_id,
teamspace_id=teamspace_id,
role=UserRole.ADMIN if user_id == creator_id else UserRole.BASIC,
role=TeamspaceUserRole.ADMIN
if user_id == creator_id
else TeamspaceUserRole.BASIC,
)
for user_id in user_ids
]
Expand Down
43 changes: 43 additions & 0 deletions backend/ee/enmedd/server/teamspace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
from ee.enmedd.server.teamspace.models import Teamspace
from ee.enmedd.server.teamspace.models import TeamspaceCreate
from ee.enmedd.server.teamspace.models import TeamspaceUpdate
from ee.enmedd.server.teamspace.models import TeamspaceUserRole
from ee.enmedd.server.teamspace.models import UpdateUserRoleRequest
from enmedd.auth.users import current_admin_user
from enmedd.db.engine import get_session
from enmedd.db.models import User
from enmedd.db.models import User__Teamspace
from enmedd.db.users import get_user_by_email

router = APIRouter(prefix="/manage")

Expand Down Expand Up @@ -86,3 +90,42 @@ def delete_teamspace(
prepare_teamspace_for_deletion(db_session, teamspace_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))


@router.patch("/admin/teamspace/user-role/{teamspace_id}")
def update_teamspace_user_role(
    teamspace_id: int,
    body: UpdateUserRoleRequest,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Update a user's role within a teamspace.

    Looks up the target user by email, updates their ``User__Teamspace.role``
    to ``body.new_role`` and commits, returning a confirmation message.

    Raises:
        HTTPException(404): if no user matches ``body.user_email`` or the user
            is not a member of the given teamspace.
        HTTPException(400): if the caller tries to change their own role to
            anything other than admin (self-demotion guard).
    """
    user_to_update = get_user_by_email(email=body.user_email, db_session=db_session)
    if not user_to_update:
        raise HTTPException(status_code=404, detail="User not found")

    # The caller is an admin (enforced by current_admin_user); block them from
    # demoting themselves so a teamspace cannot lose its acting administrator.
    if user_to_update.id == user.id and body.new_role != TeamspaceUserRole.ADMIN:
        raise HTTPException(
            status_code=400, detail="Cannot demote yourself from admin role!"
        )

    user_teamspace = (
        db_session.query(User__Teamspace)
        .filter(
            User__Teamspace.user_id == user_to_update.id,
            User__Teamspace.teamspace_id == teamspace_id,
        )
        .first()
    )

    if not user_teamspace:
        raise HTTPException(
            status_code=404, detail="User-Teamspace relationship not found"
        )

    user_teamspace.role = body.new_role

    db_session.commit()

    return {
        "message": f"User role updated to {body.new_role.value} for {body.user_email}"
    }
19 changes: 19 additions & 0 deletions backend/ee/enmedd/server/teamspace/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import List
from typing import Optional
from uuid import UUID
Expand All @@ -15,6 +16,7 @@
from enmedd.server.models import MinimalTeamspaceSnapshot
from enmedd.server.models import MinimalWorkspaceSnapshot
from enmedd.server.query_and_chat.models import ChatSessionDetails
from enmedd.server.token_rate_limits.models import TokenRateLimitDisplay


class Teamspace(BaseModel):
Expand All @@ -28,6 +30,7 @@ class Teamspace(BaseModel):
is_up_to_date: bool
is_up_for_deletion: bool
workspace: list[MinimalWorkspaceSnapshot]
token_rate_limit: Optional[TokenRateLimitDisplay] = None

@classmethod
def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
Expand Down Expand Up @@ -110,6 +113,11 @@ def from_model(cls, teamspace_model: TeamspaceModel) -> "Teamspace":
)
for workspace in teamspace_model.workspace
],
token_rate_limit=(
TokenRateLimitDisplay.from_db(teamspace_model.token_rate_limit)
if teamspace_model.token_rate_limit is not None
else None
),
)


Expand All @@ -127,3 +135,14 @@ class TeamspaceUpdate(BaseModel):
cc_pair_ids: list[int]
document_set_ids: Optional[List[int]] = []
assistant_ids: Optional[List[int]] = []


class TeamspaceUserRole(str, Enum):
    """Role a user can hold within a teamspace; values are the persisted strings."""
    BASIC = "basic"
    CREATOR = "creator"
    ADMIN = "admin"


class UpdateUserRoleRequest(BaseModel):
    """Request body for the admin endpoint that changes a teamspace member's role."""
    user_email: str  # email identifying the user whose role is being changed
    new_role: TeamspaceUserRole  # role to assign within the teamspace
36 changes: 36 additions & 0 deletions backend/enmedd/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from sqlalchemy.orm import Session

from enmedd.background.indexing.checkpointing import get_time_windows_for_index_attempt
from enmedd.background.indexing.tracer import EnmeddTracer
from enmedd.configs.app_configs import INDEXING_TRACER_INTERVAL
from enmedd.configs.app_configs import POLL_CONNECTOR_OFFSET
from enmedd.connectors.factory import instantiate_connector
from enmedd.connectors.interfaces import GenerateDocumentsOutput
Expand Down Expand Up @@ -35,6 +37,8 @@

logger = setup_logger()

INDEXING_TRACER_NUM_PRINT_ENTRIES = 5


def _get_document_generator(
db_session: Session,
Expand Down Expand Up @@ -139,6 +143,12 @@ def _run_indexing(
)
)

if INDEXING_TRACER_INTERVAL > 0:
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = EnmeddTracer()
tracer.start()
tracer.snap()

net_doc_change = 0
document_count = 0
chunk_count = 0
Expand All @@ -165,6 +175,10 @@ def _run_indexing(
)

all_connector_doc_ids: set[str] = set()

tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
Expand Down Expand Up @@ -215,6 +229,17 @@ def _run_indexing(
docs_removed_from_index=0,
)

tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.info(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)

run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
Expand Down Expand Up @@ -253,12 +278,23 @@ def _run_indexing(
credential_id=index_attempt.credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e

# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break

if INDEXING_TRACER_INTERVAL > 0:
logger.info(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.info("Memory tracer stopped.")

mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
Expand Down
77 changes: 77 additions & 0 deletions backend/enmedd/background/indexing/tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import tracemalloc

from enmedd.utils.logger import setup_logger

logger = setup_logger()

ENMEDD_TRACEMALLOC_FRAMES = 10


class EnmeddTracer:
    """Thin convenience wrapper around ``tracemalloc`` for memory diagnostics.

    Keeps three snapshots: the first one taken (baseline), the one taken just
    before the latest, and the latest itself, so callers can log incremental
    diffs during a long-running job and a total diff at the end.
    """

    def __init__(self) -> None:
        # Baseline snapshot, the snapshot preceding the latest, and the latest.
        self.snapshot_first: tracemalloc.Snapshot | None = None
        self.snapshot_prev: tracemalloc.Snapshot | None = None
        self.snapshot: tracemalloc.Snapshot | None = None

    def start(self) -> None:
        """Begin tracing allocations, recording ENMEDD_TRACEMALLOC_FRAMES frames."""
        tracemalloc.start(ENMEDD_TRACEMALLOC_FRAMES)

    def stop(self) -> None:
        """Stop tracing allocations."""
        tracemalloc.stop()

    def snap(self) -> None:
        """Take a filtered snapshot and rotate the first/prev/current bookkeeping."""
        # Drop frames belonging to tracemalloc itself and to the import
        # machinery so application allocations dominate the statistics.
        noise_filters = (
            tracemalloc.Filter(False, tracemalloc.__file__),
            tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
            tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
        )
        current = tracemalloc.take_snapshot().filter_traces(noise_filters)

        if self.snapshot_first is None:
            self.snapshot_first = current
        if self.snapshot is not None:
            self.snapshot_prev = self.snapshot
        self.snapshot = current

    def log_snapshot(self, numEntries: int) -> None:
        """Log the top ``numEntries`` allocation sites of the current snapshot."""
        if self.snapshot is None:
            return

        for stat in self.snapshot.statistics("traceback")[:numEntries]:
            logger.info(f"Tracer snap: {stat}")
            for line in stat.traceback:
                logger.info(f"* {line}")

    @staticmethod
    def log_diff(
        snap_current: tracemalloc.Snapshot,
        snap_previous: tracemalloc.Snapshot,
        numEntries: int,
    ) -> None:
        """Log the top ``numEntries`` allocation deltas between two snapshots."""
        diff_stats = snap_current.compare_to(snap_previous, "traceback")
        for stat in diff_stats[:numEntries]:
            logger.info(f"Tracer diff: {stat}")
            for line in stat.traceback.format():
                logger.info(f"* {line}")

    def log_previous_diff(self, numEntries: int) -> None:
        """Log the diff between the current and the previous snapshot, if both exist."""
        if self.snapshot is not None and self.snapshot_prev is not None:
            self.log_diff(self.snapshot, self.snapshot_prev, numEntries)

    def log_first_diff(self, numEntries: int) -> None:
        """Log the diff between the current and the baseline snapshot, if both exist."""
        if self.snapshot is not None and self.snapshot_first is not None:
            self.log_diff(self.snapshot, self.snapshot_first, numEntries)
11 changes: 8 additions & 3 deletions backend/enmedd/background/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from enmedd.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from enmedd.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from enmedd.configs.app_configs import NUM_INDEXING_WORKERS
from enmedd.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from enmedd.db.connector import fetch_connectors
from enmedd.db.embedding_model import get_current_db_embedding_model
from enmedd.db.embedding_model import get_secondary_db_embedding_model
Expand Down Expand Up @@ -335,7 +336,11 @@ def kickoff_indexing_jobs(
return existing_jobs_copy


def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
Expand Down Expand Up @@ -364,7 +369,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_workers,
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
Expand All @@ -374,7 +379,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)

existing_jobs: dict[int, Future | SimpleJob] = {}

Expand Down
7 changes: 6 additions & 1 deletion backend/enmedd/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
Expand All @@ -255,7 +258,9 @@
MINI_CHUNK_SIZE = 150
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))

# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))

#####
# Miscellaneous
Expand Down
11 changes: 11 additions & 0 deletions backend/enmedd/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,12 @@ class Teamspace(Base):
secondary=ChatSession__Teamspace.__table__,
viewonly=True,
)
token_rate_limit: Mapped["TokenRateLimit"] = relationship(
"TokenRateLimit",
secondary="token_rate_limit__teamspace",
viewonly=True,
)

chat_folders: Mapped[list[ChatFolder]] = relationship(
"ChatFolder",
secondary=ChatFolder__Teamspace.__table__,
Expand All @@ -1389,6 +1395,11 @@ class TokenRateLimit(Base):
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
groups: Mapped["Teamspace"] = relationship(
"Teamspace",
secondary="token_rate_limit__teamspace",
viewonly=True,
)


class TokenRateLimit__Teamspace(Base):
Expand Down
10 changes: 6 additions & 4 deletions backend/enmedd/server/token_rate_limits/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import BaseModel

from enmedd.db.models import TokenRateLimit
Expand All @@ -10,10 +12,10 @@ class TokenRateLimitArgs(BaseModel):


class TokenRateLimitDisplay(BaseModel):
token_id: int
enabled: bool
token_budget: int
period_hours: int
token_id: Optional[str] = None
enabled: Optional[bool] = None
token_budget: Optional[int] = None
period_hours: Optional[int] = None

@classmethod
def from_db(cls, token_rate_limit: TokenRateLimit) -> "TokenRateLimitDisplay":
Expand Down
Loading

0 comments on commit 8beda16

Please sign in to comment.