Cloud improvements (#3099)
* add improved cloud configuration

* fix typing

* finalize slackbot improvements

* minor update

* finalized keda

* moderate slackbot switch

* update some configs

* revert

* include reset engine!
pablonyx authored Nov 13, 2024
1 parent d68f8d6 commit facf1d5
Showing 20 changed files with 127 additions and 142 deletions.
10 changes: 10 additions & 0 deletions backend/danswer/background/celery/apps/beat.py
@@ -12,6 +12,7 @@
 from danswer.db.engine import SqlEngine
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
 from shared_configs.configs import MULTI_TENANT
 
 logger = setup_logger(__name__)
@@ -72,6 +73,15 @@ def _update_tenant_tasks(self) -> None:
         logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
 
         for tenant_id in tenant_ids:
+            if (
+                IGNORED_SYNCING_TENANT_LIST
+                and tenant_id in IGNORED_SYNCING_TENANT_LIST
+            ):
+                logger.info(
+                    f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
+                )
+                continue
+
             if tenant_id not in existing_tenants:
                 logger.info(f"Processing new tenant: {tenant_id}")
 
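The guard above is the entire mechanism: when the env-derived list exists and names a tenant, beat simply never registers per-tenant tasks for it. A minimal sketch of the same filter in isolation (the helper name and sample values are illustrative, not from the diff):

IGNORED_SYNCING_TENANT_LIST: list[str] | None = ["tenant_1"]

def tenants_to_schedule(tenant_ids: list[str]) -> list[str]:
    # Mirrors the diff: the truthiness check covers the env-var-unset case
    # (None), so membership is only tested when a list was configured.
    return [
        tenant_id
        for tenant_id in tenant_ids
        if not (
            IGNORED_SYNCING_TENANT_LIST
            and tenant_id in IGNORED_SYNCING_TENANT_LIST
        )
    ]

assert tenants_to_schedule(["tenant_1", "tenant_2"]) == ["tenant_2"]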
6 changes: 6 additions & 0 deletions backend/danswer/background/celery/apps/indexing.py
@@ -6,6 +6,7 @@
 from celery import Task
 from celery.signals import celeryd_init
 from celery.signals import worker_init
+from celery.signals import worker_process_init
 from celery.signals import worker_ready
 from celery.signals import worker_shutdown
 
@@ -81,6 +82,11 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
     app_base.on_worker_shutdown(sender, **kwargs)
 
 
+@worker_process_init.connect
+def init_worker(**kwargs: Any) -> None:
+    SqlEngine.reset_engine()
+
+
 @signals.setup_logging.connect
 def on_setup_logging(
     loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
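worker_process_init fires in each forked pool process, so the new hook guarantees every child discards the engine (and pooled sockets) inherited from the parent before doing any database work. A rough equivalent with plain SQLAlchemy, assuming a module-level engine rather than Danswer's SqlEngine wrapper:

from celery.signals import worker_process_init
from sqlalchemy import create_engine

# Hypothetical DSN; any Postgres URL behaves the same way here.
engine = create_engine("postgresql+psycopg2://app:secret@localhost/app")

@worker_process_init.connect
def reset_engine_on_fork(**kwargs: object) -> None:
    # Invalidate connections inherited across fork; close=False avoids
    # touching the parent's sockets and lets this child build a fresh pool.
    engine.dispose(close=False)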
96 changes: 0 additions & 96 deletions backend/danswer/background/celery/apps/scheduler.py

This file was deleted.

6 changes: 3 additions & 3 deletions backend/danswer/background/celery/tasks/beat_schedule.py
@@ -8,7 +8,7 @@
     {
         "name": "check-for-vespa-sync",
         "task": "check_for_vespa_sync_task",
-        "schedule": timedelta(seconds=5),
+        "schedule": timedelta(seconds=20),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
@@ -20,13 +20,13 @@
     {
         "name": "check-for-indexing",
         "task": "check_for_indexing",
-        "schedule": timedelta(seconds=10),
+        "schedule": timedelta(seconds=15),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
         "name": "check-for-prune",
         "task": "check_for_pruning",
-        "schedule": timedelta(seconds=10),
+        "schedule": timedelta(seconds=15),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
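These entries are plain Celery beat schedule rows, so the timedelta is just how often the check task gets re-enqueued; stretching the 5-10 second polls to 15-20 seconds trades slightly slower pickup for noticeably less broker churn per tenant. A hedged sketch of how such a table is typically wired up (app name, broker URL, and the priority value are illustrative):

from datetime import timedelta
from celery import Celery

app = Celery("danswer_sketch", broker="redis://localhost:6379/0")

app.conf.beat_schedule = {
    "check-for-vespa-sync": {
        "task": "check_for_vespa_sync_task",
        # Was seconds=5; a 20s period enqueues 4x fewer checks.
        "schedule": timedelta(seconds=20),
        "options": {"priority": 1},  # stand-in for DanswerCeleryPriority.HIGH
    },
}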
14 changes: 11 additions & 3 deletions backend/danswer/background/indexing/job_client.py
@@ -29,18 +29,26 @@
 def _initializer(
     func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
 ) -> Any:
-    """Ensure the parent proc's database connections are not touched
-    in the new connection pool
+    """Initialize the child process with a fresh SQLAlchemy Engine.
 
-    Based on the recommended approach in the SQLAlchemy docs found:
+    Based on SQLAlchemy's recommendations to handle multiprocessing:
     https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
     """
     if kwargs is None:
         kwargs = {}
 
     logger.info("Initializing spawned worker child process.")
+
+    # Reset the engine in the child process
+    SqlEngine.reset_engine()
+
+    # Optionally set a custom app name for database logging purposes
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
+
+    # Initialize a new engine with desired parameters
     SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
+
+    # Proceed with executing the target function
     return func(*args, **kwargs)
 
 
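_initializer wraps the real task so the reset/init sequence always runs in the spawned process before any database work. A self-contained sketch of the same wrapper pattern with multiprocessing (run_indexing and the print are stand-ins for Danswer's task and engine calls):

import multiprocessing as mp
from typing import Any, Callable

def _initializer(
    func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
    if kwargs is None:
        kwargs = {}
    # Real code: SqlEngine.reset_engine(), then set_app_name(...), then
    # init_engine(pool_size=4, max_overflow=12, pool_recycle=60).
    print("child: building a fresh engine before any DB work")
    return func(*args, **kwargs)

def run_indexing(attempt_id: int) -> None:
    print(f"indexing attempt {attempt_id}")

if __name__ == "__main__":
    p = mp.Process(target=_initializer, args=(run_indexing, (42,)))
    p.start()
    p.join()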
8 changes: 4 additions & 4 deletions backend/danswer/danswerbot/slack/config.py
@@ -55,11 +55,11 @@ def validate_channel_names(
 # Scaling configurations for multi-tenant Slack bot handling
 TENANT_LOCK_EXPIRATION = 1800  # How long a pod can hold exclusive access to a tenant before other pods can acquire it
 TENANT_HEARTBEAT_INTERVAL = (
-    60  # How often pods send heartbeats to indicate they are still processing a tenant
+    15  # How often pods send heartbeats to indicate they are still processing a tenant
 )
-TENANT_HEARTBEAT_EXPIRATION = 180  # How long before a tenant's heartbeat expires, allowing other pods to take over
-TENANT_ACQUISITION_INTERVAL = (
-    60  # How often pods attempt to acquire unprocessed tenants
+TENANT_HEARTBEAT_EXPIRATION = (
+    30  # How long before a tenant's heartbeat expires, allowing other pods to take over
 )
+TENANT_ACQUISITION_INTERVAL = 60  # How often pods attempt to acquire unprocessed tenants and check for new tokens
 
 MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))
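The three timers cooperate: a pod that wins the 1800-second lock must keep re-asserting a much shorter-lived heartbeat, so cutting the heartbeat interval/expiration from 60/180 seconds to 15/30 seconds means a crashed pod's tenants get reassigned within about half a minute instead of three. A minimal sketch of that pattern with redis-py (key names and pod_id are illustrative, not Danswer's constants):

import redis

TENANT_LOCK_EXPIRATION = 1800
TENANT_HEARTBEAT_EXPIRATION = 30

r = redis.Redis(host="localhost", port=6379)

def try_acquire_tenant(tenant_id: str, pod_id: str) -> bool:
    # SET NX EX: exactly one pod can create the key, and it auto-expires
    # so a dead pod cannot hold a tenant forever.
    return bool(
        r.set(f"slack_bot_lock:{tenant_id}", pod_id, nx=True, ex=TENANT_LOCK_EXPIRATION)
    )

def heartbeat(tenant_id: str, pod_id: str) -> None:
    # Refreshed every TENANT_HEARTBEAT_INTERVAL; if refreshes stop, the key
    # expires after 30s and another pod may take the tenant over.
    r.set(f"slack_bot_heartbeat:{tenant_id}", pod_id, ex=TENANT_HEARTBEAT_EXPIRATION)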
32 changes: 26 additions & 6 deletions backend/danswer/danswerbot/slack/listener.py
@@ -75,6 +75,7 @@
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
+from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -164,9 +165,15 @@ def heartbeat_loop(self) -> None:
 
     def acquire_tenants(self) -> None:
         tenant_ids = get_all_tenant_ids()
-        logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
 
         for tenant_id in tenant_ids:
+            if (
+                DISALLOWED_SLACK_BOT_TENANT_LIST is not None
+                and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
+            ):
+                logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
+                continue
+
             if tenant_id in self.tenant_ids:
                 logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
                 continue
@@ -190,6 +197,9 @@ def acquire_tenants(self) -> None:
                 continue
 
             logger.debug(f"Acquired lock for tenant {tenant_id}")
+            self.tenant_ids.add(tenant_id)
+
+        for tenant_id in self.tenant_ids:
             token = CURRENT_TENANT_ID_CONTEXTVAR.set(
                 tenant_id or POSTGRES_DEFAULT_SCHEMA
             )
@@ -236,14 +246,14 @@ def acquire_tenants(self) -> None:
 
                 self.slack_bot_tokens[tenant_id] = slack_bot_tokens
 
-                if tenant_id in self.socket_clients:
+                if self.socket_clients.get(tenant_id):
                     asyncio.run(self.socket_clients[tenant_id].close())
 
                 self.start_socket_client(tenant_id, slack_bot_tokens)
 
             except KvKeyNotFoundError:
                 logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
-                if tenant_id in self.socket_clients:
+                if self.socket_clients.get(tenant_id):
                     asyncio.run(self.socket_clients[tenant_id].close())
                     del self.socket_clients[tenant_id]
                     del self.slack_bot_tokens[tenant_id]
@@ -277,14 +287,14 @@ def start_socket_client(
         logger.info(f"Connecting socket client for tenant {tenant_id}")
         socket_client.connect()
         self.socket_clients[tenant_id] = socket_client
-        self.tenant_ids.add(tenant_id)
         logger.info(f"Started SocketModeClient for tenant {tenant_id}")
 
     def stop_socket_clients(self) -> None:
         logger.info(f"Stopping {len(self.socket_clients)} socket clients")
         for tenant_id, client in self.socket_clients.items():
-            asyncio.run(client.close())
-            logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
+            if client:
+                asyncio.run(client.close())
+                logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
 
     def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
         if not self.running:
@@ -298,6 +308,16 @@ def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
         logger.info(f"Stopping {len(self.socket_clients)} socket clients")
         self.stop_socket_clients()
 
+        # Release locks for all tenants
+        logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
+        for tenant_id in self.tenant_ids:
+            try:
+                redis_client = get_redis_client(tenant_id=tenant_id)
+                redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
+                logger.info(f"Released lock for tenant {tenant_id}")
+            except Exception as e:
+                logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
+
         # Wait for background threads to finish (with timeout)
         logger.info("Waiting for background threads to finish...")
         self.acquire_thread.join(timeout=5)
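The shutdown hunk is the counterpart to acquisition: deleting the per-tenant Redis locks on SIGTERM hands those tenants to surviving pods immediately instead of making them wait out the lock TTL. A stripped-down sketch of that release path (generic client and lock key, not Danswer's DanswerRedisLocks constant):

import signal
import sys
import redis

r = redis.Redis()
tenant_ids = {"tenant_1", "tenant_2"}  # tenants this pod currently owns

def shutdown(signum, frame):
    for tenant_id in tenant_ids:
        try:
            # Deleting the lock key hands the tenant to the next acquirer.
            r.delete(f"slack_bot_lock:{tenant_id}")
        except Exception as e:
            print(f"failed to release lock for {tenant_id}: {e}", file=sys.stderr)
    sys.exit(0)

signal.signal(signal.SIGTERM, shutdown)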
7 changes: 7 additions & 0 deletions backend/danswer/db/engine.py
@@ -189,6 +189,13 @@ def get_app_name(cls) -> str:
             return ""
         return cls._app_name
 
+    @classmethod
+    def reset_engine(cls) -> None:
+        with cls._lock:
+            if cls._engine:
+                cls._engine.dispose()
+                cls._engine = None
+
 
 def get_all_tenant_ids() -> list[str] | list[None]:
     if not MULTI_TENANT:
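reset_engine() completes the singleton's lifecycle: dispose() closes every pooled connection, and nulling _engine forces the next init to build a brand-new pool in the current process, which is what the fork-safety hooks above rely on. A compact sketch of the class-level pattern (the DSN and init_engine signature are assumptions; only reset_engine mirrors the diff):

import threading
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

class SqlEngine:
    _engine: Engine | None = None
    _lock = threading.Lock()

    @classmethod
    def init_engine(cls, **pool_kwargs: int) -> Engine:
        with cls._lock:
            if cls._engine is None:
                # Hypothetical DSN for the sketch.
                cls._engine = create_engine(
                    "postgresql://localhost/danswer", **pool_kwargs
                )
            return cls._engine

    @classmethod
    def reset_engine(cls) -> None:
        with cls._lock:
            if cls._engine:
                cls._engine.dispose()  # close all pooled connections
                cls._engine = None  # next init_engine() starts from scratch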
1 change: 1 addition & 0 deletions backend/danswer/redis/redis_connector_credential_pair.py
@@ -63,6 +63,7 @@ def generate_tasks(
         stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
             cc_pair.connector_id, cc_pair.credential_id
         )
+
         for doc in db_session.scalars(stmt).yield_per(1):
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
14 changes: 14 additions & 0 deletions backend/shared_configs/configs.py
@@ -142,6 +142,20 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
 # Prefix used for all tenant ids
 TENANT_ID_PREFIX = "tenant_"
 
+ALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("ALLOWED_SLACK_BOT_TENANT_IDS")
+DISALLOWED_SLACK_BOT_TENANT_LIST = (
+    [tenant.strip() for tenant in ALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
+    if ALLOWED_SLACK_BOT_TENANT_IDS
+    else None
+)
+
+IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_ID")
+IGNORED_SYNCING_TENANT_LIST = (
+    [tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
+    if IGNORED_SYNCING_TENANT_IDS
+    else None
+)
+
 SUPPORTED_EMBEDDING_MODELS = [
     # Cloud-based models
     SupportedEmbeddingModel(
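Both blocks are the same comma-separated env-list idiom; note that DISALLOWED_SLACK_BOT_TENANT_LIST is populated from the ALLOWED_SLACK_BOT_TENANT_IDS variable, and IGNORED_SYNCING_TENANT_IDS reads the singular env var IGNORED_SYNCING_TENANT_ID, so the exact names matter when configuring a deployment. The idiom reduces to a one-line helper, sketched here (the helper is not part of the diff):

import os

def env_csv_list(name: str) -> list[str] | None:
    # Returns None when the variable is unset or empty, matching the
    # "if VALUE else None" guards above.
    raw = os.environ.get(name)
    return [item.strip() for item in raw.split(",")] if raw else None

# e.g. IGNORED_SYNCING_TENANT_LIST = env_csv_list("IGNORED_SYNCING_TENANT_ID")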
@@ -9,23 +9,22 @@ spec:
scaleTargetRef:
name: celery-worker-indexing
minReplicaCount: 1
maxReplicaCount: 10
maxReplicaCount: 30
triggers:
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:2
@@ -36,11 +35,19 @@
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: cpu
metadata:
type: Utilization
value: "70"

- type: memory
metadata:
type: Utilization
value: "70"