
Commit eaa8ae7

Bugfix/connector deletion lockout (#2901)
* first cut at deletion hardening
* clean up logging
* remove commented code
1 parent b9781c4 commit eaa8ae7

File tree

11 files changed: +281 -93 lines


backend/danswer/background/celery/apps/primary.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,7 @@
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
 from danswer.background.celery.celery_redis import RedisConnectorIndexing
 from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
@@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
         r.delete(key)
 
+    for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
+        r.delete(key)
+
 
 # @worker_process_init.connect
 # def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
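For context, the startup cleanup above assumes each connector's stop signal lives under a key beginning with RedisConnectorStop.FENCE_PREFIX, so a stale fence left by a crashed run cannot make freshly started tasks abort immediately. A minimal sketch of that lifecycle against a local Redis (connection details and the exact key shape are assumptions, not taken from this commit):

# Minimal sketch, assuming a local Redis and a key shape of FENCE_PREFIX + "_" + id.
import redis

r = redis.Redis(host="localhost", port=6379, db=0)  # connection details are assumptions

# simulate a stale stop fence from a previous run (cc_pair id 42 is hypothetical)
r.set("connectorstop_fence_42", 42)

# same pattern as on_worker_init above: scan by prefix and delete
for key in r.scan_iter("connectorstop_fence*"):
    r.delete(key)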

backend/danswer/background/celery/celery_redis.py

Lines changed: 25 additions & 0 deletions
@@ -313,6 +313,8 @@ def generate_tasks(
         lock: redis.lock.Lock,
         tenant_id: str | None,
     ) -> int | None:
+        """Returns None if the cc_pair doesn't exist.
+        Otherwise, returns an int with the number of generated tasks."""
         last_lock_time = time.monotonic()
 
         async_results = []
@@ -540,6 +542,29 @@ def is_indexing(self, redis_client: Redis) -> bool:
         return False
 
 
+class RedisConnectorStop(RedisObjectHelper):
+    """Used to signal any running tasks for a connector to stop. We should refactor
+    connector related redis helpers into a single class.
+    """
+
+    PREFIX = "connectorstop"
+    FENCE_PREFIX = PREFIX + "_fence"  # a fence for the entire indexing process
+    TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of prune tasks id's
+
+    def __init__(self, id: int) -> None:
+        super().__init__(str(id))
+
+    def generate_tasks(
+        self,
+        celery_app: Celery,
+        db_session: Session,
+        redis_client: Redis,
+        lock: redis.lock.Lock | None,
+        tenant_id: str | None,
+    ) -> int | None:
+        return None
+
+
 def celery_get_queue_length(queue: str, r: Redis) -> int:
     """This is a redis specific way to get the length of a celery queue.
     It is priority aware and knows how to count across the multiple redis lists
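RedisConnectorStop is purely a signalling helper: its generate_tasks is a no-op and only its fence key matters. A small sketch (not code from this commit) of how a running task might consult that fence via the helper's fence_key property; the wrapper function itself is hypothetical:

# Sketch only: checks whether connector deletion has left a stop fence for this cc_pair.
from redis import Redis

from danswer.background.celery.celery_redis import RedisConnectorStop


def connector_stop_requested(cc_pair_id: int, r: Redis) -> bool:
    rcs = RedisConnectorStop(cc_pair_id)
    return bool(r.exists(rcs.fence_key))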

backend/danswer/background/celery/celery_utils.py

Lines changed: 6 additions & 4 deletions
@@ -1,11 +1,11 @@
-from collections.abc import Callable
 from datetime import datetime
 from datetime import timezone
 from typing import Any
 
 from sqlalchemy.orm import Session
 
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
@@ -79,7 +79,7 @@ def document_batch_to_ids(
 
 def extract_ids_from_runnable_connector(
     runnable_connector: BaseConnector,
-    progress_callback: Callable[[int], None] | None = None,
+    callback: RunIndexingCallbackInterface | None = None,
 ) -> set[str]:
     """
     If the PruneConnector hasnt been implemented for the given connector, just pull
@@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector(
         max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
     )(document_batch_to_ids)
     for doc_batch in doc_batch_generator:
-        if progress_callback:
-            progress_callback(len(doc_batch))
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("Stop signal received")
+            callback.progress(len(doc_batch))
         all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
 
     return all_connector_doc_ids
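extract_ids_from_runnable_connector now takes a RunIndexingCallbackInterface rather than a bare progress callable, which is what lets a pruning run honor the stop signal mid-loop. A sketch of a concrete callback, assuming the interface exposes only should_stop() and progress() (those two names come from the diff above; the class itself is illustrative, not the repo's actual implementation):

# Illustrative callback: reports progress and aborts when the RedisConnectorStop
# fence for this cc_pair is set. Everything beyond should_stop()/progress() is an assumption.
from redis import Redis

from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface


class StopAwareCallback(RunIndexingCallbackInterface):
    def __init__(self, cc_pair_id: int, redis_client: Redis) -> None:
        self.redis_client = redis_client
        self.stop_fence_key = RedisConnectorStop(cc_pair_id).fence_key
        self.docs_seen = 0

    def should_stop(self) -> bool:
        # a set fence means connector deletion is waiting on this task to exit
        return bool(self.redis_client.exists(self.stop_fence_key))

    def progress(self, amount: int) -> None:
        self.docs_seen += amount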

backend/danswer/background/celery/tasks/connector_deletion/tasks.py

Lines changed: 89 additions & 22 deletions
@@ -1,3 +1,6 @@
+from datetime import datetime
+from datetime import timezone
+
 import redis
 from celery import Celery
 from celery import shared_task
@@ -8,16 +11,28 @@
 
 from danswer.background.celery.apps.app_base import task_logger
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.celery.celery_redis import RedisConnectorIndexing
+from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
+from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
+    RedisConnectorDeletionFenceData,
+)
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.search_settings import get_all_search_settings
 from danswer.redis.redis_pool import get_redis_client
 
 
+class TaskDependencyError(RuntimeError):
+    """Raised to the caller to indicate dependent tasks are running that would interfere
+    with connector deletion."""
+
+
 @shared_task(
     name="check_for_connector_deletion_task",
     soft_time_limit=JOB_TIMEOUT,
@@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
         if not lock_beat.acquire(blocking=False):
             return
 
+        # collect cc_pair_ids
         cc_pair_ids: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 cc_pair_ids.append(cc_pair.id)
 
+        # try running cleanup on the cc_pair_ids
         for cc_pair_id in cc_pair_ids:
            with get_session_with_tenant(tenant_id) as db_session:
-                try_generate_document_cc_pair_cleanup_tasks(
-                    self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
-                )
+                rcs = RedisConnectorStop(cc_pair_id)
+                try:
+                    try_generate_document_cc_pair_cleanup_tasks(
+                        self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
+                    )
+                except TaskDependencyError as e:
+                    # this means we wanted to start deleting but dependent tasks were running
+                    # Leave a stop signal to clear indexing and pruning tasks more quickly
+                    task_logger.info(str(e))
+                    r.set(rcs.fence_key, cc_pair_id)
+                else:
+                    # clear the stop signal if it exists ... no longer needed
+                    r.delete(rcs.fence_key)
+
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
     """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
     Note that syncing can still be required even if the number of sync tasks generated is zero.
     Returns None if no syncing is required.
+
+    Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
+    still running. In our case, the caller reacts by setting a stop signal in Redis to
+    exit those tasks as quickly as possible.
     """
 
     lock_beat.reacquire()
@@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks(
     if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
         return None
 
-    # add tasks to celery and build up the task set to monitor in redis
-    r.delete(rcd.taskset_key)
-
-    # Add all documents that need to be updated into the queue
-    task_logger.info(
-        f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
+    # set a basic fence to start
+    fence_value = RedisConnectorDeletionFenceData(
+        num_tasks=None,
+        submitted=datetime.now(timezone.utc),
     )
-    tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
-    if tasks_generated is None:
+    r.set(rcd.fence_key, fence_value.model_dump_json())
+
+    try:
+        # do not proceed if connector indexing or connector pruning are running
+        search_settings_list = get_all_search_settings(db_session)
+        for search_settings in search_settings_list:
+            rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
+            if r.get(rci.fence_key):
+                raise TaskDependencyError(
+                    f"Connector deletion - Delayed (indexing in progress): "
+                    f"cc_pair={cc_pair_id} "
+                    f"search_settings={search_settings.id}"
+                )
+
+        rcp = RedisConnectorPruning(cc_pair_id)
+        if r.get(rcp.fence_key):
+            raise TaskDependencyError(
+                f"Connector deletion - Delayed (pruning in progress): "
+                f"cc_pair={cc_pair_id}"
+            )
+
+        # add tasks to celery and build up the task set to monitor in redis
+        r.delete(rcd.taskset_key)
+
+        # Add all documents that need to be updated into the queue
+        task_logger.info(
+            f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
+        )
+        tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
+        if tasks_generated is None:
+            raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
+    except TaskDependencyError:
+        r.delete(rcd.fence_key)
+        raise
+    except Exception:
+        task_logger.exception("Unexpected exception")
+        r.delete(rcd.fence_key)
         return None
+    else:
+        # Currently we are allowing the sync to proceed with 0 tasks.
+        # It's possible for sets/groups to be generated initially with no entries
+        # and they still need to be marked as up to date.
+        # if tasks_generated == 0:
+        # return 0
 
-    # Currently we are allowing the sync to proceed with 0 tasks.
-    # It's possible for sets/groups to be generated initially with no entries
-    # and they still need to be marked as up to date.
-    # if tasks_generated == 0:
-    # return 0
+    task_logger.info(
+        f"RedisConnectorDeletion.generate_tasks finished. "
+        f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
+    )
 
-    task_logger.info(
-        f"RedisConnectorDeletion.generate_tasks finished. "
-        f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
-    )
+    # set this only after all tasks have been added
+    fence_value.num_tasks = tasks_generated
+    r.set(rcd.fence_key, fence_value.model_dump_json())
 
-    # set this only after all tasks have been added
-    r.set(rcd.fence_key, tasks_generated)
     return tasks_generated
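The deletion fence now stores a small JSON payload rather than a bare task count, so an observer can tell "deletion submitted, tasks not yet enqueued" (num_tasks is None) apart from "all tasks enqueued". A guess at the payload's shape, inferred only from its usage above (num_tasks, submitted, model_dump_json); the actual module in the repo may differ:

# Inferred sketch of the fence payload; field names come from the diff above,
# the Pydantic base class is an assumption based on model_dump_json().
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorDeletionFenceData(BaseModel):
    num_tasks: int | None
    submitted: datetime

With a shape like this, monitoring code that reads the fence can wait until num_tasks is populated before comparing it against the remaining taskset size.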
