29 changes: 27 additions & 2 deletions backend/onyx/background/celery/apps/app_base.py
@@ -161,9 +161,34 @@ def on_task_postrun(
return


def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn

# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn".
# But something is blocking set_start_method from working in the cloud unless
# force=True, so we use force=True as a fallback.

all_start_methods: list[str] = multiprocessing.get_all_start_methods()
logger.info(f"Multiprocessing all start methods: {all_start_methods}")

try:
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method raised an exception. Retrying with force=True..."
)
try:
multiprocessing.set_start_method(
"spawn", force=True
) # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method failed even with force=True."
)

logger.info(
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)


def wait_for_redis(sender: Any, **kwargs: Any) -> None:
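
For context: the force=True fallback in on_celeryd_init above is likely needed because multiprocessing refuses to change the start method once the default context has been fixed (for example by an early get_start_method() call or by creating a Process). A minimal standalone sketch of that behavior, not part of this diff:

```python
import multiprocessing

if __name__ == "__main__":
    multiprocessing.get_start_method()  # fixes the default context ("fork" on Linux)
    try:
        multiprocessing.set_start_method("spawn")
    except RuntimeError:
        # "context has already been set" -- force=True overrides the fixed context
        multiprocessing.set_start_method("spawn", force=True)
    print(multiprocessing.get_start_method())  # -> "spawn"
```
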
9 changes: 4 additions & 5 deletions backend/onyx/background/celery/apps/heavy.py
@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -49,17 +49,16 @@ def on_task_postrun(


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
17 changes: 8 additions & 9 deletions backend/onyx/background/celery/apps/indexing.py
@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
@@ -50,22 +50,21 @@ def on_task_postrun(


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)

# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
# rkuo: Transient errors keep happening in the indexing watchdog threads:
# "SSL connection has been closed unexpectedly"
# Actually setting the spawn method in the cloud fixes 95% of these.
# Setting pre ping might help even more, but we're not worrying about that yet.
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
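
The comment above mentions "pre ping" as a further mitigation. For reference, in plain SQLAlchemy (not the project's SqlEngine wrapper, and with a hypothetical DSN) pool_pre_ping=True tests each pooled connection on checkout and transparently replaces stale ones:

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "postgresql+psycopg2://user:password@localhost:5432/onyx",  # hypothetical DSN
    pool_size=4,
    max_overflow=8,
    pool_pre_ping=True,  # lightweight ping before handing out a pooled connection
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
```
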
12 changes: 6 additions & 6 deletions backend/onyx/background/celery/apps/light.py
@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -15,7 +15,6 @@
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
@@ -49,17 +48,18 @@ def on_task_postrun(


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

logger.info(f"Concurrency: {sender.concurrency}") # type: ignore

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
9 changes: 4 additions & 5 deletions backend/onyx/background/celery/apps/primary.py
@@ -1,12 +1,12 @@
import logging
import multiprocessing
from typing import Any
from typing import cast

from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
@@ -73,14 +73,13 @@ def on_task_postrun(


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
@@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")

# tacking on our own user data to the sender
sender.primary_worker_lock = lock
sender.primary_worker_lock = lock # type: ignore

# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
8 changes: 6 additions & 2 deletions backend/onyx/background/celery/tasks/indexing/tasks.py
@@ -1,3 +1,4 @@
import multiprocessing
import os
import sys
import time
@@ -853,11 +854,14 @@ def connector_indexing_proxy_task(
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
"""celery tasks are forked, but forking is unstable.
This is a thread that proxies work to a spawned task."""

task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
f"search_settings={search_settings_id} "
f"mp_start_method={multiprocessing.get_start_method()}"
)

if not self.request.id:
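
The updated docstring describes a watchdog/proxy shape: the Celery task itself does no indexing work, it spawns a separate process and monitors it. A generic sketch of that pattern (assumed names, not the project's actual implementation):

```python
import multiprocessing as mp
import time


def heavy_work(attempt_id: int) -> None:
    print(f"indexing attempt {attempt_id} running in a spawned child")


def proxy(attempt_id: int) -> int | None:
    ctx = mp.get_context("spawn")  # spawn regardless of the global start method
    p = ctx.Process(target=heavy_work, args=(attempt_id,), daemon=True)
    p.start()
    while p.is_alive():  # watchdog loop; cancellation checks would go here
        time.sleep(1)
    return p.exitcode


if __name__ == "__main__":
    print(proxy(1))
```
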
10 changes: 7 additions & 3 deletions backend/onyx/background/indexing/job_client.py
@@ -4,9 +4,10 @@

NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from multiprocessing.context import SpawnProcess
from typing import Any
from typing import Literal
from typing import Optional
@@ -63,7 +64,7 @@ class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""

id: int
process: Optional["Process"] = None
process: Optional["SpawnProcess"] = None

def cancel(self) -> bool:
return self.release()
@@ -131,7 +132,10 @@ def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
job_id = self.job_id_counter
self.job_id_counter += 1

process = Process(target=_run_in_process, args=(func, args), daemon=True)
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

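
The switch to get_context("spawn") above makes the watchdog's child process independent of whatever global start method the interpreter has fixed. A minimal sketch (not part of this diff; assumes Linux, where "fork" is available) showing that a spawn-context child does not inherit parent module state even when the global default is "fork":

```python
import multiprocessing as mp

PARENT_ONLY: list[str] = []  # mutated only in the parent process


def child() -> None:
    # under "spawn" the module is re-imported, so the child sees an empty list
    print(f"child sees PARENT_ONLY={PARENT_ONLY}")


if __name__ == "__main__":
    mp.set_start_method("fork", force=True)  # global default stays "fork"
    PARENT_ONLY.append("parent state")

    ctx = mp.get_context("spawn")
    p = ctx.Process(target=child, daemon=True)
    p.start()
    p.join()
```
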