 import requests
 from celery.result import AsyncResult
-from django.db import models
+from django.db import models, transaction
 from django.utils.text import slugify
 from django.utils.timezone import now
 from django_pydantic_field import SchemaField
@@ -185,8 +185,8 @@ def process_pipeline_request(pipeline_request: dict, project_id: int):
 def process_images(
     pipeline: Pipeline,
     images: typing.Iterable[SourceImage],
+    project_id: int,
     job_id: int | None = None,
-    project_id: int | None = None,
     process_sync: bool = False,
 ) -> PipelineResultsResponse | None:
     """
@@ -207,13 +207,12 @@ def process_images(
         job = Job.objects.get(pk=job_id)
         task_logger = job.logger

-    if project_id:
-        project = Project.objects.get(pk=project_id)
-    else:
-        task_logger.warning(f"Pipeline {pipeline} is not associated with a project")
-        project = None
+    # Pipelines must be associated with a project in order to select a processing service
+    # A processing service is required to send requests to the /process endpoint
+    project = Project.objects.get(pk=project_id)
+    task_logger.info(f"Using project: {project}")

-    pipeline_config = pipeline.get_config(project_id=project_id)
+    pipeline_config = pipeline.get_config(project_id=project.pk)
     task_logger.info(f"Using pipeline config: {pipeline_config}")

     prefiltered_images = list(images)
@@ -265,15 +264,16 @@ def process_images(
     task_logger.info(f"Found {len(detection_requests)} existing detections.")

     if not process_sync:
+        assert job_id is not None, "job_id is required to process images using async tasks."
         handle_async_process_images(
             pipeline.slug,
             source_image_requests,
             images,
             pipeline_config,
             detection_requests,
+            project_id,
             job_id,
             task_logger,
-            project_id,
         )
         return
     else:
@@ -289,12 +289,11 @@ def handle_async_process_images(
     source_images: list[SourceImage],
     pipeline_config: PipelineRequestConfigParameters,
     detection_requests: list[DetectionRequest],
-    job_id: int | None = None,
+    project_id: int,
+    job_id: int,
     task_logger: logging.Logger = logger,
-    project_id: int | None = None,
 ):
     """Handle asynchronous processing by submitting tasks to the appropriate pipeline queue."""
-    task_ids = []
     batch_size = pipeline_config.get("batch_size", 1)

     # Group source images into batches
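Note: the hunk above reads `batch_size` from the pipeline config right before the (unchanged) batching code. A minimal sketch of that kind of fixed-size grouping, assuming a simple list-chunking helper (the name `chunk` is illustrative, not Antenna's actual implementation):

```python
# Illustrative only: group items into consecutive batches of at most `batch_size`.
def chunk(items: list, batch_size: int) -> list[list]:
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

# e.g. with batch_size = pipeline_config.get("batch_size", 1):
# chunk([img1, img2, img3, img4, img5], 2) -> [[img1, img2], [img3, img4], [img5]]
```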
@@ -333,14 +332,18 @@ def handle_async_process_images(
             detections=detections_batch,
             config=pipeline_config,
         )
-        task_result = process_pipeline_request.apply_async(
-            args=[prediction_request.dict(), project_id],
-            # TODO: make ml-pipeline an environment variable (i.e. PIPELINE_QUEUE_PREFIX)?
-            queue=f"ml-pipeline-{pipeline}",
-            # all pipelines have their own queue beginning with "ml-pipeline-"
-            # the antenna celeryworker should subscribe to all pipeline queues
+
+        task_id = str(uuid.uuid4())
+        transaction.on_commit(
+            lambda: process_pipeline_request.apply_async(
+                args=[prediction_request.dict(), project_id],
+                task_id=task_id,
+                # TODO: make ml-pipeline an environment variable (i.e. PIPELINE_QUEUE_PREFIX)?
+                queue=f"ml-pipeline-{pipeline}",
+                # all pipelines have their own queue beginning with "ml-pipeline-"
+                # the antenna celeryworker should subscribe to all pipeline queues
+            )
         )
-        task_ids.append(task_result.id)

         if job_id:
             from ami.jobs.models import Job, MLTaskRecord
@@ -349,21 +352,17 @@ def handle_async_process_images(
             # Create a new MLTaskRecord for this task
             ml_task_record = MLTaskRecord.objects.create(
                 job=job,
-                task_id=task_result.id,
+                task_id=task_id,
                 task_name="process_pipeline_request",
                 pipeline_request=prediction_request,
-                num_captures=len(source_image_batches[idx]),
+                num_captures=len(source_image_batches[i]),
             )
-            ml_task_record.source_images.set(source_image_batches[idx])
+            ml_task_record.source_images.set(source_image_batches[i])
             ml_task_record.save()
-            # job.logger.info(
-            #     f"Created MLTaskRecord for job {job_id} with task ID {task_result.id}"
-            #     " and task name process_pipeline_request"
-            # )
         else:
             task_logger.warning("No job ID provided, MLTaskRecord will not be created.")

-    task_logger.info(f"Submitted {len(task_ids)} batch image processing task(s).")
+    task_logger.info(f"Submitted {len(source_image_request_batches)} batch image processing task(s).")


 def handle_sync_process_images(
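Note: the two hunks above are the core of this change. The Celery task id is now generated up front with `uuid.uuid4()`, stored on the `MLTaskRecord`, and the actual `apply_async` call is deferred with `transaction.on_commit` so a worker cannot pick up the task before the surrounding database transaction has committed. A minimal standalone sketch of that pattern, using a hypothetical `TaskRecord` model and `process_batch` task in place of Antenna's own classes; when used inside a loop, the loop variables should be bound explicitly (e.g. as lambda defaults), since `on_commit` callbacks only run after the loop finishes:

```python
import uuid

from django.db import transaction

from myapp.models import TaskRecord  # hypothetical model, stands in for MLTaskRecord
from myapp.tasks import process_batch  # hypothetical Celery task


def submit_batches(batches: list[list], queue_name: str) -> None:
    for batch in batches:
        # Pre-generate the Celery task id so it can be stored before dispatch.
        task_id = str(uuid.uuid4())
        TaskRecord.objects.create(task_id=task_id, num_items=len(batch))

        # Defer the broker publish until the transaction commits, binding the
        # loop variables as defaults to avoid late-binding surprises.
        transaction.on_commit(
            lambda batch=batch, task_id=task_id: process_batch.apply_async(
                args=[batch],
                task_id=task_id,
                queue=queue_name,
            )
        )
```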
@@ -377,9 +376,6 @@ def handle_sync_process_images(
     job: Job | None,
 ) -> PipelineResultsResponse:
     """Handle synchronous processing by sending HTTP requests to the processing service."""
-    if project_id is None:
-        raise ValueError("Project ID must be provided when syncronously processing images.")
-
     processing_service = pipeline.choose_processing_service_for_pipeline(job_id, pipeline.name, project_id)
     if not processing_service.endpoint_url:
         raise ValueError(f"No endpoint URL configured for this pipeline's processing service ({processing_service})")
@@ -1207,8 +1203,8 @@ def get_config(self, project_id: int | None = None) -> PipelineRequestConfigPara
                 )
             except self.project_pipeline_configs.model.DoesNotExist as e:
                 logger.warning(f"No project-pipeline config for Pipeline {self} " f"and Project #{project_id}: {e}")
-
-        logger.warning("No project_id, no pipeline config is used.")
+        else:
+            logger.warning("No project_id. No pipeline config is used. Using default empty config instead.")

         return config