41 changes: 34 additions & 7 deletions backend/model_server/encoders.py
@@ -37,24 +37,51 @@ def get_embedding_model(
model_name: str,
max_context_length: int,
) -> "SentenceTransformer":
"""
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
"""
from sentence_transformers import SentenceTransformer # type: ignore

global _GLOBAL_MODELS_DICT # A dictionary to store models
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
"""
Build RoPE cos/sin caches once on the final device/dtype so later forwards only read.
Works by encoding a long dummy text so a full forward pass at the target length builds the caches.
"""
try:
# Ensure the dummy text exceeds max_seq_length after tokenization.
# Ideally we would use the model's own tokenizer to count tokens, but a rough
# assumption about tokenization is good enough here.
long_text = "x " * (target_len * 2)
_ = st_model.encode(
[long_text],
batch_size=1,
convert_to_tensor=True,
show_progress_bar=False,
normalize_embeddings=False,
)
logger.info("RoPE pre-warm successful")
except Exception as e:
logger.warning(f"RoPE pre-warm skipped/failed: {e}")

global _GLOBAL_MODELS_DICT

if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into Transformers or Sentence
# Transformers need their remote code downloaded in order to be loaded locally.
# This does not mean data is sent to remote servers for inference; however, the
# remote code can be fairly arbitrary, so only use trusted models.
model = SentenceTransformer(
model_name_or_path=model_name,
trust_remote_code=True,
)
model.max_seq_length = max_context_length
_prewarm_rope(model, max_context_length)
_GLOBAL_MODELS_DICT[model_name] = model
elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
_GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length
else:
model = _GLOBAL_MODELS_DICT[model_name]
if max_context_length != model.max_seq_length:
model.max_seq_length = max_context_length
prev = getattr(model, "_rope_prewarmed_to", 0)
if max_context_length > int(prev or 0):
_prewarm_rope(model, max_context_length)

return _GLOBAL_MODELS_DICT[model_name]

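A minimal usage sketch of the caching and pre-warm behavior above. The model name is only a placeholder (not part of this PR), and the assertions simply restate what the branches in the diff do.

model = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=512)
cached = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=512)
assert cached is model  # second call is served from _GLOBAL_MODELS_DICT

# Requesting a longer context reuses the same instance, bumps max_seq_length,
# and re-runs _prewarm_rope because the new length exceeds the recorded
# _rope_prewarmed_to value.
longer = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=8192)
assert longer is model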
7 changes: 4 additions & 3 deletions backend/onyx/background/celery/apps/primary.py
@@ -20,6 +20,7 @@
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_POOL_OVERFLOW
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -83,11 +84,11 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")

EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
SqlEngine.init_engine(
pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW
)

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
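A small sketch (not from the PR) of how the two new settings bound each primary worker's Postgres connections: SQLAlchemy permits at most pool_size + max_overflow connections per engine, and pool_size here is the worker concurrency taken from sender.concurrency.

import os

# Mirrors the defaults added in app_configs.py further down; actual values come
# from the environment when the corresponding variables are set.
concurrency = int(os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4)
overflow = int(os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4)

# Upper bound on simultaneous DB connections for one primary worker.
max_connections_per_worker = concurrency + overflow
print(max_connections_per_worker)  # 8 with the defaults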
3 changes: 2 additions & 1 deletion backend/onyx/background/celery/configs/primary.py
@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -15,6 +16,6 @@
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = 4
worker_concurrency = CELERY_WORKER_PRIMARY_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
8 changes: 8 additions & 0 deletions backend/onyx/configs/app_configs.py
@@ -355,6 +355,14 @@
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)

CELERY_WORKER_PRIMARY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
)

CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)

# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192

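A quick note on the int(os.environ.get(...) or 4) pattern used above: the or-fallback also covers a variable that is set but empty, which a plain default argument to os.environ.get would not. A minimal sketch of the difference:

import os

os.environ["CELERY_WORKER_PRIMARY_CONCURRENCY"] = ""  # set, but empty

# The pattern from the diff falls back to 4 because the empty string is falsy.
value = int(os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4)
assert value == 4

# A plain default would return "" here, and int("") raises ValueError:
# int(os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY", 4))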