Skip to content

Commit 3ce2b29

Browse files
authored
Use job system to populate capture collections (#612)
* feat: use job system for populating collections
* fix: migrate old job types
* fix: update job type choices
* fix: add outstanding migration for job limit default
* feat: add method to retry / re-run jobs
* feat: add test for retries, attempt to fix others
* fix: reset progress of job and stages on retry
* feat: use a simple job type in tests, attempt to fix progress failure again
1 parent 18bc563 commit 3ce2b29

File tree

14 files changed: +399 additions, -74 deletions (lines changed)

ami/jobs/admin.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ami.main.admin import AdminBase
66

7-
from .models import Job
7+
from .models import Job, get_job_type_by_inferred_key
88

99

1010
@admin.register(Job)
@@ -19,7 +19,8 @@ class JobAdmin(AdminBase):
1919
"started_at",
2020
"finished_at",
2121
"duration",
22-
"get_job_type_display",
22+
"job_type_key",
23+
"inferred_job_type",
2324
)
2425

2526
@admin.action()
@@ -28,9 +29,15 @@ def enqueue_jobs(self, request: HttpRequest, queryset: QuerySet[Job]) -> None:
2829
job.enqueue()
2930
self.message_user(request, f"Queued {queryset.count()} job(s).")
3031

31-
@admin.display(description="Job Type")
32-
def get_job_type_display(self, obj: Job) -> str:
33-
return obj.job_type().name
32+
@admin.display(description="Inferred Job Type")
33+
def inferred_job_type(self, obj: Job) -> str:
34+
"""
35+
@TODO Remove this after running migration 0011_job_job_type_key.py and troubleshooting.
36+
"""
37+
job_type = get_job_type_by_inferred_key(obj)
38+
return job_type.name if job_type else "Could not infer"
39+
40+
# return obj.job_type().name
3441

3542
actions = [enqueue_jobs]
3643

ami/jobs/migrations/0011_job_job_type_key.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Generated by Django 4.2.10 on 2024-11-11 15:17
2+
3+
from django.db import migrations, models
4+
5+
6+
# Add method to set job_type_key based on inferred job type
7+
def set_job_type_key(apps, schema_editor):
8+
from ami.jobs.models import get_job_type_by_inferred_key, UnknownJobType
9+
10+
Job = apps.get_model("jobs", "Job")
11+
for job in Job.objects.all():
12+
inferred_key = get_job_type_by_inferred_key(job)
13+
if inferred_key:
14+
job.job_type_key = inferred_key.key
15+
else:
16+
job.job_type_key = UnknownJobType.key
17+
job.save()
18+
19+
20+
class Migration(migrations.Migration):
21+
dependencies = [
22+
("jobs", "0010_job_limit_job_shuffle"),
23+
]
24+
25+
operations = [
26+
migrations.AddField(
27+
model_name="job",
28+
name="job_type_key",
29+
field=models.CharField(
30+
choices=[
31+
("ml", "ML pipeline"),
32+
("populate_captures_collection", "Populate captures collection"),
33+
("data_storage_sync", "Data storage sync"),
34+
("unknown", "Unknown"),
35+
],
36+
default="unknown",
37+
max_length=255,
38+
verbose_name="Job Type",
39+
),
40+
),
41+
migrations.RunPython(set_job_type_key, reverse_code=migrations.RunPython.noop),
42+
]
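
ami/jobs/migrations/0012_… (exact file name not captured on this page; this migration depends on 0011_job_job_type_key)

Lines changed: 23 additions & 0 deletions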
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Generated by Django 4.2.10 on 2024-11-11 17:42
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("jobs", "0011_job_job_type_key"),
9+
]
10+
11+
operations = [
12+
migrations.AlterField(
13+
model_name="job",
14+
name="limit",
15+
field=models.IntegerField(
16+
blank=True,
17+
default=None,
18+
help_text="Limit the number of images to process",
19+
null=True,
20+
verbose_name="Limit",
21+
),
22+
),
23+
]

ami/jobs/models.py

Lines changed: 135 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,11 @@ class JobProgress(pydantic.BaseModel):
116116
logs: list[str] = []
117117

118118
def get_stage_key(self, name: str) -> str:
119+
"""Generate a key for a stage or param based on its name"""
119120
return python_slugify(name)
120121

121-
def add_stage(self, name: str) -> JobProgressStageDetail:
122-
key = self.get_stage_key(name)
122+
def add_stage(self, name: str, key: str | None = None) -> JobProgressStageDetail:
123+
key = key or self.get_stage_key(name)
123124
try:
124125
return self.get_stage(key)
125126
except ValueError:
@@ -188,6 +189,16 @@ def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgres
188189
self.add_or_update_stage_param(stage_key, k, v)
189190
return stage
190191

192+
def reset(self, status: JobState = JobState.CREATED):
193+
"""
194+
Set the progress of summary and all stages to 0.
195+
"""
196+
self.summary.progress = 0
197+
self.summary.status = status
198+
for stage in self.stages:
199+
stage.progress = 0
200+
stage.status = status
201+
191202
class Config:
192203
use_enum_values = True
193204
as_dict = True
@@ -265,6 +276,12 @@ def emit(self, record):
265276

266277
@dataclass
267278
class JobType:
279+
"""
280+
The run method of a job is specific to the job type.
281+
282+
Job types must be defined as classes because they define code, not just configuration.
283+
"""
284+
268285
name: str
269286
key: str
270287

@@ -273,10 +290,7 @@ def run(cls, job: "Job"):
273290
"""
274291
Execute the run function specific to this job type.
275292
"""
276-
pass
277-
278-
279-
AnyJobType = typing.TypeVar("AnyJobType", bound=JobType)
293+
raise NotImplementedError("Job type has not implemented the run method")
280294

281295

282296
class MLJob(JobType):
@@ -411,14 +425,6 @@ class DataStorageSyncJob(JobType):
411425
name = "Data storage sync"
412426
key = "data_storage_sync"
413427

414-
@classmethod
415-
def setup(cls, job: "Job", save=True):
416-
job.progress = job.progress or default_job_progress
417-
job.progress.add_stage(name=cls.name)
418-
419-
if save:
420-
job.save()
421-
422428
@classmethod
423429
def run(cls, job: "Job"):
424430
"""
@@ -427,7 +433,8 @@ def run(cls, job: "Job"):
427433
This is meant to be called by an async task, not directly.
428434
"""
429435

430-
job.progress.add_stage_param(cls.key, "Total Files", "")
436+
job.progress.add_stage(cls.name)
437+
job.progress.add_stage_param(cls.key, "Total files", "")
431438
job.update_status(JobState.STARTED)
432439
job.started_at = datetime.datetime.now()
433440
job.finished_at = None
@@ -461,6 +468,62 @@ def run(cls, job: "Job"):
461468
job.save()
462469

463470

471+
class SourceImageCollectionPopulateJob(JobType):
472+
name = "Populate captures collection"
473+
key = "populate_captures_collection"
474+
475+
@classmethod
476+
def run(cls, job: "Job"):
477+
"""
478+
Run the populate source image collection job.
479+
480+
This is meant to be called by an async task, not directly.
481+
"""
482+
job.progress.add_stage(cls.name, key=cls.key)
483+
job.progress.add_stage_param(cls.key, "Captures added", "")
484+
job.update_status(JobState.STARTED)
485+
job.started_at = datetime.datetime.now()
486+
job.finished_at = None
487+
job.save()
488+
489+
if not job.source_image_collection:
490+
job.logger.error("No source image collection provided")
491+
job.update_status(JobState.FAILURE)
492+
job.finished_at = datetime.datetime.now()
493+
job.save()
494+
return
495+
496+
job.logger.info(f"Populating source image collection {job.source_image_collection}")
497+
job.update_status(JobState.STARTED)
498+
job.started_at = datetime.datetime.now()
499+
job.finished_at = None
500+
job.progress.update_stage(
501+
cls.key,
502+
status=JobState.STARTED,
503+
progress=0.10,
504+
captures_added=0,
505+
)
506+
job.update_progress(save=True)
507+
508+
job.source_image_collection.populate_sample(job=job)
509+
job.logger.info(f"Finished populating source image collection {job.source_image_collection}")
510+
job.save()
511+
512+
captures_added = job.source_image_collection.images.count()
513+
job.logger.info(f"Added {captures_added} captures to source image collection {job.source_image_collection}")
514+
515+
job.progress.update_stage(
516+
cls.key,
517+
status=JobState.SUCCESS,
518+
progress=1,
519+
captures_added=captures_added,
520+
)
521+
job.finished_at = datetime.datetime.now()
522+
job.update_status(JobState.SUCCESS, save=False)
523+
job.update_progress(save=False)
524+
job.save()
525+
526+
464527
class UnknownJobType(JobType):
465528
name = "Unknown"
466529
key = "unknown"
@@ -472,6 +535,32 @@ def run(cls, job: "Job"):
472535
job.save()
473536

474537

538+
VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType]
539+
540+
541+
def get_job_type_by_key(key: str) -> type[JobType] | None:
542+
for job_type in VALID_JOB_TYPES:
543+
if job_type.key == key:
544+
return job_type
545+
546+
547+
def get_job_type_by_inferred_key(job: "Job") -> type[JobType] | None:
548+
"""
549+
Infer the job type from the job's attributes.
550+
551+
This is used for a data migration to set the job type of existing jobs
552+
before the job type field was added to the model.
553+
"""
554+
555+
if job.pipeline:
556+
return MLJob
557+
# Check the key of the first stage in the job progress
558+
if job.progress.stages:
559+
job_type = get_job_type_by_key(job.progress.stages[0].key)
560+
if job_type:
561+
return job_type
562+
563+
475564
class Job(BaseModel):
476565
"""A job to be run by the scheduler"""
477566

@@ -493,6 +582,9 @@ class Job(BaseModel):
493582
"Limit", null=True, blank=True, default=None, help_text="Limit the number of images to process"
494583
)
495584
shuffle = models.BooleanField("Shuffle", default=True, help_text="Process images in a random order")
585+
job_type_key = models.CharField(
586+
"Job Type", max_length=255, default=UnknownJobType.key, choices=[(t.key, t.name) for t in VALID_JOB_TYPES]
587+
)
496588

497589
project = models.ForeignKey(
498590
Project,
@@ -532,20 +624,15 @@ def __str__(self) -> str:
532624
return f'#{self.pk} "{self.name}" ({self.status})'
533625

534626
def job_type(self) -> type[JobType]:
535-
"""
536-
This is a temporary way to determine the type of job.
537-
@TODO rework Job classes and background tasks.
538-
"""
539-
if self.pipeline:
540-
return MLJob
541-
542-
try:
543-
self.progress.get_stage(DataStorageSyncJob.key)
544-
return DataStorageSyncJob
545-
except ValueError:
546-
pass
547-
548-
return UnknownJobType
627+
job_type_class = get_job_type_by_key(self.job_type_key)
628+
if job_type_class:
629+
return job_type_class
630+
else:
631+
inferred_job_type = get_job_type_by_inferred_key(self)
632+
msg = f"Could not determine job type for job {self.pk} with job_type_key '{self.job_type_key}'. "
633+
if inferred_job_type:
634+
msg += f"Inferred job type as '{inferred_job_type.name}'"
635+
raise ValueError(msg)
549636

550637
def enqueue(self):
551638
"""
@@ -603,6 +690,19 @@ def run(self):
603690
job_type.run(job=self)
604691
return None
605692

693+
def retry(self, async_task=True):
694+
"""
695+
Retry the job.
696+
"""
697+
self.logger.info(f"Re-running job {self}")
698+
self.progress.reset()
699+
self.status = JobState.RETRY
700+
self.save()
701+
if async_task:
702+
self.enqueue()
703+
else:
704+
self.run()
705+
606706
def cancel(self):
607707
"""
608708
Terminate the celery task.
@@ -613,7 +713,6 @@ def cancel(self):
613713
task = run_job.AsyncResult(self.task_id)
614714
if task:
615715
task.revoke(terminate=True)
616-
self.status = task.status
617716
self.save()
618717
else:
619718
self.status = JobState.REVOKED
@@ -646,7 +745,8 @@ def update_progress(self, save=True):
646745
Update the total aggregate progress from the progress of each stage.
647746
"""
648747
if not len(self.progress.stages):
649-
total_progress = 1
748+
# Need at least one stage to calculate progress
749+
total_progress = 0
650750
else:
651751
for stage in self.progress.stages:
652752
if stage.progress > 0 and stage.status == JobState.CREATED:
@@ -674,11 +774,14 @@ def save(self, *args, **kwargs):
674774
"""
675775
Create the job stages if they don't exist.
676776
"""
677-
if self.progress.stages:
777+
if self.pk and self.progress.stages:
678778
self.update_progress(save=False)
679779
else:
680780
self.setup(save=False)
681781
super().save(*args, **kwargs)
782+
logger.debug(f"Saved job {self}")
783+
if self.progress.summary.status != self.status:
784+
logger.warning(f"Job {self} status mismatches progress: {self.progress.summary.status} != {self.status}")
682785

683786
@classmethod
684787
def default_progress(cls) -> JobProgress:
@@ -698,4 +801,3 @@ class Meta:
698801
# permissions = [
699802
# ("run_job", "Can run a job"),
700803
# ("cancel_job", "Can cancel a job"),
701-
# ]

0 commit comments

Comments
 (0)