Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/kill indexing #3213

Merged
merged 16 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)

lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
Expand All @@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
Expand All @@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock

from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
Expand All @@ -27,7 +28,7 @@
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
Expand Down Expand Up @@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(

LOCK_TIMEOUT = 30

lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
Expand All @@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(

custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"

app.send_task(
result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
Expand All @@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)

# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)

redis_connector.permissions.set_fence(payload)
Expand Down Expand Up @@ -247,9 +248,11 @@ def connector_permission_sync_generator_task(

logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")

payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)

document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock

from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
Expand All @@ -24,6 +25,9 @@
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
Expand Down Expand Up @@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)

for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
Expand All @@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()


def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
Expand Down Expand Up @@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(

custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

_ = app.send_task(
result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
Expand All @@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)

payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)

redis_connector.external_group_sync.set_fence(payload)

except Exception:
task_logger.exception(
Expand Down Expand Up @@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(

r = get_redis_client(tenant_id=tenant_id)

lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
Expand Down Expand Up @@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task(
)

mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)

except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
Expand All @@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
59 changes: 41 additions & 18 deletions backend/danswer/background/celery/tasks/indexing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
Expand Down Expand Up @@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]

# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)

for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
Expand All @@ -237,21 +231,21 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)

search_settings_primary = False
if search_settings_instance.id == primary_search_settings.id:
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True

if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings) > 1,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue

reindex = False
if search_settings_instance.id == primary_search_settings.id:
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
Expand Down Expand Up @@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1

Expand Down Expand Up @@ -529,8 +523,11 @@ def try_creating_indexing_task(
return index_attempt_id


@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
Expand All @@ -543,6 +540,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

if not self.request.id:
task_logger.error("self.request.id is None!")

client = SimpleJobClient()

job = client.submit(
Expand Down Expand Up @@ -571,8 +572,30 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)

redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)

while True:
sleep(10)
sleep(5)

if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)

job.cancel()
break

# do nothing for ongoing jobs that haven't been stopped
if not job.done():
Expand Down
26 changes: 15 additions & 11 deletions backend/danswer/background/celery/tasks/vespa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
Expand All @@ -58,7 +59,7 @@
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
Expand Down Expand Up @@ -588,17 +589,15 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return

payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None

mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()


def monitor_ccpair_indexing_taskset(
Expand Down Expand Up @@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(

index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)

redis_connector_index.reset()
return
Expand All @@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
Expand Down Expand Up @@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:

# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
Expand Down
Loading
Loading