diff --git a/.envs/.local/.django b/.envs/.local/.django index 29780e680..8eb5610f7 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -12,6 +12,9 @@ DJANGO_SUPERUSER_PASSWORD=localadmin # Redis REDIS_URL=redis://redis:6379/0 +# NATS +NATS_URL=nats://nats:4222 + # Celery / Flower CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT CELERY_FLOWER_PASSWORD=BEQgmCtgyrFieKNoGTsux9YIye0I7P5Q7vEgfJD2C4jxmtHDetFaE2jhS7K7rxaf diff --git a/README.md b/README.md index acd0d0975..63a217297 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ docker compose -f processing_services/example/docker-compose.yml up -d - Django admin: http://localhost:8000/admin/ - OpenAPI / Swagger documentation: http://localhost:8000/api/v2/docs/ - Minio UI: http://minio:9001, Minio service: http://minio:9000 +- NATS dashboard: https://natsdashboard.com/ (Add localhost) NOTE: If one of these services is not working properly, it could be due another process is using the port. You can check for this with `lsof -i :`. diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b94baa9a2..482d01a58 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -322,15 +322,13 @@ def run(cls, job: "Job"): """ Procedure for an ML pipeline as a job. """ + from ami.ml.orchestration.jobs import queue_images_to_nats + job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None job.save() - # Keep track of sub-tasks for saving results, pair with batch number - save_tasks: list[tuple[int, AsyncResult]] = [] - save_tasks_completed: list[tuple[int, AsyncResult]] = [] - if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -365,7 +363,7 @@ def run(cls, job: "Job"): progress=0, ) - images = list( + images: list[SourceImage] = list( # @TODO return generator plus image count # @TODO pass to celery group chain? job.pipeline.collect_images( @@ -389,8 +387,6 @@ def run(cls, job: "Job"): images = images[: job.limit] image_count = len(images) job.progress.add_stage_param("collect", "Limit", image_count) - else: - image_count = source_image_count job.progress.update_stage( "collect", @@ -401,6 +397,24 @@ def run(cls, job: "Job"): # End image collection stage job.save() + if job.project.feature_flags.async_pipeline_workers: + queued = queue_images_to_nats(job, images) + if not queued: + job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) + job.progress.update_stage("collect", status=JobState.FAILURE) + job.update_status(JobState.FAILURE) + job.finished_at = datetime.datetime.now() + job.save() + return + else: + cls.process_images(job, images) + + @classmethod + def process_images(cls, job, images): + image_count = len(images) + # Keep track of sub-tasks for saving results, pair with batch number + save_tasks: list[tuple[int, AsyncResult]] = [] + save_tasks_completed: list[tuple[int, AsyncResult]] = [] total_captures = 0 total_detections = 0 total_classifications = 0 diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b12271178..bac6b1236 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,8 +1,17 @@ +import functools import logging +import time +from collections.abc import Callable +from datetime import datetime +from asgiref.sync import async_to_sync from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun +from django.db import transaction +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app @@ -30,6 +39,130 @@ def run_job(self, job_id: int) -> None: job.logger.info(f"Finished job {job}") +@celery_app.task( + bind=True, + max_retries=3, + default_retry_delay=60, + autoretry_for=(Exception,), + soft_time_limit=300, # 5 minutes + time_limit=360, # 6 minutes +) +def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None: + """ + Process a single pipeline result asynchronously. + + This task: + 1. Deserializes the pipeline result + 2. Saves it to the database + 3. Updates progress by removing processed image IDs from Redis + 4. Acknowledges the task via NATS + + Args: + job_id: The job ID + result_data: Dictionary containing the pipeline result + reply_subject: NATS reply subject for acknowledgment + """ + from ami.jobs.models import Job # avoid circular import + + _, t = log_time() + error = result_data.get("error") + pipeline_result = None + if not error: + pipeline_result = PipelineResultsResponse(**result_data) + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + else: + image_id = result_data.get("image_id") + processed_image_ids = {str(image_id)} if image_id else set() + logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}") + + state_manager = TaskStateManager(job_id) + + progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + + try: + _update_job_progress(job_id, "process", progress_info.percentage) + + _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") + job = Job.objects.get(pk=job_id) + job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") + job.logger.info( + f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " + "processed" + ) + + # Save to database (this is the slow operation) + if pipeline_result: + # should never happen since otherwise we could not be processing results here + assert job.pipeline is not None, "Job pipeline is None" + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) + job.logger.info(f"Successfully saved results for job {job_id}") + + _, t = t( + f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" + f", percentage: {progress_info.percentage*100}%" + ) + # Acknowledge the task via NATS + try: + + async def ack_task(): + async with TaskQueueManager() as manager: + return await manager.acknowledge_task(reply_subject) + + ack_success = async_to_sync(ack_task)() + + if ack_success: + job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") + else: + job.logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") + except Exception as ack_error: + job.logger.error(f"Error acknowledging task via NATS: {ack_error}") + # Don't fail the task if ACK fails - data is already saved + + # Update job stage with calculated progress + progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + _update_job_progress(job_id, "results", progress_info.percentage) + + except Job.DoesNotExist: + logger.error(f"Job {job_id} not found") + raise + except Exception as e: + logger.error(f"Failed to process pipeline result for job {job_id}: {e}") + # Celery will automatically retry based on autoretry_for + raise + + +def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: + from ami.jobs.models import Job, JobState # avoid circular import + + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + job.progress.update_stage( + stage, + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + if stage == "results" and progress_percentage >= 1.0: + job.status = JobState.SUCCESS + job.progress.summary.status = JobState.SUCCESS + job.finished_at = datetime.now() + job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") + job.save() + + @task_postrun.connect(sender=run_job) @task_prerun.connect(sender=run_job) def update_job_status(sender, task_id, task, *args, **kwargs): @@ -63,3 +196,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}') job.save() + + +def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: + """ + Small helper to measure time between calls. + + Returns: elapsed time since the last call, and a partial function to measure from the current call + Usage: + + _, tlog = log_time() + # do something + _, tlog = tlog("Did something") # will log the time taken by 'something' + # do something else + t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't' + """ + + end = time.perf_counter() + if start == 0: + dur = 0.0 + else: + dur = end - start + if msg and start > 0: + logger.info(f"{msg}: {dur:.3f}s") + new_start = time.perf_counter() + return dur, functools.partial(log_time, new_start) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 5fffdb6fd..8f9d10b67 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,18 +1,21 @@ import logging +from asgiref.sync import async_to_sync from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone +from django_filters import rest_framework as filters from drf_spectacular.utils import extend_schema from rest_framework.decorators import action -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.response import Response from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin +from ami.jobs.tasks import process_pipeline_result from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param -from ami.utils.requests import project_id_doc_param +from ami.utils.requests import batch_param, ids_only_param, incomplete_only_param, project_id_doc_param from .models import Job, JobState from .serializers import JobListSerializer, JobSerializer @@ -20,6 +23,26 @@ logger = logging.getLogger(__name__) +class JobFilterSet(filters.FilterSet): + """Custom filterset to enable pipeline name filtering.""" + + pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact") + + class Meta: + model = Job + fields = [ + "status", + "project", + "deployment", + "source_image_collection", + "source_image_single", + "pipeline", + "pipeline__name", + "pipeline__slug", + "job_type_key", + ] + + class JobViewSet(DefaultViewSet, ProjectMixin): """ API endpoint that allows jobs to be viewed or edited. @@ -46,15 +69,8 @@ class JobViewSet(DefaultViewSet, ProjectMixin): "source_image_single", ) serializer_class = JobSerializer - filterset_fields = [ - "status", - "project", - "deployment", - "source_image_collection", - "source_image_single", - "pipeline", - "job_type_key", - ] + filterset_class = JobFilterSet + search_fields = ["name", "pipeline__name"] ordering_fields = [ "name", "created_at", @@ -153,6 +169,171 @@ def get_queryset(self) -> QuerySet: updated_at__lt=cutoff_datetime, ) - @extend_schema(parameters=[project_id_doc_param]) + @extend_schema( + parameters=[ + project_id_doc_param, + ids_only_param, + incomplete_only_param, + ] + ) def list(self, request, *args, **kwargs): + # Check if ids_only parameter is set + ids_only = request.query_params.get("ids_only", "false").lower() in ["true", "1", "yes"] + + # Check if incomplete_only parameter is set + incomplete_only = request.query_params.get("incomplete_only", "false").lower() in ["true", "1", "yes"] + + # Get the base queryset + queryset = self.filter_queryset(self.get_queryset()) + + # Filter to incomplete jobs if requested (checks "results" stage status) + if incomplete_only: + from django.db.models import Q + + # Create filters for each final state to exclude + final_states = JobState.final_states() + exclude_conditions = Q() + + # Exclude jobs where the "results" stage has a final state status + for state in final_states: + # JSON path query to check if results stage status is in final states + exclude_conditions |= Q(progress__stages__contains=[{"key": "results", "status": state}]) + + queryset = queryset.exclude(exclude_conditions) + + if ids_only: + # Return only IDs + job_ids = list(queryset.values_list("id", flat=True)) + return Response({"job_ids": job_ids, "count": len(job_ids)}) + + # Override the queryset for the list view + self.queryset = queryset return super().list(request, *args, **kwargs) + + @extend_schema( + parameters=[batch_param], + responses={200: dict}, + ) + @action(detail=True, methods=["get"], name="tasks") + def tasks(self, request, pk=None): + """ + Get tasks from the job queue. + + Returns task data with reply_subject for acknowledgment. External workers should: + 1. Call this endpoint to get tasks + 2. Process the tasks + 3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge + + This stateless approach allows workers to communicate over HTTP without + maintaining persistent connections to the queue system. + """ + job: Job = self.get_object() + batch = IntegerField(required=False, min_value=1).clean(request.query_params.get("batch", 1)) + job_id = f"job{job.pk}" + + # Validate that the job has a pipeline + if not job.pipeline: + raise ValidationError("This job does not have a pipeline configured") + + # Get tasks from NATS JetStream + from ami.ml.orchestration.nats_queue import TaskQueueManager + + async def get_tasks(): + tasks = [] + async with TaskQueueManager() as manager: + for i in range(batch): + task = await manager.reserve_task(job_id, timeout=0.1) + if task: + tasks.append(task) + return tasks + + # Use async_to_sync to properly handle the async call + tasks = async_to_sync(get_tasks)() + + return Response({"tasks": tasks}) + + @action(detail=True, methods=["post"], name="result") + def result(self, request, pk=None): + """ + Submit pipeline results for asynchronous processing. + + This endpoint accepts a list of pipeline results and queues them for + background processing. Each result will be validated, saved to the database, + and acknowledged via NATS in a Celery task. + + The request body should be a list of results: + [ + { + "reply_subject": "string", # Required: from the task response + "result": { # Required: PipelineResultsResponse (kept as JSON) + "pipeline": "string", + "algorithms": {}, + "total_time": 0.0, + "source_images": [...], + "detections": [...], + "errors": null + } + }, + ... + ] + """ + + job_id = pk if pk else self.kwargs.get("pk") + if not job_id: + raise ValidationError("Job ID is required") + job_id = int(job_id) + + # Validate request data is a list + if not isinstance(request.data, list): + raise ValidationError("Request body must be a list of results") + + # Queue each result for background processing + queued_tasks = [] + + for idx, item in enumerate(request.data): + reply_subject = item.get("reply_subject") + result_data = item.get("result") + + if not reply_subject: + raise ValidationError(f"Item {idx}: reply_subject is required") + + if not result_data: + raise ValidationError(f"Item {idx}: result is required") + + try: + # Queue the background task + task = process_pipeline_result.delay( + job_id=job_id, result_data=result_data, reply_subject=reply_subject + ) + + queued_tasks.append( + { + "reply_subject": reply_subject, + "status": "queued", + "task_id": task.id, + } + ) + + logger.info( + f"Queued pipeline result processing for job {job_id}, " + f"task_id: {task.id}, reply_subject: {reply_subject}" + ) + + except Exception as e: + logger.error(f"Failed to queue result {idx} for job {job_id}: {e}") + queued_tasks.append( + { + "reply_subject": reply_subject, + "status": "error", + "error": str(e), + } + ) + + return Response( + { + "status": "accepted", + "job_id": job_id, + "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "tasks": queued_tasks, + } + ) diff --git a/ami/main/models.py b/ami/main/models.py index 4eecd74f0..fb467e1d5 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -218,6 +218,7 @@ class ProjectFeatureFlags(pydantic.BaseModel): default_filters: bool = False # Whether to show default filters form in UI # Feature flag for jobs to reprocess all images in the project, even if already processed reprocess_all_images: bool = False + async_pipeline_workers: bool = False # Whether to use async pipeline workers that pull tasks from a queue def get_default_feature_flags() -> ProjectFeatureFlags: diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py index d05bbbd82..75c2ec3b5 100644 --- a/ami/ml/orchestration/__init__.py +++ b/ami/ml/orchestration/__init__.py @@ -1 +1,5 @@ -from .processing import * # noqa: F401, F403 +# cgjs: This creates a circular import: +# - ami.jobs.models imports ami.jobs.tasks.run_job +# - ami.jobs.tasks imports ami.ml.orchestration +# -.processing imports ami.jobs.models +# from .processing import * # noqa: F401, F403 diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py new file mode 100644 index 000000000..240000ca9 --- /dev/null +++ b/ami/ml/orchestration/jobs.py @@ -0,0 +1,101 @@ +from asgiref.sync import async_to_sync +from django.utils import timezone + +from ami.jobs.models import Job, JobState, logger +from ami.main.models import SourceImage +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager + + +# TODO CGJS: Call this once a job is fully complete (all images processed and saved) +def cleanup_nats_resources(job: "Job") -> bool: + """ + Clean up NATS JetStream resources (stream and consumer) for a completed job. + + Args: + job: The Job instance + """ + job_id = f"job{job.pk}" + + async def cleanup(): + async with TaskQueueManager() as manager: + success = await manager.cleanup_job_resources(job_id) + return success + + return async_to_sync(cleanup)() + + +def queue_images_to_nats(job: "Job", images: list[SourceImage]): + """ + Queue all images for a job to a NATS JetStream stream for the job. + + Args: + job: The Job instance + images: List of SourceImage instances to queue + + Returns: + bool: True if all images were successfully queued, False otherwise + """ + job_id = f"job{job.pk}" + job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job_id}'") + + # Prepare all messages outside of async context to avoid Django ORM issues + messages = [] + image_ids = [] + for i, image in enumerate(images): + image_id = str(image.pk) + image_ids.append(image_id) + message = { + "job_id": job.pk, + "image_id": image_id, + "image_url": image.url() if hasattr(image, "url") else None, + "timestamp": (image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None), + "batch_index": i, + "total_images": len(images), + "queue_timestamp": timezone.now().isoformat(), + } + messages.append((image.pk, message)) + + # Store all image IDs in Redis for progress tracking + state_manager = TaskStateManager(job.pk) + state_manager.initialize_job(image_ids) + job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") + + async def queue_all_images(): + successful_queues = 0 + failed_queues = 0 + + async with TaskQueueManager() as manager: + for image_pk, message in messages: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") + success = await manager.publish_task( + job_id=job_id, + data=message, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 + + return successful_queues, failed_queues + + successful_queues, failed_queues = async_to_sync(queue_all_images)() + + if not images: + job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) + job.save() + + # Log results (back in sync context) + if successful_queues > 0: + job.logger.info(f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job_id}'") + + if failed_queues > 0: + job.logger.warning(f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job_id}'") + return False + + return True diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py new file mode 100644 index 000000000..b4f7c13c9 --- /dev/null +++ b/ami/ml/orchestration/nats_queue.py @@ -0,0 +1,296 @@ +""" +NATS JetStream utility for task queue management in the antenna project. + +This module provides a TaskQueueManager that uses NATS JetStream for distributed +task queuing with acknowledgment support via reply subjects. This allows workers +to pull tasks over HTTP and acknowledge them later without maintaining a persistent +connection to NATS. +""" + +import json +import logging +from typing import Any + +import nats +from django.conf import settings +from nats.js import JetStreamContext +from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy + +logger = logging.getLogger(__name__) + + +async def get_connection(nats_url: str): + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js + + +TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds + + +class TaskQueueManager: + """ + Manager for NATS JetStream task queue operations. + + Use as an async context manager: + async with TaskQueueManager() as manager: + await manager.publish_task('job123', {'data': 'value'}) + task = await manager.reserve_task('job123') + await manager.acknowledge_task(task['reply_subject']) + """ + + def __init__(self, nats_url: str | None = None): + self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") + self.nc: nats.NATS | None = None + self.js: JetStreamContext | None = None + + async def __aenter__(self): + """Create connection on enter.""" + self.nc, self.js = await get_connection(self.nats_url) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.js: + self.js = None + if self.nc and not self.nc.is_closed: + await self.nc.close() + self.nc = None + + return False + + def _get_stream_name(self, job_id: str) -> str: + """Get stream name from job_id.""" + return f"job_{job_id}" + + def _get_subject(self, job_id: str) -> str: + """Get subject name from job_id.""" + return f"job.{job_id}.tasks" + + def _get_consumer_name(self, job_id: str) -> str: + """Get consumer name from job_id.""" + return f"job-{job_id}-consumer" + + async def _ensure_stream(self, job_id: str): + """Ensure stream exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + subject = self._get_subject(job_id) + + try: + await self.js.stream_info(stream_name) + logger.debug(f"Stream {stream_name} already exists") + except Exception as e: + logger.warning(f"Stream {stream_name} does not exist: {e}") + # Stream doesn't exist, create it + await self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ) + logger.info(f"Created stream {stream_name}") + + async def _ensure_consumer(self, job_id: str): + """Ensure consumer exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + try: + info = await self.js.consumer_info(stream_name, consumer_name) + logger.debug(f"Consumer {consumer_name} already exists: {info}") + except Exception: + # Consumer doesn't exist, create it + await self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=TASK_TTR, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), + ) + logger.info(f"Created consumer {consumer_name}") + + async def publish_task(self, job_id: str, data: dict[str, Any]) -> bool: + """ + Publish a task to it's job queue. + + Args: + job_id: The job ID (e.g., 'job123' or '123') + data: Task data (dict will be JSON-encoded) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + subject = self._get_subject(job_id) + task_data = json.dumps(data) + + # Publish to JetStream + ack = await self.js.publish(subject, task_data.encode()) + + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") + return True + + except Exception as e: + logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") + return False + + async def reserve_task(self, job_id: str, timeout: float | None = None) -> dict[str, Any] | None: + """ + Reserve a task from the specified stream. + + Args: + job_id: The job ID to pull tasks from + timeout: Timeout in seconds for reservation (default: 5 seconds) + + Returns: + Dict with task details including 'reply_subject' for acknowledgment, or None if no task available + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + if timeout is None: + timeout = 5 + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + # Create ephemeral subscription for this pull + psub = await self.js.pull_subscribe(subject, consumer_name) + + try: + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) + + if msgs: + msg = msgs[0] + task_data = json.loads(msg.data.decode()) + metadata = msg.metadata + + result = { + "id": metadata.sequence.stream, + "body": task_data, + "reply_subject": msg.reply, # For acknowledgment + } + + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return result + + except nats.errors.TimeoutError: + # No messages available + logger.debug(f"No tasks available in stream for job '{job_id}'") + return None + finally: + # Always unsubscribe + await psub.unsubscribe() + + except Exception as e: + logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") + return None + + async def acknowledge_task(self, reply_subject: str) -> bool: + """ + Acknowledge (delete) a completed task using its reply subject. + + Args: + reply_subject: The reply subject from reserve_task + + Returns: + bool: True if successful + """ + if self.nc is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + await self.nc.publish(reply_subject, b"+ACK") + logger.debug(f"Acknowledged task with reply subject {reply_subject}") + return True + except Exception as e: + logger.error(f"Failed to acknowledge task: {e}") + return False + + async def delete_consumer(self, job_id: str) -> bool: + """ + Delete the consumer for a job. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + + await self.js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete consumer for job '{job_id}': {e}") + return False + + async def delete_stream(self, job_id: str) -> bool: + """ + Delete the stream for a job. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + + await self.js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete stream for job '{job_id}': {e}") + return False + + async def cleanup_job_resources(self, job_id: str) -> bool: + """ + Clean up all NATS resources (consumer and stream) for a job. + + This should be called when a job completes or is cancelled. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + # Delete consumer first, then stream + consumer_deleted = await self.delete_consumer(job_id) + stream_deleted = await self.delete_stream(job_id) + + return consumer_deleted and stream_deleted diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py new file mode 100644 index 000000000..483275453 --- /dev/null +++ b/ami/ml/orchestration/task_state.py @@ -0,0 +1,125 @@ +""" +Task state management for job progress tracking using Redis. +""" + +import logging +from collections import namedtuple + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +# Define a namedtuple for a TaskProgress with the image counts +TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) + + +class TaskStateManager: + """ + Manages job progress tracking state in Redis. + + Tracks pending images for jobs to calculate progress percentages + as workers process images asynchronously. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] + + def __init__(self, job_id: int): + """ + Initialize the task state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + for stage in self.STAGES: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + ) -> None | TaskProgress: + """ + Update the task state with newly processed images. + + Args: + processed_image_ids: Set of image IDs that have just been processed + """ + # Create a unique lock key for this job + lock_key = f"job:{self.job_id}:process_results_lock" + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None + + try: + # Update progress tracking in Redis + progress_info = self._get_progress(processed_image_ids, stage) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: + """ + Get current progress information for the job. + + Returns: + TaskProgress namedtuple with fields: + - remaining: Number of images still pending (or None if not tracked) + - total: Total number of images (or None if not tracked) + - processed: Number of images processed (or None if not tracked) + - percentage: Progress as float 0.0-1.0 (or None if not tracked) + """ + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + assert len(pending_images) >= len(remaining_images) + cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) + + remaining = len(remaining_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " + f"{remaining}/{total_images}: {percentage*100}%" + ) + + return TaskProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + ) + + def cleanup(self) -> None: + """ + Delete all Redis keys associated with this job. + """ + for stage in self.STAGES: + cache.delete(self._get_pending_key(stage)) + cache.delete(self._total_key) diff --git a/ami/utils/requests.py b/ami/utils/requests.py index aff13209d..e4de57c0f 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -2,6 +2,7 @@ import requests from django.forms import BooleanField, FloatField +from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter from requests.adapters import HTTPAdapter from rest_framework.request import Request @@ -151,3 +152,22 @@ def get_default_classification_threshold(project: "Project | None" = None, reque required=False, type=int, ) + +ids_only_param = OpenApiParameter( + name="ids_only", + description="Return only job IDs instead of full job objects", + required=False, + type=OpenApiTypes.BOOL, +) +incomplete_only_param = OpenApiParameter( + name="incomplete_only", + description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", + required=False, + type=OpenApiTypes.BOOL, +) +batch_param = OpenApiParameter( + name="batch", + description="Number of tasks to pull in the batch", + required=False, + type=OpenApiTypes.INT, +) diff --git a/config/settings/base.py b/config/settings/base.py index 03124d41a..7f635fa4f 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -262,6 +262,10 @@ } REDIS_URL = env("REDIS_URL", default=None) +# NATS +# ------------------------------------------------------------------------------ +NATS_URL = env("NATS_URL", default="nats://localhost:4222") # type: ignore[no-untyped-call] + # ADMIN # ------------------------------------------------------------------------------ # Django Admin URL. diff --git a/docker-compose.yml b/docker-compose.yml index e2ad3a100..703ecea0d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,7 @@ services: depends_on: - postgres - redis + - nats - minio-init - ml_backend - rabbitmq @@ -75,7 +76,12 @@ services: volumes: - ./.git:/app/.git:ro - ./ui:/app - entrypoint: ["sh", "-c", "yarn install && yarn start --debug --host 0.0.0.0 --port 4000"] + entrypoint: + [ + "sh", + "-c", + "yarn install && yarn start --debug --host 0.0.0.0 --port 4000", + ] environment: - API_PROXY_TARGET=http://django:8000 - CHOKIDAR_USEPOLLING=true @@ -84,6 +90,20 @@ services: image: redis:6 container_name: ami_local_redis + nats: + image: nats:2.10-alpine + container_name: ami_local_nats + hostname: nats + ports: + - "4222:4222" # Client port + - "8222:8222" # HTTP monitoring port + command: ["-js", "-m", "8222"] # Enable JetStream and monitoring + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"] + interval: 10s + timeout: 5s + retries: 3 + celeryworker: <<: *django image: ami_local_celeryworker diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..d6f27a4ec 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -8,6 +8,7 @@ celery==5.4.0 # pyup: < 6.0 # https://github.com/celery/celery django-celery-beat==2.5.0 # https://github.com/celery/django-celery-beat flower==2.0.1 # https://github.com/mher/flower kombu==5.4.2 +nats-py==2.10.0 # https://github.com/nats-io/nats.py uvicorn[standard]==0.22.0 # https://github.com/encode/uvicorn rich==13.5.0 markdown==3.4.4 @@ -41,7 +42,7 @@ djoser==2.2.0 django-guardian==2.4.0 # Email sending django-sendgrid-v5==1.2.2 -django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail +django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail/ ## Formerly dev-only dependencies # However we cannot run the app without some of these these dependencies @@ -52,6 +53,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing