72 changes: 26 additions & 46 deletions backend/danswer/background/celery/apps/app_base.py
@@ -21,11 +21,11 @@
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN


@@ -173,52 +173,38 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:


def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")

# Exit early if multi-tenant since primary worker check not needed
if MULTI_TENANT:
return

# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60

logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()

logger.info("Waiting for primary worker to be ready...")
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)

if all_tenants_ready:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break

time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)

logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)

if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)

time.sleep(WAIT_INTERVAL)

logger.info("All tenant primary workers are ready. Continuing...")
logger.info("Wait for primary worker completed successfully. Continuing...")
return


@@ -230,26 +216,20 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return

if not hasattr(sender, "primary_worker_locks"):
if not sender.primary_worker_lock:
return

for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")


def on_setup_logging(
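For context on the app_base.py change: the new `on_secondary_worker_init` path boils down to polling one Redis key at a fixed interval and shutting the worker down if the key never appears within the timeout. Below is a minimal, self-contained sketch of that pattern; `redis_client`, `primary_worker_key`, and `WorkerStartupTimeout` are illustrative stand-ins for the project's `get_redis_client(tenant_id=None)`, `DanswerRedisLocks.PRIMARY_WORKER`, and Celery's `WorkerShutdown`, not the exact module code.

```python
import time

WAIT_INTERVAL = 5  # seconds between polls of the primary worker key
WAIT_LIMIT = 60    # give up after this many seconds


class WorkerStartupTimeout(RuntimeError):
    """Stand-in for celery.exceptions.WorkerShutdown in this sketch."""


def wait_for_primary_worker(redis_client, primary_worker_key: str) -> None:
    """Block until the primary worker's lock key exists in Redis, or time out."""
    time_start = time.monotonic()
    while True:
        if redis_client.exists(primary_worker_key):
            return  # primary worker is up; the secondary worker may continue

        time_elapsed = time.monotonic() - time_start
        if time_elapsed > WAIT_LIMIT:
            raise WorkerStartupTimeout(
                f"Primary worker was not ready within {WAIT_LIMIT} seconds."
            )
        time.sleep(WAIT_INTERVAL)
```

Compared with the removed per-tenant loop, the single-key poll keeps one Redis round trip per interval and one failure message, which is the simplification this diff is after.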
3 changes: 2 additions & 1 deletion backend/danswer/background/celery/apps/beat.py
@@ -13,8 +13,9 @@
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()
# Import the custom scheduler

logger = setup_logger(__name__)

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
199 changes: 91 additions & 108 deletions backend/danswer/background/celery/apps/primary.py
@@ -24,10 +24,10 @@
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()
@@ -79,91 +79,90 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

logger.info("Running as the primary celery worker.")

sender.primary_worker_locks = {}
if MULTI_TENANT:
return

# This is singleton work that should be done on startup exactly once
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)

# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)

# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)

# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)

# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)

logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")

# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# tacking on our own user data to the sender
sender.primary_worker_lock = lock

# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())

for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)


@worker_ready.connect
@@ -216,52 +215,36 @@ def run_periodic_task(self, worker: Any) -> None:
if not celery_is_worker_primary(worker):
return

if not hasattr(worker, "primary_worker_locks"):
if not hasattr(worker, "primary_worker_lock"):
return

# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()

for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant

r = get_redis_client(tenant_id=tenant_id)

if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
lock = worker.primary_worker_lock

r = get_redis_client(tenant_id=None)

if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)

task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)

task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")

except Exception:
task_logger.exception("Periodic task failed.")
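For context on the primary.py change: `run_periodic_task` now maintains a single primary worker lock with a reacquire-or-retake pattern. The sketch below illustrates that pattern only, assuming a redis-py style `Lock` object (`owned()`, `reacquire()`, `acquire(blocking_timeout=...)`); the real function additionally guards on `celery_is_worker_primary` and logs through `task_logger`, and `LOCK_TIMEOUT` stands in for `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT`.

```python
LOCK_TIMEOUT = 60  # seconds; stand-in for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT


def refresh_primary_worker_lock(redis_client, worker, lock_name: str):
    """Extend the primary worker lock if this worker still holds it; otherwise retake it."""
    lock = worker.primary_worker_lock

    if lock.owned():
        # Happy path: we still own the lock, so just push its expiry forward.
        lock.reacquire()
        return lock

    # The lock expired or the worker restarted; attempt a full acquisition.
    lock = redis_client.lock(lock_name, timeout=LOCK_TIMEOUT)
    if not lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2):
        raise TimeoutError("Primary worker lock could not be acquired!")

    worker.primary_worker_lock = lock
    return lock
```

The design keeps lock state on the worker object (`worker.primary_worker_lock`), so the periodic task is the single place where the lock's lifetime is extended or recovered.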