From ae02d2eaa10e4e6992c6738f19dc82713cea4ba6 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 30 Sep 2025 14:36:07 -0700 Subject: [PATCH 01/29] fix: Job serialization overhead --- .gitignore | 3 + object_model_diagram.md | 238 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 object_model_diagram.md diff --git a/.gitignore b/.gitignore index e5274663f..8ea095114 100644 --- a/.gitignore +++ b/.gitignore @@ -276,3 +276,6 @@ sandbox/ # Other flower + +# huggingface cache +huggingface_cache/ diff --git a/object_model_diagram.md b/object_model_diagram.md new file mode 100644 index 000000000..003821948 --- /dev/null +++ b/object_model_diagram.md @@ -0,0 +1,238 @@ +# Object Model Diagram: ML Pipeline System + +```mermaid +classDiagram + %% Core ML Pipeline Classes + class Pipeline { + +string slug (unique) + +string name + +string description + +int version + +string version_name + +stages[] PipelineStage + +default_config PipelineRequestConfigParameters + -- + +get_config(project_id) PipelineRequestConfigParameters + +collect_images() Iterable~SourceImage~ + +process_images() PipelineResultsResponse + +choose_processing_service_for_pipeline() ProcessingService + } + + class Algorithm { + +string key (unique) + +string name + +AlgorithmTaskType task_type + +string description + +int version + +string version_name + +string uri + -- + +detection_task_types[] AlgorithmTaskType + +classification_task_types[] AlgorithmTaskType + +has_valid_category_map() boolean + } + + class AlgorithmCategoryMap { + +data JSONField + +labels[] string + +int labels_hash + +string version + +string description + +string uri + -- + +make_labels_hash() int + +get_category() int + +with_taxa() dict[] + } + + class PipelineStage { + +string key + +string name + +string description + +boolean enabled + +params[] ConfigurableStageParam + } + + class ProcessingService { + +string name + +string description + +string endpoint_url + +datetime last_checked + +boolean last_checked_live + +float last_checked_latency + -- + +create_pipelines() PipelineRegistrationResponse + +get_status() ProcessingServiceStatusResponse + +get_pipeline_configs() PipelineConfigResponse[] + } + + %% Job System Classes + class Job { + +string name + +string queue + +datetime scheduled_at + +datetime started_at + +datetime finished_at + +JobState status + +JobProgress progress + +JobLogs logs + +params JSONField + +result JSONField + +string task_id + +int delay + +int limit + +boolean shuffle + +string job_type_key + -- + +job_type() JobType + +update_status() void + +logger JobLogger + } + + class JobProgress { + +JobProgressSummary summary + +stages[] JobProgressStageDetail + +errors[] string + -- + +add_stage() JobProgressStageDetail + +update_stage() void + +add_stage_param() void + } + + class JobProgressSummary { + +JobState status + +float progress + -- + +status_label string + } + + class JobProgressStageDetail { + +string key + +string name + +string description + +boolean enabled + +params[] ConfigurableStageParam + +JobState status + +float progress + } + + %% Configuration Classes + class ProjectPipelineConfig { + +boolean enabled + +config JSONField + -- + +get_config() dict + } + + class Project { + +string name + +string slug + +feature_flags JSONField + -- + +default_processing_pipeline Pipeline + } + + %% Enums + class AlgorithmTaskType { + <> + DETECTION + LOCALIZATION + SEGMENTATION + CLASSIFICATION + EMBEDDING + TRACKING + TAGGING + REGRESSION + CAPTIONING + GENERATION + TRANSLATION + SUMMARIZATION + QUESTION_ANSWERING + DEPTH_ESTIMATION + POSE_ESTIMATION + SIZE_ESTIMATION + OTHER + UNKNOWN + } + + class JobState { + <> + CREATED + PENDING + STARTED + SUCCESS + FAILURE + RETRY + CANCELING + REVOKED + RECEIVED + UNKNOWN + } + + %% Relationships + + %% Pipeline Relationships + Pipeline }|--|| Algorithm : "algorithms (M2M)" + Pipeline ||--o{ PipelineStage : "stages (SchemaField)" + Pipeline }|--|| ProcessingService : "processing_services (M2M)" + Pipeline ||--o{ Job : "jobs (FK)" + Pipeline ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" + + %% Algorithm Relationships + Algorithm ||--o| AlgorithmCategoryMap : "category_map (FK)" + Algorithm ||--|| AlgorithmTaskType : "task_type" + + %% Job Relationships + Job ||--o| Pipeline : "pipeline (FK, nullable)" + Job ||--|| Project : "project (FK)" + Job ||--|| JobProgress : "progress (SchemaField)" + Job ||--|| JobState : "status" + + %% Job Progress Relationships + JobProgress ||--|| JobProgressSummary : "summary" + JobProgress ||--o{ JobProgressStageDetail : "stages" + JobProgressSummary ||--|| JobState : "status" + JobProgressStageDetail ||--|| JobState : "status" + + %% Processing Service Relationships + ProcessingService }|--|| Project : "projects (M2M)" + + %% Project Configuration Relationships + Project ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" + ProjectPipelineConfig ||--|| Pipeline : "pipeline (FK)" + ProjectPipelineConfig ||--|| Project : "project (FK)" + + %% Notes + note for Pipeline "Identified by unique slug\nAuto-generated from name + version + UUID" + note for Algorithm "Identified by unique key\nAuto-generated from name + version" + note for Job "MLJob is the primary job type\nfor running ML pipelines" + note for ProcessingService "External ML services that\nexecute pipeline algorithms" +``` + +## Key Relationships Summary + +### Core ML Pipeline Flow: +1. **ProcessingService** → registers → **Pipeline** → contains → **Algorithm** +2. **Project** → configures → **Pipeline** through **ProjectPipelineConfig** +3. **Job** → executes → **Pipeline** → uses → **ProcessingService** + +### Model Identification: +- **Pipeline**: Identified by unique `slug` (string) - auto-generated from `name + version + UUID` +- **Algorithm**: Identified by unique `key` (string) - auto-generated from `name + version` +- **Job**: Uses standard Django `id` but also has `task_id` for Celery integration + +### Stage Management: +- **Pipeline** contains **PipelineStage** objects (for configuration display) +- **Job** tracks execution through **JobProgressStageDetail** objects (for runtime progress) +- Both share the same base **ConfigurableStage** schema + +### Algorithm Classification: +- **Algorithm** has task types (detection, classification, etc.) +- Classification algorithms require **AlgorithmCategoryMap** for label mapping +- Detection algorithms don't require category maps + +### Job Execution Flow: +1. **Job** is created with a **Pipeline** reference +2. **Pipeline** selects appropriate **ProcessingService** +3. **ProcessingService** executes algorithms and returns results +4. **Job** tracks progress through **JobProgress** and **JobProgressStageDetail** From 24a15af0862e2964c1f9909c0ab5e70b214cdcd7 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 30 Sep 2025 14:50:42 -0700 Subject: [PATCH 02/29] syntax --- object_model_diagram.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/object_model_diagram.md b/object_model_diagram.md index 003821948..2b17a0143 100644 --- a/object_model_diagram.md +++ b/object_model_diagram.md @@ -172,9 +172,9 @@ classDiagram %% Relationships %% Pipeline Relationships - Pipeline }|--|| Algorithm : "algorithms (M2M)" + Pipeline }|--|{ Algorithm : "algorithms (M2M)" Pipeline ||--o{ PipelineStage : "stages (SchemaField)" - Pipeline }|--|| ProcessingService : "processing_services (M2M)" + Pipeline }|--|{ ProcessingService : "processing_services (M2M)" Pipeline ||--o{ Job : "jobs (FK)" Pipeline ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" @@ -195,7 +195,7 @@ classDiagram JobProgressStageDetail ||--|| JobState : "status" %% Processing Service Relationships - ProcessingService }|--|| Project : "projects (M2M)" + ProcessingService }|--|{ Project : "projects (M2M)" %% Project Configuration Relationships Project ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" From 0da97a604f2c1da667823cd2599a254367358b07 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 30 Sep 2025 14:58:07 -0700 Subject: [PATCH 03/29] fix syntax --- object_model_diagram.md | 59 ++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/object_model_diagram.md b/object_model_diagram.md index 2b17a0143..f337c2849 100644 --- a/object_model_diagram.md +++ b/object_model_diagram.md @@ -4,7 +4,7 @@ classDiagram %% Core ML Pipeline Classes class Pipeline { - +string slug (unique) + +string slug +string name +string description +int version @@ -19,7 +19,7 @@ classDiagram } class Algorithm { - +string key (unique) + +string key +string name +AlgorithmTaskType task_type +string description @@ -170,37 +170,30 @@ classDiagram } %% Relationships - - %% Pipeline Relationships - Pipeline }|--|{ Algorithm : "algorithms (M2M)" - Pipeline ||--o{ PipelineStage : "stages (SchemaField)" - Pipeline }|--|{ ProcessingService : "processing_services (M2M)" - Pipeline ||--o{ Job : "jobs (FK)" - Pipeline ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" - - %% Algorithm Relationships - Algorithm ||--o| AlgorithmCategoryMap : "category_map (FK)" - Algorithm ||--|| AlgorithmTaskType : "task_type" - - %% Job Relationships - Job ||--o| Pipeline : "pipeline (FK, nullable)" - Job ||--|| Project : "project (FK)" - Job ||--|| JobProgress : "progress (SchemaField)" - Job ||--|| JobState : "status" - - %% Job Progress Relationships - JobProgress ||--|| JobProgressSummary : "summary" - JobProgress ||--o{ JobProgressStageDetail : "stages" - JobProgressSummary ||--|| JobState : "status" - JobProgressStageDetail ||--|| JobState : "status" - - %% Processing Service Relationships - ProcessingService }|--|{ Project : "projects (M2M)" - - %% Project Configuration Relationships - Project ||--o{ ProjectPipelineConfig : "project_pipeline_configs (FK)" - ProjectPipelineConfig ||--|| Pipeline : "pipeline (FK)" - ProjectPipelineConfig ||--|| Project : "project (FK)" + Pipeline "M" -- "M" Algorithm : algorithms + Pipeline "1" -- "many" PipelineStage : stages + Pipeline "M" -- "M" ProcessingService : processing_services + Pipeline "1" -- "many" Job : jobs + Pipeline "1" -- "many" ProjectPipelineConfig : project_pipeline_configs + + Algorithm "1" -- "0..1" AlgorithmCategoryMap : category_map + Algorithm "1" -- "1" AlgorithmTaskType : task_type + + Job "0..1" -- "1" Pipeline : pipeline + Job "1" -- "1" Project : project + Job "1" -- "1" JobProgress : progress + Job "1" -- "1" JobState : status + + JobProgress "1" -- "1" JobProgressSummary : summary + JobProgress "1" -- "many" JobProgressStageDetail : stages + JobProgressSummary "1" -- "1" JobState : status + JobProgressStageDetail "1" -- "1" JobState : status + + ProcessingService "M" -- "M" Project : projects + + Project "1" -- "many" ProjectPipelineConfig : project_pipeline_configs + ProjectPipelineConfig "1" -- "1" Pipeline : pipeline + ProjectPipelineConfig "1" -- "1" Project : project %% Notes note for Pipeline "Identified by unique slug\nAuto-generated from name + version + UUID" From 2db7d66f4d7e1d48944408c11a7767710223dcca Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 1 Oct 2025 10:34:54 -0700 Subject: [PATCH 04/29] Simplify diagram --- object_model_diagram.md | 64 ----------------------------------------- 1 file changed, 64 deletions(-) diff --git a/object_model_diagram.md b/object_model_diagram.md index f337c2849..0acdfdf61 100644 --- a/object_model_diagram.md +++ b/object_model_diagram.md @@ -53,19 +53,6 @@ classDiagram +params[] ConfigurableStageParam } - class ProcessingService { - +string name - +string description - +string endpoint_url - +datetime last_checked - +boolean last_checked_live - +float last_checked_latency - -- - +create_pipelines() PipelineRegistrationResponse - +get_status() ProcessingServiceStatusResponse - +get_pipeline_configs() PipelineConfigResponse[] - } - %% Job System Classes class Job { +string name @@ -89,33 +76,6 @@ classDiagram +logger JobLogger } - class JobProgress { - +JobProgressSummary summary - +stages[] JobProgressStageDetail - +errors[] string - -- - +add_stage() JobProgressStageDetail - +update_stage() void - +add_stage_param() void - } - - class JobProgressSummary { - +JobState status - +float progress - -- - +status_label string - } - - class JobProgressStageDetail { - +string key - +string name - +string description - +boolean enabled - +params[] ConfigurableStageParam - +JobState status - +float progress - } - %% Configuration Classes class ProjectPipelineConfig { +boolean enabled @@ -155,24 +115,10 @@ classDiagram UNKNOWN } - class JobState { - <> - CREATED - PENDING - STARTED - SUCCESS - FAILURE - RETRY - CANCELING - REVOKED - RECEIVED - UNKNOWN - } %% Relationships Pipeline "M" -- "M" Algorithm : algorithms Pipeline "1" -- "many" PipelineStage : stages - Pipeline "M" -- "M" ProcessingService : processing_services Pipeline "1" -- "many" Job : jobs Pipeline "1" -- "many" ProjectPipelineConfig : project_pipeline_configs @@ -181,15 +127,6 @@ classDiagram Job "0..1" -- "1" Pipeline : pipeline Job "1" -- "1" Project : project - Job "1" -- "1" JobProgress : progress - Job "1" -- "1" JobState : status - - JobProgress "1" -- "1" JobProgressSummary : summary - JobProgress "1" -- "many" JobProgressStageDetail : stages - JobProgressSummary "1" -- "1" JobState : status - JobProgressStageDetail "1" -- "1" JobState : status - - ProcessingService "M" -- "M" Project : projects Project "1" -- "many" ProjectPipelineConfig : project_pipeline_configs ProjectPipelineConfig "1" -- "1" Pipeline : pipeline @@ -199,7 +136,6 @@ classDiagram note for Pipeline "Identified by unique slug\nAuto-generated from name + version + UUID" note for Algorithm "Identified by unique key\nAuto-generated from name + version" note for Job "MLJob is the primary job type\nfor running ML pipelines" - note for ProcessingService "External ML services that\nexecute pipeline algorithms" ``` ## Key Relationships Summary From 8a714cd494fe70c315f71c581d7d991db7086641 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 3 Oct 2025 14:06:42 -0700 Subject: [PATCH 05/29] Add RabbitMQ --- .dockerignore | 64 +++++++- .envs/.local/.django | 12 ++ ami/utils/rabbitmq.py | 187 ++++++++++++++++++++++++ compose/local/rabbitmq/definitions.json | 38 +++++ compose/local/rabbitmq/rabbitmq.conf | 19 +++ config/settings/base.py | 13 +- docker-compose.yml | 27 +++- requirements/base.txt | 4 +- 8 files changed, 354 insertions(+), 10 deletions(-) create mode 100644 ami/utils/rabbitmq.py create mode 100644 compose/local/rabbitmq/definitions.json create mode 100644 compose/local/rabbitmq/rabbitmq.conf diff --git a/.dockerignore b/.dockerignore index 8b71212d1..ea4fa3ffc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,20 +1,72 @@ .editorconfig .gitattributes .github -.gitignore .gitlab-ci.yml .idea .pre-commit-config.yaml .readthedocs.yml .travis.yml -.git ui ami/media backups -venv -.venv -.env -.envs +venv/ +.venv/ +.env/ +.envs/ .envs/* node_modules data + +# Python cache / bytecode +__pycache__/ +*.py[cod] +*.pyo +*.pyd +*.pdb +*.egg-info/ +*.egg +*.whl + + +# Django / runtime artifacts +*.log +*.pot +*.pyc +db.sqlite3 +media/ +staticfiles/ # collected static files (use collectstatic inside container) + +# Node / UI dependencies (if using React/Vue in your UI service) +node_modules/ +npm-debug.log +yarn-error.log +.pnpm-debug.log + +# Docs build artifacts +/docs/_build/ + +# Git / VCS +.git/ +.gitignore +.gitattributes +*.swp +*.swo + +# IDE / editor +.vscode/ +.idea/ +*.iml + +# OS cruft +.DS_Store +Thumbs.db + +# Docker itself +.dockerignore +Dockerfile +docker-compose*.yml + +# Build / dist +build/ +dist/ +.eggs/ diff --git a/.envs/.local/.django b/.envs/.local/.django index 76a69e68e..5b53cbb49 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -12,6 +12,18 @@ DJANGO_SUPERUSER_PASSWORD=localadmin # Redis REDIS_URL=redis://redis:6379/0 +# RabbitMQ Admin (for management only) +RABBITMQ_ADMIN_USER=admin +RABBITMQ_ADMIN_PASS=admin123 + +# RabbitMQ Django App User (for application use) +RABBITMQ_DJANGO_USER=django_app +RABBITMQ_DJANGO_PASS=django_secure_pass +RABBITMQ_DEFAULT_VHOST=/ +RABBITMQ_HOST=rabbitmq +RABBITMQ_PORT=5672 +RABBITMQ_URL=amqp://django_app:django_secure_pass@rabbitmq:5672/ + # Celery / Flower CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT CELERY_FLOWER_PASSWORD=BEQgmCtgyrFieKNoGTsux9YIye0I7P5Q7vEgfJD2C4jxmtHDetFaE2jhS7K7rxaf diff --git a/ami/utils/rabbitmq.py b/ami/utils/rabbitmq.py new file mode 100644 index 000000000..f46276016 --- /dev/null +++ b/ami/utils/rabbitmq.py @@ -0,0 +1,187 @@ +""" +RabbitMQ utilities for the Antenna application. + +This module provides a simple interface for interacting with RabbitMQ +using the pika library. +""" + +import json +import logging +import os +from collections.abc import Callable +from typing import Any + +import pika +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class RabbitMQConnection: + """ + A context manager for RabbitMQ connections. + """ + + def __init__(self, connection_url: str = ""): + self.connection_url: str + self.connection_url = connection_url or getattr(settings, "RABBITMQ_URL", "") + if not self.connection_url: + # Fallback to Django settings or environment variables + host = getattr(settings, "RABBITMQ_HOST", os.getenv("RABBITMQ_HOST", "localhost")) + port = getattr(settings, "RABBITMQ_PORT", int(os.getenv("RABBITMQ_PORT", "5672"))) + user = getattr(settings, "RABBITMQ_DJANGO_USER", os.getenv("RABBITMQ_DJANGO_USER", "guest")) + password = getattr(settings, "RABBITMQ_DJANGO_PASS", os.getenv("RABBITMQ_DJANGO_PASS", "guest")) + vhost = getattr(settings, "RABBITMQ_DEFAULT_VHOST", os.getenv("RABBITMQ_DEFAULT_VHOST", "/")) + self.connection_url = f"amqp://{user}:{password}@{host}:{port}{vhost}" # noqa: E231 + + self.connection = None + self.channel = None + + def __enter__(self): + try: + parameters = pika.URLParameters(self.connection_url) + self.connection = pika.BlockingConnection(parameters) + self.channel = self.connection.channel() + return self.channel + except Exception as e: + logger.error(f"Failed to connect to RabbitMQ: {e}") + raise + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.connection and not self.connection.is_closed: + self.connection.close() + + +class RabbitMQPublisher: + """ + A simple publisher for RabbitMQ messages. + """ + + def __init__(self, connection_url: str | None = None): + self.connection_url = connection_url + + def publish_message( + self, + queue_name: str, + message: dict[str, Any], + exchange: str = "", + routing_key: str | None = None, + durable: bool = True, + ) -> bool: + """ + Publish a message to a RabbitMQ queue. + + Args: + queue_name: Name of the queue to publish to + message: Message data (will be JSON serialized) + exchange: Exchange name (default: '') + routing_key: Routing key (default: queue_name) + durable: Whether the queue should be durable + + Returns: + bool: True if message was published successfully + """ + if routing_key is None: + routing_key = queue_name + + try: + with RabbitMQConnection(self.connection_url) as channel: + # Declare the queue + channel.queue_declare(queue=queue_name, durable=durable) + + # Publish the message + channel.basic_publish( + exchange=exchange, + routing_key=routing_key, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2 if durable else 1, # Make message persistent if durable + content_type="application/json", + ), + ) + + logger.info(f"Published message to queue '{queue_name}': {message}") + return True + + except Exception as e: + logger.error(f"Failed to publish message to queue '{queue_name}': {e}") + return False + + +class RabbitMQConsumer: + """ + A simple consumer for RabbitMQ messages. + """ + + def __init__(self, connection_url: str | None = None): + self.connection_url = connection_url + + def consume_messages( + self, queue_name: str, callback: Callable[[dict[str, Any]], None], durable: bool = True, auto_ack: bool = False + ): + """ + Consume messages from a RabbitMQ queue. + + Args: + queue_name: Name of the queue to consume from + callback: Function to call for each message + durable: Whether the queue should be durable + auto_ack: Whether to automatically acknowledge messages + """ + + def message_callback(ch, method, properties, body): + try: + message = json.loads(body.decode("utf-8")) + callback(message) + + if not auto_ack: + ch.basic_ack(delivery_tag=method.delivery_tag) + + except Exception as e: + logger.error(f"Error processing message from queue '{queue_name}': {e}") + if not auto_ack: + ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) + + try: + with RabbitMQConnection(self.connection_url) as channel: + # Declare the queue + channel.queue_declare(queue=queue_name, durable=durable) + + # Set up the consumer + channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=auto_ack) + + logger.info(f"Starting to consume messages from queue '{queue_name}'") + channel.start_consuming() + + except KeyboardInterrupt: + logger.info("Stopping consumer...") + if "channel" in locals(): + channel.stop_consuming() + except Exception as e: + logger.error(f"Error consuming from queue '{queue_name}': {e}") + raise + + +# Convenience functions +def publish_to_queue(queue_name: str, message: dict[str, Any], **kwargs) -> bool: + """ + Convenience function to publish a message to a queue. + """ + publisher = RabbitMQPublisher() + return publisher.publish_message(queue_name, message, **kwargs) + + +def test_connection() -> bool: + """ + Test the RabbitMQ connection. + + Returns: + bool: True if connection is successful + """ + try: + with RabbitMQConnection() as _: + logger.info("RabbitMQ connection test successful") + return True + except Exception as e: + logger.error(f"RabbitMQ connection test failed: {e}") + return False diff --git a/compose/local/rabbitmq/definitions.json b/compose/local/rabbitmq/definitions.json new file mode 100644 index 000000000..74fe80fe1 --- /dev/null +++ b/compose/local/rabbitmq/definitions.json @@ -0,0 +1,38 @@ +{ + "users": [ + { + "name": "admin", + "password": "admin123", + "tags": "administrator" + }, + { + "name": "django_app", + "password": "django_secure_pass", + "tags": "management" + } + ], + "vhosts": [ + { + "name": "/" + } + ], + "permissions": [ + { + "user": "admin", + "vhost": "/", + "configure": ".*", + "write": ".*", + "read": ".*" + }, + { + "user": "django_app", + "vhost": "/", + "configure": ".*", + "write": ".*", + "read": ".*" + } + ], + "exchanges": [], + "queues": [], + "bindings": [] +} diff --git a/compose/local/rabbitmq/rabbitmq.conf b/compose/local/rabbitmq/rabbitmq.conf new file mode 100644 index 000000000..99e737298 --- /dev/null +++ b/compose/local/rabbitmq/rabbitmq.conf @@ -0,0 +1,19 @@ +# Enable management plugin +management.tcp.port = 15672 + +# Load user definitions from JSON file +management.load_definitions = /etc/rabbitmq/definitions.json + +# Default virtual host +default_vhost = / + +# Logging +log.console = true +log.console.level = info + +# Memory and disk limits +vm_memory_high_watermark.relative = 0.6 +disk_free_limit.relative = 2.0 + +# Queue master locator +queue_master_locator = min-masters diff --git a/config/settings/base.py b/config/settings/base.py index aec01b715..38aa2d86d 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -24,7 +24,9 @@ # GENERAL # ------------------------------------------------------------------------------ EXTERNAL_HOSTNAME = env("EXTERNAL_HOSTNAME", default="localhost:8000") # type: ignore[no-untyped-call] -EXTERNAL_BASE_URL = env("EXTERNAL_BASE_URL", default=f"http://{EXTERNAL_HOSTNAME}") # type: ignore[no-untyped-call] +EXTERNAL_BASE_URL = env( + "EXTERNAL_BASE_URL", default=f"http://{EXTERNAL_HOSTNAME}" # noqa: E231, E501 # type: ignore[no-untyped-call] +) # https://docs.djangoproject.com/en/dev/ref/settings/#debug DEBUG = env.bool("DJANGO_DEBUG", False) # type: ignore[no-untyped-call] @@ -261,6 +263,15 @@ } } +# RABBITMQ +# ------------------------------------------------------------------------------ +RABBITMQ_URL = env("RABBITMQ_URL", default="amqp://guest:guest@localhost:5672/") # type: ignore[no-untyped-call] +RABBITMQ_HOST = env("RABBITMQ_HOST", default="localhost") # type: ignore[no-untyped-call] +RABBITMQ_PORT = env.int("RABBITMQ_PORT", default=5672) # type: ignore[no-untyped-call] +RABBITMQ_DJANGO_USER = env("RABBITMQ_DJANGO_USER", default="guest") # type: ignore[no-untyped-call] +RABBITMQ_DJANGO_PASS = env("RABBITMQ_DJANGO_PASS", default="guest") # type: ignore[no-untyped-call] +RABBITMQ_DEFAULT_VHOST = env("RABBITMQ_DEFAULT_VHOST", default="/") # type: ignore[no-untyped-call] + # ADMIN # ------------------------------------------------------------------------------ # Django Admin URL. diff --git a/docker-compose.yml b/docker-compose.yml index ff9d125f0..f9567604c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,9 +8,10 @@ volumes: o: bind minio_data: driver: local + rabbitmq_data: + driver: local node_modules: - services: django: &django build: @@ -22,6 +23,7 @@ services: depends_on: - postgres - redis + - rabbitmq - minio-init - ml_backend volumes: @@ -86,6 +88,27 @@ services: image: redis:6 container_name: ami_local_redis + rabbitmq: + image: rabbitmq:3.12-management + container_name: ami_local_rabbitmq + hostname: rabbitmq + ports: + - "5672:5672" # AMQP port + - "15672:15672" # Management UI port + environment: + RABBITMQ_DEFAULT_USER: ${RABBITMQ_ADMIN_USER:-admin} + RABBITMQ_DEFAULT_PASS: ${RABBITMQ_ADMIN_PASS:-admin123} + RABBITMQ_DEFAULT_VHOST: ${RABBITMQ_DEFAULT_VHOST:-/} + volumes: + - rabbitmq_data:/var/lib/rabbitmq + - ./compose/local/rabbitmq/definitions.json:/etc/rabbitmq/definitions.json:ro + - ./compose/local/rabbitmq/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf:ro + healthcheck: + test: rabbitmq-diagnostics -q ping + interval: 30s + timeout: 30s + retries: 3 + celeryworker: <<: *django image: ami_local_celeryworker @@ -123,7 +146,7 @@ services: ports: - "9001:9001" healthcheck: - test: [ "CMD", "mc", "ready", "local" ] + test: ["CMD", "mc", "ready", "local"] interval: 5s timeout: 5s retries: 5 diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..09b332d66 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -8,6 +8,7 @@ celery==5.4.0 # pyup: < 6.0 # https://github.com/celery/celery django-celery-beat==2.5.0 # https://github.com/celery/django-celery-beat flower==2.0.1 # https://github.com/mher/flower kombu==5.4.2 +pika==1.3.2 # https://github.com/pika/pika uvicorn[standard]==0.22.0 # https://github.com/encode/uvicorn rich==13.5.0 markdown==3.4.4 @@ -41,7 +42,7 @@ djoser==2.2.0 django-guardian==2.4.0 # Email sending django-sendgrid-v5==1.2.2 -django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail +django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail/ ## Formerly dev-only dependencies # However we cannot run the app without some of these these dependencies @@ -52,6 +53,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From 700f594540e48f6c70431775b7a19eb1d6fc1849 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 8 Oct 2025 16:37:13 -0700 Subject: [PATCH 06/29] WIP: Use NATS JetStream for queuing --- .envs/.local/.django | 13 +- .vscode/launch.json | 24 +- README.md | 1 + ami/jobs/models.py | 144 +++++++++++- ami/jobs/views.py | 118 +++++++++- ami/utils/nats_queue.py | 288 ++++++++++++++++++++++++ ami/utils/rabbitmq.py | 187 --------------- compose/local/rabbitmq/definitions.json | 38 ---- compose/local/rabbitmq/rabbitmq.conf | 19 -- config/settings/base.py | 9 +- docker-compose.yml | 34 ++- requirements/base.txt | 2 +- 12 files changed, 578 insertions(+), 299 deletions(-) create mode 100644 ami/utils/nats_queue.py delete mode 100644 ami/utils/rabbitmq.py delete mode 100644 compose/local/rabbitmq/definitions.json delete mode 100644 compose/local/rabbitmq/rabbitmq.conf diff --git a/.envs/.local/.django b/.envs/.local/.django index 5b53cbb49..d38d383b9 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -12,17 +12,8 @@ DJANGO_SUPERUSER_PASSWORD=localadmin # Redis REDIS_URL=redis://redis:6379/0 -# RabbitMQ Admin (for management only) -RABBITMQ_ADMIN_USER=admin -RABBITMQ_ADMIN_PASS=admin123 - -# RabbitMQ Django App User (for application use) -RABBITMQ_DJANGO_USER=django_app -RABBITMQ_DJANGO_PASS=django_secure_pass -RABBITMQ_DEFAULT_VHOST=/ -RABBITMQ_HOST=rabbitmq -RABBITMQ_PORT=5672 -RABBITMQ_URL=amqp://django_app:django_secure_pass@rabbitmq:5672/ +# NATS +NATS_URL=nats://nats:4222 # Celery / Flower CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT diff --git a/.vscode/launch.json b/.vscode/launch.json index f85af6b31..c5c79f15e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2,7 +2,29 @@ "version": "0.2.0", "configurations": [ { - "name": "Python Debugger: Remote Attach", + "name": "Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, + { + "name": "Django attach", + "type": "debugpy", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5679 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "." + } + ] + }, + { + "name": "Celery worker attach", "type": "debugpy", "request": "attach", "connect": { diff --git a/README.md b/README.md index 7d8a26eff..17da13ea1 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ docker compose -f processing_services/example/docker-compose.yml up -d - Django admin: http://localhost:8000/admin/ - OpenAPI / Swagger documentation: http://localhost:8000/api/v2/docs/ - Minio UI: http://minio:9001, Minio service: http://minio:9000 +- NATS dashboard: https://natsdashboard.com/ (Add localhost) NOTE: If one of these services is not working properly, it could be due another process is using the port. You can check for this with `lsof -i :`. diff --git a/ami/jobs/models.py b/ami/jobs/models.py index f7b85283b..844669696 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -71,7 +71,7 @@ def get_status_label(status: JobState, progress: float) -> str: if status in [JobState.CREATED, JobState.PENDING, JobState.RECEIVED]: return "Waiting to start" elif status in [JobState.STARTED, JobState.RETRY, JobState.SUCCESS]: - return f"{progress:.0%} complete" + return f"{progress:.0%} complete" # noqa E231 else: return f"{status.name}" @@ -132,14 +132,14 @@ def get_stage(self, stage_key: str) -> JobProgressStageDetail: for stage in self.stages: if stage.key == stage_key: return stage - raise ValueError(f"Job stage with key '{stage_key}' not found in progress") + raise ValueError(f"Job stage with key '{stage_key}' not found in progress") # noqa E713 def get_stage_param(self, stage_key: str, param_key: str) -> ConfigurableStageParam: stage = self.get_stage(stage_key) for param in stage.params: if param.key == param_key: return param - raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") + raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") # noqa E713 def add_stage_param(self, stage_key: str, param_name: str, value: typing.Any = None) -> ConfigurableStageParam: stage = self.get_stage(stage_key) @@ -326,10 +326,6 @@ def run(cls, job: "Job"): job.finished_at = None job.save() - # Keep track of sub-tasks for saving results, pair with batch number - save_tasks: list[tuple[int, AsyncResult]] = [] - save_tasks_completed: list[tuple[int, AsyncResult]] = [] - if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -372,7 +368,7 @@ def run(cls, job: "Job"): deployment=job.deployment, source_images=[job.source_image_single] if job.source_image_single else None, job_id=job.pk, - skip_processed=True, + skip_processed=False, # WIP don't commit # shuffle=job.shuffle, ) ) @@ -388,8 +384,6 @@ def run(cls, job: "Job"): images = images[: job.limit] image_count = len(images) job.progress.add_stage_param("collect", "Limit", image_count) - else: - image_count = source_image_count job.progress.update_stage( "collect", @@ -400,6 +394,17 @@ def run(cls, job: "Job"): # End image collection stage job.save() + # WIP: don't commit + # TODO: do this conditionally based on the type of processing service this job is using + # cls.process_images(job, images) + cls.queue_images_to_nats(job, images) + + @classmethod + def process_images(cls, job, images): + image_count = len(images) + # Keep track of sub-tasks for saving results, pair with batch number + save_tasks: list[tuple[int, AsyncResult]] = [] + save_tasks_completed: list[tuple[int, AsyncResult]] = [] total_captures = 0 total_detections = 0 total_classifications = 0 @@ -419,7 +424,7 @@ def run(cls, job: "Job"): job_id=job.pk, project_id=job.project.pk, ) - job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") + job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") # noqa E231 except Exception as e: # Log error about image batch and continue job.logger.error(f"Failed to process image batch {i+1}: {e}") @@ -471,7 +476,7 @@ def run(cls, job: "Job"): if image_count: percent_successful = 1 - len(request_failed_images) / image_count if image_count else 0 - job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") + job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") # noqa E231 # Check all Celery sub-tasks if they have completed saving results save_tasks_remaining = set(save_tasks) - set(save_tasks_completed) @@ -511,6 +516,121 @@ def run(cls, job: "Job"): job.finished_at = datetime.datetime.now() job.save() + # TODO: This needs to happen once a job is done + @classmethod + def cleanup_nats_resources(cls, job: "Job"): + """ + Clean up NATS JetStream resources (stream and consumer) for a completed job. + + Args: + job: The Job instance + """ + import asyncio + + from ami.utils.nats_queue import TaskQueueManager + + job_id = f"job{job.pk}" + + async def cleanup(): + async with TaskQueueManager() as manager: + success = await manager.cleanup_job_resources(job_id) + return success + + # Run cleanup in a new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + success = loop.run_until_complete(cleanup()) + if success: + job.logger.info(f"Cleaned up NATS resources for job '{job_id}'") + else: + job.logger.warning(f"Failed to fully clean up NATS resources for job '{job_id}'") + except Exception as e: + job.logger.error(f"Error cleaning up NATS resources for job '{job_id}': {e}") + finally: + loop.close() + + @classmethod + def queue_images_to_nats(cls, job: "Job", images: list): + """ + Queue all images for a job to a NATS JetStream stream for the job. + + Args: + job: The Job instance + images: List of SourceImage instances to queue + + Returns: + bool: True if all images were successfully queued, False otherwise + """ + import asyncio + + from ami.utils.nats_queue import TaskQueueManager + + job_id = f"job{job.pk}" + job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job_id}'") + + # Prepare all messages outside of async context to avoid Django ORM issues + messages = [] + for i, image in enumerate(images): + message = { + "job_id": job.pk, + "image_id": image.id if hasattr(image, "id") else image.pk, + "image_url": image.url() if hasattr(image, "url") else None, + "timestamp": ( + image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None + ), + "batch_index": i, + "total_images": len(images), + "queue_timestamp": datetime.datetime.now().isoformat(), + } + messages.append((image.pk, message)) + + async def queue_all_images(): + successful_queues = 0 + failed_queues = 0 + + async with TaskQueueManager() as manager: + for i, (image_pk, message) in enumerate(messages): + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") + # Use TTR of 300 seconds (5 minutes) for image processing + success = await manager.publish_job( + job_id=job_id, + data=message, + ttr=300, # 5 minutes visibility timeout + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 + + return successful_queues, failed_queues + + # Run the async function in a new event loop to avoid conflicts with Django + # Use new_event_loop() to ensure we're not mixing with Django's async context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + successful_queues, failed_queues = loop.run_until_complete(queue_all_images()) + finally: + loop.close() + + # Log results (back in sync context) + if successful_queues > 0: + job.logger.info( + f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job_id}'" + ) + + if failed_queues > 0: + job.logger.warning(f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job_id}'") + return False + + return True + class DataStorageSyncJob(JobType): name = "Data storage sync" diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 5fffdb6fd..1f200e5a0 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,11 +1,13 @@ +import asyncio import logging from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone -from drf_spectacular.utils import extend_schema +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter, extend_schema from rest_framework.decorators import action -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.response import Response from ami.base.permissions import ObjectPermission @@ -55,6 +57,7 @@ class JobViewSet(DefaultViewSet, ProjectMixin): "pipeline", "job_type_key", ] + search_fields = ["name", "pipeline__name"] ordering_fields = [ "name", "created_at", @@ -153,6 +156,115 @@ def get_queryset(self) -> QuerySet: updated_at__lt=cutoff_datetime, ) - @extend_schema(parameters=[project_id_doc_param]) + @extend_schema( + parameters=[ + project_id_doc_param, + OpenApiParameter( + name="pipeline", + description="Filter jobs by pipeline ID", + required=False, + type=OpenApiTypes.INT, + ), + OpenApiParameter( + name="ids_only", + description="Return only job IDs instead of full job objects", + required=False, + type=OpenApiTypes.BOOL, + ), + ] + ) def list(self, request, *args, **kwargs): + # Check if ids_only parameter is set + ids_only = request.query_params.get("ids_only", "false").lower() in ["true", "1", "yes"] + + if ids_only: + # Get filtered queryset and return only IDs + queryset = self.filter_queryset(self.get_queryset()) + job_ids = list(queryset.values_list("id", flat=True)) + return Response({"job_ids": job_ids, "count": len(job_ids)}) + return super().list(request, *args, **kwargs) + + @extend_schema( + parameters=[ + OpenApiParameter( + name="batch", + description="Number of tasks to pull in the batch", + required=False, + type=OpenApiTypes.INT, + ), + ], + responses={200: dict}, + ) + @action(detail=True, methods=["get"], name="tasks") + def tasks(self, request, pk=None): + """ + Get tasks from the job queue. + + Returns task data with reply_subject for acknowledgment. External workers should: + 1. Call this endpoint to get tasks + 2. Process the tasks + 3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge + + This stateless approach allows workers to communicate over HTTP without + maintaining persistent connections to the queue system. + """ + job: Job = self.get_object() + batch = IntegerField(required=False, min_value=1).clean(request.query_params.get("batch", 1)) + job_id = f"job{job.pk}" + + # Validate that the job has a pipeline + if not job.pipeline: + raise ValidationError("This job does not have a pipeline configured") + + # Get tasks from NATS JetStream + from ami.utils.nats_queue import TaskQueueManager + + async def get_tasks(): + tasks = [] + async with TaskQueueManager() as manager: + for i in range(batch): + task = await manager.reserve_job(job_id) + if task: + tasks.append(task) + return tasks + + tasks = asyncio.run(get_tasks()) + return Response({"tasks": tasks}) + + @action(detail=True, methods=["post"], name="result") + def result(self, request, pk=None): + """ + Acknowledge task completion. + + External services should POST task results with the reply_subject received + from the /tasks endpoint to acknowledge task completion. + + The request body should contain: + { + "reply_subject": "string", # Required: from the task response + "status": "completed" | "failed", # Optional + "result_data": {...}, # Optional + "error_message": "Error details...", # Optional for failed tasks + } + """ + from ami.utils.nats_queue import TaskQueueManager + + reply_subject = request.data.get("reply_subject") + + if reply_subject is None: + raise ValidationError("reply_subject is required") + + # Acknowledge the task via NATS + async def ack_task(): + async with TaskQueueManager() as manager: + return await manager.acknowledge_job(reply_subject) + + success = asyncio.run(ack_task()) + + # TODO: Record the job results + + if success: + return Response({"status": "acknowledged"}) + else: + return Response({"status": "failed to acknowledge"}, status=500) diff --git a/ami/utils/nats_queue.py b/ami/utils/nats_queue.py new file mode 100644 index 000000000..3de515d00 --- /dev/null +++ b/ami/utils/nats_queue.py @@ -0,0 +1,288 @@ +""" +NATS JetStream utility for task queue management in the antenna project. + +This module provides a TaskQueueManager that uses NATS JetStream for distributed +task queuing with acknowledgment support via reply subjects. This allows workers +to pull tasks over HTTP and acknowledge them later without maintaining a persistent +connection to NATS. +""" + +import json +import logging +from typing import Any + +import nats +from django.conf import settings +from nats.js import JetStreamContext +from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy + +logger = logging.getLogger(__name__) + + +async def get_connection(nats_url: str): + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js + + +class TaskQueueManager: + """ + Manager for NATS JetStream task queue operations. + + Use as an async context manager: + async with TaskQueueManager() as manager: + await manager.publish_job('job123', {'data': 'value'}) + task = await manager.reserve_job('job123') + await manager.acknowledge_job(task['reply_subject']) + """ + + def __init__(self, nats_url: str = None): + self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") + self.nc: nats.NATS | None = None + self.js: JetStreamContext | None = None + + async def __aenter__(self): + """Create connection on enter.""" + self.nc, self.js = await get_connection(self.nats_url) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.js: + self.js = None + if self.nc and not self.nc.is_closed: + await self.nc.close() + self.nc = None + + return False + + def _get_stream_name(self, job_id: str) -> str: + """Get stream name from job_id.""" + return f"job_{job_id}" + + def _get_subject(self, job_id: str) -> str: + """Get subject name from job_id.""" + return f"job.{job_id}.tasks" + + def _get_consumer_name(self, job_id: str) -> str: + """Get consumer name from job_id.""" + return f"job-{job_id}-consumer" + + async def _ensure_stream(self, job_id: str, ttr: int = 30): + """Ensure stream exists for the given job.""" + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + stream_name = self._get_stream_name(job_id) + subject = self._get_subject(job_id) + + try: + await self.js.stream_info(stream_name) + logger.debug(f"Stream {stream_name} already exists") + except Exception as e: + logger.warning(f"Stream {stream_name} does not exist: {e}") + # Stream doesn't exist, create it + await self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ) + logger.info(f"Created stream {stream_name}") + + async def _ensure_consumer(self, job_id: str, ttr: int = 30): + """Ensure consumer exists for the given job.""" + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + try: + info = await self.js.consumer_info(stream_name, consumer_name) + logger.debug(f"Consumer {consumer_name} already exists: {info}") + except Exception: + # Consumer doesn't exist, create it + await self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=ttr, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), + ) + logger.info(f"Created consumer {consumer_name}") + + async def publish_job(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> bool: + """ + Publish a job to a stream. + + Args: + job_id: The job ID (e.g., 'job123' or '123') + data: Job data (dict will be JSON-encoded) + ttr: Time-to-run in seconds (visibility timeout, default 30) + + Returns: + bool: True if successful, False otherwise + """ + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id, ttr) + await self._ensure_consumer(job_id, ttr) + + subject = self._get_subject(job_id) + job_data = json.dumps(data) + + # Publish to JetStream + ack = await self.js.publish(subject, job_data.encode()) + + logger.info(f"Published job to stream for job '{job_id}', sequence {ack.seq}") + return True + + except Exception as e: + logger.error(f"Failed to publish job to stream for job '{job_id}': {e}") + return False + + async def reserve_job(self, job_id: str, timeout: int | None = None) -> dict[str, Any] | None: + """ + Reserve a job from the specified stream. + + Args: + job_id: The job ID to pull tasks from + timeout: Timeout in seconds for reservation (default: 5 seconds) + + Returns: + Dict with job details including 'reply_subject' for acknowledgment, or None if no job available + """ + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + if timeout is None: + timeout = 5 + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + # stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + # Create ephemeral subscription for this pull + psub = await self.js.pull_subscribe(subject, consumer_name) + + try: + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) + + if msgs: + msg = msgs[0] + job_data = json.loads(msg.data.decode()) + metadata = msg.metadata + + result = { + "id": metadata.sequence.stream, + "body": job_data, + "reply_subject": msg.reply, # For acknowledgment + } + + logger.debug(f"Reserved job from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return result + + except nats.errors.TimeoutError: + # No messages available + logger.debug(f"No jobs available in stream for job '{job_id}'") + return None + finally: + # Always unsubscribe + await psub.unsubscribe() + + except Exception as e: + logger.error(f"Failed to reserve job from stream for job '{job_id}': {e}") + return None + + async def acknowledge_job(self, reply_subject: str) -> bool: + """ + Acknowledge (delete) a completed job using its reply subject. + + Args: + reply_subject: The reply subject from reserve_job + + Returns: + bool: True if successful + """ + assert self.nc is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + try: + await self.nc.publish(reply_subject, b"+ACK") + logger.debug(f"Acknowledged job with reply subject {reply_subject}") + return True + except Exception as e: + logger.error(f"Failed to acknowledge job: {e}") + return False + + async def delete_consumer(self, job_id: str) -> bool: + """ + Delete the consumer for a job. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + try: + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + + await self.js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete consumer for job '{job_id}': {e}") + return False + + async def delete_stream(self, job_id: str) -> bool: + """ + Delete the stream for a job. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + + try: + stream_name = self._get_stream_name(job_id) + + await self.js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete stream for job '{job_id}': {e}") + return False + + async def cleanup_job_resources(self, job_id: str) -> bool: + """ + Clean up all NATS resources (consumer and stream) for a job. + + This should be called when a job completes or is cancelled. + + Args: + job_id: The job ID + + Returns: + bool: True if successful, False otherwise + """ + # Delete consumer first, then stream + consumer_deleted = await self.delete_consumer(job_id) + stream_deleted = await self.delete_stream(job_id) + + return consumer_deleted and stream_deleted diff --git a/ami/utils/rabbitmq.py b/ami/utils/rabbitmq.py deleted file mode 100644 index f46276016..000000000 --- a/ami/utils/rabbitmq.py +++ /dev/null @@ -1,187 +0,0 @@ -""" -RabbitMQ utilities for the Antenna application. - -This module provides a simple interface for interacting with RabbitMQ -using the pika library. -""" - -import json -import logging -import os -from collections.abc import Callable -from typing import Any - -import pika -from django.conf import settings - -logger = logging.getLogger(__name__) - - -class RabbitMQConnection: - """ - A context manager for RabbitMQ connections. - """ - - def __init__(self, connection_url: str = ""): - self.connection_url: str - self.connection_url = connection_url or getattr(settings, "RABBITMQ_URL", "") - if not self.connection_url: - # Fallback to Django settings or environment variables - host = getattr(settings, "RABBITMQ_HOST", os.getenv("RABBITMQ_HOST", "localhost")) - port = getattr(settings, "RABBITMQ_PORT", int(os.getenv("RABBITMQ_PORT", "5672"))) - user = getattr(settings, "RABBITMQ_DJANGO_USER", os.getenv("RABBITMQ_DJANGO_USER", "guest")) - password = getattr(settings, "RABBITMQ_DJANGO_PASS", os.getenv("RABBITMQ_DJANGO_PASS", "guest")) - vhost = getattr(settings, "RABBITMQ_DEFAULT_VHOST", os.getenv("RABBITMQ_DEFAULT_VHOST", "/")) - self.connection_url = f"amqp://{user}:{password}@{host}:{port}{vhost}" # noqa: E231 - - self.connection = None - self.channel = None - - def __enter__(self): - try: - parameters = pika.URLParameters(self.connection_url) - self.connection = pika.BlockingConnection(parameters) - self.channel = self.connection.channel() - return self.channel - except Exception as e: - logger.error(f"Failed to connect to RabbitMQ: {e}") - raise - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.connection and not self.connection.is_closed: - self.connection.close() - - -class RabbitMQPublisher: - """ - A simple publisher for RabbitMQ messages. - """ - - def __init__(self, connection_url: str | None = None): - self.connection_url = connection_url - - def publish_message( - self, - queue_name: str, - message: dict[str, Any], - exchange: str = "", - routing_key: str | None = None, - durable: bool = True, - ) -> bool: - """ - Publish a message to a RabbitMQ queue. - - Args: - queue_name: Name of the queue to publish to - message: Message data (will be JSON serialized) - exchange: Exchange name (default: '') - routing_key: Routing key (default: queue_name) - durable: Whether the queue should be durable - - Returns: - bool: True if message was published successfully - """ - if routing_key is None: - routing_key = queue_name - - try: - with RabbitMQConnection(self.connection_url) as channel: - # Declare the queue - channel.queue_declare(queue=queue_name, durable=durable) - - # Publish the message - channel.basic_publish( - exchange=exchange, - routing_key=routing_key, - body=json.dumps(message), - properties=pika.BasicProperties( - delivery_mode=2 if durable else 1, # Make message persistent if durable - content_type="application/json", - ), - ) - - logger.info(f"Published message to queue '{queue_name}': {message}") - return True - - except Exception as e: - logger.error(f"Failed to publish message to queue '{queue_name}': {e}") - return False - - -class RabbitMQConsumer: - """ - A simple consumer for RabbitMQ messages. - """ - - def __init__(self, connection_url: str | None = None): - self.connection_url = connection_url - - def consume_messages( - self, queue_name: str, callback: Callable[[dict[str, Any]], None], durable: bool = True, auto_ack: bool = False - ): - """ - Consume messages from a RabbitMQ queue. - - Args: - queue_name: Name of the queue to consume from - callback: Function to call for each message - durable: Whether the queue should be durable - auto_ack: Whether to automatically acknowledge messages - """ - - def message_callback(ch, method, properties, body): - try: - message = json.loads(body.decode("utf-8")) - callback(message) - - if not auto_ack: - ch.basic_ack(delivery_tag=method.delivery_tag) - - except Exception as e: - logger.error(f"Error processing message from queue '{queue_name}': {e}") - if not auto_ack: - ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) - - try: - with RabbitMQConnection(self.connection_url) as channel: - # Declare the queue - channel.queue_declare(queue=queue_name, durable=durable) - - # Set up the consumer - channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=auto_ack) - - logger.info(f"Starting to consume messages from queue '{queue_name}'") - channel.start_consuming() - - except KeyboardInterrupt: - logger.info("Stopping consumer...") - if "channel" in locals(): - channel.stop_consuming() - except Exception as e: - logger.error(f"Error consuming from queue '{queue_name}': {e}") - raise - - -# Convenience functions -def publish_to_queue(queue_name: str, message: dict[str, Any], **kwargs) -> bool: - """ - Convenience function to publish a message to a queue. - """ - publisher = RabbitMQPublisher() - return publisher.publish_message(queue_name, message, **kwargs) - - -def test_connection() -> bool: - """ - Test the RabbitMQ connection. - - Returns: - bool: True if connection is successful - """ - try: - with RabbitMQConnection() as _: - logger.info("RabbitMQ connection test successful") - return True - except Exception as e: - logger.error(f"RabbitMQ connection test failed: {e}") - return False diff --git a/compose/local/rabbitmq/definitions.json b/compose/local/rabbitmq/definitions.json deleted file mode 100644 index 74fe80fe1..000000000 --- a/compose/local/rabbitmq/definitions.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "users": [ - { - "name": "admin", - "password": "admin123", - "tags": "administrator" - }, - { - "name": "django_app", - "password": "django_secure_pass", - "tags": "management" - } - ], - "vhosts": [ - { - "name": "/" - } - ], - "permissions": [ - { - "user": "admin", - "vhost": "/", - "configure": ".*", - "write": ".*", - "read": ".*" - }, - { - "user": "django_app", - "vhost": "/", - "configure": ".*", - "write": ".*", - "read": ".*" - } - ], - "exchanges": [], - "queues": [], - "bindings": [] -} diff --git a/compose/local/rabbitmq/rabbitmq.conf b/compose/local/rabbitmq/rabbitmq.conf deleted file mode 100644 index 99e737298..000000000 --- a/compose/local/rabbitmq/rabbitmq.conf +++ /dev/null @@ -1,19 +0,0 @@ -# Enable management plugin -management.tcp.port = 15672 - -# Load user definitions from JSON file -management.load_definitions = /etc/rabbitmq/definitions.json - -# Default virtual host -default_vhost = / - -# Logging -log.console = true -log.console.level = info - -# Memory and disk limits -vm_memory_high_watermark.relative = 0.6 -disk_free_limit.relative = 2.0 - -# Queue master locator -queue_master_locator = min-masters diff --git a/config/settings/base.py b/config/settings/base.py index 38aa2d86d..f9c59d9c9 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -263,14 +263,9 @@ } } -# RABBITMQ +# NATS # ------------------------------------------------------------------------------ -RABBITMQ_URL = env("RABBITMQ_URL", default="amqp://guest:guest@localhost:5672/") # type: ignore[no-untyped-call] -RABBITMQ_HOST = env("RABBITMQ_HOST", default="localhost") # type: ignore[no-untyped-call] -RABBITMQ_PORT = env.int("RABBITMQ_PORT", default=5672) # type: ignore[no-untyped-call] -RABBITMQ_DJANGO_USER = env("RABBITMQ_DJANGO_USER", default="guest") # type: ignore[no-untyped-call] -RABBITMQ_DJANGO_PASS = env("RABBITMQ_DJANGO_PASS", default="guest") # type: ignore[no-untyped-call] -RABBITMQ_DEFAULT_VHOST = env("RABBITMQ_DEFAULT_VHOST", default="/") # type: ignore[no-untyped-call] +NATS_URL = env("NATS_URL", default="nats://localhost:4222") # type: ignore[no-untyped-call] # ADMIN # ------------------------------------------------------------------------------ diff --git a/docker-compose.yml b/docker-compose.yml index f9567604c..33b4f29b4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,8 +8,6 @@ volumes: o: bind minio_data: driver: local - rabbitmq_data: - driver: local node_modules: services: @@ -23,7 +21,7 @@ services: depends_on: - postgres - redis - - rabbitmq + - nats - minio-init - ml_backend volumes: @@ -35,7 +33,10 @@ services: required: false ports: - "8000:8000" + - "5679:5679" command: /start + # for debugging with debugpy: + # command: python -m debugpy --listen 0.0.0.0:5679 -m django runserver 0.0.0.0:8000 networks: - default - antenna_network @@ -88,25 +89,18 @@ services: image: redis:6 container_name: ami_local_redis - rabbitmq: - image: rabbitmq:3.12-management - container_name: ami_local_rabbitmq - hostname: rabbitmq + nats: + image: nats:2.10-alpine + container_name: ami_local_nats + hostname: nats ports: - - "5672:5672" # AMQP port - - "15672:15672" # Management UI port - environment: - RABBITMQ_DEFAULT_USER: ${RABBITMQ_ADMIN_USER:-admin} - RABBITMQ_DEFAULT_PASS: ${RABBITMQ_ADMIN_PASS:-admin123} - RABBITMQ_DEFAULT_VHOST: ${RABBITMQ_DEFAULT_VHOST:-/} - volumes: - - rabbitmq_data:/var/lib/rabbitmq - - ./compose/local/rabbitmq/definitions.json:/etc/rabbitmq/definitions.json:ro - - ./compose/local/rabbitmq/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf:ro + - "4222:4222" # Client port + - "8222:8222" # HTTP monitoring port + command: ["-js", "-m", "8222"] # Enable JetStream and monitoring healthcheck: - test: rabbitmq-diagnostics -q ping - interval: 30s - timeout: 30s + test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"] + interval: 10s + timeout: 5s retries: 3 celeryworker: diff --git a/requirements/base.txt b/requirements/base.txt index 09b332d66..a0b1be14a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -8,7 +8,7 @@ celery==5.4.0 # pyup: < 6.0 # https://github.com/celery/celery django-celery-beat==2.5.0 # https://github.com/celery/django-celery-beat flower==2.0.1 # https://github.com/mher/flower kombu==5.4.2 -pika==1.3.2 # https://github.com/pika/pika +nats-py==2.10.0 # https://github.com/nats-io/nats.py uvicorn[standard]==0.22.0 # https://github.com/encode/uvicorn rich==13.5.0 markdown==3.4.4 From 8ea5d7d8ce7f2fbf3e7f97d377b2d99d9abf911f Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 17 Oct 2025 12:19:32 -0700 Subject: [PATCH 07/29] Saving of results --- ami/base/views.py | 4 +- ami/jobs/models.py | 2 +- ami/jobs/tasks.py | 79 +++++++++++++++++++++++++++ ami/jobs/views.py | 118 +++++++++++++++++++++++++++++----------- ami/utils/nats_queue.py | 2 +- 5 files changed, 170 insertions(+), 35 deletions(-) diff --git a/ami/base/views.py b/ami/base/views.py index aa8862ef5..bfda35e60 100644 --- a/ami/base/views.py +++ b/ami/base/views.py @@ -33,7 +33,9 @@ def get_active_project(self) -> Project | None: if not project_id: # Look for project_id in GET query parameters or POST data # POST data returns a list of ints, but QueryDict.get() returns a single value - project_id = self.request.query_params.get(param) or self.request.data.get(param) + project_id = self.request.query_params.get(param) or ( + self.request.data if isinstance(self.request.data, dict) else {} + ).get(param) project_id = SingleParamSerializer[int].clean( param_name=param, diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 844669696..339c44ab9 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -597,7 +597,7 @@ async def queue_all_images(): success = await manager.publish_job( job_id=job_id, data=message, - ttr=300, # 5 minutes visibility timeout + ttr=120, # visibility timeout in seconds ) except Exception as e: logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b12271178..2993bbcdb 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,9 +1,11 @@ +import asyncio import logging from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun from ami.tasks import default_soft_time_limit, default_time_limit +from ami.utils.nats_queue import TaskQueueManager from config import celery_app logger = logging.getLogger(__name__) @@ -30,6 +32,83 @@ def run_job(self, job_id: int) -> None: job.logger.info(f"Finished job {job}") +@celery_app.task( + bind=True, + max_retries=3, + default_retry_delay=60, + autoretry_for=(Exception,), + soft_time_limit=300, # 5 minutes + time_limit=360, # 6 minutes +) +def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> dict: + """ + Process a single pipeline result asynchronously. + + This task: + 1. Deserializes the pipeline result + 2. Saves it to the database + 3. Acknowledges the task via NATS + + Args: + job_id: The job ID + result_json: JSON string of the pipeline result + reply_subject: NATS reply subject for acknowledgment + + Returns: + dict with status information + """ + + from ami.jobs.models import Job + from ami.ml.schemas import PipelineResultsResponse + + try: + job = Job.objects.get(pk=job_id) + job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") + + # Deserialize the result + pipeline_result = PipelineResultsResponse(**result_data) + + # Save to database (this is the slow operation) + if job.pipeline: + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) + job.logger.info(f"Successfully saved results for job {job_id}") + else: + job.logger.warning(f"Job {job_id} has no pipeline, skipping save_results") + + # Acknowledge the task via NATS + try: + + async def ack_task(): + async with TaskQueueManager() as manager: + return await manager.acknowledge_job(reply_subject) + + ack_success = asyncio.run(ack_task()) + + if ack_success: + if ack_success: + job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") + else: + job.logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") + except Exception as ack_error: + job.logger.error(f"Error acknowledging task via NATS: {ack_error}") + # Don't fail the task if ACK fails - data is already saved + + return { + "status": "success", + "job_id": job_id, + "reply_subject": reply_subject, + "acknowledged": ack_success if "ack_success" in locals() else False, + } + + except Job.DoesNotExist: + logger.error(f"Job {job_id} not found") + raise + except Exception as e: + logger.error(f"Failed to process pipeline result for job {job_id}: {e}") + # Celery will automatically retry based on autoretry_for + raise + + @task_postrun.connect(sender=run_job) @task_prerun.connect(sender=run_job) def update_job_status(sender, task_id, task, *args, **kwargs): diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 1f200e5a0..0246253d8 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -224,7 +224,7 @@ async def get_tasks(): tasks = [] async with TaskQueueManager() as manager: for i in range(batch): - task = await manager.reserve_job(job_id) + task = await manager.reserve_job(job_id, timeout=0.1) if task: tasks.append(task) return tasks @@ -235,36 +235,90 @@ async def get_tasks(): @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ - Acknowledge task completion. - - External services should POST task results with the reply_subject received - from the /tasks endpoint to acknowledge task completion. - - The request body should contain: - { - "reply_subject": "string", # Required: from the task response - "status": "completed" | "failed", # Optional - "result_data": {...}, # Optional - "error_message": "Error details...", # Optional for failed tasks - } + Submit pipeline results for asynchronous processing. + + This endpoint accepts a list of pipeline results and queues them for + background processing. Each result will be validated, saved to the database, + and acknowledged via NATS in a Celery task. + + The request body should be a list of results: + [ + { + "reply_subject": "string", # Required: from the task response + "result": { # Required: PipelineResultsResponse (kept as JSON) + "pipeline": "string", + "algorithms": {}, + "total_time": 0.0, + "source_images": [...], + "detections": [...], + "errors": null + } + }, + ... + ] """ - from ami.utils.nats_queue import TaskQueueManager - - reply_subject = request.data.get("reply_subject") - - if reply_subject is None: - raise ValidationError("reply_subject is required") - - # Acknowledge the task via NATS - async def ack_task(): - async with TaskQueueManager() as manager: - return await manager.acknowledge_job(reply_subject) - - success = asyncio.run(ack_task()) - - # TODO: Record the job results - if success: - return Response({"status": "acknowledged"}) - else: - return Response({"status": "failed to acknowledge"}, status=500) + from ami.jobs.tasks import process_pipeline_result + + job_id = pk if pk else self.kwargs.get("pk") + if not job_id: + raise ValidationError("Job ID is required") + job_id = int(job_id) + + # Validate request data is a list + if not isinstance(request.data, list): + raise ValidationError("Request body must be a list of results") + + if not request.data: + raise ValidationError("Request body cannot be empty") + + # Queue each result for background processing + queued_tasks = [] + + for idx, item in enumerate(request.data): + reply_subject = item.get("reply_subject") + result_data = item.get("result") + + if not reply_subject: + raise ValidationError(f"Item {idx}: reply_subject is required") + + if not result_data: + raise ValidationError(f"Item {idx}: result is required") + + try: + # Queue the background task + task = process_pipeline_result.delay( + job_id=job_id, result_data=result_data, reply_subject=reply_subject + ) + + queued_tasks.append( + { + "reply_subject": reply_subject, + "status": "queued", + "task_id": task.id, + } + ) + + logger.info( + f"Queued pipeline result processing for job {job_id}, " + f"task_id: {task.id}, reply_subject: {reply_subject}" + ) + + except Exception as e: + logger.error(f"Failed to queue result {idx} for job {job_id}: {e}") + queued_tasks.append( + { + "reply_subject": reply_subject, + "status": "error", + "error": str(e), + } + ) + + return Response( + { + "status": "accepted", + "job_id": job_id, + "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "tasks": queued_tasks, + } + ) diff --git a/ami/utils/nats_queue.py b/ami/utils/nats_queue.py index 3de515d00..430e361e8 100644 --- a/ami/utils/nats_queue.py +++ b/ami/utils/nats_queue.py @@ -146,7 +146,7 @@ async def publish_job(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> logger.error(f"Failed to publish job to stream for job '{job_id}': {e}") return False - async def reserve_job(self, job_id: str, timeout: int | None = None) -> dict[str, Any] | None: + async def reserve_job(self, job_id: str, timeout: float | None = None) -> dict[str, Any] | None: """ Reserve a job from the specified stream. From 61fc2c5af653a75044b1375e17fa14d69f62fe7c Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 17 Oct 2025 16:06:51 -0700 Subject: [PATCH 08/29] Update progress --- ami/jobs/models.py | 24 +++++++++++-- ami/jobs/tasks.py | 89 ++++++++++++++++++++++++++++++++++++++-------- ami/jobs/views.py | 42 ++++++++++++++++++++-- 3 files changed, 134 insertions(+), 21 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 339c44ab9..efc8bacb0 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -360,7 +360,7 @@ def run(cls, job: "Job"): progress=0, ) - images = list( + images: list[SourceImage] = list( # @TODO return generator plus image count # @TODO pass to celery group chain? job.pipeline.collect_images( @@ -551,7 +551,7 @@ async def cleanup(): loop.close() @classmethod - def queue_images_to_nats(cls, job: "Job", images: list): + def queue_images_to_nats(cls, job: "Job", images: list[SourceImage]): """ Queue all images for a job to a NATS JetStream stream for the job. @@ -564,6 +564,8 @@ def queue_images_to_nats(cls, job: "Job", images: list): """ import asyncio + from django.core.cache import cache + from ami.utils.nats_queue import TaskQueueManager job_id = f"job{job.pk}" @@ -571,10 +573,13 @@ def queue_images_to_nats(cls, job: "Job", images: list): # Prepare all messages outside of async context to avoid Django ORM issues messages = [] + image_ids = [] for i, image in enumerate(images): + image_id = str(image.pk) + image_ids.append(image_id) message = { "job_id": job.pk, - "image_id": image.id if hasattr(image, "id") else image.pk, + "image_id": image_id, "image_url": image.url() if hasattr(image, "url") else None, "timestamp": ( image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None @@ -585,6 +590,15 @@ def queue_images_to_nats(cls, job: "Job", images: list): } messages.append((image.pk, message)) + # Store all image IDs in Redis for progress tracking + # TODO CGJS: put these formats in a common place + redis_key = f"job:{job.pk}:pending_images" # noqa E231 + redis_key_total = f"job:{job.pk}:pending_images_total" # noqa E231 + # TODO CGJS: Make the timeout proportional to the expected job duration, e.g. the number of images + cache.set(redis_key, image_ids, timeout=86400 * 7) # 7 days timeout + cache.set(redis_key_total, len(image_ids), timeout=86400 * 7) # 7 days timeout + job.logger.info(f"Stored {len(image_ids)} image IDs in Redis at key '{redis_key}'") + async def queue_all_images(): successful_queues = 0 failed_queues = 0 @@ -619,6 +633,10 @@ async def queue_all_images(): finally: loop.close() + if not images: + job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) + job.save() + # Log results (back in sync context) if successful_queues > 0: job.logger.info( diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 2993bbcdb..c7202bcd4 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -40,25 +40,28 @@ def run_job(self, job_id: int) -> None: soft_time_limit=300, # 5 minutes time_limit=360, # 6 minutes ) -def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> dict: +def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None: """ Process a single pipeline result asynchronously. This task: 1. Deserializes the pipeline result 2. Saves it to the database - 3. Acknowledges the task via NATS + 3. Updates progress by removing processed image IDs from Redis + 4. Acknowledges the task via NATS Args: job_id: The job ID - result_json: JSON string of the pipeline result + result_data: Dictionary containing the pipeline result reply_subject: NATS reply subject for acknowledgment Returns: dict with status information """ - from ami.jobs.models import Job + from django.core.cache import cache + + from ami.jobs.models import Job, JobState from ami.ml.schemas import PipelineResultsResponse try: @@ -66,15 +69,70 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") # Deserialize the result - pipeline_result = PipelineResultsResponse(**result_data) # Save to database (this is the slow operation) - if job.pipeline: - job.pipeline.save_results(results=pipeline_result, job_id=job.pk) - job.logger.info(f"Successfully saved results for job {job_id}") - else: + if not job.pipeline: job.logger.warning(f"Job {job_id} has no pipeline, skipping save_results") + return + + # TODO CGJS: do we need this? it was for jobs that got in a bad state + # if len(result_data) == 0: + # job.logger.warning(f"Job {job_id} received empty result_data, skipping save_results") + # job.progress.update_stage( + # "results", + # status=JobState.SUCCESS, + # progress=1.0, + # ) + # job.progress.update_stage( + # "process", + # status=JobState.SUCCESS, + # progress=1.0, + # ) + # job.save() + # return + + job.logger.info(f"Successfully saved results for job {job_id}") + + # Update progress tracking in Redis + redis_key = f"job:{job.pk}:pending_images" # noqa E231 + redis_key_total = f"job:{job.pk}:pending_images_total" # noqa E231 + pending_images = cache.get(redis_key) + total_images = cache.get(redis_key_total) + logger.info(f"Pending images from Redis for job {job_id}: {len(pending_images)}/{total_images}") + progress_percentage = 1.0 + pipeline_result = PipelineResultsResponse(**result_data) + if pending_images is not None: + # Extract processed image IDs from the result + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + + # Update Redis with remaining images + if remaining_images: + cache.set(redis_key, remaining_images, timeout=86400 * 7) # 7 days + else: + cache.delete(redis_key) + + # Calculate progress percentage + images_processed = total_images - len(remaining_images) + progress_percentage = float(images_processed) / total_images if total_images > 0 else 1.0 + + job.logger.info( + f"Job {job_id} progress: {images_processed}/{total_images} images processed " + f"({progress_percentage*100}%), {len(remaining_images)} remaining" + ) + + else: + job.logger.warning(f"No pending images found in Redis for job {job_id}, setting progress to 100%") + + job.progress.update_stage( + "process", + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + job.save() + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) # Acknowledge the task via NATS try: @@ -93,12 +151,13 @@ async def ack_task(): job.logger.error(f"Error acknowledging task via NATS: {ack_error}") # Don't fail the task if ACK fails - data is already saved - return { - "status": "success", - "job_id": job_id, - "reply_subject": reply_subject, - "acknowledged": ack_success if "ack_success" in locals() else False, - } + # Update job stage with calculated progress + job.progress.update_stage( + "results", + status=JobState.STARTED if remaining_images else JobState.SUCCESS, + progress=progress_percentage, + ) + job.save() except Job.DoesNotExist: logger.error(f"Job {job_id} not found") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 0246253d8..356587fa6 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -171,18 +171,46 @@ def get_queryset(self) -> QuerySet: required=False, type=OpenApiTypes.BOOL, ), + OpenApiParameter( + name="incomplete_only", + description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", + required=False, + type=OpenApiTypes.BOOL, + ), ] ) def list(self, request, *args, **kwargs): # Check if ids_only parameter is set ids_only = request.query_params.get("ids_only", "false").lower() in ["true", "1", "yes"] + # Check if incomplete_only parameter is set + incomplete_only = request.query_params.get("incomplete_only", "false").lower() in ["true", "1", "yes"] + + # Get the base queryset + queryset = self.filter_queryset(self.get_queryset()) + + # Filter to incomplete jobs if requested (checks "results" stage status) + if incomplete_only: + from django.db.models import Q + + # Create filters for each final state to exclude + final_states = JobState.final_states() + exclude_conditions = Q() + + # Exclude jobs where the "results" stage has a final state status + for state in final_states: + # JSON path query to check if results stage status is in final states + exclude_conditions |= Q(progress__stages__contains=[{"key": "results", "status": state}]) + + queryset = queryset.exclude(exclude_conditions) + if ids_only: - # Get filtered queryset and return only IDs - queryset = self.filter_queryset(self.get_queryset()) + # Return only IDs job_ids = list(queryset.values_list("id", flat=True)) return Response({"job_ids": job_ids, "count": len(job_ids)}) + # Override the queryset for the list view + self.queryset = queryset return super().list(request, *args, **kwargs) @extend_schema( @@ -270,7 +298,15 @@ def result(self, request, pk=None): raise ValidationError("Request body must be a list of results") if not request.data: - raise ValidationError("Request body cannot be empty") + task = process_pipeline_result.delay(job_id=job_id, result_data={}, reply_subject="") + return Response( + { + "status": "accepted", + "job_id": job_id, + "tasks": [{"reply_subject": "", "status": "queued", "task_id": task.id}], + } + ) + # raise ValidationError("Request body cannot be empty") # Queue each result for background processing queued_tasks = [] From 9af597c18538dbd8ddf485cc5e3f07a573e64b0f Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 24 Oct 2025 11:19:39 -0700 Subject: [PATCH 09/29] Clean up and refactor task state mgmt --- ami/jobs/models.py | 70 ++++++++++++----------------- ami/jobs/task_state.py | 100 +++++++++++++++++++++++++++++++++++++++++ ami/jobs/tasks.py | 63 +++++++------------------- ami/jobs/views.py | 20 +++------ 4 files changed, 150 insertions(+), 103 deletions(-) create mode 100644 ami/jobs/task_state.py diff --git a/ami/jobs/models.py b/ami/jobs/models.py index efc8bacb0..99a141a63 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -1,3 +1,4 @@ +import asyncio import datetime import logging import random @@ -15,9 +16,11 @@ from ami.base.models import BaseModel from ami.base.schemas import ConfigurableStage, ConfigurableStageParam +from ami.jobs.task_state import TaskStateManager from ami.jobs.tasks import run_job from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.utils.nats_queue import TaskQueueManager from ami.utils.schemas import OrderedEnum logger = logging.getLogger(__name__) @@ -394,10 +397,9 @@ def run(cls, job: "Job"): # End image collection stage job.save() - # WIP: don't commit - # TODO: do this conditionally based on the type of processing service this job is using - # cls.process_images(job, images) - cls.queue_images_to_nats(job, images) + cls.process_images(job, images) + # TODO CGJS: do this conditionally based on the type of processing service this job is using + # cls.queue_images_to_nats(job, images) @classmethod def process_images(cls, job, images): @@ -525,10 +527,6 @@ def cleanup_nats_resources(cls, job: "Job"): Args: job: The Job instance """ - import asyncio - - from ami.utils.nats_queue import TaskQueueManager - job_id = f"job{job.pk}" async def cleanup(): @@ -536,19 +534,7 @@ async def cleanup(): success = await manager.cleanup_job_resources(job_id) return success - # Run cleanup in a new event loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - success = loop.run_until_complete(cleanup()) - if success: - job.logger.info(f"Cleaned up NATS resources for job '{job_id}'") - else: - job.logger.warning(f"Failed to fully clean up NATS resources for job '{job_id}'") - except Exception as e: - job.logger.error(f"Error cleaning up NATS resources for job '{job_id}': {e}") - finally: - loop.close() + _run_in_async_loop(cleanup, f"cleaning up NATS resources for job '{job_id}'") @classmethod def queue_images_to_nats(cls, job: "Job", images: list[SourceImage]): @@ -562,12 +548,6 @@ def queue_images_to_nats(cls, job: "Job", images: list[SourceImage]): Returns: bool: True if all images were successfully queued, False otherwise """ - import asyncio - - from django.core.cache import cache - - from ami.utils.nats_queue import TaskQueueManager - job_id = f"job{job.pk}" job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job_id}'") @@ -591,13 +571,9 @@ def queue_images_to_nats(cls, job: "Job", images: list[SourceImage]): messages.append((image.pk, message)) # Store all image IDs in Redis for progress tracking - # TODO CGJS: put these formats in a common place - redis_key = f"job:{job.pk}:pending_images" # noqa E231 - redis_key_total = f"job:{job.pk}:pending_images_total" # noqa E231 - # TODO CGJS: Make the timeout proportional to the expected job duration, e.g. the number of images - cache.set(redis_key, image_ids, timeout=86400 * 7) # 7 days timeout - cache.set(redis_key_total, len(image_ids), timeout=86400 * 7) # 7 days timeout - job.logger.info(f"Stored {len(image_ids)} image IDs in Redis at key '{redis_key}'") + state_manager = TaskStateManager(job.pk) + state_manager.initialize_job(image_ids) + job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") async def queue_all_images(): successful_queues = 0 @@ -624,14 +600,11 @@ async def queue_all_images(): return successful_queues, failed_queues - # Run the async function in a new event loop to avoid conflicts with Django - # Use new_event_loop() to ensure we're not mixing with Django's async context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - successful_queues, failed_queues = loop.run_until_complete(queue_all_images()) - finally: - loop.close() + result = _run_in_async_loop(queue_all_images, f"queuing images to NATS for job '{job_id}'") + if result is None: + job.logger.error(f"Failed to queue images to NATS for job '{job_id}'") + return False + successful_queues, failed_queues = result if not images: job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) @@ -1104,3 +1077,16 @@ class Meta: # permissions = [ # ("run_job", "Can run a job"), # ("cancel_job", "Can cancel a job"), + + +def _run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: + # helper to use new_event_loop() to ensure we're not mixing with Django's async context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(func()) + except Exception as e: + logger.error(f"Error in async loop - {error_msg}: {e}") + return None + finally: + loop.close() diff --git a/ami/jobs/task_state.py b/ami/jobs/task_state.py new file mode 100644 index 000000000..07b5415a0 --- /dev/null +++ b/ami/jobs/task_state.py @@ -0,0 +1,100 @@ +""" +Task state management for job progress tracking using Redis. +""" + +import logging +from collections import namedtuple + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +# Define a namedtuple for a TaskProgress with x and y coordinates +TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) + + +class TaskStateManager: + """ + Manages job progress tracking state in Redis. + + Tracks pending images for jobs to calculate progress percentages + as workers process images asynchronously. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + + def __init__(self, job_id: int): + """ + Initialize the task state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" # noqa E231 + self._total_key = f"job:{job_id}:pending_images_total" # noqa E231 + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + cache.set(self._pending_key, image_ids, timeout=self.TIMEOUT) + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def mark_images_processed(self, processed_image_ids: set[str]) -> None: + """ + Mark a set of images as processed by removing them from pending list. + + Args: + processed_image_ids: Set of image IDs that have been processed + """ + pending_images = cache.get(self._pending_key) + if pending_images is None: + return + + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + + if remaining_images: + cache.set(self._pending_key, remaining_images, timeout=self.TIMEOUT) + else: + cache.delete(self._pending_key) + + def get_progress(self) -> TaskProgress | None: + """ + Get current progress information for the job. + + Returns: + TaskProgress namedtuple with fields: + - remaining: Number of images still pending (or None if not tracked) + - total: Total number of images (or None if not tracked) + - processed: Number of images processed (or None if not tracked) + - percentage: Progress as float 0.0-1.0 (or None if not tracked) + """ + pending_images = cache.get(self._pending_key) + total_images = cache.get(self._total_key) + + if pending_images is None or total_images is None: + return None + + remaining = len(pending_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + logger.info(f"Pending images from Redis for job {self.job_id}: " f"{remaining}/{total_images}") + + return TaskProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + ) + + def cleanup(self) -> None: + """ + Delete all Redis keys associated with this job. + """ + cache.delete(self._pending_key) + cache.delete(self._total_key) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index c7202bcd4..c397a58bc 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -4,6 +4,9 @@ from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun +from ami.jobs.models import Job, JobState +from ami.jobs.task_state import TaskStateManager +from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from ami.utils.nats_queue import TaskQueueManager from config import celery_app @@ -59,71 +62,39 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: dict with status information """ - from django.core.cache import cache - - from ami.jobs.models import Job, JobState - from ami.ml.schemas import PipelineResultsResponse - try: job = Job.objects.get(pk=job_id) job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") - # Deserialize the result - # Save to database (this is the slow operation) if not job.pipeline: job.logger.warning(f"Job {job_id} has no pipeline, skipping save_results") return - # TODO CGJS: do we need this? it was for jobs that got in a bad state - # if len(result_data) == 0: - # job.logger.warning(f"Job {job_id} received empty result_data, skipping save_results") - # job.progress.update_stage( - # "results", - # status=JobState.SUCCESS, - # progress=1.0, - # ) - # job.progress.update_stage( - # "process", - # status=JobState.SUCCESS, - # progress=1.0, - # ) - # job.save() - # return - job.logger.info(f"Successfully saved results for job {job_id}") - # Update progress tracking in Redis - redis_key = f"job:{job.pk}:pending_images" # noqa E231 - redis_key_total = f"job:{job.pk}:pending_images_total" # noqa E231 - pending_images = cache.get(redis_key) - total_images = cache.get(redis_key_total) - logger.info(f"Pending images from Redis for job {job_id}: {len(pending_images)}/{total_images}") - progress_percentage = 1.0 + # Deserialize the result pipeline_result = PipelineResultsResponse(**result_data) - if pending_images is not None: - # Extract processed image IDs from the result - processed_image_ids = {str(img.id) for img in pipeline_result.source_images} - remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + # Update progress tracking in Redis + state_manager = TaskStateManager(job.pk) + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + state_manager.mark_images_processed(processed_image_ids) - # Update Redis with remaining images - if remaining_images: - cache.set(redis_key, remaining_images, timeout=86400 * 7) # 7 days - else: - cache.delete(redis_key) + progress_info = state_manager.get_progress() + progress_percentage = 0.0 - # Calculate progress percentage - images_processed = total_images - len(remaining_images) - progress_percentage = float(images_processed) / total_images if total_images > 0 else 1.0 + if progress_info is not None: + # Get updated progress + progress_percentage = progress_info.percentage job.logger.info( - f"Job {job_id} progress: {images_processed}/{total_images} images processed " - f"({progress_percentage*100}%), {len(remaining_images)} remaining" + f"Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " + f"({progress_percentage*100}%), {progress_info.remaining} remaining" ) - else: job.logger.warning(f"No pending images found in Redis for job {job_id}, setting progress to 100%") + progress_percentage = 1.0 job.progress.update_stage( "process", @@ -154,7 +125,7 @@ async def ack_task(): # Update job stage with calculated progress job.progress.update_stage( "results", - status=JobState.STARTED if remaining_images else JobState.SUCCESS, + status=JobState.STARTED if progress_percentage < 1.0 else JobState.SUCCESS, progress=progress_percentage, ) job.save() diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 356587fa6..64ed212e7 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,6 +1,6 @@ -import asyncio import logging +from asgiref.sync import async_to_sync from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone @@ -12,6 +12,7 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin +from ami.jobs.tasks import process_pipeline_result from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param from ami.utils.requests import project_id_doc_param @@ -257,7 +258,9 @@ async def get_tasks(): tasks.append(task) return tasks - tasks = asyncio.run(get_tasks()) + # Use async_to_sync to properly handle the async call + tasks = async_to_sync(get_tasks)() + return Response({"tasks": tasks}) @action(detail=True, methods=["post"], name="result") @@ -286,8 +289,6 @@ def result(self, request, pk=None): ] """ - from ami.jobs.tasks import process_pipeline_result - job_id = pk if pk else self.kwargs.get("pk") if not job_id: raise ValidationError("Job ID is required") @@ -297,17 +298,6 @@ def result(self, request, pk=None): if not isinstance(request.data, list): raise ValidationError("Request body must be a list of results") - if not request.data: - task = process_pipeline_result.delay(job_id=job_id, result_data={}, reply_subject="") - return Response( - { - "status": "accepted", - "job_id": job_id, - "tasks": [{"reply_subject": "", "status": "queued", "task_id": task.id}], - } - ) - # raise ValidationError("Request body cannot be empty") - # Queue each result for background processing queued_tasks = [] From 7ff88652e67bc9bf07e16c3746ca2532c5e53837 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 24 Oct 2025 11:29:49 -0700 Subject: [PATCH 10/29] fix async use --- ami/jobs/models.py | 17 ++--------------- ami/jobs/tasks.py | 11 +++++------ ami/jobs/utils.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 21 deletions(-) create mode 100644 ami/jobs/utils.py diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 99a141a63..86c04a7b4 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -1,4 +1,3 @@ -import asyncio import datetime import logging import random @@ -18,6 +17,7 @@ from ami.base.schemas import ConfigurableStage, ConfigurableStageParam from ami.jobs.task_state import TaskStateManager from ami.jobs.tasks import run_job +from ami.jobs.utils import _run_in_async_loop from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.utils.nats_queue import TaskQueueManager @@ -371,7 +371,7 @@ def run(cls, job: "Job"): deployment=job.deployment, source_images=[job.source_image_single] if job.source_image_single else None, job_id=job.pk, - skip_processed=False, # WIP don't commit + skip_processed=True, # shuffle=job.shuffle, ) ) @@ -1077,16 +1077,3 @@ class Meta: # permissions = [ # ("run_job", "Can run a job"), # ("cancel_job", "Can cancel a job"), - - -def _run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: - # helper to use new_event_loop() to ensure we're not mixing with Django's async context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(func()) - except Exception as e: - logger.error(f"Error in async loop - {error_msg}: {e}") - return None - finally: - loop.close() diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index c397a58bc..69dabfeae 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,4 +1,3 @@ -import asyncio import logging from celery.result import AsyncResult @@ -6,6 +5,7 @@ from ami.jobs.models import Job, JobState from ami.jobs.task_state import TaskStateManager +from ami.jobs.utils import _run_in_async_loop from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from ami.utils.nats_queue import TaskQueueManager @@ -111,13 +111,12 @@ async def ack_task(): async with TaskQueueManager() as manager: return await manager.acknowledge_job(reply_subject) - ack_success = asyncio.run(ack_task()) + ack_success = _run_in_async_loop(ack_task, f"acknowledging job {job.pk} via NATS") if ack_success: - if ack_success: - job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") - else: - job.logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") + job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") + else: + job.logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") except Exception as ack_error: job.logger.error(f"Error acknowledging task via NATS: {ack_error}") # Don't fail the task if ACK fails - data is already saved diff --git a/ami/jobs/utils.py b/ami/jobs/utils.py new file mode 100644 index 000000000..c9c71052d --- /dev/null +++ b/ami/jobs/utils.py @@ -0,0 +1,17 @@ +import asyncio +import typing + +from ami.jobs.models import logger + + +def _run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: + # helper to use new_event_loop() to ensure we're not mixing with Django's async context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(func()) + except Exception as e: + logger.error(f"Error in async loop - {error_msg}: {e}") + return None + finally: + loop.close() From 7899fc5c62a701585fd44c9c91d50b8cf1265290 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 24 Oct 2025 15:46:58 -0700 Subject: [PATCH 11/29] Fix circular dependency, jobset query by pipeline slug --- ami/jobs/task_state.py | 5 +--- ami/jobs/tasks.py | 2 +- ami/jobs/utils.py | 3 +- ami/jobs/views.py | 65 ++++++++++++++++++------------------------ ami/utils/requests.py | 20 +++++++++++++ 5 files changed, 51 insertions(+), 44 deletions(-) diff --git a/ami/jobs/task_state.py b/ami/jobs/task_state.py index 07b5415a0..76932b3d3 100644 --- a/ami/jobs/task_state.py +++ b/ami/jobs/task_state.py @@ -58,10 +58,7 @@ def mark_images_processed(self, processed_image_ids: set[str]) -> None: remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] - if remaining_images: - cache.set(self._pending_key, remaining_images, timeout=self.TIMEOUT) - else: - cache.delete(self._pending_key) + cache.set(self._pending_key, remaining_images, timeout=self.TIMEOUT) def get_progress(self) -> TaskProgress | None: """ diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 69dabfeae..b5bed3018 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,7 +3,6 @@ from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun -from ami.jobs.models import Job, JobState from ami.jobs.task_state import TaskStateManager from ami.jobs.utils import _run_in_async_loop from ami.ml.schemas import PipelineResultsResponse @@ -61,6 +60,7 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: Returns: dict with status information """ + from ami.jobs.models import Job, JobState # avoid circular import try: job = Job.objects.get(pk=job_id) diff --git a/ami/jobs/utils.py b/ami/jobs/utils.py index c9c71052d..3dbceb5de 100644 --- a/ami/jobs/utils.py +++ b/ami/jobs/utils.py @@ -1,7 +1,8 @@ import asyncio +import logging import typing -from ami.jobs.models import logger +logger = logging.getLogger(__name__) def _run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 64ed212e7..ae10308be 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -4,8 +4,8 @@ from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter, extend_schema +from django_filters import rest_framework as filters +from drf_spectacular.utils import extend_schema from rest_framework.decorators import action from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.response import Response @@ -15,7 +15,7 @@ from ami.jobs.tasks import process_pipeline_result from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param -from ami.utils.requests import project_id_doc_param +from ami.utils.requests import batch_param, ids_only_param, incomplete_only_param, project_id_doc_param from .models import Job, JobState from .serializers import JobListSerializer, JobSerializer @@ -23,6 +23,26 @@ logger = logging.getLogger(__name__) +class JobFilterSet(filters.FilterSet): + """Custom filterset to enable pipeline name filtering.""" + + pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact") + + class Meta: + model = Job + fields = [ + "status", + "project", + "deployment", + "source_image_collection", + "source_image_single", + "pipeline", + "pipeline__name", + "pipeline__slug", + "job_type_key", + ] + + class JobViewSet(DefaultViewSet, ProjectMixin): """ API endpoint that allows jobs to be viewed or edited. @@ -49,15 +69,7 @@ class JobViewSet(DefaultViewSet, ProjectMixin): "source_image_single", ) serializer_class = JobSerializer - filterset_fields = [ - "status", - "project", - "deployment", - "source_image_collection", - "source_image_single", - "pipeline", - "job_type_key", - ] + filterset_class = JobFilterSet search_fields = ["name", "pipeline__name"] ordering_fields = [ "name", @@ -160,24 +172,8 @@ def get_queryset(self) -> QuerySet: @extend_schema( parameters=[ project_id_doc_param, - OpenApiParameter( - name="pipeline", - description="Filter jobs by pipeline ID", - required=False, - type=OpenApiTypes.INT, - ), - OpenApiParameter( - name="ids_only", - description="Return only job IDs instead of full job objects", - required=False, - type=OpenApiTypes.BOOL, - ), - OpenApiParameter( - name="incomplete_only", - description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", - required=False, - type=OpenApiTypes.BOOL, - ), + ids_only_param, + incomplete_only_param, ] ) def list(self, request, *args, **kwargs): @@ -215,14 +211,7 @@ def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @extend_schema( - parameters=[ - OpenApiParameter( - name="batch", - description="Number of tasks to pull in the batch", - required=False, - type=OpenApiTypes.INT, - ), - ], + parameters=[batch_param], responses={200: dict}, ) @action(detail=True, methods=["get"], name="tasks") diff --git a/ami/utils/requests.py b/ami/utils/requests.py index dca9c5c43..35c3fa597 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -2,6 +2,7 @@ import requests from django.forms import BooleanField, FloatField +from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter from requests.adapters import HTTPAdapter from rest_framework.request import Request @@ -110,3 +111,22 @@ def get_default_classification_threshold(project: "Project | None" = None, reque required=False, type=int, ) + +ids_only_param = OpenApiParameter( + name="ids_only", + description="Return only job IDs instead of full job objects", + required=False, + type=OpenApiTypes.BOOL, +) +incomplete_only_param = OpenApiParameter( + name="incomplete_only", + description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", + required=False, + type=OpenApiTypes.BOOL, +) +batch_param = OpenApiParameter( + name="batch", + description="Number of tasks to pull in the batch", + required=False, + type=OpenApiTypes.INT, +) From d9f8ffd135649dbe658fb78cca0631f3160bd279 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 24 Oct 2025 16:01:04 -0700 Subject: [PATCH 12/29] GH review comments --- ami/base/views.py | 7 +++---- ami/jobs/task_state.py | 2 +- ami/utils/nats_queue.py | 23 +++++++++++++++-------- requirements/base.txt | 2 +- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/ami/base/views.py b/ami/base/views.py index ced0316e2..0763d9832 100644 --- a/ami/base/views.py +++ b/ami/base/views.py @@ -37,10 +37,9 @@ def get_active_project( # If not in URL, try query parameters if not project_id: # Look for project_id in GET query parameters or POST data - # POST data returns a list of ints, but QueryDict.get() returns a single value - project_id = request.query_params.get(param) or (request.data if isinstance(request.data, dict) else {}).get( - param - ) + # request.data may not always be a dict (e.g., for non-POST requests), so we check its type + post_data = request.data if isinstance(request.data, dict) else {} + project_id = request.query_params.get(param) or post_data.get(param) project_id = SingleParamSerializer[int].clean( param_name=param, diff --git a/ami/jobs/task_state.py b/ami/jobs/task_state.py index 76932b3d3..55edcc1b0 100644 --- a/ami/jobs/task_state.py +++ b/ami/jobs/task_state.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -# Define a namedtuple for a TaskProgress with x and y coordinates +# Define a namedtuple for a TaskProgress with the image counts TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) diff --git a/ami/utils/nats_queue.py b/ami/utils/nats_queue.py index 430e361e8..51b1c385b 100644 --- a/ami/utils/nats_queue.py +++ b/ami/utils/nats_queue.py @@ -36,7 +36,7 @@ class TaskQueueManager: await manager.acknowledge_job(task['reply_subject']) """ - def __init__(self, nats_url: str = None): + def __init__(self, nats_url: str | None = None): self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") self.nc: nats.NATS | None = None self.js: JetStreamContext | None = None @@ -69,7 +69,8 @@ def _get_consumer_name(self, job_id: str) -> str: async def _ensure_stream(self, job_id: str, ttr: int = 30): """Ensure stream exists for the given job.""" - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") stream_name = self._get_stream_name(job_id) subject = self._get_subject(job_id) @@ -89,7 +90,8 @@ async def _ensure_stream(self, job_id: str, ttr: int = 30): async def _ensure_consumer(self, job_id: str, ttr: int = 30): """Ensure consumer exists for the given job.""" - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) @@ -126,7 +128,8 @@ async def publish_job(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> Returns: bool: True if successful, False otherwise """ - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") try: # Ensure stream and consumer exist @@ -157,7 +160,8 @@ async def reserve_job(self, job_id: str, timeout: float | None = None) -> dict[s Returns: Dict with job details including 'reply_subject' for acknowledgment, or None if no job available """ - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") if timeout is None: timeout = 5 @@ -214,7 +218,8 @@ async def acknowledge_job(self, reply_subject: str) -> bool: Returns: bool: True if successful """ - assert self.nc is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.nc is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") try: await self.nc.publish(reply_subject, b"+ACK") @@ -234,7 +239,8 @@ async def delete_consumer(self, job_id: str) -> bool: Returns: bool: True if successful, False otherwise """ - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") try: stream_name = self._get_stream_name(job_id) @@ -257,7 +263,8 @@ async def delete_stream(self, job_id: str) -> bool: Returns: bool: True if successful, False otherwise """ - assert self.js is not None, "Connection is not open. Use TaskQueueManager as an async context manager." + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") try: stream_name = self._get_stream_name(job_id) diff --git a/requirements/base.txt b/requirements/base.txt index a0b1be14a..d6f27a4ec 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -53,7 +53,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail/ Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg -# psycopg==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From edad552eb6030af777c3cddb21a5844b0bd61099 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 29 Oct 2025 14:06:17 -0700 Subject: [PATCH 13/29] Add feature flag, rename "job" to "task" --- ami/jobs/models.py | 9 ++++---- ami/jobs/tasks.py | 2 +- ami/jobs/views.py | 2 +- ami/main/models.py | 1 + ami/utils/nats_queue.py | 47 ++++++++++++++++++++--------------------- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index e2bd058f3..a29a2507a 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -398,9 +398,10 @@ def run(cls, job: "Job"): # End image collection stage job.save() - cls.process_images(job, images) - # TODO CGJS: do this conditionally based on the type of processing service this job is using - # cls.queue_images_to_nats(job, images) + if job.project.feature_flags.async_pipeline_workers: + cls.queue_images_to_nats(job, images) + else: + cls.process_images(job, images) @classmethod def process_images(cls, job, images): @@ -585,7 +586,7 @@ async def queue_all_images(): try: logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") # Use TTR of 300 seconds (5 minutes) for image processing - success = await manager.publish_job( + success = await manager.publish_task( job_id=job_id, data=message, ttr=120, # visibility timeout in seconds diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b5bed3018..2b385ae49 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -109,7 +109,7 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: async def ack_task(): async with TaskQueueManager() as manager: - return await manager.acknowledge_job(reply_subject) + return await manager.acknowledge_task(reply_subject) ack_success = _run_in_async_loop(ack_task, f"acknowledging job {job.pk} via NATS") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ae10308be..3ec068b31 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -242,7 +242,7 @@ async def get_tasks(): tasks = [] async with TaskQueueManager() as manager: for i in range(batch): - task = await manager.reserve_job(job_id, timeout=0.1) + task = await manager.reserve_task(job_id, timeout=0.1) if task: tasks.append(task) return tasks diff --git a/ami/main/models.py b/ami/main/models.py index 231fef6f6..123bedbdd 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -215,6 +215,7 @@ class ProjectFeatureFlags(pydantic.BaseModel): tags: bool = False # Whether the project supports tagging taxa reprocess_existing_detections: bool = False # Whether to reprocess existing detections default_filters: bool = False # Whether to show default filters form in UI + async_pipeline_workers: bool = False # Whether to use async pipeline workers that pull tasks from a queue def get_default_feature_flags() -> ProjectFeatureFlags: diff --git a/ami/utils/nats_queue.py b/ami/utils/nats_queue.py index 51b1c385b..c566c6235 100644 --- a/ami/utils/nats_queue.py +++ b/ami/utils/nats_queue.py @@ -31,9 +31,9 @@ class TaskQueueManager: Use as an async context manager: async with TaskQueueManager() as manager: - await manager.publish_job('job123', {'data': 'value'}) - task = await manager.reserve_job('job123') - await manager.acknowledge_job(task['reply_subject']) + await manager.publish_task('job123', {'data': 'value'}) + task = await manager.reserve_task('job123') + await manager.acknowledge_task(task['reply_subject']) """ def __init__(self, nats_url: str | None = None): @@ -116,13 +116,13 @@ async def _ensure_consumer(self, job_id: str, ttr: int = 30): ) logger.info(f"Created consumer {consumer_name}") - async def publish_job(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> bool: + async def publish_task(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> bool: """ - Publish a job to a stream. + Publish a task to it's job queue. Args: job_id: The job ID (e.g., 'job123' or '123') - data: Job data (dict will be JSON-encoded) + data: Task data (dict will be JSON-encoded) ttr: Time-to-run in seconds (visibility timeout, default 30) Returns: @@ -137,28 +137,28 @@ async def publish_job(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> await self._ensure_consumer(job_id, ttr) subject = self._get_subject(job_id) - job_data = json.dumps(data) + task_data = json.dumps(data) # Publish to JetStream - ack = await self.js.publish(subject, job_data.encode()) + ack = await self.js.publish(subject, task_data.encode()) - logger.info(f"Published job to stream for job '{job_id}', sequence {ack.seq}") + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") return True except Exception as e: - logger.error(f"Failed to publish job to stream for job '{job_id}': {e}") + logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") return False - async def reserve_job(self, job_id: str, timeout: float | None = None) -> dict[str, Any] | None: + async def reserve_task(self, job_id: str, timeout: float | None = None) -> dict[str, Any] | None: """ - Reserve a job from the specified stream. + Reserve a task from the specified stream. Args: job_id: The job ID to pull tasks from timeout: Timeout in seconds for reservation (default: 5 seconds) Returns: - Dict with job details including 'reply_subject' for acknowledgment, or None if no job available + Dict with task details including 'reply_subject' for acknowledgment, or None if no task available """ if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") @@ -171,7 +171,6 @@ async def reserve_job(self, job_id: str, timeout: float | None = None) -> dict[s await self._ensure_stream(job_id) await self._ensure_consumer(job_id) - # stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) subject = self._get_subject(job_id) @@ -184,36 +183,36 @@ async def reserve_job(self, job_id: str, timeout: float | None = None) -> dict[s if msgs: msg = msgs[0] - job_data = json.loads(msg.data.decode()) + task_data = json.loads(msg.data.decode()) metadata = msg.metadata result = { "id": metadata.sequence.stream, - "body": job_data, + "body": task_data, "reply_subject": msg.reply, # For acknowledgment } - logger.debug(f"Reserved job from stream for job '{job_id}', sequence {metadata.sequence.stream}") + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") return result except nats.errors.TimeoutError: # No messages available - logger.debug(f"No jobs available in stream for job '{job_id}'") + logger.debug(f"No tasks available in stream for job '{job_id}'") return None finally: # Always unsubscribe await psub.unsubscribe() except Exception as e: - logger.error(f"Failed to reserve job from stream for job '{job_id}': {e}") + logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") return None - async def acknowledge_job(self, reply_subject: str) -> bool: + async def acknowledge_task(self, reply_subject: str) -> bool: """ - Acknowledge (delete) a completed job using its reply subject. + Acknowledge (delete) a completed task using its reply subject. Args: - reply_subject: The reply subject from reserve_job + reply_subject: The reply subject from reserve_task Returns: bool: True if successful @@ -223,10 +222,10 @@ async def acknowledge_job(self, reply_subject: str) -> bool: try: await self.nc.publish(reply_subject, b"+ACK") - logger.debug(f"Acknowledged job with reply subject {reply_subject}") + logger.debug(f"Acknowledged task with reply subject {reply_subject}") return True except Exception as e: - logger.error(f"Failed to acknowledge job: {e}") + logger.error(f"Failed to acknowledge task: {e}") return False async def delete_consumer(self, job_id: str) -> bool: From d2548670e385d97f2724e3940689325eca4692d9 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 31 Oct 2025 14:09:11 -0700 Subject: [PATCH 14/29] Code reorganization --- ami/jobs/models.py | 112 +----------------- ami/jobs/tasks.py | 8 +- ami/jobs/views.py | 2 +- ami/ml/orchestration/jobs.py | 107 +++++++++++++++++ ami/{utils => ml/orchestration}/nats_queue.py | 0 ami/{jobs => ml/orchestration}/task_state.py | 0 ami/{jobs => ml/orchestration}/utils.py | 2 +- 7 files changed, 117 insertions(+), 114 deletions(-) create mode 100644 ami/ml/orchestration/jobs.py rename ami/{utils => ml/orchestration}/nats_queue.py (100%) rename ami/{jobs => ml/orchestration}/task_state.py (100%) rename ami/{jobs => ml/orchestration}/utils.py (85%) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index a29a2507a..d3cff336e 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -15,13 +15,10 @@ from ami.base.models import BaseModel from ami.base.schemas import ConfigurableStage, ConfigurableStageParam -from ami.jobs.task_state import TaskStateManager from ami.jobs.tasks import run_job -from ami.jobs.utils import _run_in_async_loop from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.post_processing.registry import get_postprocessing_task -from ami.utils.nats_queue import TaskQueueManager from ami.utils.schemas import OrderedEnum logger = logging.getLogger(__name__) @@ -325,6 +322,8 @@ def run(cls, job: "Job"): """ Procedure for an ML pipeline as a job. """ + from ami.ml.orchestration.jobs import queue_images_to_nats + job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None @@ -399,7 +398,7 @@ def run(cls, job: "Job"): job.save() if job.project.feature_flags.async_pipeline_workers: - cls.queue_images_to_nats(job, images) + queue_images_to_nats(job, images) else: cls.process_images(job, images) @@ -520,111 +519,8 @@ def process_images(cls, job, images): job.finished_at = datetime.datetime.now() job.save() - # TODO: This needs to happen once a job is done - @classmethod - def cleanup_nats_resources(cls, job: "Job"): - """ - Clean up NATS JetStream resources (stream and consumer) for a completed job. - - Args: - job: The Job instance - """ - job_id = f"job{job.pk}" - - async def cleanup(): - async with TaskQueueManager() as manager: - success = await manager.cleanup_job_resources(job_id) - return success - - _run_in_async_loop(cleanup, f"cleaning up NATS resources for job '{job_id}'") - - @classmethod - def queue_images_to_nats(cls, job: "Job", images: list[SourceImage]): - """ - Queue all images for a job to a NATS JetStream stream for the job. - - Args: - job: The Job instance - images: List of SourceImage instances to queue - - Returns: - bool: True if all images were successfully queued, False otherwise - """ - job_id = f"job{job.pk}" - job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job_id}'") - - # Prepare all messages outside of async context to avoid Django ORM issues - messages = [] - image_ids = [] - for i, image in enumerate(images): - image_id = str(image.pk) - image_ids.append(image_id) - message = { - "job_id": job.pk, - "image_id": image_id, - "image_url": image.url() if hasattr(image, "url") else None, - "timestamp": ( - image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None - ), - "batch_index": i, - "total_images": len(images), - "queue_timestamp": datetime.datetime.now().isoformat(), - } - messages.append((image.pk, message)) - - # Store all image IDs in Redis for progress tracking - state_manager = TaskStateManager(job.pk) - state_manager.initialize_job(image_ids) - job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") - - async def queue_all_images(): - successful_queues = 0 - failed_queues = 0 - - async with TaskQueueManager() as manager: - for i, (image_pk, message) in enumerate(messages): - try: - logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") - # Use TTR of 300 seconds (5 minutes) for image processing - success = await manager.publish_task( - job_id=job_id, - data=message, - ttr=120, # visibility timeout in seconds - ) - except Exception as e: - logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 - - return successful_queues, failed_queues - - result = _run_in_async_loop(queue_all_images, f"queuing images to NATS for job '{job_id}'") - if result is None: - job.logger.error(f"Failed to queue images to NATS for job '{job_id}'") - return False - successful_queues, failed_queues = result - - if not images: - job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) - job.save() - - # Log results (back in sync context) - if successful_queues > 0: - job.logger.info( - f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job_id}'" - ) - - if failed_queues > 0: - job.logger.warning(f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job_id}'") - return False - - return True - +# TODO: This needs to happen once a job is done class DataStorageSyncJob(JobType): name = "Data storage sync" key = "data_storage_sync" diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 2b385ae49..d1a340638 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,11 +3,11 @@ from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun -from ami.jobs.task_state import TaskStateManager -from ami.jobs.utils import _run_in_async_loop +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.orchestration.utils import run_in_async_loop from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit -from ami.utils.nats_queue import TaskQueueManager from config import celery_app logger = logging.getLogger(__name__) @@ -111,7 +111,7 @@ async def ack_task(): async with TaskQueueManager() as manager: return await manager.acknowledge_task(reply_subject) - ack_success = _run_in_async_loop(ack_task, f"acknowledging job {job.pk} via NATS") + ack_success = run_in_async_loop(ack_task, f"acknowledging job {job.pk} via NATS") if ack_success: job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 3ec068b31..8f9d10b67 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -236,7 +236,7 @@ def tasks(self, request, pk=None): raise ValidationError("This job does not have a pipeline configured") # Get tasks from NATS JetStream - from ami.utils.nats_queue import TaskQueueManager + from ami.ml.orchestration.nats_queue import TaskQueueManager async def get_tasks(): tasks = [] diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py new file mode 100644 index 000000000..6a8898eeb --- /dev/null +++ b/ami/ml/orchestration/jobs.py @@ -0,0 +1,107 @@ +import datetime + +from ami.jobs.models import Job, JobState, logger +from ami.main.models import SourceImage +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.orchestration.utils import run_in_async_loop + + +# TODO CGJS: Call this once a job is fully complete (all images processed and saved) +def cleanup_nats_resources(job: "Job"): + """ + Clean up NATS JetStream resources (stream and consumer) for a completed job. + + Args: + job: The Job instance + """ + job_id = f"job{job.pk}" + + async def cleanup(): + async with TaskQueueManager() as manager: + success = await manager.cleanup_job_resources(job_id) + return success + + run_in_async_loop(cleanup, f"cleaning up NATS resources for job '{job_id}'") + + +def queue_images_to_nats(job: "Job", images: list[SourceImage]): + """ + Queue all images for a job to a NATS JetStream stream for the job. + + Args: + job: The Job instance + images: List of SourceImage instances to queue + + Returns: + bool: True if all images were successfully queued, False otherwise + """ + job_id = f"job{job.pk}" + job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job_id}'") + + # Prepare all messages outside of async context to avoid Django ORM issues + messages = [] + image_ids = [] + for i, image in enumerate(images): + image_id = str(image.pk) + image_ids.append(image_id) + message = { + "job_id": job.pk, + "image_id": image_id, + "image_url": image.url() if hasattr(image, "url") else None, + "timestamp": (image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None), + "batch_index": i, + "total_images": len(images), + "queue_timestamp": datetime.datetime.now().isoformat(), + } + messages.append((image.pk, message)) + + # Store all image IDs in Redis for progress tracking + state_manager = TaskStateManager(job.pk) + state_manager.initialize_job(image_ids) + job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") + + async def queue_all_images(): + successful_queues = 0 + failed_queues = 0 + + async with TaskQueueManager() as manager: + for i, (image_pk, message) in enumerate(messages): + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") + # Use TTR of 300 seconds (5 minutes) for image processing + success = await manager.publish_task( + job_id=job_id, + data=message, + ttr=120, # visibility timeout in seconds + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 + + return successful_queues, failed_queues + + result = run_in_async_loop(queue_all_images, f"queuing images to NATS for job '{job_id}'") + if result is None: + job.logger.error(f"Failed to queue images to NATS for job '{job_id}'") + return False + successful_queues, failed_queues = result + + if not images: + job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) + job.save() + + # Log results (back in sync context) + if successful_queues > 0: + job.logger.info(f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job_id}'") + + if failed_queues > 0: + job.logger.warning(f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job_id}'") + return False + + return True diff --git a/ami/utils/nats_queue.py b/ami/ml/orchestration/nats_queue.py similarity index 100% rename from ami/utils/nats_queue.py rename to ami/ml/orchestration/nats_queue.py diff --git a/ami/jobs/task_state.py b/ami/ml/orchestration/task_state.py similarity index 100% rename from ami/jobs/task_state.py rename to ami/ml/orchestration/task_state.py diff --git a/ami/jobs/utils.py b/ami/ml/orchestration/utils.py similarity index 85% rename from ami/jobs/utils.py rename to ami/ml/orchestration/utils.py index 3dbceb5de..752e591dd 100644 --- a/ami/jobs/utils.py +++ b/ami/ml/orchestration/utils.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) -def _run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: +def run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: # helper to use new_event_loop() to ensure we're not mixing with Django's async context loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) From 1cc890ed7713744ab3084ba87b5dcc9910e8ae69 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 31 Oct 2025 14:33:51 -0700 Subject: [PATCH 15/29] Resolve circular deps --- ami/jobs/models.py | 1 - ami/ml/orchestration/__init__.py | 7 ++++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index d3cff336e..847190dbc 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -520,7 +520,6 @@ def process_images(cls, job, images): job.save() -# TODO: This needs to happen once a job is done class DataStorageSyncJob(JobType): name = "Data storage sync" key = "data_storage_sync" diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py index d05bbbd82..d503079ff 100644 --- a/ami/ml/orchestration/__init__.py +++ b/ami/ml/orchestration/__init__.py @@ -1 +1,6 @@ -from .processing import * # noqa: F401, F403 +# cgjs: This creates a circular import: +# - ami.jobs.models imports ami.jobs.tasks.run_job +# - ami.jobs.tasks imports ami.ml.orchestration +# -.processing imports ami.jobs.models + +# from .processing import * # noqa: F401, F403 From 84ee5a2dafce97de9e755ed5ed4697121b07ac70 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 31 Oct 2025 14:36:44 -0700 Subject: [PATCH 16/29] Update ami/jobs/models.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- ami/jobs/models.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 847190dbc..5be9e8a21 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -398,7 +398,15 @@ def run(cls, job: "Job"): job.save() if job.project.feature_flags.async_pipeline_workers: - queue_images_to_nats(job, images) + if job.project.feature_flags.async_pipeline_workers: + queued = queue_images_to_nats(job, images) + if not queued: + job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) + job.progress.update_stage("collect", status=JobState.FAILURE) + job.update_status(JobState.FAILURE) + job.finished_at = datetime.datetime.now() + job.save() + return else: cls.process_images(job, images) From 09fee924bbd30332c095385bfa28c0192ccf9b54 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 31 Oct 2025 14:39:32 -0700 Subject: [PATCH 17/29] cleanup --- ami/ml/orchestration/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py index d503079ff..75c2ec3b5 100644 --- a/ami/ml/orchestration/__init__.py +++ b/ami/ml/orchestration/__init__.py @@ -2,5 +2,4 @@ # - ami.jobs.models imports ami.jobs.tasks.run_job # - ami.jobs.tasks imports ami.ml.orchestration # -.processing imports ami.jobs.models - # from .processing import * # noqa: F401, F403 From 4480b0dc97205dedea49164ff48f237958248f7a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 4 Nov 2025 14:13:02 -0800 Subject: [PATCH 18/29] Consistent progress updates, single image job command --- ami/jobs/management/commands/README.md | 11 ++ .../commands/process_single_image.py | 166 ++++++++++++++++++ ami/jobs/tasks.py | 126 +++++++++---- ami/jobs/utils.py | 78 ++++++++ ami/ml/orchestration/jobs.py | 4 +- ami/ml/orchestration/task_state.py | 65 +++++-- 6 files changed, 392 insertions(+), 58 deletions(-) create mode 100644 ami/jobs/management/commands/README.md create mode 100644 ami/jobs/management/commands/process_single_image.py create mode 100644 ami/jobs/utils.py diff --git a/ami/jobs/management/commands/README.md b/ami/jobs/management/commands/README.md new file mode 100644 index 000000000..2901e229c --- /dev/null +++ b/ami/jobs/management/commands/README.md @@ -0,0 +1,11 @@ +# Process Single Image Command + +A Django management command for processing a single image through a pipeline. Useful for testing, debugging, and reprocessing individual images. + +## Usage + +### With Wait Flag (Monitor Progress) + +```bash +docker compose run --rm web python manage.py process_single_image 12345 --pipeline 1 --wait +``` diff --git a/ami/jobs/management/commands/process_single_image.py b/ami/jobs/management/commands/process_single_image.py new file mode 100644 index 000000000..12cf28b1d --- /dev/null +++ b/ami/jobs/management/commands/process_single_image.py @@ -0,0 +1,166 @@ +"""Management command to process a single image through a pipeline for testing/debugging.""" + +import logging +import time + +from django.core.management.base import BaseCommand, CommandError + +from ami.jobs.utils import submit_single_image_job +from ami.main.models import Detection, SourceImage +from ami.ml.models import Pipeline + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Submit a job to process a single image through a pipeline (for testing/debugging)" + + def add_arguments(self, parser): + parser.add_argument("image_id", type=int, help="SourceImage ID to process") + parser.add_argument( + "--pipeline", + type=int, + required=True, + help="Pipeline ID to use for processing", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Custom job name (optional)", + ) + parser.add_argument( + "--wait", + action="store_true", + help="Wait for the job to complete and show results", + ) + parser.add_argument( + "--poll-interval", + type=int, + default=2, + help="Polling interval in seconds when using --wait (default: 2)", + ) + + def handle(self, *args, **options): + image_id = options["image_id"] + pipeline_id = options["pipeline"] + job_name = options["name"] + wait = options["wait"] + poll_interval = options["poll_interval"] + + # Validate image exists + try: + image = SourceImage.objects.select_related("deployment__project").get(pk=image_id) + self.stdout.write(self.style.SUCCESS(f"✓ Found image: {image.path}")) + self.stdout.write(f" Project: {image.deployment.project.name}") + self.stdout.write(f" Deployment: {image.deployment.name}") + except SourceImage.DoesNotExist: + raise CommandError(f"SourceImage with id {image_id} does not exist") + + # Validate pipeline exists + try: + pipeline = Pipeline.objects.get(pk=pipeline_id) + self.stdout.write(self.style.SUCCESS(f"✓ Using pipeline: {pipeline.name} (v{pipeline.version})")) + except Pipeline.DoesNotExist: + raise CommandError(f"Pipeline with id {pipeline_id} does not exist") + + # Submit the job + self.stdout.write("") + self.stdout.write(self.style.WARNING("Submitting job...")) + + try: + job = submit_single_image_job( + image_id=image_id, + pipeline_id=pipeline_id, + job_name=job_name, + ) + except Exception as e: + raise CommandError(f"Failed to submit job: {str(e)}") + + self.stdout.write( + self.style.SUCCESS( + f"✓ Job {job.pk} created and enqueued\n" + f" Task ID: {job.task_id}\n" + f" Status: {job.status}\n" + f" Name: {job.name}" + ) + ) + + if not wait: + self.stdout.write("") + self.stdout.write("To check job status, run:") + self.stdout.write(f" Job.objects.get(pk={job.pk}).status") + return + + # Wait for job completion + self.stdout.write("") + self.stdout.write(self.style.WARNING("Waiting for job to complete...")) + self.stdout.write("(Press Ctrl+C to stop waiting)\n") + + try: + start_time = time.time() + last_status = None + last_progress = None + + while True: + job.refresh_from_db() + progress = job.progress.summary.progress * 100 + status = job.status + + # Only update display if something changed + if status != last_status or abs(progress - (last_progress or 0)) > 0.1: + elapsed = time.time() - start_time + self.stdout.write( + f" Status: {status:15s} | Progress: {progress:5.1f}% | Elapsed: {elapsed:6.1f}s", + ending="\r", + ) + last_status = status + last_progress = progress + + # Check if job is done + if job.status in ["SUCCESS", "FAILURE", "REVOKED", "REJECTED"]: + self.stdout.write("") # New line after progress updates + break + + time.sleep(poll_interval) + + except KeyboardInterrupt: + self.stdout.write("") + self.stdout.write(self.style.WARNING("\n⚠ Stopped waiting (job is still running)")) + self.stdout.write(f" Job ID: {job.pk}") + return + + # Show results + self.stdout.write("") + elapsed_total = time.time() - start_time + + if job.status == "SUCCESS": + self.stdout.write(self.style.SUCCESS(f"✓ Job completed successfully in {elapsed_total:.1f}s")) + + # Show results summary + detection_count = Detection.objects.filter(source_image_id=image_id).count() + self.stdout.write("\nResults:") + self.stdout.write(f" Detections created: {detection_count}") + + # Show classifications if any + if detection_count > 0: + from ami.main.models import Classification + + classification_count = Classification.objects.filter(detection__source_image_id=image_id).count() + self.stdout.write(f" Classifications created: {classification_count}") + + elif job.status == "FAILURE": + self.stdout.write(self.style.ERROR(f"✗ Job failed after {elapsed_total:.1f}s")) + self.stdout.write("\nCheck job logs for details:") + self.stdout.write(f" Job.objects.get(pk={job.pk}).logs") + + # Show any error messages + if job.progress.errors: + self.stdout.write("\nErrors:") + for error in job.progress.errors[-5:]: # Last 5 errors + self.stdout.write(f" - {error}") # noqa: E221 + + else: + self.stdout.write(self.style.WARNING(f"⚠ Job ended with status: {job.status}")) + + self.stdout.write(f"\nJob ID: {job.pk}") diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index d1a340638..7e1f6a4f3 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,7 +1,11 @@ +import functools import logging +import time +from collections.abc import Callable from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun +from django.db import transaction from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.orchestration.task_state import TaskStateManager @@ -60,50 +64,54 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: Returns: dict with status information """ - from ami.jobs.models import Job, JobState # avoid circular import + from ami.jobs.models import Job # avoid circular import + + _, t = log_time() + error = result_data.get("error") + pipeline_result = None + if not error: + pipeline_result = PipelineResultsResponse(**result_data) + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + else: + processed_image_ids = set() + image_id = result_data.get("image_id") + logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}") + + state_manager = TaskStateManager(job_id) + + progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) try: + _update_job_progress(job_id, "process", progress_info.percentage) + + _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") job = Job.objects.get(pk=job_id) job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") + job.logger.info( + f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " + "processed" + ) # Save to database (this is the slow operation) if not job.pipeline: job.logger.warning(f"Job {job_id} has no pipeline, skipping save_results") return - job.logger.info(f"Successfully saved results for job {job_id}") - - # Deserialize the result - pipeline_result = PipelineResultsResponse(**result_data) - - # Update progress tracking in Redis - state_manager = TaskStateManager(job.pk) - processed_image_ids = {str(img.id) for img in pipeline_result.source_images} - state_manager.mark_images_processed(processed_image_ids) - - progress_info = state_manager.get_progress() - progress_percentage = 0.0 + if pipeline_result: + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) + job.logger.info(f"Successfully saved results for job {job_id}") - if progress_info is not None: - # Get updated progress - progress_percentage = progress_info.percentage - - job.logger.info( - f"Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " - f"({progress_percentage*100}%), {progress_info.remaining} remaining" + _, t = t( + f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" + f", percentage: {progress_info.percentage*100}%" ) - else: - job.logger.warning(f"No pending images found in Redis for job {job_id}, setting progress to 100%") - progress_percentage = 1.0 - - job.progress.update_stage( - "process", - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, - progress=progress_percentage, - ) - job.save() - - job.pipeline.save_results(results=pipeline_result, job_id=job.pk) # Acknowledge the task via NATS try: @@ -122,12 +130,15 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved # Update job stage with calculated progress - job.progress.update_stage( - "results", - status=JobState.STARTED if progress_percentage < 1.0 else JobState.SUCCESS, - progress=progress_percentage, - ) - job.save() + progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + _update_job_progress(job_id, "results", progress_info.percentage) except Job.DoesNotExist: logger.error(f"Job {job_id} not found") @@ -138,6 +149,20 @@ async def ack_task(): raise +def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: + from ami.jobs.models import Job, JobState # avoid circular import + + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + job.progress.update_stage( + stage, + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") + job.save() + + @task_postrun.connect(sender=run_job) @task_prerun.connect(sender=run_job) def update_job_status(sender, task_id, task, *args, **kwargs): @@ -171,3 +196,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}') job.save() + + +def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: + """ + Small helper to measure time between calls. + + Returns: elapsed time since the last call, and a partial function to measure from the current call + Usage: + + _, tlog = log_time() + # do something + _, tlog = tlog("Did something") # will log the time taken by 'something' + # do something else + t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't' + """ + + end = time.perf_counter() + if start == 0: + dur = 0.0 + else: + dur = end - start + if msg and start > 0: + logger.info(f"{msg}: {dur:.3f}s") + new_start = time.perf_counter() + return dur, functools.partial(log_time, new_start) diff --git a/ami/jobs/utils.py b/ami/jobs/utils.py new file mode 100644 index 000000000..87cb5d12e --- /dev/null +++ b/ami/jobs/utils.py @@ -0,0 +1,78 @@ +"""Utility functions for job management and testing.""" + +import logging + +from ami.jobs.models import Job, MLJob +from ami.main.models import Project, SourceImage +from ami.ml.models import Pipeline + +logger = logging.getLogger(__name__) + + +def submit_single_image_job( + image_id: int, + pipeline_id: int, + project_id: int | None = None, + job_name: str | None = None, +) -> Job: + """ + Submit a job to process a single image through a pipeline. + + This is useful for testing, debugging, or reprocessing individual images. + + Args: + image_id: The SourceImage ID to process + pipeline_id: The Pipeline ID to use for processing + project_id: Optional project ID (will be inferred from image if not provided) + job_name: Optional custom job name (will be auto-generated if not provided) + + Returns: + The created Job instance + + Raises: + SourceImage.DoesNotExist: If the image doesn't exist + Pipeline.DoesNotExist: If the pipeline doesn't exist + """ + # Fetch the image and validate it exists + try: + image = SourceImage.objects.select_related("deployment__project").get(pk=image_id) + except SourceImage.DoesNotExist: + logger.error(f"SourceImage with id {image_id} does not exist") + raise + + # Fetch the pipeline and validate it exists + try: + pipeline = Pipeline.objects.get(pk=pipeline_id) + except Pipeline.DoesNotExist: + logger.error(f"Pipeline with id {pipeline_id} does not exist") + raise + + # Infer project from image if not provided + if project_id is None: + project = image.deployment.project + else: + project = Project.objects.get(pk=project_id) + + # Generate job name if not provided + if job_name is None: + job_name = f"Single image {image_id} - {pipeline.name}" + + # Create the job + job = Job.objects.create( + name=job_name, + project=project, + pipeline=pipeline, + job_type_key=MLJob.key, + source_image_single=image, + ) + + logger.info( + f"Created job {job.pk} for single image {image_id} " f"with pipeline {pipeline.name} (id: {pipeline_id})" + ) + + # Enqueue the job (starts the Celery task) + job.enqueue() + + logger.info(f"Job {job.pk} enqueued with task_id: {job.task_id}") + + return job diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 6a8898eeb..bcfca58f2 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -58,7 +58,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): # Store all image IDs in Redis for progress tracking state_manager = TaskStateManager(job.pk) - state_manager.initialize_job(image_ids) + state_manager.initialize_job(image_ids, stages=["process", "results"]) job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") async def queue_all_images(): @@ -73,7 +73,7 @@ async def queue_all_images(): success = await manager.publish_task( job_id=job_id, data=message, - ttr=120, # visibility timeout in seconds + ttr=300, # visibility timeout in seconds ) except Exception as e: logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 55edcc1b0..073d3f650 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -35,32 +35,55 @@ def __init__(self, job_id: int): self._pending_key = f"job:{job_id}:pending_images" # noqa E231 self._total_key = f"job:{job_id}:pending_images_total" # noqa E231 - def initialize_job(self, image_ids: list[str]) -> None: + def initialize_job(self, image_ids: list[str], stages: list[str]) -> None: """ Initialize job tracking with a list of image IDs to process. Args: image_ids: List of image IDs that need to be processed + stages: List of stages to track for each image """ - cache.set(self._pending_key, image_ids, timeout=self.TIMEOUT) + self.stages = stages + for stage in stages: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) - def mark_images_processed(self, processed_image_ids: set[str]) -> None: + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" # noqa E231 + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + ) -> None | TaskProgress: """ - Mark a set of images as processed by removing them from pending list. + Update the task state with newly processed images. Args: - processed_image_ids: Set of image IDs that have been processed + processed_image_ids: Set of image IDs that have just been processed """ - pending_images = cache.get(self._pending_key) - if pending_images is None: - return - - remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] - - cache.set(self._pending_key, remaining_images, timeout=self.TIMEOUT) + # Create a unique lock key for this job + lock_key = f"job:{self.job_id}:process_results_lock" + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None - def get_progress(self) -> TaskProgress | None: + try: + # Update progress tracking in Redis + progress_info = self._get_progress(processed_image_ids, stage) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: """ Get current progress information for the job. @@ -71,16 +94,21 @@ def get_progress(self) -> TaskProgress | None: - processed: Number of images processed (or None if not tracked) - percentage: Progress as float 0.0-1.0 (or None if not tracked) """ - pending_images = cache.get(self._pending_key) + pending_images = cache.get(self._get_pending_key(stage)) total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: return None + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + assert len(pending_images) >= len(remaining_images) + cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) - remaining = len(pending_images) + remaining = len(remaining_images) processed = total_images - remaining percentage = float(processed) / total_images if total_images > 0 else 1.0 - logger.info(f"Pending images from Redis for job {self.job_id}: " f"{remaining}/{total_images}") + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " + f"{remaining}/{total_images}: {percentage*100}%" + ) return TaskProgress( remaining=remaining, @@ -93,5 +121,6 @@ def cleanup(self) -> None: """ Delete all Redis keys associated with this job. """ - cache.delete(self._pending_key) + for stage in self.stages: + cache.delete(self._get_pending_key(stage)) cache.delete(self._total_key) From 303270953dd6b6546490ac308ba37079deef6b7f Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 4 Nov 2025 14:17:03 -0800 Subject: [PATCH 19/29] Fix typo --- ami/jobs/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 5be9e8a21..6f9a022fb 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -397,7 +397,6 @@ def run(cls, job: "Job"): # End image collection stage job.save() - if job.project.feature_flags.async_pipeline_workers: if job.project.feature_flags.async_pipeline_workers: queued = queue_images_to_nats(job, images) if not queued: From a8b94e38f0d949d73d5f4ce4637a94d52e37f9b1 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 18 Nov 2025 10:01:11 -0800 Subject: [PATCH 20/29] Remove unnecesary file --- ami/jobs/utils.py | 78 ----------------------------------------------- 1 file changed, 78 deletions(-) delete mode 100644 ami/jobs/utils.py diff --git a/ami/jobs/utils.py b/ami/jobs/utils.py deleted file mode 100644 index 87cb5d12e..000000000 --- a/ami/jobs/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Utility functions for job management and testing.""" - -import logging - -from ami.jobs.models import Job, MLJob -from ami.main.models import Project, SourceImage -from ami.ml.models import Pipeline - -logger = logging.getLogger(__name__) - - -def submit_single_image_job( - image_id: int, - pipeline_id: int, - project_id: int | None = None, - job_name: str | None = None, -) -> Job: - """ - Submit a job to process a single image through a pipeline. - - This is useful for testing, debugging, or reprocessing individual images. - - Args: - image_id: The SourceImage ID to process - pipeline_id: The Pipeline ID to use for processing - project_id: Optional project ID (will be inferred from image if not provided) - job_name: Optional custom job name (will be auto-generated if not provided) - - Returns: - The created Job instance - - Raises: - SourceImage.DoesNotExist: If the image doesn't exist - Pipeline.DoesNotExist: If the pipeline doesn't exist - """ - # Fetch the image and validate it exists - try: - image = SourceImage.objects.select_related("deployment__project").get(pk=image_id) - except SourceImage.DoesNotExist: - logger.error(f"SourceImage with id {image_id} does not exist") - raise - - # Fetch the pipeline and validate it exists - try: - pipeline = Pipeline.objects.get(pk=pipeline_id) - except Pipeline.DoesNotExist: - logger.error(f"Pipeline with id {pipeline_id} does not exist") - raise - - # Infer project from image if not provided - if project_id is None: - project = image.deployment.project - else: - project = Project.objects.get(pk=project_id) - - # Generate job name if not provided - if job_name is None: - job_name = f"Single image {image_id} - {pipeline.name}" - - # Create the job - job = Job.objects.create( - name=job_name, - project=project, - pipeline=pipeline, - job_type_key=MLJob.key, - source_image_single=image, - ) - - logger.info( - f"Created job {job.pk} for single image {image_id} " f"with pipeline {pipeline.name} (id: {pipeline_id})" - ) - - # Enqueue the job (starts the Celery task) - job.enqueue() - - logger.info(f"Job {job.pk} enqueued with task_id: {job.task_id}") - - return job From 0a5c89ed74e76874be0ce276efcec522d267bc7a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 11:01:30 -0800 Subject: [PATCH 21/29] Remove diagram, fix flakes --- ami/jobs/models.py | 6 +- ami/ml/orchestration/task_state.py | 6 +- config/settings/base.py | 4 +- object_model_diagram.md | 167 ----------------------------- 4 files changed, 7 insertions(+), 176 deletions(-) delete mode 100644 object_model_diagram.md diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b4c57f0e0..628fa8b23 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -72,7 +72,7 @@ def get_status_label(status: JobState, progress: float) -> str: if status in [JobState.CREATED, JobState.PENDING, JobState.RECEIVED]: return "Waiting to start" elif status in [JobState.STARTED, JobState.RETRY, JobState.SUCCESS]: - return f"{progress:.0%} complete" # noqa E231 + return f"{progress:.0%} complete" else: return f"{status.name}" @@ -435,7 +435,7 @@ def process_images(cls, job, images): project_id=job.project.pk, reprocess_all_images=job.project.feature_flags.reprocess_all_images, ) - job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") # noqa E231 + job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") except Exception as e: # Log error about image batch and continue job.logger.error(f"Failed to process image batch {i+1}: {e}") @@ -487,7 +487,7 @@ def process_images(cls, job, images): if image_count: percent_successful = 1 - len(request_failed_images) / image_count if image_count else 0 - job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") # noqa E231 + job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") # Check all Celery sub-tasks if they have completed saving results save_tasks_remaining = set(save_tasks) - set(save_tasks_completed) diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 073d3f650..b8bfeec2d 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -32,8 +32,8 @@ def __init__(self, job_id: int): job_id: The job primary key """ self.job_id = job_id - self._pending_key = f"job:{job_id}:pending_images" # noqa E231 - self._total_key = f"job:{job_id}:pending_images_total" # noqa E231 + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" def initialize_job(self, image_ids: list[str], stages: list[str]) -> None: """ @@ -50,7 +50,7 @@ def initialize_job(self, image_ids: list[str], stages: list[str]) -> None: cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) def _get_pending_key(self, stage: str) -> str: - return f"{self._pending_key}:{stage}" # noqa E231 + return f"{self._pending_key}:{stage}" def update_state( self, diff --git a/config/settings/base.py b/config/settings/base.py index 7fec5158a..7f635fa4f 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -24,9 +24,7 @@ # GENERAL # ------------------------------------------------------------------------------ EXTERNAL_HOSTNAME = env("EXTERNAL_HOSTNAME", default="localhost:8000") # type: ignore[no-untyped-call] -EXTERNAL_BASE_URL = env( - "EXTERNAL_BASE_URL", default=f"http://{EXTERNAL_HOSTNAME}" # noqa: E231, E501 # type: ignore[no-untyped-call] -) +EXTERNAL_BASE_URL = env("EXTERNAL_BASE_URL", default=f"http://{EXTERNAL_HOSTNAME}") # type: ignore[no-untyped-call] # https://docs.djangoproject.com/en/dev/ref/settings/#debug DEBUG = env.bool("DJANGO_DEBUG", False) # type: ignore[no-untyped-call] diff --git a/object_model_diagram.md b/object_model_diagram.md deleted file mode 100644 index 0acdfdf61..000000000 --- a/object_model_diagram.md +++ /dev/null @@ -1,167 +0,0 @@ -# Object Model Diagram: ML Pipeline System - -```mermaid -classDiagram - %% Core ML Pipeline Classes - class Pipeline { - +string slug - +string name - +string description - +int version - +string version_name - +stages[] PipelineStage - +default_config PipelineRequestConfigParameters - -- - +get_config(project_id) PipelineRequestConfigParameters - +collect_images() Iterable~SourceImage~ - +process_images() PipelineResultsResponse - +choose_processing_service_for_pipeline() ProcessingService - } - - class Algorithm { - +string key - +string name - +AlgorithmTaskType task_type - +string description - +int version - +string version_name - +string uri - -- - +detection_task_types[] AlgorithmTaskType - +classification_task_types[] AlgorithmTaskType - +has_valid_category_map() boolean - } - - class AlgorithmCategoryMap { - +data JSONField - +labels[] string - +int labels_hash - +string version - +string description - +string uri - -- - +make_labels_hash() int - +get_category() int - +with_taxa() dict[] - } - - class PipelineStage { - +string key - +string name - +string description - +boolean enabled - +params[] ConfigurableStageParam - } - - %% Job System Classes - class Job { - +string name - +string queue - +datetime scheduled_at - +datetime started_at - +datetime finished_at - +JobState status - +JobProgress progress - +JobLogs logs - +params JSONField - +result JSONField - +string task_id - +int delay - +int limit - +boolean shuffle - +string job_type_key - -- - +job_type() JobType - +update_status() void - +logger JobLogger - } - - %% Configuration Classes - class ProjectPipelineConfig { - +boolean enabled - +config JSONField - -- - +get_config() dict - } - - class Project { - +string name - +string slug - +feature_flags JSONField - -- - +default_processing_pipeline Pipeline - } - - %% Enums - class AlgorithmTaskType { - <> - DETECTION - LOCALIZATION - SEGMENTATION - CLASSIFICATION - EMBEDDING - TRACKING - TAGGING - REGRESSION - CAPTIONING - GENERATION - TRANSLATION - SUMMARIZATION - QUESTION_ANSWERING - DEPTH_ESTIMATION - POSE_ESTIMATION - SIZE_ESTIMATION - OTHER - UNKNOWN - } - - - %% Relationships - Pipeline "M" -- "M" Algorithm : algorithms - Pipeline "1" -- "many" PipelineStage : stages - Pipeline "1" -- "many" Job : jobs - Pipeline "1" -- "many" ProjectPipelineConfig : project_pipeline_configs - - Algorithm "1" -- "0..1" AlgorithmCategoryMap : category_map - Algorithm "1" -- "1" AlgorithmTaskType : task_type - - Job "0..1" -- "1" Pipeline : pipeline - Job "1" -- "1" Project : project - - Project "1" -- "many" ProjectPipelineConfig : project_pipeline_configs - ProjectPipelineConfig "1" -- "1" Pipeline : pipeline - ProjectPipelineConfig "1" -- "1" Project : project - - %% Notes - note for Pipeline "Identified by unique slug\nAuto-generated from name + version + UUID" - note for Algorithm "Identified by unique key\nAuto-generated from name + version" - note for Job "MLJob is the primary job type\nfor running ML pipelines" -``` - -## Key Relationships Summary - -### Core ML Pipeline Flow: -1. **ProcessingService** → registers → **Pipeline** → contains → **Algorithm** -2. **Project** → configures → **Pipeline** through **ProjectPipelineConfig** -3. **Job** → executes → **Pipeline** → uses → **ProcessingService** - -### Model Identification: -- **Pipeline**: Identified by unique `slug` (string) - auto-generated from `name + version + UUID` -- **Algorithm**: Identified by unique `key` (string) - auto-generated from `name + version` -- **Job**: Uses standard Django `id` but also has `task_id` for Celery integration - -### Stage Management: -- **Pipeline** contains **PipelineStage** objects (for configuration display) -- **Job** tracks execution through **JobProgressStageDetail** objects (for runtime progress) -- Both share the same base **ConfigurableStage** schema - -### Algorithm Classification: -- **Algorithm** has task types (detection, classification, etc.) -- Classification algorithms require **AlgorithmCategoryMap** for label mapping -- Detection algorithms don't require category maps - -### Job Execution Flow: -1. **Job** is created with a **Pipeline** reference -2. **Pipeline** selects appropriate **ProcessingService** -3. **ProcessingService** executes algorithms and returns results -4. **Job** tracks progress through **JobProgress** and **JobProgressStageDetail** From 344f8839ae3a289c2f790cd9e8c69e94b148bbeb Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 11:10:48 -0800 Subject: [PATCH 22/29] Use async_to_sync --- ami/jobs/models.py | 4 ++-- ami/jobs/tasks.py | 4 ++-- ami/ml/orchestration/jobs.py | 7 ++++--- ami/ml/orchestration/utils.py | 18 ------------------ 4 files changed, 8 insertions(+), 25 deletions(-) delete mode 100644 ami/ml/orchestration/utils.py diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 628fa8b23..482d01a58 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -133,14 +133,14 @@ def get_stage(self, stage_key: str) -> JobProgressStageDetail: for stage in self.stages: if stage.key == stage_key: return stage - raise ValueError(f"Job stage with key '{stage_key}' not found in progress") # noqa E713 + raise ValueError(f"Job stage with key '{stage_key}' not found in progress") def get_stage_param(self, stage_key: str, param_key: str) -> ConfigurableStageParam: stage = self.get_stage(stage_key) for param in stage.params: if param.key == param_key: return param - raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") # noqa E713 + raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") def add_stage_param(self, stage_key: str, param_name: str, value: typing.Any = None) -> ConfigurableStageParam: stage = self.get_stage(stage_key) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 7e1f6a4f3..95f7fd56c 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,13 +3,13 @@ import time from collections.abc import Callable +from asgiref.sync import async_to_sync from celery.result import AsyncResult from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.orchestration.task_state import TaskStateManager -from ami.ml.orchestration.utils import run_in_async_loop from ami.ml.schemas import PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app @@ -119,7 +119,7 @@ async def ack_task(): async with TaskQueueManager() as manager: return await manager.acknowledge_task(reply_subject) - ack_success = run_in_async_loop(ack_task, f"acknowledging job {job.pk} via NATS") + ack_success = async_to_sync(ack_task)() if ack_success: job.logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index bcfca58f2..b79edb924 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -1,10 +1,11 @@ import datetime +from asgiref.sync import async_to_sync + from ami.jobs.models import Job, JobState, logger from ami.main.models import SourceImage from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.orchestration.task_state import TaskStateManager -from ami.ml.orchestration.utils import run_in_async_loop # TODO CGJS: Call this once a job is fully complete (all images processed and saved) @@ -22,7 +23,7 @@ async def cleanup(): success = await manager.cleanup_job_resources(job_id) return success - run_in_async_loop(cleanup, f"cleaning up NATS resources for job '{job_id}'") + async_to_sync(cleanup)() def queue_images_to_nats(job: "Job", images: list[SourceImage]): @@ -86,7 +87,7 @@ async def queue_all_images(): return successful_queues, failed_queues - result = run_in_async_loop(queue_all_images, f"queuing images to NATS for job '{job_id}'") + result = async_to_sync(queue_all_images)() if result is None: job.logger.error(f"Failed to queue images to NATS for job '{job_id}'") return False diff --git a/ami/ml/orchestration/utils.py b/ami/ml/orchestration/utils.py deleted file mode 100644 index 752e591dd..000000000 --- a/ami/ml/orchestration/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import asyncio -import logging -import typing - -logger = logging.getLogger(__name__) - - -def run_in_async_loop(func: typing.Callable, error_msg: str) -> typing.Any: - # helper to use new_event_loop() to ensure we're not mixing with Django's async context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(func()) - except Exception as e: - logger.error(f"Error in async loop - {error_msg}: {e}") - return None - finally: - loop.close() From df7eaa3ed65d0b52004a6c15bbcc4f2ddd06d835 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 11:31:16 -0800 Subject: [PATCH 23/29] CR feedback --- ami/ml/orchestration/jobs.py | 3 +-- ami/ml/orchestration/nats_queue.py | 16 +++++++++------- ami/ml/orchestration/task_state.py | 9 ++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index b79edb924..b7010291b 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -59,7 +59,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): # Store all image IDs in Redis for progress tracking state_manager = TaskStateManager(job.pk) - state_manager.initialize_job(image_ids, stages=["process", "results"]) + state_manager.initialize_job(image_ids) job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") async def queue_all_images(): @@ -74,7 +74,6 @@ async def queue_all_images(): success = await manager.publish_task( job_id=job_id, data=message, - ttr=300, # visibility timeout in seconds ) except Exception as e: logger.error(f"Failed to queue image {image_pk} to stream for job '{job_id}': {e}") diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index c566c6235..b4f7c13c9 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -25,6 +25,9 @@ async def get_connection(nats_url: str): return nc, js +TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds + + class TaskQueueManager: """ Manager for NATS JetStream task queue operations. @@ -67,7 +70,7 @@ def _get_consumer_name(self, job_id: str) -> str: """Get consumer name from job_id.""" return f"job-{job_id}-consumer" - async def _ensure_stream(self, job_id: str, ttr: int = 30): + async def _ensure_stream(self, job_id: str): """Ensure stream exists for the given job.""" if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") @@ -88,7 +91,7 @@ async def _ensure_stream(self, job_id: str, ttr: int = 30): ) logger.info(f"Created stream {stream_name}") - async def _ensure_consumer(self, job_id: str, ttr: int = 30): + async def _ensure_consumer(self, job_id: str): """Ensure consumer exists for the given job.""" if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") @@ -107,7 +110,7 @@ async def _ensure_consumer(self, job_id: str, ttr: int = 30): config=ConsumerConfig( durable_name=consumer_name, ack_policy=AckPolicy.EXPLICIT, - ack_wait=ttr, # Visibility timeout (TTR) + ack_wait=TASK_TTR, # Visibility timeout (TTR) max_deliver=5, # Max retry attempts deliver_policy=DeliverPolicy.ALL, max_ack_pending=100, # Max unacked messages @@ -116,14 +119,13 @@ async def _ensure_consumer(self, job_id: str, ttr: int = 30): ) logger.info(f"Created consumer {consumer_name}") - async def publish_task(self, job_id: str, data: dict[str, Any], ttr: int = 30) -> bool: + async def publish_task(self, job_id: str, data: dict[str, Any]) -> bool: """ Publish a task to it's job queue. Args: job_id: The job ID (e.g., 'job123' or '123') data: Task data (dict will be JSON-encoded) - ttr: Time-to-run in seconds (visibility timeout, default 30) Returns: bool: True if successful, False otherwise @@ -133,8 +135,8 @@ async def publish_task(self, job_id: str, data: dict[str, Any], ttr: int = 30) - try: # Ensure stream and consumer exist - await self._ensure_stream(job_id, ttr) - await self._ensure_consumer(job_id, ttr) + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) subject = self._get_subject(job_id) task_data = json.dumps(data) diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index b8bfeec2d..483275453 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -23,6 +23,7 @@ class TaskStateManager: """ TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] def __init__(self, job_id: int): """ @@ -35,16 +36,14 @@ def __init__(self, job_id: int): self._pending_key = f"job:{job_id}:pending_images" self._total_key = f"job:{job_id}:pending_images_total" - def initialize_job(self, image_ids: list[str], stages: list[str]) -> None: + def initialize_job(self, image_ids: list[str]) -> None: """ Initialize job tracking with a list of image IDs to process. Args: image_ids: List of image IDs that need to be processed - stages: List of stages to track for each image """ - self.stages = stages - for stage in stages: + for stage in self.STAGES: cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) @@ -121,6 +120,6 @@ def cleanup(self) -> None: """ Delete all Redis keys associated with this job. """ - for stage in self.stages: + for stage in self.STAGES: cache.delete(self._get_pending_key(stage)) cache.delete(self._total_key) From 0391642c55efa101a7c393bba73c964efe935a77 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 12:30:43 -0800 Subject: [PATCH 24/29] clean up --- ami/jobs/tasks.py | 8 +++++--- ami/ml/orchestration/jobs.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 95f7fd56c..9fa5c7639 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Callable +from datetime import datetime from asgiref.sync import async_to_sync from celery.result import AsyncResult @@ -60,9 +61,6 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: job_id: The job ID result_data: Dictionary containing the pipeline result reply_subject: NATS reply subject for acknowledgment - - Returns: - dict with status information """ from ami.jobs.models import Job # avoid circular import @@ -159,6 +157,10 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, ) + if stage == "results" and progress_percentage >= 1.0: + job.status = JobState.SUCCESS + job.progress.summary.status = JobState.SUCCESS + job.finished_at = datetime.now() job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index b7010291b..5b4f2e9d6 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -9,7 +9,7 @@ # TODO CGJS: Call this once a job is fully complete (all images processed and saved) -def cleanup_nats_resources(job: "Job"): +def cleanup_nats_resources(job: "Job") -> bool: """ Clean up NATS JetStream resources (stream and consumer) for a completed job. @@ -23,7 +23,7 @@ async def cleanup(): success = await manager.cleanup_job_resources(job_id) return success - async_to_sync(cleanup)() + return async_to_sync(cleanup)() def queue_images_to_nats(job: "Job", images: list[SourceImage]): From 4ae27b00ae29b2ba31e45023fd0f3ee409b27d98 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 14:58:01 -0800 Subject: [PATCH 25/29] more cleanup --- ami/ml/orchestration/jobs.py | 9 ++------- docker-compose.yml | 8 ++++++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 5b4f2e9d6..b5682a7e4 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -67,10 +67,9 @@ async def queue_all_images(): failed_queues = 0 async with TaskQueueManager() as manager: - for i, (image_pk, message) in enumerate(messages): + for image_pk, message in messages: try: logger.info(f"Queueing image {image_pk} to stream for job '{job_id}': {message}") - # Use TTR of 300 seconds (5 minutes) for image processing success = await manager.publish_task( job_id=job_id, data=message, @@ -86,11 +85,7 @@ async def queue_all_images(): return successful_queues, failed_queues - result = async_to_sync(queue_all_images)() - if result is None: - job.logger.error(f"Failed to queue images to NATS for job '{job_id}'") - return False - successful_queues, failed_queues = result + successful_queues, failed_queues = async_to_sync(queue_all_images)() if not images: job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) diff --git a/docker-compose.yml b/docker-compose.yml index 566f4ab62..6f2bd2c76 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -34,7 +34,6 @@ services: required: false ports: - "8000:8000" - - "5679:5679" command: /start # for debugging with debugpy: # command: python -m debugpy --listen 0.0.0.0:5679 -m django runserver 0.0.0.0:8000 @@ -79,7 +78,12 @@ services: volumes: - ./.git:/app/.git:ro - ./ui:/app - entrypoint: ["sh", "-c", "yarn install && yarn start --debug --host 0.0.0.0 --port 4000"] + entrypoint: + [ + "sh", + "-c", + "yarn install && yarn start --debug --host 0.0.0.0 --port 4000", + ] environment: - API_PROXY_TARGET=http://django:8000 - CHOKIDAR_USEPOLLING=true From 4f50b3db17e682f028ef9f94000019246c8e4de9 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 21 Nov 2025 15:19:06 -0800 Subject: [PATCH 26/29] Apply suggestions from code review Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- ami/ml/orchestration/jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index b5682a7e4..06173059c 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -53,7 +53,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): "timestamp": (image.timestamp.isoformat() if hasattr(image, "timestamp") and image.timestamp else None), "batch_index": i, "total_images": len(images), - "queue_timestamp": datetime.datetime.now().isoformat(), + "queue_timestamp": timezone.now().isoformat(), } messages.append((image.pk, message)) From a8fc79a413a30fcea704576004fa2ff439856440 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 15:19:44 -0800 Subject: [PATCH 27/29] Remove old comments --- docker-compose.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 6f2bd2c76..703ecea0d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -35,8 +35,6 @@ services: ports: - "8000:8000" command: /start - # for debugging with debugpy: - # command: python -m debugpy --listen 0.0.0.0:5679 -m django runserver 0.0.0.0:8000 networks: - default - antenna_network From 4efdf0700664fc70f229a74ba8c61f8085da42fa Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 15:23:54 -0800 Subject: [PATCH 28/29] Fix processing error cases --- ami/jobs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 9fa5c7639..00d7472a2 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -71,8 +71,8 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: pipeline_result = PipelineResultsResponse(**result_data) processed_image_ids = {str(img.id) for img in pipeline_result.source_images} else: - processed_image_ids = set() image_id = result_data.get("image_id") + processed_image_ids = {str(image_id)} if image_id else set() logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}") state_manager = TaskStateManager(job_id) From f221a1ada095d1b292b9bdfed35b8da9c17fc44a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 21 Nov 2025 15:31:36 -0800 Subject: [PATCH 29/29] updates --- ami/jobs/tasks.py | 6 ++---- ami/ml/orchestration/jobs.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 00d7472a2..bac6b1236 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -98,11 +98,9 @@ def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: ) # Save to database (this is the slow operation) - if not job.pipeline: - job.logger.warning(f"Job {job_id} has no pipeline, skipping save_results") - return - if pipeline_result: + # should never happen since otherwise we could not be processing results here + assert job.pipeline is not None, "Job pipeline is None" job.pipeline.save_results(results=pipeline_result, job_id=job.pk) job.logger.info(f"Successfully saved results for job {job_id}") diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 06173059c..240000ca9 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -1,6 +1,5 @@ -import datetime - from asgiref.sync import async_to_sync +from django.utils import timezone from ami.jobs.models import Job, JobState, logger from ami.main.models import SourceImage