Commit c016d47

Use transaction.on_commit with all async celery tasks
1 parent 7c86612 commit c016d47
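
The change applied across all four files is the same pattern: wrap each `apply_async` call in `django.db.transaction.on_commit` so the Celery task is only enqueued after the surrounding database transaction commits, which means a worker can never pick up a task before the rows it depends on (job status, MLTaskRecord, etc.) are visible. A minimal sketch of the idea, using a hypothetical task and helper (not names from this repo):

import uuid

from celery import shared_task
from django.db import transaction


@shared_task
def my_task(record_pk: int):
    # Hypothetical task body; by the time a worker runs this, the row is committed.
    ...


def schedule_after_commit(record_pk: int) -> str:
    # Pre-generate the task id so it can be stored on tracking records
    # in the same transaction, before the task is actually enqueued.
    task_id = str(uuid.uuid4())
    transaction.on_commit(
        lambda: my_task.apply_async(args=[record_pk], task_id=task_id)
    )
    return task_id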

4 files changed (+54 -58 lines)

ami/jobs/models.py (15 additions, 19 deletions)

@@ -579,10 +579,8 @@ def update_job_progress(cls, job: "Job"):
             failed=num_failed_save_tasks,
         )

-        # The ML job is completed, log general job stags
-        if job.status != JobState.FAILURE:
-            # the job might've already been marked as failed because of unsent process pipeline request tasks
-            job.update_status(JobState.FAILURE if any_failed_tasks else JobState.SUCCESS, save=False)
+        # The ML job is completed, log general job stats
+        job.update_status(JobState.FAILURE if any_failed_tasks else JobState.SUCCESS, save=False)

         if any_failed_tasks:
             failed_save_task_ids = [
@@ -708,23 +706,11 @@ def run(cls, job: "Job"):

         except Exception as e:
             job.logger.error(f"Failed to submit all images: {e}")
-            # mark the job as failed
-            job.progress.update_stage(
-                "process",
-                status=JobState.FAILURE,
-                progress=1,
-                failed=image_count,
-                processed=0,
-                remaining=image_count,
-            )
             job.update_status(JobState.FAILURE)
             job.save()
-        finally:
-            # Handle the successfully submitted tasks
-            subtasks = job.ml_task_records.all()
-            if subtasks:
-                check_ml_job_status.apply_async([job.pk])
-            else:
+        else:
+            subtasks = job.ml_task_records.filter(created_at__gte=job.started_at)
+            if not subtasks:
                 # No tasks were scheduled, mark the job as done
                 job.logger.info("No subtasks were scheduled, ending the job.")
                 job.progress.update_stage(
@@ -740,6 +726,16 @@ def run(cls, job: "Job"):
                 job.update_status(JobState.SUCCESS, save=False)
                 job.finished_at = datetime.datetime.now()
                 job.save()
+            else:
+                job.logger.info(
+                    f"Continue processing the remaining {subtasks.count()} process image request subtasks."
+                )
+                from django.db import transaction
+
+                transaction.on_commit(lambda: check_ml_job_status.apply_async([job.pk]))
+        finally:
+            # TODO: clean up?
+            pass


 class DataStorageSyncJob(JobType):

ami/jobs/views.py (3 additions, 1 deletion)

@@ -168,7 +168,9 @@ def check_inprogress_subtasks(self, request, pk=None):
         has_inprogress_tasks = job.check_inprogress_subtasks()
         if has_inprogress_tasks:
             # Schedule task to update the job status
+            from django.db import transaction
+
             from ami.ml.tasks import check_ml_job_status

-            check_ml_job_status.apply_async((job.pk,))
+            transaction.on_commit(lambda: check_ml_job_status.apply_async((job.pk,)))
         return Response({"inprogress_subtasks": has_inprogress_tasks})

ami/ml/models/pipeline.py (28 additions, 32 deletions)

@@ -16,7 +16,7 @@

 import requests
 from celery.result import AsyncResult
-from django.db import models
+from django.db import models, transaction
 from django.utils.text import slugify
 from django.utils.timezone import now
 from django_pydantic_field import SchemaField
@@ -185,8 +185,8 @@ def process_pipeline_request(pipeline_request: dict, project_id: int):
 def process_images(
     pipeline: Pipeline,
     images: typing.Iterable[SourceImage],
+    project_id: int,
     job_id: int | None = None,
-    project_id: int | None = None,
     process_sync: bool = False,
 ) -> PipelineResultsResponse | None:
     """
@@ -207,13 +207,12 @@ def process_images(
         job = Job.objects.get(pk=job_id)
         task_logger = job.logger

-    if project_id:
-        project = Project.objects.get(pk=project_id)
-    else:
-        task_logger.warning(f"Pipeline {pipeline} is not associated with a project")
-        project = None
+    # Pipelines must be associated with a project in order to select a processing service
+    # A processing service is required to send requests to the /process endpoint
+    project = Project.objects.get(pk=project_id)
+    task_logger.info(f"Using project: {project}")

-    pipeline_config = pipeline.get_config(project_id=project_id)
+    pipeline_config = pipeline.get_config(project_id=project.pk)
     task_logger.info(f"Using pipeline config: {pipeline_config}")

     prefiltered_images = list(images)
@@ -265,15 +264,16 @@ def process_images(
     task_logger.info(f"Found {len(detection_requests)} existing detections.")

     if not process_sync:
+        assert job_id is not None, "job_id is required to process images using async tasks."
         handle_async_process_images(
             pipeline.slug,
             source_image_requests,
             images,
             pipeline_config,
             detection_requests,
+            project_id,
             job_id,
             task_logger,
-            project_id,
         )
         return
     else:
@@ -289,12 +289,11 @@ def handle_async_process_images(
     source_images: list[SourceImage],
     pipeline_config: PipelineRequestConfigParameters,
     detection_requests: list[DetectionRequest],
-    job_id: int | None = None,
+    project_id: int,
+    job_id: int,
     task_logger: logging.Logger = logger,
-    project_id: int | None = None,
 ):
     """Handle asynchronous processing by submitting tasks to the appropriate pipeline queue."""
-    task_ids = []
     batch_size = pipeline_config.get("batch_size", 1)

     # Group source images into batches
@@ -333,14 +332,18 @@ def handle_async_process_images(
             detections=detections_batch,
             config=pipeline_config,
         )
-        task_result = process_pipeline_request.apply_async(
-            args=[prediction_request.dict(), project_id],
-            # TODO: make ml-pipeline an environment variable (i.e. PIPELINE_QUEUE_PREFIX)?
-            queue=f"ml-pipeline-{pipeline}",
-            # all pipelines have their own queue beginning with "ml-pipeline-"
-            # the antenna celeryworker should subscribe to all pipeline queues
+
+        task_id = str(uuid.uuid4())
+        transaction.on_commit(
+            lambda: process_pipeline_request.apply_async(
+                args=[prediction_request.dict(), project_id],
+                task_id=task_id,
+                # TODO: make ml-pipeline an environment variable (i.e. PIPELINE_QUEUE_PREFIX)?
+                queue=f"ml-pipeline-{pipeline}",
+                # all pipelines have their own queue beginning with "ml-pipeline-"
+                # the antenna celeryworker should subscribe to all pipeline queues
+            )
         )
-        task_ids.append(task_result.id)

         if job_id:
             from ami.jobs.models import Job, MLTaskRecord
@@ -349,21 +352,17 @@ def handle_async_process_images(
             # Create a new MLTaskRecord for this task
             ml_task_record = MLTaskRecord.objects.create(
                 job=job,
-                task_id=task_result.id,
+                task_id=task_id,
                 task_name="process_pipeline_request",
                 pipeline_request=prediction_request,
-                num_captures=len(source_image_batches[idx]),
+                num_captures=len(source_image_batches[i]),
             )
-            ml_task_record.source_images.set(source_image_batches[idx])
+            ml_task_record.source_images.set(source_image_batches[i])
             ml_task_record.save()
-            # job.logger.info(
-            #     f"Created MLTaskRecord for job {job_id} with task ID {task_result.id}"
-            #     " and task name process_pipeline_request"
-            # )
         else:
             task_logger.warning("No job ID provided, MLTaskRecord will not be created.")

-    task_logger.info(f"Submitted {len(task_ids)} batch image processing task(s).")
+    task_logger.info(f"Submitted {len(source_image_request_batches)} batch image processing task(s).")


 def handle_sync_process_images(
@@ -377,9 +376,6 @@ def handle_sync_process_images(
     job: Job | None,
 ) -> PipelineResultsResponse:
     """Handle synchronous processing by sending HTTP requests to the processing service."""
-    if project_id is None:
-        raise ValueError("Project ID must be provided when syncronously processing images.")
-
     processing_service = pipeline.choose_processing_service_for_pipeline(job_id, pipeline.name, project_id)
     if not processing_service.endpoint_url:
         raise ValueError(f"No endpoint URL configured for this pipeline's processing service ({processing_service})")
@@ -1207,8 +1203,8 @@ def get_config(self, project_id: int | None = None) -> PipelineRequestConfigPara
                 )
             except self.project_pipeline_configs.model.DoesNotExist as e:
                 logger.warning(f"No project-pipeline config for Pipeline {self} " f"and Project #{project_id}: {e}")
-
-        logger.warning("No project_id, no pipeline config is used.")
+        else:
+            logger.warning("No project_id. No pipeline config is used. Using default empty config instead.")

         return config

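For reference, the enqueue loop in handle_async_process_images now pre-generates a UUID task id, stores it on the MLTaskRecord inside the open transaction, and defers the actual apply_async with transaction.on_commit. A hedged sketch of that loop (batch handling simplified; unlike the diff above, this sketch binds the loop variables as default arguments of the lambda, a common safeguard since on_commit callbacks only run after the loop has finished):

import uuid

from django.db import transaction


def enqueue_batches(batches, project_id: int, queue_name: str):
    for batch in batches:
        # Pre-generated id: saved on the tracking record before the task exists.
        task_id = str(uuid.uuid4())
        transaction.on_commit(
            lambda batch=batch, task_id=task_id: process_pipeline_request.apply_async(
                args=[batch, project_id],
                task_id=task_id,
                queue=queue_name,
            )
        )
        # ...create the MLTaskRecord with task_id in the same transaction...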

ami/ml/tasks.py (8 additions, 6 deletions)

@@ -26,12 +26,10 @@ def process_source_images_async(pipeline_choice: str, image_ids: list[int], job_

     images = SourceImage.objects.filter(pk__in=image_ids)
     pipeline = Pipeline.objects.get(slug=pipeline_choice)
+    project = pipeline.projects.first()
+    assert project, f"Pipeline {pipeline} must be associated with a project."

-    results = process_images(
-        pipeline=pipeline,
-        images=images,
-        job_id=job_id,
-    )
+    results = process_images(pipeline=pipeline, images=images, job_id=job_id, project_id=project.pk)

     try:
         save_results(results=results, job_id=job_id)
@@ -136,5 +134,9 @@ def check_ml_job_status(ml_job_id: int):
         logger.info(f"ML Job {ml_job_id} is complete.")
         job.logger.info(f"ML Job {ml_job_id} is complete.")
     else:
+        from django.db import transaction
+
         logger.info(f"ML Job {ml_job_id} still in progress. Checking again for completed tasks.")
-        check_ml_job_status.apply_async([ml_job_id], countdown=10)  # check again in 10 seconds
+        transaction.on_commit(
+            lambda: check_ml_job_status.apply_async([ml_job_id], countdown=10)
+        )  # check again in 10 seconds
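
The rescheduling in check_ml_job_status is the usual Celery polling pattern, just deferred until commit. A standalone sketch, assuming a hypothetical job_is_complete helper:

from celery import shared_task
from django.db import transaction


@shared_task
def poll_job(job_pk: int):
    if job_is_complete(job_pk):  # hypothetical status check
        return
    # Re-check in 10 seconds. on_commit defers the enqueue until any open
    # transaction commits; in autocommit mode the callback runs immediately.
    transaction.on_commit(lambda: poll_job.apply_async([job_pk], countdown=10))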
