diff --git a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py new file mode 100644 index 00000000000..f284c7b4bf1 --- /dev/null +++ b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py @@ -0,0 +1,66 @@ +"""Add last synced and last modified to document table + +Revision ID: 52a219fb5233 +Revises: f17bf3b0d9f1 +Create Date: 2024-08-28 17:40:46.077470 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import func + +# revision identifiers, used by Alembic. +revision = "52a219fb5233" +down_revision = "f7e58d357687" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # last modified represents the last time anything needing syncing to vespa changed + # including row metadata and the document itself. This obviously does not include + # the last_synced column. + op.add_column( + "document", + sa.Column( + "last_modified", + sa.DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + + # last synced represents the last time this document was synced to Vespa + op.add_column( + "document", + sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True), + ) + + # Set last_synced to the same value as last_modified for existing rows + op.execute( + """ + UPDATE document + SET last_synced = last_modified + """ + ) + + op.create_index( + op.f("ix_document_last_modified"), + "document", + ["last_modified"], + unique=False, + ) + + op.create_index( + op.f("ix_document_last_synced"), + "document", + ["last_synced"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_document_last_synced"), table_name="document") + op.drop_index(op.f("ix_document_last_modified"), table_name="document") + op.drop_column("document", "last_synced") + op.drop_column("document", "last_modified") diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 5501980ab48..9088ddf8425 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -3,21 +3,49 @@ from danswer.access.models import DocumentAccess from danswer.access.utils import prefix_user from danswer.configs.constants import PUBLIC_DOC_PAT -from danswer.db.document import get_acccess_info_for_documents +from danswer.db.document import get_access_info_for_document +from danswer.db.document import get_access_info_for_documents from danswer.db.models import User from danswer.utils.variable_functionality import fetch_versioned_implementation +def _get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + info = get_access_info_for_document( + db_session=db_session, + document_id=document_id, + ) + + if not info: + return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + + return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2]) + + +def get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + versioned_get_access_for_document_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_access_for_document" + ) + return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore + + def _get_access_for_documents( document_ids: list[str], db_session: Session, ) -> dict[str, DocumentAccess]: - document_access_info = get_acccess_info_for_documents( + document_access_info = get_access_info_for_documents( db_session=db_session, document_ids=document_ids, ) return { - document_id: DocumentAccess.build(user_ids, [], is_public) + document_id: DocumentAccess.build( + user_ids=user_ids, user_groups=[], is_public=is_public + ) for document_id, user_ids, is_public in document_access_info } diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index c401dde83ca..5029520dcb6 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -3,66 +3,83 @@ from typing import Any from typing import cast -from celery import Celery # type: ignore +import redis +from celery import Celery +from celery import signals +from celery import Task from celery.contrib.abortable import AbortableTask # type: ignore +from celery.exceptions import SoftTimeLimitExceeded from celery.exceptions import TaskRevokedError +from celery.signals import beat_init +from celery.signals import worker_init +from celery.states import READY_STATES +from celery.utils.log import get_task_logger +from redis import Redis from sqlalchemy import inspect from sqlalchemy import text from sqlalchemy.orm import Session +from danswer.access.access import get_access_for_document +from danswer.background.celery.celery_redis import RedisConnectorCredentialPair +from danswer.background.celery.celery_redis import RedisDocumentSet +from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair from danswer.background.celery.celery_utils import should_prune_cc_pair -from danswer.background.celery.celery_utils import should_sync_doc_set from danswer.background.connector_deletion import delete_connector_credential_pair from danswer.background.connector_deletion import delete_connector_credential_pair_batch from danswer.background.task_utils import build_celery_task_wrapper from danswer.background.task_utils import name_cc_cleanup_task from danswer.background.task_utils import name_cc_prune_task -from danswer.background.task_utils import name_document_set_sync_task from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY -from danswer.configs.app_configs import REDIS_HOST -from danswer.configs.app_configs import REDIS_PASSWORD -from danswer.configs.app_configs import REDIS_PORT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerRedisLocks +from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME +from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME from danswer.configs.constants import PostgresAdvisoryLocks from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType -from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.connector_credential_pair import ( + get_connector_credential_pair, +) from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed +from danswer.db.document import count_documents_by_needs_sync +from danswer.db.document import get_document from danswer.db.document import get_documents_for_connector_credential_pair -from danswer.db.document import prepare_to_modify_documents +from danswer.db.document import mark_document_as_synced from danswer.db.document_set import delete_document_set +from danswer.db.document_set import fetch_document_set_for_document from danswer.db.document_set import fetch_document_sets -from danswer.db.document_set import fetch_document_sets_for_documents -from danswer.db.document_set import fetch_documents_for_document_set_paginated from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import init_sqlalchemy_engine from danswer.db.models import DocumentSet +from danswer.db.models import UserGroup from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import UpdateRequest +from danswer.redis.redis_pool import RedisPool from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation_with_fallback, +) +from danswer.utils.variable_functionality import noop_fallback logger = setup_logger() -CELERY_PASSWORD_PART = "" -if REDIS_PASSWORD: - CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@" - -# example celery_broker_url: "redis://:password@localhost:6379/15" -celery_broker_url = ( - f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" -) -celery_backend_url = ( - f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" -) -celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url) +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) +redis_pool = RedisPool() -_SYNC_BATCH_SIZE = 100 +celery_app = Celery(__name__) +celery_app.config_from_object( + "danswer.background.celery.celeryconfig" +) # Load configuration from 'celeryconfig.py' ##### @@ -111,7 +128,10 @@ def cleanup_connector_credential_pair_task( cc_pair=cc_pair, ) except Exception as e: - logger.exception(f"Failed to run connector_deletion due to {e}") + task_logger.exception( + f"Failed to run connector_deletion. " + f"connector_id={connector_id} credential_id={credential_id}" + ) raise e @@ -130,7 +150,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) if not cc_pair: - logger.warning(f"ccpair not found for {connector_id} {credential_id}") + task_logger.warning( + f"ccpair not found for {connector_id} {credential_id}" + ) return runnable_connector = instantiate_connector( @@ -162,12 +184,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) if len(doc_ids_to_remove) == 0: - logger.info( + task_logger.info( f"No docs to prune from {cc_pair.connector.source} connector" ) return - logger.info( + task_logger.info( f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector" ) delete_connector_credential_pair_batch( @@ -177,113 +199,202 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: document_index=document_index, ) except Exception as e: - logger.exception( - f"Failed to run pruning for connector id {connector_id} due to {e}" + task_logger.exception( + f"Failed to run pruning for connector id {connector_id}." ) raise e -@build_celery_task_wrapper(name_document_set_sync_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_document_set_task(document_set_id: int) -> None: - """For document sets marked as not up to date, sync the state from postgres - into the datastore. Also handles deletions.""" - - def _sync_document_batch(document_ids: list[str], db_session: Session) -> None: - logger.debug(f"Syncing document sets for: {document_ids}") - - # Acquires a lock on the documents so that no other process can modify them - with prepare_to_modify_documents( - db_session=db_session, document_ids=document_ids - ): - # get current state of document sets for these documents - document_set_map = { - document_id: document_sets - for document_id, document_sets in fetch_document_sets_for_documents( - document_ids=document_ids, db_session=db_session - ) - } +def try_generate_stale_document_sync_tasks( + db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + # the fence is up, do nothing + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + return None - # update Vespa - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - update_requests = [ - UpdateRequest( - document_ids=[document_id], - document_sets=set(document_set_map.get(document_id, [])), - ) - for document_id in document_ids - ] - document_index.update(update_requests=update_requests) + r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset - with Session(get_sqlalchemy_engine()) as db_session: - try: - cursor = None - while True: - document_batch, cursor = fetch_documents_for_document_set_paginated( - document_set_id=document_set_id, - db_session=db_session, - current_only=False, - last_document_id=cursor, - limit=_SYNC_BATCH_SIZE, - ) - _sync_document_batch( - document_ids=[document.id for document in document_batch], - db_session=db_session, - ) - if cursor is None: - break - - # if there are no connectors, then delete the document set. Otherwise, just - # mark it as successfully synced. - document_set = cast( - DocumentSet, - get_document_set_by_id( - db_session=db_session, document_set_id=document_set_id - ), - ) # casting since we "know" a document set with this ID exists - if not document_set.connector_credential_pairs: - delete_document_set( - document_set_row=document_set, db_session=db_session - ) - logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" - ) - else: - mark_document_set_as_synced( - document_set_id=document_set_id, db_session=db_session - ) - logger.info(f"Document set sync for '{document_set_id}' complete!") + # add tasks to celery and build up the task set to monitor in redis + stale_doc_count = count_documents_by_needs_sync(db_session) + if stale_doc_count == 0: + return None + + task_logger.info( + f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." + ) + + # rkuo: we could technically sync all stale docs in one big pass. + # but I feel it's more understandable to group the docs by cc_pair + total_tasks_generated = 0 + cc_pairs = get_connector_credential_pairs(db_session) + for cc_pair in cc_pairs: + rc = RedisConnectorCredentialPair(cc_pair.id) + tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat) + + if tasks_generated is None: + continue + + if tasks_generated == 0: + continue + + task_logger.info( + f"RedisConnector.generate_tasks finished. " + f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + ) + + total_tasks_generated += tasks_generated + + task_logger.info( + f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}" + ) + + r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated) + return total_tasks_generated + + +def try_generate_document_set_sync_tasks( + document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rds = RedisDocumentSet(document_set.id) - except Exception: - logger.exception("Failed to sync document set %s", document_set_id) - raise + # don't generate document set sync tasks if tasks are still pending + if r.exists(rds.fence_key): + return None + + # don't generate sync tasks if we're up to date + if document_set.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rds.taskset_key) + + task_logger.info( + f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}" + ) + + # Add all documents that need to be updated into the queue + tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"RedisDocumentSet.generate_tasks finished. " + f"document_set_id={document_set.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rds.fence_key, tasks_generated) + return tasks_generated + + +def try_generate_user_group_sync_tasks( + usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rug = RedisUserGroup(usergroup.id) + + # don't generate sync tasks if tasks are still pending + if r.exists(rug.fence_key): + return None + + if usergroup.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rug.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}") + tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"generate_tasks finished. " + f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rug.fence_key, tasks_generated) + return tasks_generated ##### # Periodic Tasks ##### @celery_app.task( - name="check_for_document_sets_sync_task", + name="check_for_vespa_sync_task", soft_time_limit=JOB_TIMEOUT, ) -def check_for_document_sets_sync_task() -> None: - """Runs periodically to check if any sync tasks should be run and adds them - to the queue""" - with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced - document_set_info = fetch_document_sets( - user_id=None, db_session=db_session, include_outdated=True - ) - for document_set, _ in document_set_info: - if should_sync_doc_set(document_set, db_session): - logger.info(f"Syncing the {document_set.name} document set") - sync_document_set_task.apply_async( - kwargs=dict(document_set_id=document_set.id), +def check_for_vespa_sync_task() -> None: + """Runs periodically to check if any document needs syncing. + Generates sets of tasks for Celery if syncing is needed.""" + + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + try_generate_stale_document_sync_tasks(db_session, r, lock_beat) + + # check if any document sets are not synced + document_set_info = fetch_document_sets( + user_id=None, db_session=db_session, include_outdated=True + ) + for document_set, _ in document_set_info: + try_generate_document_set_sync_tasks( + document_set, db_session, r, lock_beat ) + # check if any user groups are not synced + try: + fetch_user_groups = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_groups" + ) + + user_groups = fetch_user_groups( + db_session=db_session, only_up_to_date=False + ) + for usergroup in user_groups: + try_generate_user_group_sync_tasks( + usergroup, db_session, r, lock_beat + ) + except ModuleNotFoundError: + # Always exceptions on the MIT version, which is expected + pass + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception") + finally: + if lock_beat.owned(): + lock_beat.release() + @celery_app.task( name="check_for_cc_pair_deletion_task", @@ -292,11 +403,13 @@ def check_for_document_sets_sync_task() -> None: def check_for_cc_pair_deletion_task() -> None: """Runs periodically to check if any deletion tasks should be run""" with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced + # check if any cc pairs are up for deletion cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: if should_kick_off_deletion_of_cc_pair(cc_pair, db_session): - logger.notice(f"Deleting the {cc_pair.name} connector credential pair") + task_logger.info( + f"Deleting the {cc_pair.name} connector credential pair" + ) cleanup_connector_credential_pair_task.apply_async( kwargs=dict( connector_id=cc_pair.connector.id, @@ -343,7 +456,9 @@ def kombu_message_cleanup_task(self: Any) -> int: db_session.commit() if ctx["deleted"] > 0: - logger.info(f"Deleted {ctx['deleted']} orphaned messages from kombu_message.") + task_logger.info( + f"Deleted {ctx['deleted']} orphaned messages from kombu_message." + ) return ctx["deleted"] @@ -417,12 +532,6 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: ) if result.rowcount > 0: # type: ignore ctx["deleted"] += 1 - else: - task_name = payload["headers"]["task"] - logger.warning( - f"Message found for task older than {ctx['cleanup_age']} days. " - f"id={task_id} name={task_name}" - ) ctx["last_processed_id"] = msg[0] @@ -446,7 +555,7 @@ def check_for_prune_task() -> None: credential=cc_pair.credential, db_session=db_session, ): - logger.info(f"Pruning the {cc_pair.connector.name} connector") + task_logger.info(f"Pruning the {cc_pair.connector.name} connector") prune_documents_task.apply_async( kwargs=dict( @@ -456,19 +565,331 @@ def check_for_prune_task() -> None: ) +@celery_app.task( + name="vespa_metadata_sync_task", + bind=True, + soft_time_limit=45, + time_limit=60, + max_retries=3, +) +def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: + task_logger.info(f"document_id={document_id}") + + try: + with Session(get_sqlalchemy_engine()) as db_session: + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + doc = get_document(document_id, db_session) + if not doc: + return False + + # document set sync + doc_sets = fetch_document_set_for_document(document_id, db_session) + update_doc_sets: set[str] = set(doc_sets) + + # User group sync + doc_access = get_access_for_document( + document_id=document_id, db_session=db_session + ) + update_request = UpdateRequest( + document_ids=[document_id], + document_sets=update_doc_sets, + access=doc_access, + boost=doc.boost, + hidden=doc.hidden, + ) + + # update Vespa + document_index.update(update_requests=[update_request]) + + # update db last. Worst case = we crash right before this and + # the sync might repeat again later + mark_document_as_synced(document_id, db_session) + except SoftTimeLimitExceeded: + task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") + except Exception as e: + task_logger.exception("Unexpected exception") + + # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 + countdown = 2 ** (self.request.retries + 4) + self.retry(exc=e, countdown=countdown) + + return True + + +@signals.task_postrun.connect +def celery_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + """We handle this signal in order to remove completed tasks + from their respective tasksets. This allows us to track the progress of document set + and user group syncs. + + This function runs after any task completes (both success and failure) + Note that this signal does not fire on a task that failed to complete and is going + to be retried. + """ + if not task: + return + + task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") + # logger.debug(f"Result: {retval}") + + if state not in READY_STATES: + return + + if not task_id: + return + + if task_id.startswith(RedisConnectorCredentialPair.PREFIX): + r = redis_pool.get_client() + r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) + return + + if task_id.startswith(RedisDocumentSet.PREFIX): + r = redis_pool.get_client() + document_set_id = RedisDocumentSet.get_id_from_task_id(task_id) + if document_set_id is not None: + rds = RedisDocumentSet(document_set_id) + r.srem(rds.taskset_key, task_id) + return + + if task_id.startswith(RedisUserGroup.PREFIX): + r = redis_pool.get_client() + usergroup_id = RedisUserGroup.get_id_from_task_id(task_id) + if usergroup_id is not None: + rug = RedisUserGroup(usergroup_id) + r.srem(rug.taskset_key, task_id) + return + + +def monitor_connector_taskset(r: Redis) -> None: + fence_value = r.get(RedisConnectorCredentialPair.get_fence_key()) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = r.scard(RedisConnectorCredentialPair.get_taskset_key()) + task_logger.info(f"Stale documents: remaining={count} initial={initial_count}") + if count == 0: + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + task_logger.info(f"Successfully synced stale documents. count={initial_count}") + + +def monitor_document_set_taskset( + key_bytes: bytes, r: Redis, db_session: Session +) -> None: + fence_key = key_bytes.decode("utf-8") + document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key) + if document_set_id is None: + task_logger.warning("could not parse document set id from {key}") + return + + rds = RedisDocumentSet(document_set_id) + + fence_value = r.get(rds.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rds.taskset_key)) + task_logger.info( + f"document_set_id={document_set_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + document_set = cast( + DocumentSet, + get_document_set_by_id(db_session=db_session, document_set_id=document_set_id), + ) # casting since we "know" a document set with this ID exists + if document_set: + if not document_set.connector_credential_pairs: + # if there are no connectors, then delete the document set. + delete_document_set(document_set_row=document_set, db_session=db_session) + task_logger.info( + f"Successfully deleted document set with ID: '{document_set_id}'!" + ) + else: + mark_document_set_as_synced(document_set_id, db_session) + task_logger.info( + f"Successfully synced document set with ID: '{document_set_id}'!" + ) + + r.delete(rds.taskset_key) + r.delete(rds.fence_key) + + +def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None: + key = key_bytes.decode("utf-8") + usergroup_id = RedisUserGroup.get_id_from_fence_key(key) + if not usergroup_id: + task_logger.warning("Could not parse usergroup id from {key}") + return + + rug = RedisUserGroup(usergroup_id) + fence_value = r.get(rug.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rug.taskset_key)) + task_logger.info( + f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + try: + fetch_user_group = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_group" + ) + except ModuleNotFoundError: + task_logger.exception( + "fetch_versioned_implementation failed to look up fetch_user_group." + ) + return + + user_group: UserGroup | None = fetch_user_group( + db_session=db_session, user_group_id=usergroup_id + ) + if user_group: + if user_group.is_up_for_deletion: + delete_user_group = fetch_versioned_implementation_with_fallback( + "danswer.db.user_group", "delete_user_group", noop_fallback + ) + + delete_user_group(db_session=db_session, user_group=user_group) + task_logger.info(f" Deleted usergroup. id='{usergroup_id}'") + else: + mark_user_group_as_synced = fetch_versioned_implementation_with_fallback( + "danswer.db.user_group", "mark_user_group_as_synced", noop_fallback + ) + + mark_user_group_as_synced(db_session=db_session, user_group=user_group) + task_logger.info(f"Synced usergroup. id='{usergroup_id}'") + + r.delete(rug.taskset_key) + r.delete(rug.fence_key) + + +@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300) +def monitor_vespa_sync() -> None: + """This is a celery beat task that monitors and finalizes metadata sync tasksets. + It scans for fence values and then gets the counts of any associated tasksets. + If the count is 0, that means all tasks finished and we should clean up. + + This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't + do anything too expensive in this function! + """ + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # prevent overlapping tasks + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + monitor_connector_taskset(r) + + for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + monitor_document_set_taskset(key_bytes, r, db_session) + + for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + monitor_usergroup_taskset(key_bytes, r, db_session) + + # + # r_celery = celery_app.broker_connection().channel().client + # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) + # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + finally: + if lock_beat.owned(): + lock_beat.release() + + +@beat_init.connect +def on_beat_init(sender: Any, **kwargs: Any) -> None: + init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME) + + +@worker_init.connect +def on_worker_init(sender: Any, **kwargs: Any) -> None: + init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME) + + # TODO(rkuo): this is singleton work that should be done on startup exactly once + # if we run multiple workers, we'll need to centralize where this cleanup happens + r = redis_pool.get_client() + + r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) + r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) + + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + + for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + r.delete(key) + + ##### # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "check-for-document-set-sync": { - "task": "check_for_document_sets_sync_task", + "check-for-vespa-sync": { + "task": "check_for_vespa_sync_task", "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, "check-for-cc-pair-deletion": { "task": "check_for_cc_pair_deletion_task", # don't need to check too often, since we kick off a deletion initially # during the API call that actually marks the CC pair for deletion "schedule": timedelta(minutes=1), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } celery_app.conf.beat_schedule.update( @@ -476,6 +897,7 @@ def check_for_prune_task() -> None: "check-for-prune": { "task": "check_for_prune_task", "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } ) @@ -484,6 +906,16 @@ def check_for_prune_task() -> None: "kombu-message-cleanup": { "task": "kombu_message_cleanup_task", "schedule": timedelta(seconds=3600), + "options": {"priority": DanswerCeleryPriority.LOWEST}, + }, + } +) +celery_app.conf.beat_schedule.update( + { + "monitor-vespa-sync": { + "task": "monitor_vespa_sync", + "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } ) diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py new file mode 100644 index 00000000000..bf82f0a7274 --- /dev/null +++ b/backend/danswer/background/celery/celery_redis.py @@ -0,0 +1,299 @@ +# These are helper objects for tracking the keys we need to write in redis +import time +from abc import ABC +from abc import abstractmethod +from typing import cast +from uuid import uuid4 + +import redis +from celery import Celery +from redis import Redis +from sqlalchemy.orm import Session + +from danswer.background.celery.celeryconfig import CELERY_SEPARATOR +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.document import ( + construct_document_select_for_connector_credential_pair_by_needs_sync, +) +from danswer.db.document_set import construct_document_select_by_docset +from danswer.utils.variable_functionality import fetch_versioned_implementation + + +class RedisObjectHelper(ABC): + PREFIX = "base" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def __init__(self, id: int): + self._id: int = id + + @property + def task_id_prefix(self) -> str: + return f"{self.PREFIX}_{self._id}" + + @property + def fence_key(self) -> str: + # example: documentset_fence_1 + return f"{self.FENCE_PREFIX}_{self._id}" + + @property + def taskset_key(self) -> str: + # example: documentset_taskset_1 + return f"{self.TASKSET_PREFIX}_{self._id}" + + @staticmethod + def get_id_from_fence_key(key: str) -> int | None: + """ + Extracts the object ID from a fence key in the format `PREFIX_fence_X`. + + Args: + key (str): The fence key string. + + Returns: + Optional[int]: The extracted ID if the key is in the correct format, otherwise None. + """ + parts = key.split("_") + if len(parts) != 3: + return None + + try: + object_id = int(parts[2]) + except ValueError: + return None + + return object_id + + @staticmethod + def get_id_from_task_id(task_id: str) -> int | None: + """ + Extracts the object ID from a task ID string. + + This method assumes the task ID is formatted as `prefix_objectid_suffix`, where: + - `prefix` is an arbitrary string (e.g., the name of the task or entity), + - `objectid` is the ID you want to extract, + - `suffix` is another arbitrary string (e.g., a UUID). + + Example: + If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`, + this method will return the string `"1"`. + + Args: + task_id (str): The task ID string from which to extract the object ID. + + Returns: + str | None: The extracted object ID if the task ID is in the correct format, otherwise None. + """ + # example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc + parts = task_id.split("_") + if len(parts) != 3: + return None + + try: + object_id = int(parts[1]) + except ValueError: + return None + + return object_id + + @abstractmethod + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + pass + + +class RedisDocumentSet(RedisObjectHelper): + PREFIX = "documentset" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + stmt = construct_document_select_by_docset(self._id) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the set BEFORE creating the task. + redis_client.sadd(self.taskset_key, custom_task_id) + + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.LOW, + ) + + async_results.append(result) + + return len(async_results) + + +class RedisUserGroup(RedisObjectHelper): + PREFIX = "usergroup" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + + try: + construct_document_select_by_usergroup = fetch_versioned_implementation( + "danswer.db.user_group", + "construct_document_select_by_usergroup", + ) + except ModuleNotFoundError: + return 0 + + stmt = construct_document_select_by_usergroup(self._id) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the set BEFORE creating the task. + redis_client.sadd(self.taskset_key, custom_task_id) + + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.LOW, + ) + + async_results.append(result) + + return len(async_results) + + +class RedisConnectorCredentialPair(RedisObjectHelper): + PREFIX = "connectorsync" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + @classmethod + def get_fence_key(cls) -> str: + return RedisConnectorCredentialPair.FENCE_PREFIX + + @classmethod + def get_taskset_key(cls) -> str: + return RedisConnectorCredentialPair.TASKSET_PREFIX + + @property + def taskset_key(self) -> str: + """Notice that this is intentionally reusing the same taskset for all + connector syncs""" + # example: connector_taskset + return f"{self.TASKSET_PREFIX}" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + cc_pair = get_connector_credential_pair_from_id(self._id, db_session) + if not cc_pair: + return None + + stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( + cc_pair.connector_id, cc_pair.credential_id + ) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the tracking taskset in redis BEFORE creating the celery task. + # note that for the moment we are using a single taskset key, not differentiated by cc_pair id + redis_client.sadd( + RedisConnectorCredentialPair.get_taskset_key(), custom_task_id + ) + + # Priority on sync's triggered by new indexing should be medium + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.MEDIUM, + ) + + async_results.append(result) + + return len(async_results) + + +def celery_get_queue_length(queue: str, r: Redis) -> int: + """This is a redis specific way to get the length of a celery queue. + It is priority aware and knows how to count across the multiple redis lists + used to implement task prioritization. + This operation is not atomic.""" + total_length = 0 + for i in range(len(DanswerCeleryPriority)): + queue_name = queue + if i > 0: + queue_name += CELERY_SEPARATOR + queue_name += str(i) + + length = r.llen(queue_name) + total_length += cast(int, length) + + return total_length diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index e4d4d13bb1d..a51bd8cca35 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -5,7 +5,6 @@ from danswer.background.task_utils import name_cc_cleanup_task from danswer.background.task_utils import name_cc_prune_task -from danswer.background.task_utils import name_document_set_sync_task from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( @@ -22,7 +21,6 @@ from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential -from danswer.db.models import DocumentSet from danswer.db.models import TaskQueueState from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task @@ -81,21 +79,6 @@ def should_kick_off_deletion_of_cc_pair( return True -def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool: - if document_set.is_up_to_date: - return False - - task_name = name_document_set_sync_task(document_set.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): - logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.") - return False - - logger.info(f"Document set {document_set.id} syncing now.") - return True - - def should_prune_cc_pair( connector: Connector, credential: Credential, db_session: Session ) -> bool: diff --git a/backend/danswer/background/celery/celeryconfig.py b/backend/danswer/background/celery/celeryconfig.py new file mode 100644 index 00000000000..cf7e72719fd --- /dev/null +++ b/backend/danswer/background/celery/celeryconfig.py @@ -0,0 +1,35 @@ +# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html +from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PASSWORD +from danswer.configs.app_configs import REDIS_PORT +from danswer.configs.constants import DanswerCeleryPriority + +CELERY_SEPARATOR = ":" + +CELERY_PASSWORD_PART = "" +if REDIS_PASSWORD: + CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@" + +# example celery_broker_url: "redis://:password@localhost:6379/15" +broker_url = ( + f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" +) + +result_backend = ( + f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}" +) + +# NOTE: prefetch 4 is significantly faster than prefetch 1 +# however, prefetching is bad when tasks are lengthy as those tasks +# can stall other tasks. +worker_prefetch_multiplier = 4 + +broker_transport_options = { + "priority_steps": list(range(len(DanswerCeleryPriority))), + "sep": CELERY_SEPARATOR, + "queue_order_strategy": "priority", +} + +task_default_priority = DanswerCeleryPriority.MEDIUM +task_acks_late = True diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index eff8ee30a63..e807f381e87 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -61,6 +61,8 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings" KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__" +CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60 + class DocumentSource(str, Enum): # Special case, document passed in via Danswer APIs without specifying a source type @@ -167,3 +169,23 @@ class FileOrigin(str, Enum): class PostgresAdvisoryLocks(Enum): KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto() + + +class DanswerCeleryQueues: + VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator" + VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator" + VESPA_METADATA_SYNC = "vespa_metadata_sync" + CONNECTOR_DELETION = "connector_deletion" + + +class DanswerRedisLocks: + CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat" + MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat" + + +class DanswerCeleryPriority(int, Enum): + HIGHEST = 0 + HIGH = auto() + MEDIUM = auto() + LOW = auto() + LOWEST = auto() diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 77ea4e3dd9d..92b093ab587 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -3,6 +3,7 @@ from collections.abc import Generator from collections.abc import Sequence from datetime import datetime +from datetime import timezone from uuid import UUID from sqlalchemy import and_ @@ -10,6 +11,7 @@ from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import or_ +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine.util import TransactionalContext @@ -38,6 +40,68 @@ def check_docs_exist(db_session: Session) -> bool: return result.scalar() or False +def count_documents_by_needs_sync(session: Session) -> int: + """Get the count of all documents where: + 1. last_modified is newer than last_synced + 2. last_synced is null (meaning we've never synced) + + This function executes the query and returns the count of + documents matching the criteria.""" + + count = ( + session.query(func.count()) + .select_from(DbDocument) + .filter( + or_( + DbDocument.last_modified > DbDocument.last_synced, + DbDocument.last_synced.is_(None), + ) + ) + .scalar() + ) + + return count + + +def construct_document_select_for_connector_credential_pair_by_needs_sync( + connector_id: int, credential_id: int +) -> Select: + initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( + and_( + DocumentByConnectorCredentialPair.connector_id == connector_id, + DocumentByConnectorCredentialPair.credential_id == credential_id, + ) + ) + + stmt = ( + select(DbDocument) + .where( + DbDocument.id.in_(initial_doc_ids_stmt), + or_( + DbDocument.last_modified + > DbDocument.last_synced, # last_modified is newer than last_synced + DbDocument.last_synced.is_(None), # never synced + ), + ) + .distinct() + ) + + return stmt + + +def construct_document_select_for_connector_credential_pair( + connector_id: int, credential_id: int | None = None +) -> Select: + initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( + and_( + DocumentByConnectorCredentialPair.connector_id == connector_id, + DocumentByConnectorCredentialPair.credential_id == credential_id, + ) + ) + stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() + return stmt + + def get_documents_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, limit: int | None = None ) -> Sequence[DbDocument]: @@ -108,7 +172,29 @@ def get_document_cnts_for_cc_pairs( return db_session.execute(stmt).all() # type: ignore -def get_acccess_info_for_documents( +def get_access_info_for_document( + db_session: Session, + document_id: str, +) -> tuple[str, list[UUID | None], bool] | None: + """Gets access info for a single document by calling the get_access_info_for_documents function + and passing a list with a single document ID. + + Args: + db_session (Session): The database session to use. + document_id (str): The document ID to fetch access info for. + + Returns: + Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs, + and a boolean indicating if the document is globally public, or None if no results are found. + """ + results = get_access_info_for_documents(db_session, [document_id]) + if not results: + return None + + return results[0] + + +def get_access_info_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[UUID | None], bool]]: @@ -173,6 +259,7 @@ def upsert_documents( semantic_id=doc.semantic_identifier, link=doc.first_link, doc_updated_at=None, # this is intentional + last_modified=datetime.now(timezone.utc), primary_owners=doc.primary_owners, secondary_owners=doc.secondary_owners, ) @@ -214,7 +301,7 @@ def upsert_document_by_connector_credential_pair( db_session.commit() -def update_docs_updated_at( +def update_docs_updated_at__no_commit( ids_to_new_updated_at: dict[str, datetime], db_session: Session, ) -> None: @@ -226,6 +313,28 @@ def update_docs_updated_at( for document in documents_to_update: document.doc_updated_at = ids_to_new_updated_at[document.id] + +def update_docs_last_modified__no_commit( + document_ids: list[str], + db_session: Session, +) -> None: + documents_to_update = ( + db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all() + ) + + now = datetime.now(timezone.utc) + for doc in documents_to_update: + doc.last_modified = now + + +def mark_document_as_synced(document_id: str, db_session: Session) -> None: + stmt = select(DbDocument).where(DbDocument.id == document_id) + doc = db_session.scalar(stmt) + if doc is None: + raise ValueError(f"No document with ID: {document_id}") + + # update last_synced + doc.last_synced = datetime.now(timezone.utc) db_session.commit() @@ -379,3 +488,12 @@ def get_documents_by_cc_pair( .filter(ConnectorCredentialPair.id == cc_pair_id) .all() ) + + +def get_document( + document_id: str, + db_session: Session, +) -> DbDocument | None: + stmt = select(DbDocument).where(DbDocument.id == document_id) + doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none() + return doc diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index c2900593835..4a37f8bdced 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -248,6 +248,10 @@ def update_document_set( document_set_update_request: DocumentSetUpdateRequest, user: User | None = None, ) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]: + """If successful, this sets document_set_row.is_up_to_date = False. + That will be processed via Celery in check_for_vespa_sync_task + and trigger a long running background sync to Vespa. + """ if not document_set_update_request.cc_pair_ids: # It's cc-pairs in actuality but the UI displays this error raise ValueError("Cannot create a document set with no Connectors") @@ -519,6 +523,70 @@ def fetch_documents_for_document_set_paginated( return documents, documents[-1].id if documents else None +def construct_document_select_by_docset( + document_set_id: int, + current_only: bool = True, +) -> Select: + """This returns a statement that should be executed using + .yield_per() to minimize overhead. The primary consumers of this function + are background processing task generators.""" + + stmt = ( + select(Document) + .join( + DocumentByConnectorCredentialPair, + DocumentByConnectorCredentialPair.id == Document.id, + ) + .join( + ConnectorCredentialPair, + and_( + ConnectorCredentialPair.connector_id + == DocumentByConnectorCredentialPair.connector_id, + ConnectorCredentialPair.credential_id + == DocumentByConnectorCredentialPair.credential_id, + ), + ) + .join( + DocumentSet__ConnectorCredentialPair, + DocumentSet__ConnectorCredentialPair.connector_credential_pair_id + == ConnectorCredentialPair.id, + ) + .join( + DocumentSetDBModel, + DocumentSetDBModel.id + == DocumentSet__ConnectorCredentialPair.document_set_id, + ) + .where(DocumentSetDBModel.id == document_set_id) + .order_by(Document.id) + ) + + if current_only: + stmt = stmt.where( + DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712 + ) + + stmt = stmt.distinct() + return stmt + + +def fetch_document_set_for_document( + document_id: str, + db_session: Session, +) -> list[str]: + """ + Fetches the document set names for a single document ID. + + :param document_id: The ID of the document to fetch sets for. + :param db_session: The SQLAlchemy session to use for the query. + :return: A list of document set names, or None if no result is found. + """ + result = fetch_document_sets_for_documents([document_id], db_session) + if not result: + return [] + + return result[0][1] + + def fetch_document_sets_for_documents( document_ids: list[str], db_session: Session, diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 79557f209dc..6df1f1f5051 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -1,3 +1,5 @@ +from datetime import datetime +from datetime import timezone from uuid import UUID from fastapi import HTTPException @@ -24,7 +26,6 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair from danswer.db.models import UserRole from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.interfaces import UpdateRequest from danswer.utils.logger import setup_logger logger = setup_logger() @@ -123,12 +124,11 @@ def update_document_boost( db_session: Session, document_id: str, boost: int, - document_index: DocumentIndex, user: User | None = None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) stmt = _add_user_filters(stmt, user, get_editable=True) - result = db_session.execute(stmt).scalar_one_or_none() + result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none() if result is None: raise HTTPException( status_code=400, detail="Document is not editable by this user" @@ -136,13 +136,9 @@ def update_document_boost( result.boost = boost - update = UpdateRequest( - document_ids=[document_id], - boost=boost, - ) - - document_index.update(update_requests=[update]) - + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + result.last_modified = datetime.now(timezone.utc) db_session.commit() @@ -163,13 +159,9 @@ def update_document_hidden( result.hidden = hidden - update = UpdateRequest( - document_ids=[document_id], - hidden=hidden, - ) - - document_index.update(update_requests=[update]) - + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + result.last_modified = datetime.now(timezone.utc) db_session.commit() @@ -210,11 +202,9 @@ def create_doc_retrieval_feedback( SearchFeedbackType.REJECT, SearchFeedbackType.HIDE, ]: - update = UpdateRequest( - document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden - ) - # Updates are generally batched for efficiency, this case only 1 doc/value is updated - document_index.update(update_requests=[update]) + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + db_doc.last_modified = datetime.now(timezone.utc) db_session.add(retrieval_feedback) db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index ffc12323a52..adceeea17b8 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -428,12 +428,27 @@ class Document(Base): semantic_id: Mapped[str] = mapped_column(String) # First Section's link link: Mapped[str | None] = mapped_column(String, nullable=True) + # The updated time is also used as a measure of the last successful state of the doc # pulled from the source (to help skip reindexing already updated docs in case of # connector retries) + # TODO: rename this column because it conflates the time of the source doc + # with the local last modified time of the doc and any associated metadata + # it should just be the server timestamp of the source doc doc_updated_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) + + # last time any vespa relevant row metadata or the doc changed. + # does not include last_synced + last_modified: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=False, index=True, default=func.now() + ) + + # last successful sync to vespa + last_synced: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, index=True + ) # The following are not attached to User because the account/email may not be known # within Danswer # Something like the document creator diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index d07da5b06bb..32e95a8f061 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -282,7 +282,7 @@ def _update_chunk( raise requests.HTTPError(failure_msg) from e def update(self, update_requests: list[UpdateRequest]) -> None: - logger.info(f"Updating {len(update_requests)} documents in Vespa") + logger.debug(f"Updating {len(update_requests)} documents in Vespa") # Handle Vespa character limitations # Mutating update_requests but it's not used later anyway diff --git a/backend/danswer/document_index/vespa/indexing_utils.py b/backend/danswer/document_index/vespa/indexing_utils.py index 1b16cfc4947..6b6ba8709d5 100644 --- a/backend/danswer/document_index/vespa/indexing_utils.py +++ b/backend/danswer/document_index/vespa/indexing_utils.py @@ -162,14 +162,16 @@ def _index_vespa_chunk( METADATA_SUFFIX: chunk.metadata_suffix_keyword, EMBEDDINGS: embeddings_name_vector_map, TITLE_EMBEDDING: chunk.title_embedding, - BOOST: chunk.boost, DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at), PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners), SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners), # the only `set` vespa has is `weightedset`, so we have to give each # element an arbitrary weight + # rkuo: acl, docset and boost metadata are also updated through the metadata sync queue + # which only calls VespaIndex.update ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()}, DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets}, + BOOST: chunk.boost, } vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}" diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index afe825d11ec..51cd23e7431 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -18,7 +18,8 @@ from danswer.connectors.models import IndexAttemptMetadata from danswer.db.document import get_documents_by_ids from danswer.db.document import prepare_to_modify_documents -from danswer.db.document import update_docs_updated_at +from danswer.db.document import update_docs_last_modified__no_commit +from danswer.db.document import update_docs_updated_at__no_commit from danswer.db.document import upsert_documents_complete from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.index_attempt import create_index_attempt_error @@ -264,7 +265,7 @@ def index_doc_batch( Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" - no_access = DocumentAccess.build([], [], False) + no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) ctx = index_doc_batch_prepare( document_batch=document_batch, @@ -295,9 +296,6 @@ def index_doc_batch( # NOTE: don't need to acquire till here, since this is when the actual race condition # with Vespa can occur. with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids): - # Attach the latest status from Postgres (source of truth for access) to each - # chunk. This access status will be attached to each chunk in the document index - # TODO: attach document sets to the chunk based on the status of Postgres as well document_id_to_access_info = get_access_for_documents( document_ids=updatable_ids, db_session=db_session ) @@ -307,6 +305,12 @@ def index_doc_batch( document_ids=updatable_ids, db_session=db_session ) } + + # we're concerned about race conditions where multiple simultaneous indexings might result + # in one set of metadata overwriting another one in vespa. + # we still write data here for immediate and most likely correct sync, but + # to resolve this, an update of the last modified field at the end of this loop + # always triggers a final metadata sync access_aware_chunks = [ DocMetadataAwareIndexChunk.from_index_chunk( index_chunk=chunk, @@ -338,17 +342,25 @@ def index_doc_batch( doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids ] - # Update the time of latest version of the doc successfully indexed + last_modified_ids = [] ids_to_new_updated_at = {} for doc in successful_docs: + last_modified_ids.append(doc.id) + # doc_updated_at is the connector source's idea of when the doc was last modified if doc.doc_updated_at is None: continue ids_to_new_updated_at[doc.id] = doc.doc_updated_at - update_docs_updated_at( + update_docs_updated_at__no_commit( ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session ) + update_docs_last_modified__no_commit( + document_ids=last_modified_ids, db_session=db_session + ) + + db_session.commit() + return len([r for r in insertion_records if r.already_existed is False]), len( access_aware_chunks ) diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index 93dc0f7315d..c789a2b351b 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -61,6 +61,8 @@ class IndexChunk(DocAwareChunk): title_embedding: Embedding | None +# TODO(rkuo): currently, this extra metadata sent during indexing is just for speed, +# but full consistency happens on background sync class DocMetadataAwareIndexChunk(IndexChunk): """An `IndexChunk` that contains all necessary metadata to be indexed. This includes the following: diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py new file mode 100644 index 00000000000..edea22fc05b --- /dev/null +++ b/backend/danswer/redis/redis_pool.py @@ -0,0 +1,49 @@ +import threading +from typing import Optional + +import redis +from redis.client import Redis +from redis.connection import ConnectionPool + +from danswer.configs.app_configs import REDIS_DB_NUMBER +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PASSWORD +from danswer.configs.app_configs import REDIS_PORT + +REDIS_POOL_MAX_CONNECTIONS = 10 + + +class RedisPool: + _instance: Optional["RedisPool"] = None + _lock: threading.Lock = threading.Lock() + _pool: ConnectionPool + + def __new__(cls) -> "RedisPool": + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(RedisPool, cls).__new__(cls) + cls._instance._init_pool() + return cls._instance + + def _init_pool(self) -> None: + self._pool = redis.ConnectionPool( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB_NUMBER, + password=REDIS_PASSWORD, + max_connections=REDIS_POOL_MAX_CONNECTIONS, + ) + + def get_client(self) -> Redis: + return redis.Redis(connection_pool=self._pool) + + +# # Usage example +# redis_pool = RedisPool() +# redis_client = redis_pool.get_client() + +# # Example of setting and getting a value +# redis_client.set('key', 'value') +# value = redis_client.get('key') +# print(value.decode()) # Output: 'value' diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index a2d7156892c..f45d5c38529 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -77,16 +77,10 @@ def document_boost_update( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - update_document_boost( db_session=db_session, document_id=boost_update.document_id, boost=boost_update.boost, - document_index=document_index, user=user, ) return StatusResponse(success=True, message="Updated document boost") diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py index 97c6592601e..55f296aa8e7 100644 --- a/backend/danswer/utils/variable_functionality.py +++ b/backend/danswer/utils/variable_functionality.py @@ -31,6 +31,28 @@ def set_is_ee_based_on_env_variable() -> None: @functools.lru_cache(maxsize=128) def fetch_versioned_implementation(module: str, attribute: str) -> Any: + """ + Fetches a versioned implementation of a specified attribute from a given module. + This function first checks if the application is running in an Enterprise Edition (EE) + context. If so, it attempts to import the attribute from the EE-specific module. + If the module or attribute is not found, it falls back to the default module or + raises the appropriate exception depending on the context. + + Args: + module (str): The name of the module from which to fetch the attribute. + attribute (str): The name of the attribute to fetch from the module. + + Returns: + Any: The fetched implementation of the attribute. + + Raises: + ModuleNotFoundError: If the module cannot be found and the error is not related to + the Enterprise Edition fallback logic. + + Logs: + Logs debug information about the fetching process and warnings if the versioned + implementation cannot be found or loaded. + """ logger.debug("Fetching versioned implementation for %s.%s", module, attribute) is_ee = global_version.get_is_ee_version() @@ -66,6 +88,19 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any: def fetch_versioned_implementation_with_fallback( module: str, attribute: str, fallback: T ) -> T: + """ + Attempts to fetch a versioned implementation of a specified attribute from a given module. + If the attempt fails (e.g., due to an import error or missing attribute), the function logs + a warning and returns the provided fallback implementation. + + Args: + module (str): The name of the module from which to fetch the attribute. + attribute (str): The name of the attribute to fetch from the module. + fallback (T): The fallback implementation to return if fetching the attribute fails. + + Returns: + T: The fetched implementation if successful, otherwise the provided fallback. + """ try: return fetch_versioned_implementation(module, attribute) except Exception: @@ -73,4 +108,14 @@ def fetch_versioned_implementation_with_fallback( def noop_fallback(*args: Any, **kwargs: Any) -> None: - pass + """ + A no-op (no operation) fallback function that accepts any arguments but does nothing. + This is often used as a default or placeholder callback function. + + Args: + *args (Any): Positional arguments, which are ignored. + **kwargs (Any): Keyword arguments, which are ignored. + + Returns: + None + """ diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py index c2b05ee881f..2b3cdb7a9dc 100644 --- a/backend/ee/danswer/access/access.py +++ b/backend/ee/danswer/access/access.py @@ -11,6 +11,17 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user +def _get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + id_to_access = _get_access_for_documents([document_id], db_session) + if len(id_to_access) == 0: + return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + + return next(iter(id_to_access.values())) + + def _get_access_for_documents( document_ids: list[str], db_session: Session, diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 403adbd74e1..2b4c96ccb1e 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -1,28 +1,18 @@ from datetime import timedelta -from typing import Any -from celery.signals import beat_init -from celery.signals import worker_init from sqlalchemy.orm import Session from danswer.background.celery.celery_app import celery_app from danswer.background.task_utils import build_celery_task_wrapper from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME from danswer.db.chat import delete_chat_sessions_older_than from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import init_sqlalchemy_engine from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.background.celery_utils import should_perform_chat_ttl_check -from ee.danswer.background.celery_utils import should_sync_user_groups from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import name_user_group_sync_task -from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report -from ee.danswer.user_groups.sync import sync_user_groups logger = setup_logger() @@ -30,17 +20,6 @@ global_version.set_ee() -@build_celery_task_wrapper(name_user_group_sync_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_user_group_task(user_group_id: int) -> None: - with Session(get_sqlalchemy_engine()) as db_session: - # actual sync logic - try: - sync_user_groups(user_group_id=user_group_id, db_session=db_session) - except Exception as e: - logger.exception(f"Failed to sync user group - {e}") - - @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) def perform_ttl_management_task(retention_limit_days: int) -> None: @@ -51,8 +30,6 @@ def perform_ttl_management_task(retention_limit_days: int) -> None: ##### # Periodic Tasks ##### - - @celery_app.task( name="check_ttl_management_task", soft_time_limit=JOB_TIMEOUT, @@ -69,24 +46,6 @@ def check_ttl_management_task() -> None: ) -@celery_app.task( - name="check_for_user_groups_sync_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_for_user_groups_sync_task() -> None: - """Runs periodically to check if any user groups are out of sync - Creates a task to sync the user group if needed""" - with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced - user_groups = fetch_user_groups(db_session=db_session, only_current=False) - for user_group in user_groups: - if should_sync_user_groups(user_group, db_session): - logger.info(f"User Group {user_group.id} is not synced. Syncing now!") - sync_user_group_task.apply_async( - kwargs=dict(user_group_id=user_group.id), - ) - - @celery_app.task( name="autogenerate_usage_report_task", soft_time_limit=JOB_TIMEOUT, @@ -101,25 +60,11 @@ def autogenerate_usage_report_task() -> None: ) -@beat_init.connect -def on_beat_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME) - - -@worker_init.connect -def on_worker_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME) - - ##### # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "check-for-user-group-sync": { - "task": "check_for_user_groups_sync_task", - "schedule": timedelta(seconds=5), - }, - "autogenerate_usage_report": { + "autogenerate-usage-report": { "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index 0134f6642f7..34190255f5a 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -1,27 +1,13 @@ from sqlalchemy.orm import Session -from danswer.db.models import UserGroup from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task from danswer.utils.logger import setup_logger from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import name_user_group_sync_task logger = setup_logger() -def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool: - if user_group.is_up_to_date: - return False - task_name = name_user_group_sync_task(user_group.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): - logger.info("TTL check is already being performed. Skipping.") - return False - return True - - def should_perform_chat_ttl_check( retention_limit_days: int | None, db_session: Session ) -> bool: diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 00e7d4d5ebd..ab666f747b5 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -5,6 +5,7 @@ from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import func +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session @@ -81,10 +82,25 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non def fetch_user_groups( - db_session: Session, only_current: bool = True + db_session: Session, only_up_to_date: bool = True ) -> Sequence[UserGroup]: + """ + Fetches user groups from the database. + + This function retrieves a sequence of `UserGroup` objects from the database. + If `only_up_to_date` is set to `True`, it filters the user groups to return only those + that are marked as up-to-date (`is_up_to_date` is `True`). + + Args: + db_session (Session): The SQLAlchemy session used to query the database. + only_up_to_date (bool, optional): Flag to determine whether to filter the results + to include only up to date user groups. Defaults to `True`. + + Returns: + Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria. + """ stmt = select(UserGroup) - if only_current: + if only_up_to_date: stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712 return db_session.scalars(stmt).all() @@ -103,6 +119,42 @@ def fetch_user_groups_for_user( return db_session.scalars(stmt).all() +def construct_document_select_by_usergroup( + user_group_id: int, +) -> Select: + """This returns a statement that should be executed using + .yield_per() to minimize overhead. The primary consumers of this function + are background processing task generators.""" + stmt = ( + select(Document) + .join( + DocumentByConnectorCredentialPair, + Document.id == DocumentByConnectorCredentialPair.id, + ) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .join( + UserGroup__ConnectorCredentialPair, + UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, + ) + .join( + UserGroup, + UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, + ) + .where(UserGroup.id == user_group_id) + .order_by(Document.id) + ) + stmt = stmt.distinct() + return stmt + + def fetch_documents_for_user_group_paginated( db_session: Session, user_group_id: int, @@ -361,6 +413,10 @@ def update_user_group( user_group_id: int, user_group_update: UserGroupUpdate, ) -> UserGroup: + """If successful, this can set db_user_group.is_up_to_date = False. + That will be processed by check_for_vespa_user_groups_sync_task and trigger + a long running background sync to Vespa. + """ stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) if db_user_group is None: diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py index b33daddea64..355e59fff1d 100644 --- a/backend/ee/danswer/server/user_group/api.py +++ b/backend/ee/danswer/server/user_group/api.py @@ -32,7 +32,7 @@ def list_user_groups( db_session: Session = Depends(get_session), ) -> list[UserGroup]: if user is None or user.role == UserRole.ADMIN: - user_groups = fetch_user_groups(db_session, only_current=False) + user_groups = fetch_user_groups(db_session, only_up_to_date=False) else: user_groups = fetch_user_groups_for_user( db_session=db_session, diff --git a/backend/ee/danswer/user_groups/sync.py b/backend/ee/danswer/user_groups/sync.py deleted file mode 100644 index e3bea192670..00000000000 --- a/backend/ee/danswer/user_groups/sync.py +++ /dev/null @@ -1,87 +0,0 @@ -from sqlalchemy.orm import Session - -from danswer.access.access import get_access_for_documents -from danswer.db.document import prepare_to_modify_documents -from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings -from danswer.document_index.factory import get_default_document_index -from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.interfaces import UpdateRequest -from danswer.utils.logger import setup_logger -from ee.danswer.db.user_group import delete_user_group -from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated -from ee.danswer.db.user_group import fetch_user_group -from ee.danswer.db.user_group import mark_user_group_as_synced - -logger = setup_logger() - -_SYNC_BATCH_SIZE = 100 - - -def _sync_user_group_batch( - document_ids: list[str], document_index: DocumentIndex, db_session: Session -) -> None: - logger.debug(f"Syncing document sets for: {document_ids}") - - # Acquires a lock on the documents so that no other process can modify them - with prepare_to_modify_documents(db_session=db_session, document_ids=document_ids): - # get current state of document sets for these documents - document_id_to_access = get_access_for_documents( - document_ids=document_ids, db_session=db_session - ) - - # update Vespa - document_index.update( - update_requests=[ - UpdateRequest( - document_ids=[document_id], - access=document_id_to_access[document_id], - ) - for document_id in document_ids - ] - ) - - # Finish the transaction and release the locks - db_session.commit() - - -def sync_user_groups(user_group_id: int, db_session: Session) -> None: - """Sync the status of Postgres for the specified user group""" - search_settings = get_current_search_settings(db_session) - secondary_search_settings = get_secondary_search_settings(db_session) - - document_index = get_default_document_index( - primary_index_name=search_settings.index_name, - secondary_index_name=secondary_search_settings.index_name - if secondary_search_settings - else None, - ) - - user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id) - if user_group is None: - raise ValueError(f"User group '{user_group_id}' does not exist") - - cursor = None - while True: - # NOTE: this may miss some documents, but that is okay. Any new documents added - # will be added with the correct group membership - document_batch, cursor = fetch_documents_for_user_group_paginated( - db_session=db_session, - user_group_id=user_group_id, - last_document_id=cursor, - limit=_SYNC_BATCH_SIZE, - ) - - _sync_user_group_batch( - document_ids=[document.id for document in document_batch], - document_index=document_index, - db_session=db_session, - ) - - if cursor is None: - break - - if user_group.is_up_for_deletion: - delete_user_group(db_session=db_session, user_group=user_group) - else: - mark_user_group_as_synced(db_session=db_session, user_group=user_group) diff --git a/backend/supervisord.conf b/backend/supervisord.conf index b56c763b94f..697866b6c0a 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -24,14 +24,21 @@ autorestart=true # relatively compute-light (e.g. they tend to just make a bunch of requests to # Vespa / Postgres) [program:celery_worker] -command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=6 --loglevel=INFO --logfile=/var/log/celery_worker_supervisor.log +command=celery -A danswer.background.celery.celery_run:celery_app worker + --pool=threads + --concurrency=6 + --loglevel=INFO + --logfile=/var/log/celery_worker_supervisor.log + -Q celery,vespa_metadata_sync environment=LOG_FILE_NAME=celery_worker redirect_stderr=true autorestart=true # Job scheduler for periodic tasks [program:celery_beat] -command=celery -A danswer.background.celery.celery_run:celery_app beat --loglevel=INFO --logfile=/var/log/celery_beat_supervisor.log +command=celery -A danswer.background.celery.celery_run:celery_app beat + --loglevel=INFO + --logfile=/var/log/celery_beat_supervisor.log environment=LOG_FILE_NAME=celery_beat redirect_stderr=true autorestart=true diff --git a/web/src/app/admin/documents/explorer/Explorer.tsx b/web/src/app/admin/documents/explorer/Explorer.tsx index a773c222484..c1722b01edf 100644 --- a/web/src/app/admin/documents/explorer/Explorer.tsx +++ b/web/src/app/admin/documents/explorer/Explorer.tsx @@ -211,7 +211,7 @@ export function Explorer({ )} {!query && (
- Search for a document above to modify it's boost or hide it from + Search for a document above to modify its boost or hide it from searches.
)}