@@ -116,10 +116,11 @@ class JobProgress(pydantic.BaseModel):
116
116
logs : list [str ] = []
117
117
118
118
def get_stage_key(self, name: str) -> str:
    """
    Derive the lookup key for a stage or stage parameter from its name.

    The key is whatever `python_slugify` returns for the given name.
    """
    slug = python_slugify(name)
    return slug
120
121
121
- def add_stage (self , name : str ) -> JobProgressStageDetail :
122
- key = self .get_stage_key (name )
122
+ def add_stage (self , name : str , key : str | None = None ) -> JobProgressStageDetail :
123
+ key = key or self .get_stage_key (name )
123
124
try :
124
125
return self .get_stage (key )
125
126
except ValueError :
@@ -188,6 +189,16 @@ def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgres
188
189
self .add_or_update_stage_param (stage_key , k , v )
189
190
return stage
190
191
192
def reset(self, status: JobState = JobState.CREATED):
    """
    Zero the progress of the summary and of every stage, and assign each of
    them the given status (`JobState.CREATED` unless another one is passed).
    """
    # The summary is handled first, followed by the stages in stored order.
    for section in (self.summary, *self.stages):
        section.progress = 0
        section.status = status
201
+
191
202
class Config :
192
203
use_enum_values = True
193
204
as_dict = True
@@ -265,6 +276,12 @@ def emit(self, record):
265
276
266
277
@dataclass
267
278
class JobType :
279
+ """
280
+ The run method of a job is specific to the job type.
281
+
282
+ Job types must be defined as classes because they define code, not just configuration.
283
+ """
284
+
268
285
name : str
269
286
key : str
270
287
@@ -273,10 +290,7 @@ def run(cls, job: "Job"):
273
290
"""
274
291
Execute the run function specific to this job type.
275
292
"""
276
- pass
277
-
278
-
279
- AnyJobType = typing .TypeVar ("AnyJobType" , bound = JobType )
293
+ raise NotImplementedError ("Job type has not implemented the run method" )
280
294
281
295
282
296
class MLJob (JobType ):
@@ -411,14 +425,6 @@ class DataStorageSyncJob(JobType):
411
425
name = "Data storage sync"
412
426
key = "data_storage_sync"
413
427
414
- @classmethod
415
- def setup (cls , job : "Job" , save = True ):
416
- job .progress = job .progress or default_job_progress
417
- job .progress .add_stage (name = cls .name )
418
-
419
- if save :
420
- job .save ()
421
-
422
428
@classmethod
423
429
def run (cls , job : "Job" ):
424
430
"""
@@ -427,7 +433,8 @@ def run(cls, job: "Job"):
427
433
This is meant to be called by an async task, not directly.
428
434
"""
429
435
430
- job .progress .add_stage_param (cls .key , "Total Files" , "" )
436
+ job .progress .add_stage (cls .name )
437
+ job .progress .add_stage_param (cls .key , "Total files" , "" )
431
438
job .update_status (JobState .STARTED )
432
439
job .started_at = datetime .datetime .now ()
433
440
job .finished_at = None
@@ -461,6 +468,62 @@ def run(cls, job: "Job"):
461
468
job .save ()
462
469
463
470
471
class SourceImageCollectionPopulateJob(JobType):
    """
    Job type that populates the source image collection attached to a job.
    """

    name = "Populate captures collection"
    key = "populate_captures_collection"

    @classmethod
    def run(cls, job: "Job"):
        """
        Run the populate source image collection job.

        This is meant to be called by an async task, not directly.

        A single stage keyed by `cls.key` is registered on the job's progress
        and the job is marked as started. If the job has no source image
        collection, an error is logged, the job is marked as failed and the
        method returns early. Otherwise `populate_sample` is called on the
        collection and the job and its stage are marked as successful.
        """
        job.progress.add_stage(cls.name, key=cls.key)
        job.progress.add_stage_param(cls.key, "Captures added", "")

        # Record the start of the run exactly once, so `started_at` reflects
        # the moment the job actually began.
        job.update_status(JobState.STARTED)
        job.started_at = datetime.datetime.now()
        job.finished_at = None
        job.save()

        if not job.source_image_collection:
            job.logger.error("No source image collection provided")
            job.update_status(JobState.FAILURE)
            job.finished_at = datetime.datetime.now()
            job.save()
            return

        job.logger.info(f"Populating source image collection {job.source_image_collection}")
        job.progress.update_stage(
            cls.key,
            status=JobState.STARTED,
            progress=0.10,
            captures_added=0,
        )
        job.update_progress(save=True)

        job.source_image_collection.populate_sample(job=job)
        job.logger.info(f"Finished populating source image collection {job.source_image_collection}")
        job.save()

        # NOTE(review): this is the number of images in the collection after
        # populating; presumably equal to the number added -- confirm against
        # `populate_sample`.
        captures_added = job.source_image_collection.images.count()
        job.logger.info(f"Added {captures_added} captures to source image collection {job.source_image_collection}")

        job.progress.update_stage(
            cls.key,
            status=JobState.SUCCESS,
            progress=1,
            captures_added=captures_added,
        )
        job.finished_at = datetime.datetime.now()
        # Defer persistence to the single save() below.
        job.update_status(JobState.SUCCESS, save=False)
        job.update_progress(save=False)
        job.save()
525
+
526
+
464
527
class UnknownJobType (JobType ):
465
528
name = "Unknown"
466
529
key = "unknown"
@@ -472,6 +535,32 @@ def run(cls, job: "Job"):
472
535
job .save ()
473
536
474
537
538
# Job type classes that can be looked up by key; the order here is the order
# in which lookups scan them.
VALID_JOB_TYPES = [
    MLJob,
    SourceImageCollectionPopulateJob,
    DataStorageSyncJob,
    UnknownJobType,
]
539
+
540
+
541
def get_job_type_by_key(key: str) -> type[JobType] | None:
    """
    Look up a job type class by its `key` attribute.

    Returns the first entry of `VALID_JOB_TYPES` whose `key` equals the given
    key, or None when no entry matches.
    """
    matches = (candidate for candidate in VALID_JOB_TYPES if candidate.key == key)
    return next(matches, None)
545
+
546
+
547
def get_job_type_by_inferred_key(job: "Job") -> type[JobType] | None:
    """
    Guess a job's type from the data stored on the job itself.

    This is used for a data migration to set the job type of existing jobs
    that were created before the job type field was added to the model.

    Returns `MLJob` when the job has a pipeline. Otherwise the key of the
    first stage in the job's progress is looked up among the known job
    types. Returns None when nothing can be inferred.
    """
    if job.pipeline:
        return MLJob

    stages = job.progress.stages
    if not stages:
        return None

    # Only the first stage is consulted; its key is matched against the
    # keys of the known job types.
    return get_job_type_by_key(stages[0].key) or None
562
+
563
+
475
564
class Job (BaseModel ):
476
565
"""A job to be run by the scheduler"""
477
566
@@ -493,6 +582,9 @@ class Job(BaseModel):
493
582
"Limit" , null = True , blank = True , default = None , help_text = "Limit the number of images to process"
494
583
)
495
584
shuffle = models .BooleanField ("Shuffle" , default = True , help_text = "Process images in a random order" )
585
+ job_type_key = models .CharField (
586
+ "Job Type" , max_length = 255 , default = UnknownJobType .key , choices = [(t .key , t .name ) for t in VALID_JOB_TYPES ]
587
+ )
496
588
497
589
project = models .ForeignKey (
498
590
Project ,
@@ -532,20 +624,15 @@ def __str__(self) -> str:
532
624
return f'#{ self .pk } "{ self .name } " ({ self .status } )'
533
625
534
626
def job_type(self) -> type[JobType]:
    """
    Resolve the job type class for this job from its `job_type_key`.

    Raises:
        ValueError: if `job_type_key` does not match any known job type.
            When a type can be inferred from the job's attributes, its name
            is included in the error message.
    """
    resolved = get_job_type_by_key(self.job_type_key)
    if resolved:
        return resolved

    # Unknown key: the inferred type is only reported in the error message,
    # it is never returned.
    guess = get_job_type_by_inferred_key(self)
    msg = f"Could not determine job type for job {self.pk} with job_type_key '{self.job_type_key}'. "
    if guess:
        msg += f"Inferred job type as '{guess.name}'"
    raise ValueError(msg)
549
636
550
637
def enqueue (self ):
551
638
"""
@@ -603,6 +690,19 @@ def run(self):
603
690
job_type .run (job = self )
604
691
return None
605
692
693
def retry(self, async_task=True):
    """
    Retry the job.

    Resets the progress of the summary and all stages, sets the job status
    to `JobState.RETRY`, saves the job and then starts it again.

    Args:
        async_task: when true, the job is started with `enqueue()`;
            otherwise `run()` is called directly.
    """
    self.logger.info(f"Re-running job {self}")
    # Drop the end time of the previous attempt so the saved job does not
    # look finished while it is waiting to run again.
    self.finished_at = None
    self.progress.reset()
    self.status = JobState.RETRY
    self.save()
    if async_task:
        self.enqueue()
    else:
        self.run()
705
+
606
706
def cancel (self ):
607
707
"""
608
708
Terminate the celery task.
@@ -613,7 +713,6 @@ def cancel(self):
613
713
task = run_job .AsyncResult (self .task_id )
614
714
if task :
615
715
task .revoke (terminate = True )
616
- self .status = task .status
617
716
self .save ()
618
717
else :
619
718
self .status = JobState .REVOKED
@@ -646,7 +745,8 @@ def update_progress(self, save=True):
646
745
Update the total aggregate progress from the progress of each stage.
647
746
"""
648
747
if not len (self .progress .stages ):
649
- total_progress = 1
748
+ # Need at least one stage to calculate progress
749
+ total_progress = 0
650
750
else :
651
751
for stage in self .progress .stages :
652
752
if stage .progress > 0 and stage .status == JobState .CREATED :
@@ -674,11 +774,14 @@ def save(self, *args, **kwargs):
674
774
"""
675
775
Create the job stages if they don't exist.
676
776
"""
677
- if self .progress .stages :
777
+ if self .pk and self . progress .stages :
678
778
self .update_progress (save = False )
679
779
else :
680
780
self .setup (save = False )
681
781
super ().save (* args , ** kwargs )
782
+ logger .debug (f"Saved job { self } " )
783
+ if self .progress .summary .status != self .status :
784
+ logger .warning (f"Job { self } status mismatches progress: { self .progress .summary .status } != { self .status } " )
682
785
683
786
@classmethod
684
787
def default_progress (cls ) -> JobProgress :
@@ -698,4 +801,3 @@ class Meta:
698
801
# permissions = [
699
802
# ("run_job", "Can run a job"),
700
803
# ("cancel_job", "Can cancel a job"),
701
- # ]
0 commit comments